Saturday 2 January 2021

Reshaping before as_strided for optimisation

import numpy as np
from numpy.lib.stride_tricks import as_strided

def forward(x, f, s):
    B, H, W, C = x.shape # e.g. 64, 16, 16, 3
    Fh, Fw, C, _ = f.shape # e.g. 4, 4, 3, 3 
    # C is redeclared to emphasise that the dimension is the same
    
    Sh, Sw = s # e.g. 2, 2

    strided_shape = B, 1 + (H - Fh) // Sh, 1 + (W - Fw) // Sw, Fh, Fw, C

    x = as_strided(x, strided_shape, strides=(
        x.strides[0], 
        Sh * x.strides[1], 
        Sw * x.strides[2], 
        x.strides[1], 
        x.strides[2], 
        x.strides[3]), 
    )

    # print(x.flags, f.flags)

    # The reshaping changes the einsum from 'wxyijk,ijkd->wxyd' to 'wxyz,zd->wxyd'
    f = f.reshape(-1, f.shape[-1])
    x = x.reshape(*x.shape[:3], -1) # Bottleneck!
    
    return np.einsum('wxyz,zd->wxyd', x, f, optimize='optimal')

For reference, here are the flags for x and f before reshaping:

x.flags:

C_CONTIGUOUS : False
F_CONTIGUOUS : False
OWNDATA : False
WRITEABLE : True
ALIGNED : True
WRITEBACKIFCOPY : False
UPDATEIFCOPY : False


f.flags:

C_CONTIGUOUS : True
F_CONTIGUOUS : False
OWNDATA : True
WRITEABLE : True
ALIGNED : True
WRITEBACKIFCOPY : False
UPDATEIFCOPY : False

Interestingly, the major bottleneck in the routine is not the einsum but the reshaping (flattening) of x. I understand that f does not suffer from this problem: its memory is C-contiguous, so the reshape only changes the shape and strides metadata and leaves the data untouched. But x is not C-contiguous (and does not own its data, for that matter), so the reshape is far more expensive: it has to copy the data into a new buffer, gathering elements that are scattered through memory. This, in turn, results from the as_strided call on x, whose modified strides no longer describe the natural ordering of the data. (FYI, as_strided itself is incredibly fast, and should be fast no matter what strides are passed to it, since it only creates a view.)

Is there a way to achieve the same result without incurring the bottleneck? Perhaps by reshaping x before using as_strided?
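One direction I could sketch, as an untimed assumption rather than a known fix: for a C-contiguous x, the W and C axes can be merged before as_strided without copying, so each window row becomes a single axis of length Fw * C. The Fh axis still cannot be folded in without a copy, so the contraction keeps two summed indices. The name forward_partial_merge is hypothetical, and optimize is left at its default here.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

def forward_partial_merge(x, f, s):
    """Hypothetical variant: merge (W, C) before as_strided, no flattening copy."""
    B, H, W, C = x.shape
    Fh, Fw, _, D = f.shape
    Sh, Sw = s

    # Assumption: x must be C-contiguous for the merge below to be a view.
    x = np.ascontiguousarray(x)
    x2 = x.reshape(B, H, W * C)

    rows = as_strided(
        x2,
        shape=(B, 1 + (H - Fh) // Sh, 1 + (W - Fw) // Sw, Fh, Fw * C),
        strides=(
            x2.strides[0],
            Sh * x2.strides[1],       # step Sh rows per output row
            Sw * C * x2.strides[2],   # step Sw columns (C elements each)
            x2.strides[1],            # next row inside the window
            x2.strides[2],            # next element inside a window row
        ),
    )

    # Merge Fw and C of the filter in the same order as the window rows.
    f2 = f.reshape(Fh, Fw * C, D)
    return np.einsum('wxyiz,izd->wxyd', rows, f2)

rng = np.random.default_rng(0)
x_demo = rng.standard_normal((2, 8, 8, 3))
f_demo = rng.standard_normal((4, 4, 3, 5))
out = forward_partial_merge(x_demo, f_demo, (2, 2))
print(out.shape)
```

This only shows that the result matches the original contraction; whether it is actually faster than the reshape version would have to be measured for the shapes in question.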


