Tuesday, 19 February 2019

Assign indexed entry of Keras tensor

I'm something of a Keras beginner, so apologies in advance for any gaps in my understanding.

I want to manually set some values of my Keras tensor at positions given by, say, indices stored in another tensor. I believe I understand how to read entries of a tensor with tf.gather_nd (see my untested attempt below), and I think I understand that only a variable, not a tensor, can have its values assigned.
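A minimal sketch of the gather step, assuming TensorFlow 2.x with eager execution. The tensors are made-up stand-ins for the generator output and the per-sample index input, and `row_col_indices` is a hypothetical helper. It pairs each row number with that row's column index, stacking along `axis=1` so the index tensor has shape `(batch, 2)`, which is what `tf.gather_nd` needs to pick one scalar per row. It also shows that a plain tensor rejects item assignment.

```python
import tensorflow as tf

# Hypothetical stand-ins: generator output (batch=3, features=4)
# and one column index per sample, shaped like Input(shape=(1,)).
gen_out = tf.constant([[0.1, 0.2, 0.3, 0.4],
                       [0.5, 0.6, 0.7, 0.8],
                       [0.9, 1.0, 1.1, 1.2]])
cols = tf.constant([[2], [0], [3]], dtype=tf.int32)

def row_col_indices(values, cols):
    """Pair each row number with that row's column index -> shape (batch, 2)."""
    batch_size = tf.shape(values)[0]
    rows = tf.range(batch_size, dtype=tf.int32)  # [0, 1, ..., batch-1]
    cols = tf.reshape(cols, [-1])                # (batch, 1) -> (batch,)
    return tf.stack([rows, cols], axis=1)        # (batch, 2), not (2, batch)

idx = row_col_indices(gen_out, cols)
grabbed_entries = tf.gather_nd(gen_out, idx)     # one value per row

# Tensors are immutable: item assignment raises TypeError.
try:
    gen_out[0, 2] = 0.0
    assignment_failed = False
except TypeError:
    assignment_failed = True
```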

For clarity, this is taking place between the generation and discrimination stages of a GAN.

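import tensorflow as tf
from keras import backend as K
from keras.layers import Input, Lambda

gen_out = generator(inputs)

# One column index per sample, shape (batch, 1)
indices_to_reset = Input(shape=(1,), dtype='int32')

def grab_entries(args):
    out, cols = args
    batch_size = K.shape(out)[0]
    rows = tf.range(batch_size)            # row numbers 0..batch_size-1
    cols = K.reshape(cols, (-1,))          # (batch, 1) -> (batch,)
    idx = K.stack((rows, cols), axis=1)    # (batch, 2) (row, col) pairs for gather_nd
    return tf.gather_nd(out, idx)

# Doing the index ops inside a Lambda keeps them part of the Keras graph
grabbed_entries = Lambda(grab_entries)([gen_out, indices_to_reset])

updated_gen_out = ???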
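One possible way to fill in the `???`, sketched under the assumption of TensorFlow 2.x with `tf.keras` (older 1.x releases may lack `tf.tensor_scatter_nd_update`). Because a tensor cannot be assigned in place, this builds a *new* tensor in which entry `[i, cols[i]]` of each row is replaced. `reset_entries`, `fill_value` and `n_features` are hypothetical names chosen for illustration, and the Input standing in for the generator output is an assumption.

```python
import tensorflow as tf

def reset_entries(args, fill_value=0.0):
    """Return a copy of `values` with values[i, cols[i]] set to fill_value."""
    values, cols = args
    batch_size = tf.shape(values)[0]
    rows = tf.range(batch_size, dtype=tf.int32)
    cols = tf.reshape(tf.cast(cols, tf.int32), [-1])
    idx = tf.stack([rows, cols], axis=1)                  # (batch, 2)
    # One replacement value per row, in the same dtype as `values`.
    updates = tf.fill(tf.shape(rows), tf.cast(fill_value, values.dtype))
    # Out-of-place update: `values` itself is left unchanged.
    return tf.tensor_scatter_nd_update(values, idx, updates)

# Hypothetical wiring: an Input stands in for the generator output.
n_features = 4
gen_out_in = tf.keras.Input(shape=(n_features,))
indices_to_reset = tf.keras.Input(shape=(1,), dtype='int32')
updated_gen_out = tf.keras.layers.Lambda(reset_entries)([gen_out_in, indices_to_reset])
model = tf.keras.Model([gen_out_in, indices_to_reset], updated_gen_out)
```

Since the result is an ordinary tensor, it could then be fed to the discriminator; gradients should still reach the entries that were not overwritten. An alternative that avoids scatter ops is a one-hot mask: with `mask = tf.one_hot(cols, depth)` (cols flattened to shape `(batch,)`), compute `values * (1 - mask) + fill_value * mask`.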
