import numpy as np
from numpy.lib.stride_tricks import as_strided

def forward(x, f, s):
    B, H, W, C = x.shape    # e.g. 64, 16, 16, 3
    Fh, Fw, C, _ = f.shape  # e.g. 4, 4, 3, 3
    # C is redeclared to emphasise that the dimension is the same
    Sh, Sw = s              # e.g. 2, 2
    # Sliding-window view of x: (B, out_h, out_w, Fh, Fw, C)
    strided_shape = B, 1 + (H - Fh) // Sh, 1 + (W - Fw) // Sw, Fh, Fw, C
    x = as_strided(x, strided_shape, strides=(
        x.strides[0],
        Sh * x.strides[1],
        Sw * x.strides[2],
        x.strides[1],
        x.strides[2],
        x.strides[3]),
    )
    # print(x.flags, f.flags)
    # The reshaping changes the einsum from 'wxyijk,ijkd->wxyd' to 'wxyz,zd->wxyd'
    f = f.reshape(-1, f.shape[-1])
    x = x.reshape(*x.shape[:3], -1)  # Bottleneck!
    return np.einsum('wxyz,zd->wxyd', x, f, optimize='optimal')
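For reference, a minimal usage sketch with the example dimensions from the comments (the random inputs are my assumption, not part of the original code):

x = np.random.rand(64, 16, 16, 3)  # batch of 16x16 images, 3 channels
f = np.random.rand(4, 4, 3, 3)     # 4x4 filters, 3 in / 3 out channels
out = forward(x, f, (2, 2))
print(out.shape)                   # (64, 7, 7, 3), since 1 + (16 - 4) // 2 == 7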
For reference, here are the flags for x and f before reshaping:
x.flags:
  C_CONTIGUOUS : False
  F_CONTIGUOUS : False
  OWNDATA : False
  WRITEABLE : True
  ALIGNED : True
  WRITEBACKIFCOPY : False
  UPDATEIFCOPY : False

f.flags:
  C_CONTIGUOUS : True
  F_CONTIGUOUS : False
  OWNDATA : True
  WRITEABLE : True
  ALIGNED : True
  WRITEBACKIFCOPY : False
  UPDATEIFCOPY : False
Interestingly, the major bottleneck in the routine is not the einsum but the reshaping (flattening) of x. f does not suffer from this problem: its memory is C-contiguous, so its reshape is a cheap metadata change that never touches the data. x, however, is not C-contiguous (and does not own its data, for that matter), so its reshape is far more expensive: the flattened layout cannot be expressed by a stride change alone, so NumPy must copy every element into a new contiguous buffer, repeatedly fetching data that is not laid out sequentially (and often not cache-aligned) in memory. This, in turn, is a consequence of the as_strided call on x: the modified strides produce an overlapping view that breaks the natural memory ordering. (FYI, as_strided itself is incredibly fast, and should be fast no matter what strides are passed to it, since it only creates a view.)
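The copy can be confirmed directly: a reshape of a C-contiguous array shares memory with its source, while a reshape of a non-contiguous view does not (a quick check of mine, not from the original question):

a = np.arange(12).reshape(3, 4)             # C-contiguous
b = a.T                                     # non-contiguous view of a
print(np.shares_memory(a, a.reshape(12)))   # True  - metadata-only reshape
print(np.shares_memory(b, b.reshape(12)))   # False - reshape had to copy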
Is there a way to achieve the same result without incurring the bottleneck? Perhaps by reshaping x before using as_strided?
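One direction worth trying (my suggestion, not part of the original question): skip the flattening entirely and let einsum contract the window axes of the strided view directly. einsum accepts non-contiguous inputs, so no explicit reshape of x is needed; note, however, that with optimize enabled it may still copy internally to dispatch to BLAS, so this is a sketch to benchmark rather than a guaranteed win:

# Hypothetical: x_view is the 6-D output of as_strided above, before any reshape
out = np.einsum('wxyijk,ijkd->wxyd', x_view, f, optimize='optimal')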
from Reshaping before as_strided for optimisation