Consider the following code. I'd like to know how to use `tf.gather` to select non-Tensor values (here, instances of a `tf.experimental.ExtensionType` subclass) from a Python list.
import tensorflow as tf
class Point(tf.experimental.ExtensionType):
    """A minimal TensorFlow extension type holding a single tensor field.

    Because ``xx`` is annotated as ``tf.Tensor``, the stored value is a
    tensor. For example, ``Point(1)`` holds a scalar int32 tensor. The
    annotation places no constraint on shape, so one ``Point`` can also hold
    a batch of values, e.g. ``Point(tf.constant([1, 2, 3, 4]))``.
    """

    # The wrapped value; stored as a tf.Tensor.
    xx: tf.Tensor

    def __init__(self, xx):
        """Store ``xx`` as this point's ``xx`` field.

        Args:
            xx: a tensor, or a value TensorFlow can convert to one
                (e.g. a Python int).
        """
        self.xx = xx
        # NOTE(review): the TensorFlow extension-type guide's custom
        # constructor example does not call super().__init__(); this call
        # is likely redundant -- confirm before removing it.
        super().__init__()
list1 = [1, 2, 3, 4]
list2 = [Point(1), Point(2), Point(3), Point(4)]
# Positions to select; shared by every gather below.
indices = [0, 2]

# This works: tf.gather converts the list of ints to an int32 tensor first.
out1 = tf.gather(list1, indices)
print('First gather ', out1)

# tf.gather(list2, indices) raises:
#   ValueError: Attempt to convert a value (Point(xx=<tf.Tensor:
#   shape=(), dtype=int32, numpy=1>)) with an unsupported type
#   (<class '__main__.Point'>) to a Tensor.
# tf.gather has to turn its `params` argument into a single Tensor, and a
# Python list of ExtensionType values cannot be converted that way.
#
# Option 1: the values already live in a Python list, so select them with
# ordinary Python indexing.
out2 = [list2[i] for i in indices]
print('Second gather ', out2)

# Option 2: stay inside TensorFlow. Store the whole batch as ONE Point whose
# field is a 1-D tensor, gather on that field, and re-wrap the result.
points = Point(tf.constant([1, 2, 3, 4]))
out3 = Point(tf.gather(points.xx, indices))
print('Batched gather ', out3)
Source: "Usage of tf.gather to index lists containing non-Tensor types"
No comments:
Post a Comment