I'm something of a Keras beginner, so apologies in advance for any gaps in my understanding.
I want to manually set some values of a Keras tensor at positions given by indices stored in another tensor. I believe I understand how to read entries of a tensor using `tf.gather_nd` (my untested attempt is below), and I think I understand that only a variable, not a tensor, can have its values assigned.
For context, this happens between the generation and discrimination stages of a GAN.
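A minimal sketch of the read/write distinction described above, assuming TensorFlow 2.x with eager execution (the toy values and the name `cols` are made up for illustration). `tf.gather_nd` reads one entry per sample, `tf.tensor_scatter_nd_update` builds a new tensor with those entries replaced, and only a `tf.Variable` can be modified in place.

```python
import tensorflow as tf

# Toy stand-in for a generator output: 3 samples, 4 values each.
gen_out = tf.constant([[1., 2., 3., 4.],
                       [5., 6., 7., 8.],
                       [9., 10., 11., 12.]])
# Hypothetical column index to touch in each sample.
cols = tf.constant([2, 0, 3])

# Build [row, col] pairs of shape (batch, 2), as gather_nd/scatter_nd expect.
rows = tf.range(tf.shape(gen_out)[0])
pairs = tf.stack([rows, cols], axis=1)

# Reading: one entry per sample -> [3., 5., 12.]
grabbed = tf.gather_nd(gen_out, pairs)

# "Writing" to a tensor: tensors are immutable, so this returns a NEW tensor.
new_values = tf.zeros_like(grabbed)
updated = tf.tensor_scatter_nd_update(gen_out, pairs, new_values)

# True in-place assignment is only available on a tf.Variable.
var = tf.Variable(gen_out)
var.scatter_nd_update(pairs, new_values)
```

Note that `gen_out` itself is unchanged after the `tensor_scatter_nd_update` call; the replaced values exist only in `updated` (and in `var`).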
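```python
import tensorflow as tf
from keras.layers import Input, Lambda

gen_out = generator(inputs)  # generator/inputs defined elsewhere; shape (batch, n)
indices_to_reset = Input(shape=(1,), dtype='int32')  # one column index per sample

def gather_entries(args):
    out, idx = args
    rows = tf.range(tf.shape(out)[0])                  # row index of every sample
    cols = tf.cast(tf.reshape(idx, (-1,)), tf.int32)   # (batch, 1) -> (batch,)
    pairs = tf.stack((rows, cols), axis=1)             # (batch, 2) [row, col] pairs
    return tf.gather_nd(out, pairs)                    # one entry per sample

grabbed_entries = Lambda(gather_entries)([gen_out, indices_to_reset])
updated_gen_out = ...  # ??? how do I write new values back at these indices?
```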
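One possible way to fill in `updated_gen_out`, sketched under assumptions the question does not state: TensorFlow 2.x with `tf.keras`, a 2-D generator output of shape `(batch, n)`, exactly one index per sample, and a hypothetical constant replacement value. Rather than assigning into the tensor, it builds a one-hot mask from the indices and blends, which yields a new tensor and can sit inside a `Lambda` layer between the generator and the discriminator. `N_FEATURES`, `RESET_VALUE` and `reset_entries` are illustrative names, and the `keras.Input` below is only a stand-in for `generator(inputs)`.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

N_FEATURES = 4      # hypothetical width of the generator output
RESET_VALUE = 0.0   # hypothetical value written at each selected index

def reset_entries(args):
    """Return a copy of `out` with out[i, idx[i]] replaced by RESET_VALUE."""
    out, idx = args
    # Flatten (batch, 1) -> (batch,). The cast is defensive: some Keras
    # versions may trace this function with float placeholders when
    # inferring the layer's output shape.
    cols = tf.cast(tf.reshape(idx, (-1,)), tf.int32)
    # mask[i, j] = 1.0 where j == idx[i], else 0.0
    mask = tf.one_hot(cols, depth=out.shape[-1], dtype=out.dtype)
    # Unselected entries pass through unchanged (and stay differentiable).
    return out * (1.0 - mask) + RESET_VALUE * mask

# Stand-in for `gen_out = generator(inputs)`.
gen_out = keras.Input(shape=(N_FEATURES,))
indices_to_reset = keras.Input(shape=(1,), dtype="int32")
updated_gen_out = keras.layers.Lambda(reset_entries)([gen_out, indices_to_reset])
model = keras.Model([gen_out, indices_to_reset], updated_gen_out)

# Toy batch: reset column 2 of sample 0, column 0 of sample 1, column 3 of sample 2.
x = np.arange(12, dtype="float32").reshape(3, N_FEATURES)
idx = np.array([[2], [0], [3]], dtype="int32")
result = model([x, idx]).numpy()
```

Calling `tf.tensor_scatter_nd_update` inside the same `Lambda` would be an alternative; the mask version uses only elementwise ops. To write sample-dependent values instead of a constant, `RESET_VALUE` could be replaced by a `(batch, 1)` tensor, which broadcasts against the mask.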
from Assign indexed entry of Keras tensor