How can I efficiently implement forward-fill logic (inspired by pandas' ffill) for a tensor shaped NxLxC (batch, sequence dimension, channel)? Because each channel's sequence is independent, this is equivalent to working with a tensor shaped (N*C)xL.
The computation should stay in torch tensors so that the output remains differentiable.
I managed to put something together with advanced indexing, but it is O(L**2) in both memory and number of operations, so it is neither efficient nor GPU-friendly.
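For concreteness, a minimal sketch of the (N*C)xL reshape mentioned above (the sizes here are just placeholders):

import torch

N, L, C = 2, 14, 3
x = torch.randn(N, L, C)
# Move the channel axis next to the batch axis, then merge the two,
# so every row is one independent length-L sequence.
rows = x.permute(0, 2, 1).reshape(N * C, L)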
Example:
Assuming you have the sequence [0,1,2,0,0,3,0,4,0,0,0,5,6,0] in a tensor shaped 1x14, the forward fill gives you the sequence [0,1,2,2,2,3,3,4,4,4,4,5,6,6].
Another example, shaped 2x4, is [[0, 1, 0, 3], [1, 2, 0, 3]], which should be forward filled into [[0, 1, 1, 3], [1, 2, 2, 3]].
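In other words, zeros act as missing values: each zero is replaced by the last non-zero value seen so far, and leading zeros stay zero. For reference, a minimal (CPU-only, non-differentiable) pandas sketch that reproduces this behaviour:

import numpy as np
import pandas as pd

seq = pd.Series([0, 1, 2, 0, 0, 3, 0, 4, 0, 0, 0, 5, 6, 0], dtype=float)
# Treat zeros as missing, carry the last valid value forward, then restore zeros
# for positions that had nothing to carry (the leading zero).
filled = seq.replace(0, np.nan).ffill().fillna(0)
print(filled.tolist())  # [0.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 5.0, 6.0, 6.0]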
Method used today:
We currently use the following code, which is highly unoptimized but still faster than non-vectorized loops:
import torch


def last_zero_sequence_start_indices(t: torch.Tensor) -> torch.Tensor:
    """
    Given a 3D tensor `t`, return a three-dimensional tensor where each entry holds the
    starting index of the last contiguous sequence of zeros that begins at or before the
    current index (positions that come before any zero default to the last index).
    In essence, for each position in `t`, the function pinpoints the beginning of the last
    contiguous sequence of zeros up to that position; only the entries where `t` is zero
    are relevant for the forward fill.

    Args:
    - t (torch.Tensor): Input tensor with shape [Batch, Channel, Time].

    Returns:
    - torch.Tensor: Three-dimensional tensor with shape [Batch, Channel, Time] indicating
      the starting position of the last sequence of zeros up to each index in `t`.
    """
    # Create a mask indicating the start of each zero sequence
    start_of_zero_sequence = (t == 0) & torch.cat([
        torch.full(t.shape[:-1] + (1,), True, device=t.device),
        t[..., :-1] != 0,
    ], dim=2)
    # Duplicate this mask into a TxT matrix
    duplicated_mask = start_of_zero_sequence.unsqueeze(2).repeat(1, 1, t.size(-1), 1)
    # Extract the lower triangular part of this matrix (including the diagonal)
    lower_triangular = torch.tril(duplicated_mask)
    # For each row, identify the index of the rightmost '1' (start of the last zero sequence up to that row)
    indices = t.size(-1) - 1 - lower_triangular.int().flip(dims=[3]).argmax(dim=3)
    return indices
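The indices are then consumed roughly as follows; forward_fill is a hypothetical wrapper name added here for illustration, and the gather/where step is one possible way to turn the run-start indices into the filled tensor:

def forward_fill(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical wrapper, not part of the original code: for every zero entry,
    # copy the value sitting just before the zero run that contains it.
    starts = last_zero_sequence_start_indices(t)
    candidates = torch.gather(t, -1, (starts - 1).clamp(min=0))
    # Zero runs that begin at index 0 have nothing to carry forward and stay zero.
    fill_mask = (t == 0) & (starts > 0)
    return torch.where(fill_mask, candidates, t)

t = torch.tensor([[[0., 1., 2., 0., 0., 3., 0., 4., 0., 0., 0., 5., 6., 0.]]])
print(forward_fill(t))
# tensor([[[0., 1., 2., 2., 2., 3., 3., 4., 4., 4., 4., 5., 6., 6.]]])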
From: How to efficiently implement forward fill in pytorch