Saturday 3 October 2020

Slicing a tensor with a tensor of indices and tf.gather

I am trying to slice a tensor with a tensor of indices, and I am trying to use tf.gather for this. However, I am having a hard time understanding the documentation and can't get it to work the way I expect:

I have two tensors. An activations tensor with a shape of [1,240,4] and an ids tensor with the shape [1,1,120]. I want to slice the second dimension of the activations tensor with the indices provided in the third dimension of the ids tensor:

downsampled_activations = tf.gather(activations, ids, axis=1)

I have given it the axis=1 option since that is the axis in the activations tensor I want to slice.

However, this does not render the expected result and only gives me the following error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0,0,1] = 1 is not in [0, 1)
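For reference, here is a minimal sketch of how tf.gather shapes its result, assuming TensorFlow 2.x in eager mode and random stand-in data (the real activations and ids are not shown here). As far as I understand the documentation, the output shape is params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis + 1:], and the range in the error message is the size of the axis being indexed. A range of [0, 1) would therefore be consistent with a size-1 axis being indexed instead of the 240-entry one, although that reading is an assumption.

```python
import tensorflow as tf

# Stand-in tensors with the shapes from the question (values are made up).
activations = tf.zeros([1, 240, 4])
ids = tf.random.uniform([1, 1, 120], minval=0, maxval=240, dtype=tf.int32)

# Without batch_dims, the whole shape of ids replaces axis 1 of activations:
# [1] + [1, 1, 120] + [4]
out = tf.gather(activations, ids, axis=1)
print(out.shape)  # (1, 1, 1, 120, 4)
```

So even when the call succeeds, the extra leading dimensions of ids end up in the result, which is not the [1,120,4] layout one might expect.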

I have tried various combinations of the axis and batch_dims options, but to no avail so far, and the documentation hasn't really helped. It would be very helpful if somebody could explain the parameters in more detail, ideally using the example above!

Edit: The IDs are precomputed before runtime and come in through an input pipeline as such:

features = tf.io.parse_single_example(
    serialized_example,
    features={'featureIDs': tf.io.FixedLenFeature([], tf.string)})

They are then reshaped into the previous format:

feature_ids_raw = tf.io.decode_raw(features['featureIDs'], tf.int32)
feature_ids_shape = tf.stack([batch_size, (num_neighbours * 4)])
feature_ids = tf.reshape(feature_ids_raw, feature_ids_shape)
feature_ids = tf.expand_dims(feature_ids, 0)
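The reshaping steps can be checked in isolation with a small sketch. It assumes TensorFlow 2.x in eager mode and replaces the parsed feature with a hypothetical byte string (raw_bytes) holding 120 int32 values, since the real serialized example is not available here.

```python
import numpy as np
import tensorflow as tf

batch_size = 1
num_neighbours = 30

# Hypothetical stand-in for features['featureIDs']: 120 int32 IDs as raw bytes.
fake_ids = np.arange(batch_size * num_neighbours * 4, dtype=np.int32)
raw_bytes = tf.constant(fake_ids.tobytes())

feature_ids_raw = tf.io.decode_raw(raw_bytes, tf.int32)        # shape [120]
feature_ids_shape = tf.stack([batch_size, num_neighbours * 4])
feature_ids = tf.reshape(feature_ids_raw, feature_ids_shape)   # shape [1, 120]
feature_ids = tf.expand_dims(feature_ids, 0)                   # shape [1, 1, 120]
print(feature_ids.shape)
```

Note that the leading dimension comes from expand_dims, while the batch dimension is the second one.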

Afterwards they have the previously mentioned shape (batch_size = 1 and num_neighbours = 30 -> [1,1,120]) and I want to use them to slice the activations tensor.

Edit2: I would like the output to be [1,120,4]. (So I would like to gather the entries along the second dimension of the activations tensor in accordance with the IDs stored in my ids tensor.)
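To make the target concrete, here is a sketch of one way the [1,120,4] result could be produced. It assumes TensorFlow 2.x in eager mode, random stand-in data, and that dropping the leading dimension added by expand_dims is acceptable. The name ids_2d is introduced only for illustration.

```python
import tensorflow as tf

# Stand-in tensors with the shapes from the question (values are made up).
activations = tf.random.normal([1, 240, 4])
ids = tf.random.uniform([1, 1, 120], minval=0, maxval=240, dtype=tf.int32)

# Drop the leading dimension so ids lines up with the batch axis: [1, 120].
ids_2d = tf.squeeze(ids, axis=0)

# batch_dims=1 pairs each batch entry of activations with its own row of ids,
# so only the 120 indices replace axis 1: [1] + [120] + [4].
downsampled_activations = tf.gather(activations, ids_2d, axis=1, batch_dims=1)
print(downsampled_activations.shape)  # (1, 120, 4)
```

Each entry downsampled_activations[0, k] is then activations[0, ids[0, 0, k]], which is the behaviour described above.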



from Slicing a tensor with a tensor of indices and tf.gather
