Wednesday, 26 August 2020

How to call model.predict inside a loss function? (TensorFlow, Keras)

I am trying to construct a custom loss for a regression problem, using the closure pattern from this answer: Keras Custom loss function to pass arguments other than y_true and y_pred

My function currently looks like this:

import numpy as np

def CustomLoss(model, X_valid, y_valid, batch_size):
    def Loss(y_true, y_pred):
        n_samples = 5
        mc_predictions = np.zeros((n_samples, 256, 256))
        for i in range(n_samples):
            # This is the call that raises the error quoted below
            y_p = model.predict(X_valid, verbose=1, batch_size=batch_size)
        # (Other operations...)
        return LossValue
    return Loss

When execution reaches the line y_p = model.predict(X_valid, verbose=1, batch_size=batch_size), I get the following error:

Method requires being in cross-replica context, use get_replica_context().merge_call()

From what I have gathered, model.predict cannot be used inside a loss function. Is there a workaround or solution for this? Please let me know if my question is unclear or if you need any additional information. Thanks!
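For reference, one workaround that is often suggested for this kind of situation is to call the model directly, as in model(x, training=True), instead of model.predict. A direct call returns tensors and stays inside the training graph, whereas model.predict runs its own batching loop and returns NumPy arrays. The sketch below is only an illustration under assumptions: the names make_mc_loss and mc_weight are hypothetical, the variance term is a stand-in for the elided "other operations", and I have not verified it against the exact setup (for example, the distribution strategy) that produced the error above.

```python
import numpy as np
import tensorflow as tf


def make_mc_loss(model, x_valid, n_samples=5, mc_weight=0.1):
    """Build a loss that also runs stochastic forward passes on x_valid.

    make_mc_loss and mc_weight are hypothetical names used for illustration.
    """
    x_valid = tf.convert_to_tensor(x_valid, dtype=tf.float32)

    def loss(y_true, y_pred):
        # Direct call instead of model.predict: the result is a tensor.
        # training=True keeps dropout active, so each pass differs.
        mc_predictions = tf.stack(
            [model(x_valid, training=True) for _ in range(n_samples)], axis=0
        )
        # Stand-in for the elided operations (an assumption):
        # average variance across the Monte Carlo samples.
        mc_variance = tf.reduce_mean(
            tf.math.reduce_variance(mc_predictions, axis=0)
        )
        mse = tf.reduce_mean(tf.square(y_true - y_pred))
        return mse + mc_weight * mc_variance

    return loss
```

It would be passed to Keras in the same way as the closure above, for example model.compile(optimizer="adam", loss=make_mc_loss(model, X_valid)). Note that every training step then runs n_samples extra forward passes over all of X_valid, so a small validation subset is probably advisable.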



