Tuesday 26 October 2021

BatchNorm makes accuracy at prediction time around 10% of what's reported during training in TensorFlow 2.6

I know this has been discussed before, but I did not find a concrete answer, and the suggestions I tried did not work. The case is simple: with batch norm in my model, the training accuracy reported by model.fit(training_data) is above 0.9 (it increases consistently while the loss decreases), but if I run model.evaluate(training_data) after training (note that it is the same data) it returns 0.09. Predictions are also really bad: the accuracy is just as low when calculated manually from the results of model.predict(training_data). I know batch norm behaves differently at training and testing time, and that some difference should be expected, but a drop from 0.9 to 0.09 seems just wrong (and the model is completely unusable). I tried some solutions from other threads:

  • use the same batch_size in .evaluate as in .fit: it made no difference.
  • set tf.keras.backend.set_learning_phase(0): I got a message saying it is now deprecated, and it made no difference.
  • set layer.trainable=False on all batch norm layers before .predict and .evaluate: it made no difference.
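For reference, the training/testing difference mentioned above can be sketched in plain NumPy. This is a simplified illustration, not the Keras implementation: it tracks a single feature, leaves out the learned scale and offset, and the helper names (simulate_moving_stats, normalize) and the activation values are hypothetical. The momentum of 0.99 and epsilon of 0.001 are the Keras BatchNormalization defaults, and the moving mean and variance start at 0 and 1 as they do in Keras. It only shows that after a small number of training steps the moving statistics used at inference can still be far from the batch statistics used during training; it does not establish that this is the cause of the drop in my case.

```python
import numpy as np

MOMENTUM = 0.99  # Keras BatchNormalization default
EPSILON = 1e-3   # Keras BatchNormalization default


def simulate_moving_stats(batches, momentum=MOMENTUM):
    """Return (moving_mean, moving_var) after one pass over the batches."""
    moving_mean, moving_var = 0.0, 1.0  # initial values used by Keras
    for batch in batches:
        # Exponential moving average of the per-batch statistics.
        moving_mean = momentum * moving_mean + (1.0 - momentum) * batch.mean()
        moving_var = momentum * moving_var + (1.0 - momentum) * batch.var()
    return moving_mean, moving_var


def normalize(x, mean, var, eps=EPSILON):
    """Normalize x with the given statistics (no learned scale/offset)."""
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
# Hypothetical activations: one feature with mean 5 and std 2, 30 batches of 64.
batches = [rng.normal(5.0, 2.0, size=64) for _ in range(30)]

moving_mean, moving_var = simulate_moving_stats(batches)
x = batches[-1]

train_out = normalize(x, x.mean(), x.var())        # what training=True uses
infer_out = normalize(x, moving_mean, moving_var)  # what training=False uses

print("moving mean / var:", moving_mean, moving_var)
print("mean of output, training mode: ", train_out.mean())
print("mean of output, inference mode:", infer_out.mean())
```

In this sketch the training-mode output is centered at zero, while the inference-mode output is shifted because the moving mean has only moved part of the way from 0 toward 5. Running more steps or lowering the momentum shrinks that gap.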

If I remove the batch norm layers, the accuracy reported by model.fit(training_data) matches model.evaluate(training_data), but training makes no progress (the results are consistent but bad), so I need to keep batch norm.

Is this a major bug in TF 2.6?

Update: also tested TF 2.5, result is the same.

Sample code (omitting irrelevant parts such as data reading and pre-processing):

### model definition
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras import layers

class CLS_BERT_Embedding(tf.keras.Model):
    """Will only use the CLS token"""
    def __init__(self, bert_trainable=False, number_filters=50, FNN_units=512,
                 number_clases=2, dropout_rate=0.1, name="dcnn"):
        # Pass name as a keyword argument, as tf.keras.Model expects.
        super().__init__(name=name)
        self.checkpoint_id = "CLS_BERT_Embedding_bn_3fc_{}filters_{}fc_units_berttrainable{}".format(
            number_filters, FNN_units, bert_trainable)
 
        # trainable= False so we don't fine-tune bert, just use as embedding layer
        self.bert_layer = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1", 
                                          trainable=bert_trainable,
                                         input_shape=(3,376))
 
 
        self.dense_1 = layers.Dense(units = FNN_units,activation="relu")
        self.bn1 = layers.BatchNormalization()
        self.dense_2 = layers.Dense(units = FNN_units, activation="relu")
        self.bn2 = layers.BatchNormalization()
        self.dense_3 = layers.Dense(units = FNN_units, activation="relu")
        self.bn3 = layers.BatchNormalization()
        self.dropout = layers.Dropout(rate=dropout_rate)
 
        if number_clases == 2:
            self.last_dense = layers.Dense(units=1,activation="sigmoid")
        else:
            self.last_dense = layers.Dense(units=number_clases,activation="softmax")
 
    def get_bert_embeddings(self,all_tokens):
        CLS_embedding ,embeddings = self.bert_layer([all_tokens[:,0,:],
                                        all_tokens[:,1,:],
                                        all_tokens[:,2,:]])
 
        return CLS_embedding,embeddings
 
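    def call(self, inputs, training=None):
        CLS_embedding, x_seq = self.get_bert_embeddings(inputs)

        # Pass training by keyword so each batch norm layer picks its mode:
        # batch statistics when training, moving averages at inference.
        x = self.dense_1(CLS_embedding)
        x = self.bn1(x, training=training)
        x = self.dense_2(x)
        x = self.bn2(x, training=training)
        x = self.dense_3(x)
        x = self.bn3(x, training=training)

        # Note: self.dropout is defined in __init__ but is not applied here.
        output = self.last_dense(x)

        return output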
 
#### config and hyper-params
NUMBER_FILTERS = 1024
FNN_UNITS = 2048
BERT_TRAINABLE = False
 
NUMBER_CLASSES = len(tokenizer.vocab)
 
DROPOUT_RATE = 0.2
 
NUMBER_EPOCHS = 3
LR = 0.001
 
DEVICE = '/GPU:0'
 
#### optimization definition
with tf.device(DEVICE):
    model = CLS_BERT_Embedding(
                bert_trainable = BERT_TRAINABLE,
                number_filters=NUMBER_FILTERS,
                FNN_units=FNN_UNITS,
                number_clases=NUMBER_CLASSES,
                dropout_rate = DROPOUT_RATE)
 
if NUMBER_CLASSES == 2:
    loss = "binary_crossentropy"
    metrics = ["accuracy"]
else:
    loss="sparse_categorical_crossentropy"
    metrics = ["sparse_categorical_accuracy"]
 
    
optimizer = tf.keras.optimizers.Adam(learning_rate=LR)
# loss and metrics were already selected above from NUMBER_CLASSES
model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
 
 
### training
with tf.device(DEVICE):
    model.fit(train_dataset,
             batch_size = BATCH_SIZE ,
             epochs=NUMBER_EPOCHS,
             shuffle=True,
             callbacks=[MyCustomCallback(), 
                        
                        tf.keras.callbacks.ReduceLROnPlateau(monitor="loss",patience=5),
                        tensorboard,lr_tensorboard])
 
 
### testing
train_results = model.evaluate(train_dataset,batch_size = BATCH_SIZE)
print(train_results)
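
The manual accuracy check mentioned in the question can be sketched as follows. This is a minimal NumPy version, assuming integer class labels and a 2-D array of per-class probabilities like the one model.predict returns for the softmax output. The function name and the small labels/probs arrays are hypothetical stand-ins, not my real data.

```python
import numpy as np


def sparse_categorical_accuracy(y_true, y_prob):
    """Fraction of rows whose highest-probability class equals the integer label."""
    y_true = np.asarray(y_true).reshape(-1)
    y_pred = np.argmax(np.asarray(y_prob), axis=-1)
    return float(np.mean(y_pred == y_true))


# Hypothetical stand-ins for the true labels and for model.predict(...) output.
labels = np.array([2, 0, 1, 1])
probs = np.array([[0.1, 0.2, 0.7],
                  [0.8, 0.1, 0.1],
                  [0.3, 0.4, 0.3],
                  [0.6, 0.3, 0.1]])

manual_accuracy = sparse_categorical_accuracy(labels, probs)
print(manual_accuracy)
```

One caveat with a check like this: the labels have to be in the same order as the predictions. If the dataset is shuffled on every iteration, labels collected in a separate pass will not line up with the output of model.predict, and the manual accuracy will look close to chance even for a good model.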

