Wednesday 25 October 2023

usage of tf.gather to index lists containing non-Tensor types

Consider the following code. tf.gather works on a list of integers but fails on a list of tf.experimental.ExtensionType objects. How can I gather non-Tensor elements like these from a list?

```python
import tensorflow as tf

class Point(tf.experimental.ExtensionType):
    xx: tf.Tensor
    def __init__(self, xx):
        self.xx = xx
        super().__init__()

list1 = [1, 2, 3, 4]
list2 = [Point(1), Point(2), Point(3), Point(4)]

# this works
out1 = tf.gather(list1, [0, 2])
print('First gather ', out1)

# this throws: ValueError: Attempt to convert a value (Point(xx=<tf.Tensor:
# shape=(), dtype=int32, numpy=1>)) with an unsupported type
# (<class '__main__.Point'>) to a Tensor.

out2 = tf.gather(list2, [0, 2])
print('Second gather ', out2)
```
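The ValueError shows that tf.gather first tries to convert its params argument to a Tensor, and a Python list of Point objects has no Tensor form. Below is a minimal sketch of two possible workarounds, assuming TF 2.x in eager mode. Option 1 simply indexes the Python list in Python. Option 2 restructures the data so a single Point holds all values in a stacked field ("struct of arrays"), then applies tf.gather to that field. The restructuring, the use of the default ExtensionType constructor instead of the custom __init__, and the helper name gather_points are my own assumptions for illustration, not part of the question or of the TensorFlow API.

```python
import tensorflow as tf


class Point(tf.experimental.ExtensionType):
    # Same single field as in the question; this sketch relies on the
    # default ExtensionType constructor rather than a custom __init__.
    xx: tf.Tensor


list2 = [Point(xx=1), Point(xx=2), Point(xx=3), Point(xx=4)]
indices = [0, 2]

# Option 1: list2 is an ordinary Python list, so ordinary Python indexing
# selects the Point objects without converting anything to a Tensor.
picked_list = [list2[i] for i in indices]

# Option 2 (assumed restructuring): keep ONE Point whose field holds all
# values, stacked along a new leading axis, so xx has shape (4,).
batched = Point(xx=tf.stack([p.xx for p in list2]))


def gather_points(points, idx):
    """Hypothetical helper: apply tf.gather to the field of a batched Point."""
    return Point(xx=tf.gather(points.xx, idx))


# idx may be a Tensor here, because tf.gather accepts Tensor indices.
picked_batched = gather_points(batched, tf.constant(indices))

print([int(p.xx) for p in picked_list])  # [1, 3]
print(picked_batched.xx.numpy())         # [1 3]
```

Option 1 iterates over indices in Python, so it needs concrete Python integers. Option 2 keeps the whole operation in TensorFlow, which is the usual reason to prefer it when the indices are themselves Tensors; with more fields, gather_points would call tf.gather once per field.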

