Friday 2 October 2020

Binary classification model using BERT encoder stuck at 50% accuracy

I'm trying to train a simple model for the Yelp binary classification task.

Load BERT encoder:

gs_folder_bert = "gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12"
bert_config_file = os.path.join(gs_folder_bert, "bert_config.json")
config_dict = json.loads(tf.io.gfile.GFile(bert_config_file).read())
bert_config = bert.configs.BertConfig.from_dict(config_dict)
_, bert_encoder = bert.bert_models.classifier_model(
    bert_config, num_labels=2)
checkpoint = tf.train.Checkpoint(model=bert_encoder)
checkpoint.restore(
    os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()

Load data:

data, info = tfds.load('yelp_polarity_reviews', with_info=True, batch_size=-1, as_supervised=True)
train_x_orig, train_y_orig = tfds.as_numpy(data['train'])
train_x = encode_examples(train_x_orig)
train_y = train_y_orig 

Use BERT to embed the data:

encoder_output = bert_encoder.predict(train_x)

Setup the model:

inputs = keras.Input(shape=(768,))
x = keras.layers.Dense(64, activation='relu')(inputs)
x = keras.layers.Dense(8, activation='relu')(x)
outputs = keras.layers.Dense(1, activation='sigmoid')(x)
model = keras.Model(inputs=inputs, outputs=outputs)
sgd = SGD(lr=0.0001)
model.compile(loss='binary_crossentropy', optimizer=sgd, metrics=['accuracy'])

Train:

model.fit(encoder_output[0], train_y, batch_size=64, epochs=3)
# encoder_output[0].shape === (10000, 1, 768)
# y_train.shape === (100000,)
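One thing that may be worth checking here is whether the array passed to the model matches `keras.Input(shape=(768,))`, which expects batches of shape `(N, 768)`, and whether the features and labels have the same number of rows. The snippet below is only a sketch on random stand-in arrays (`fake_encoder_output` and `fake_labels` are made-up names, not real BERT output), and it assumes the length-1 middle axis carries no information:

```python
import numpy as np

# Stand-ins for encoder_output[0] and train_y: random values, NOT real data.
fake_encoder_output = np.random.rand(10, 1, 768).astype("float32")
fake_labels = np.random.randint(0, 2, size=(10,))

# Drop the length-1 middle axis so each row is one 768-dimensional vector.
features = np.squeeze(fake_encoder_output, axis=1)

# Features and labels must agree on the number of examples.
assert features.shape[0] == fake_labels.shape[0]
print(features.shape, fake_labels.shape)
```

Running the same two checks on the real `encoder_output[0]` and `train_y` would show whether the inputs line up before calling `model.fit`.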

Training results:

Epoch 1/5
157/157 [==============================] - 1s 5ms/step - loss: 0.6921 - accuracy: 0.5455
Epoch 2/5
157/157 [==============================] - 1s 5ms/step - loss: 0.6918 - accuracy: 0.5455
Epoch 3/5
157/157 [==============================] - 1s 5ms/step - loss: 0.6915 - accuracy: 0.5412
Epoch 4/5
157/157 [==============================] - 1s 5ms/step - loss: 0.6913 - accuracy: 0.5407
Epoch 5/5
157/157 [==============================] - 1s 5ms/step - loss: 0.6911 - accuracy: 0.5358
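For reference when reading these numbers: a loss close to 0.693 is what binary cross-entropy gives when the model outputs 0.5 for every example, since -ln(0.5) = ln 2 ≈ 0.6931. A minimal sketch of that arithmetic follows (the helper `binary_crossentropy` and the labels are hypothetical, written here for illustration, not the Keras implementation or the Yelp data):

```python
import math

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over paired labels and predicted probabilities."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        # Clip probabilities away from 0 and 1 to avoid log(0).
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(y_true)

labels = [0, 1, 1, 0, 1, 0]            # made-up labels, for illustration only
constant_preds = [0.5] * len(labels)   # a model that always outputs 0.5

chance_loss = binary_crossentropy(labels, constant_preds)
print(round(chance_loss, 4))
```

The reported losses of 0.6921 down to 0.6911 sit just below this value, which suggests the outputs have barely moved away from 0.5.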

I tried different learning rates, but the main issue seems to be that each epoch takes only about 1 second and the accuracy stays at ~0.5. Am I setting up the inputs or the model incorrectly?

