Tuesday, 26 January 2021

Passing non-tensor parameters to a Keras model during training / using tensors for indexing

I'm trying to train a Keras model that incorporates data augmentation in the model itself. The inputs to the model are images of different classes, and the model is supposed to hold one augmentation model per class, so that each image is augmented by the model matching its class. My code roughly looks like this:

from keras.models import Model
from keras.layers import Input
...further imports...

def get_main_model(input_shape, n_classes):
    encoder_model = get_encoder_model()
    input = Input(input_shape, name="input")
    label_input = Input((1,), name="label_input")
    aug_models = [get_augmentation_model() for i in range(n_classes)]
    # This is the line that fails: label_input is a symbolic tensor, not a Python int.
    augmentation = aug_models[label_input](input)
    x = encoder_model(input)
    y = encoder_model(augmentation)
    model = Model(inputs=[input, label_input], outputs=[x, y])
    model.add_loss(custom_loss_function(x, y))
    return model 

I would then like to pass batches of data through the model, each consisting of an array of images (passed to input) and a corresponding array of labels (passed to label_input). However, this doesn't work: whatever is fed into label_input is converted to a tensor by TensorFlow, and a tensor can't be used to index the Python list aug_models. Here is what I've tried:

  • augmentation = aug_models[int(label_input)](input) --> doesn't work because label_input is a tensor
  • augmentation = aug_models[tf.make_ndarray(label_input)](input) --> casting doesn't work (I guess because label_input is a symbolic tensor)
  • tf.gather(aug_models, label_input) --> doesn't work because tf.gather expects tensors, so TensorFlow tries to convert the list of Keras model instances into a tensor (which obviously fails)

Is there any kind of trick in TensorFlow that would let me pass a parameter to the model during training without it being converted to a tensor, or some other way to tell the model which augmentation model to select? Thanks in advance!
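One possible direction, offered only as a rough sketch and not as a confirmed solution: instead of indexing the Python list with a tensor, run the image through every augmentation model, stack the results, and let tf.gather pick one result per sample using the label. The names SelectByLabel and build_selector are made up for this illustration, and get_augmentation_model below is a hypothetical stand-in (a single Conv2D) for the real one, which is not shown in the question.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class SelectByLabel(layers.Layer):
    """Hypothetical helper: per sample, keep the candidate at index `label`."""

    def call(self, inputs):
        # The last element is the label tensor, the rest are candidates.
        *candidates, labels = inputs
        # Stack candidates on a new axis: (batch, n_classes, H, W, C).
        stacked = tf.stack(candidates, axis=1)
        # Labels arrive as (batch, 1) floats; flatten and cast to int indices.
        idx = tf.cast(tf.reshape(labels, [-1]), tf.int32)
        # batch_dims=1 selects one candidate per sample: (batch, H, W, C).
        return tf.gather(stacked, idx, axis=1, batch_dims=1)


def get_augmentation_model(input_shape):
    # Stand-in for the question's get_augmentation_model() (an assumption).
    inp = keras.Input(input_shape)
    out = layers.Conv2D(input_shape[-1], 3, padding="same")(inp)
    return keras.Model(inp, out)


def build_selector(input_shape, n_classes):
    image_input = keras.Input(input_shape, name="input")
    label_input = keras.Input((1,), name="label_input")
    aug_models = [get_augmentation_model(input_shape) for _ in range(n_classes)]
    # Apply every augmentation model, then select by label inside the graph.
    candidates = [m(image_input) for m in aug_models]
    augmentation = SelectByLabel()(candidates + [label_input])
    model = keras.Model([image_input, label_input], augmentation)
    return model, aug_models
```

The obvious cost is that every augmentation model runs on every image, so compute grows with n_classes. The gradient for a given sample flows only through the augmentation model that was selected for it.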


