I know this has been discussed before, but I did not find a concrete answer, and some of the suggested answers did not work when I tried them. The case is simple: I have a model, and if I use batch norm, the training accuracy reported by model.fit(training_data)
is above 0.9 (it consistently increases, and the loss decreases), but then after training if I run model.evaluate(training_data)
(notice it is the same data) it returns 0.09. The predictions are also really bad (the accuracy is also low when calculated manually from the results of model.predict(training_data)).
I know the difference between training and testing time in batch norm, and I know differences should be expected, but a drop from 0.9 to 0.09 seems just wrong (and the model is completely unusable). I tried some solutions from other threads:
- use batch_size in
.evaluate
to be the same as .fit
: did not make a difference - set
tf.keras.backend.set_learning_phase(0)
: got a message saying it is now deprecated and made no difference. - set all batch norm layers to have
layer.trainable=False
before .predict
and .evaluate
: it did not make a difference.
If I remove batch norm layers, the report from model.fit(training_data)
coincides with model.evaluate(training_data)
but the training does not make any progress (results are consistent but bad), so I need to keep the batch norm layers.
Is this a major bug in TF 2.6?
Update: also tested TF 2.5, result is the same.
Sample code (omitting irrelevant code, like data reading and pre-processing):
### model definition
class CLS_BERT_Embedding(tf.keras.Model):
    """Feed-forward classifier on top of the first output of a TF-Hub BERT layer.

    Only the first tensor returned by the BERT layer (named ``CLS_embedding``
    here) is used; the second one (the per-token output) is ignored. The head
    is three ``Dense(relu) -> BatchNormalization`` stages followed by a
    sigmoid (binary) or softmax (multi-class) output layer.

    Args:
        bert_trainable: Whether the BERT layer weights are fine-tuned.
        number_filters: Not used to build any layer; only recorded in
            ``checkpoint_id``.
        FNN_units: Width of each of the three hidden dense layers.
        number_clases: Number of target classes. ``2`` builds a single
            sigmoid unit, anything else a softmax over ``number_clases``.
        dropout_rate: Rate of the dropout layer created in ``__init__``.
        name: Keras model name.
    """

    def __init__(self, bert_trainable=False, number_filters=50, FNN_units=512,
                 number_clases=2, dropout_rate=0.1, name="dcnn"):
        # `name` has to be passed by keyword: tf.keras.Model only documents it
        # as a keyword argument, so a positional value is not picked up as the
        # model name.
        super().__init__(name=name)
        self.checkpoint_id = "CLS_BERT_Embedding_bn_3fc_{}filters_{}fc_units_berttrainable{}".format(
            number_filters, FNN_units, bert_trainable)
        # With trainable=False BERT is not fine-tuned, it only acts as an
        # embedding layer.
        self.bert_layer = hub.KerasLayer(
            "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1",
            trainable=bert_trainable,
            input_shape=(3, 376))
        self.dense_1 = layers.Dense(units=FNN_units, activation="relu")
        self.bn1 = layers.BatchNormalization()
        self.dense_2 = layers.Dense(units=FNN_units, activation="relu")
        self.bn2 = layers.BatchNormalization()
        self.dense_3 = layers.Dense(units=FNN_units, activation="relu")
        self.bn3 = layers.BatchNormalization()
        # NOTE(review): this layer is created but never applied in `call`, so
        # `dropout_rate` currently has no effect on the model. Left as-is to
        # keep the model's behaviour unchanged; confirm whether that is intended.
        self.dropout = layers.Dropout(rate=dropout_rate)
        if number_clases == 2:
            self.last_dense = layers.Dense(units=1, activation="sigmoid")
        else:
            self.last_dense = layers.Dense(units=number_clases, activation="softmax")

    def get_bert_embeddings(self, all_tokens):
        """Run the BERT layer on the three input slices of ``all_tokens``.

        Args:
            all_tokens: Rank-3 tensor indexed as ``[:, i, :]`` for ``i`` in
                0..2. Presumably (batch, 3, sequence_length) holding the three
                BERT inputs in the order the hub module expects; confirm
                against the pre-processing code.

        Returns:
            The two tensors returned by the BERT layer, as
            ``(CLS_embedding, embeddings)``.
        """
        CLS_embedding, embeddings = self.bert_layer([all_tokens[:, 0, :],
                                                     all_tokens[:, 1, :],
                                                     all_tokens[:, 2, :]])
        return CLS_embedding, embeddings

    def call(self, inputs, training=None):
        """Compute class probabilities for a batch of tokenised inputs.

        Args:
            inputs: Tensor accepted by ``get_bert_embeddings``.
            training: Forwarded to every BatchNormalization layer. ``None``
                lets each layer fall back to its own default handling.

        Returns:
            Output of the final sigmoid/softmax dense layer.
        """
        # Only the first BERT output is used; the per-token output is discarded.
        CLS_embedding, _ = self.get_bert_embeddings(inputs)
        x = self.dense_1(CLS_embedding)
        x = self.bn1(x, training=training)
        x = self.dense_2(x)
        x = self.bn2(x, training=training)
        x = self.dense_3(x)
        x = self.bn3(x, training=training)
        output = self.last_dense(x)
        return output
#### config and hyper-params
# -- model architecture --
FNN_UNITS = 2048
NUMBER_FILTERS = 1024
DROPOUT_RATE = 0.2
BERT_TRAINABLE = False
# One output class per entry of the tokenizer vocabulary.
NUMBER_CLASSES = len(tokenizer.vocab)
# -- optimisation / runtime --
LR = 1e-3
NUMBER_EPOCHS = 3
DEVICE = "/GPU:0"
#### optimization definition
# Build and compile the model on the configured device.
with tf.device(DEVICE):
    model = CLS_BERT_Embedding(
        bert_trainable=BERT_TRAINABLE,
        number_filters=NUMBER_FILTERS,
        FNN_units=FNN_UNITS,
        number_clases=NUMBER_CLASSES,
        dropout_rate=DROPOUT_RATE,
    )
    # The loss and metric have to match the output layer the model builds:
    # a single sigmoid unit for two classes, a softmax otherwise.
    if NUMBER_CLASSES == 2:
        loss = "binary_crossentropy"
        metrics = ["accuracy"]
    else:
        loss = "sparse_categorical_crossentropy"
        metrics = ["sparse_categorical_accuracy"]
    optimizer = tf.keras.optimizers.Adam(learning_rate=LR)
    # The loss chosen above is used as-is; it is no longer unconditionally
    # overwritten with the sparse categorical loss, which would be wrong for
    # the two-class (sigmoid) head.
    model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
### training
# Fit on the training dataset, on the configured device.
with tf.device(DEVICE):
    model.fit(
        train_dataset,
        batch_size=BATCH_SIZE,
        epochs=NUMBER_EPOCHS,
        shuffle=True,
        callbacks=[
            MyCustomCallback(),
            # NOTE(review): patience=5 is larger than NUMBER_EPOCHS if that is
            # still 3, in which case this callback cannot trigger - confirm.
            tf.keras.callbacks.ReduceLROnPlateau(monitor="loss", patience=5),
            tensorboard,
            lr_tensorboard,
        ],
    )

### testing
# Evaluate on the same data that was used for training, so the reported
# numbers can be compared directly with those printed by fit().
train_results = model.evaluate(train_dataset, batch_size=BATCH_SIZE)
print(train_results)
from BatchNorm makes accuracy at prediction time around 10% of what's reported during training in tensorflow 2.6
No comments:
Post a Comment