Note: All code for a self-contained example to reproduce my problem can be found below.
I have a `tf.keras.models.Model` instance and need to train it with a training loop written in the low-level TensorFlow API.
The problem: training the exact same tf.keras model once with a basic, standard low-level TensorFlow training loop and once with Keras' own `model.fit()` method produces very different results. I would like to find out what I'm doing wrong in my low-level TF training loop.
The model is a simple image classification model that I train on Caltech256 (link to tfrecords below).
With the low-level TensorFlow training loop, the training loss first decreases as it should, but after just 1000 training steps it plateaus and then starts increasing again.
Training the same model on the same dataset with the normal Keras training loop, on the other hand, works as expected.
What am I missing in my low-level TensorFlow training loop?
Here is the code to reproduce the problem (download the TFRecords with the link at the bottom):
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.FixedLenFeature, ...)
from tqdm import trange
import sys
import glob
import os

sess = tf.Session()
tf.keras.backend.set_session(sess)

num_classes = 257
image_size = (224, 224, 3)

# Build a simple model.
input_tensor = tf.keras.layers.Input(shape=image_size)
x = tf.keras.layers.Conv2D(64, (3,3), strides=(2,2), activation='relu', kernel_initializer='he_normal')(input_tensor)
x = tf.keras.layers.Conv2D(64, (3,3), strides=(2,2), activation='relu', kernel_initializer='he_normal')(x)
x = tf.keras.layers.Conv2D(128, (3,3), strides=(2,2), activation='relu', kernel_initializer='he_normal')(x)
x = tf.keras.layers.Conv2D(256, (3,3), strides=(2,2), activation='relu', kernel_initializer='he_normal')(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(num_classes, activation='softmax', kernel_initializer='he_normal')(x)
model = tf.keras.models.Model(input_tensor, x)

# Build a tf.data.Dataset from TFRecords.
tfrecord_directory = 'path/to/tfrecords/directory'
tfrecord_filenames = glob.glob(os.path.join(tfrecord_directory, '*.tfrecord'))
feature_schema = {'image': tf.FixedLenFeature([], tf.string),
                  'filename': tf.FixedLenFeature([], tf.string),
                  'label': tf.FixedLenFeature([], tf.int64)}
dataset = tf.data.Dataset.from_tensor_slices(tfrecord_filenames)
dataset = dataset.shuffle(len(tfrecord_filenames)) # Shuffle the TFRecord file names.
dataset = dataset.flat_map(lambda filename: tf.data.TFRecordDataset(filename))
dataset = dataset.map(lambda single_example_proto: tf.parse_single_example(single_example_proto, feature_schema)) # Deserialize tf.Example objects.
dataset = dataset.map(lambda sample: (sample['image'], sample['label']))
dataset = dataset.map(lambda image, label: (tf.image.decode_jpeg(image, channels=3), label)) # Decode JPEG images.
dataset = dataset.map(lambda image, label: (tf.image.resize_image_with_pad(image, target_height=image_size[0], target_width=image_size[1]), label))
dataset = dataset.map(lambda image, label: (tf.image.per_image_standardization(image), label))
dataset = dataset.map(lambda image, label: (image, tf.one_hot(indices=label, depth=num_classes))) # Convert labels to one-hot format.
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.repeat()
dataset = dataset.batch(32)
```
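For reference, `tf.image.per_image_standardization` is documented to compute `(x - mean) / adjusted_stddev` over all values of one image, with `adjusted_stddev = max(stddev, 1.0 / sqrt(num_elements))`. A minimal pure-Python sketch of that formula on a flat list of pixel values (the `per_image_standardization` function here is a hypothetical stand-in, not the TF op itself):

```python
import math


def per_image_standardization(pixels):
    """Hypothetical pure-Python sketch of the documented formula,
    applied to a flat list of pixel values from one image."""
    n = len(pixels)
    mean = sum(pixels) / n
    # Population standard deviation over all values of the image.
    stddev = math.sqrt(sum((p - mean) ** 2 for p in pixels) / n)
    # Floor the stddev at 1/sqrt(n) so a uniform image doesn't divide by zero.
    adjusted_stddev = max(stddev, 1.0 / math.sqrt(n))
    return [(p - mean) / adjusted_stddev for p in pixels]


out = per_image_standardization([0.0, 64.0, 128.0, 255.0])
print(out)  # roughly zero mean and unit variance
```

In the pipeline above this step runs after `resize_image_with_pad`, so the zero padding is included in each image's mean and standard deviation.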
This is the simple TensorFlow training loop:
```python
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()

# Build the training-relevant part of the graph.
model_output = model(features)
#loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=tf.stop_gradient(labels), logits=model_output))
# Per-example loss vector of shape (batch_size,).
loss = tf.keras.backend.categorical_crossentropy(target=labels, output=model_output, from_logits=False)
train_op = tf.train.AdamOptimizer().minimize(loss)

# The next three lines are only relevant for the accuracy metric.
softmax_output = tf.nn.softmax(model_output)
predictions_argmax = tf.argmax(softmax_output, axis=-1, output_type=tf.int64)
labels_argmax = tf.argmax(labels, axis=-1, output_type=tf.int64)

mean_loss_value, mean_loss_update_op = tf.metrics.mean(loss)
acc_value, acc_update_op = tf.metrics.accuracy(labels=labels_argmax, predictions=predictions_argmax)

# Run the training
epochs = 3
steps_per_epoch = 1000

sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())

with sess.as_default():
    for epoch in range(1, epochs+1):
        tr = trange(steps_per_epoch, file=sys.stdout)
        tr.set_description('Epoch {}/{}'.format(epoch, epochs))
        fetch_list = [mean_loss_value,
                      acc_value,
                      train_op,
                      mean_loss_update_op,
                      acc_update_op]
        for train_step in tr:
            # learning_phase() == 1 puts the Keras layers in training mode.
            ret = sess.run(fetch_list, feed_dict={tf.keras.backend.learning_phase(): 1})
            tr.set_postfix(ordered_dict={'loss': ret[0],
                                         'accuracy': ret[1]})
```
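`tf.metrics.mean` and `tf.metrics.accuracy` keep a running total and count in local variables, and the loop above initializes those only once, before the first epoch. The `loss` and `accuracy` shown in the progress bar are therefore averages over every step since training started, not over the current epoch (which, as far as I know, is what the `fit()` progress bar reports). The pure-Python sketch below uses a hypothetical `StreamingMean` class and made-up loss values to show how far apart the two readings can be; this affects only how the curves are displayed, not the training itself.

```python
class StreamingMean:
    """Hypothetical stand-in for the total/count pair that
    tf.metrics.mean keeps in local variables."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Corresponds to re-running the metric's local-variable initializer.
        self.total = 0.0
        self.count = 0

    def update(self, value):
        # Corresponds to running mean_loss_update_op once.
        self.total += value
        self.count += 1

    def result(self):
        # Corresponds to fetching mean_loss_value.
        return self.total / self.count if self.count else 0.0


# Made-up per-step losses for two short "epochs" (illustration only).
epoch_losses = [[5.0, 4.0, 3.0], [2.0, 1.5, 1.0]]

never_reset = StreamingMean()      # initialized once, like the loop above
per_epoch_results = []
for losses in epoch_losses:
    per_epoch = StreamingMean()    # fresh accumulator for every epoch
    for loss in losses:
        never_reset.update(loss)
        per_epoch.update(loss)
    per_epoch_results.append(per_epoch.result())

print(never_reset.result(), per_epoch_results)
```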
Below is the standard Keras training loop, which works as expected. Note that the dense layer in the model above needs a 'softmax' activation (rather than None) for the Keras loop to work.
```python
epochs = 3
steps_per_epoch = 1000

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

history = model.fit(dataset,
                    epochs=epochs,
                    steps_per_epoch=steps_per_epoch)
```
You can download the TFRecords for the Caltech256 dataset here (about 850 MB).
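One difference between the two loops that can be checked in isolation: as far as I know, `model.compile()` reduces the per-example losses to their mean over the batch, whereas `tf.train.AdamOptimizer().minimize(loss)` in the low-level loop receives the unreduced per-example vector, and `tf.gradients` differentiates the sum of a non-scalar target. The pure-Python sketch below (a hypothetical `adam_first_step` helper, made-up gradients, and the update rule from the Adam paper rather than TensorFlow's exact implementation) shows that this multiplies the gradient by the batch size but leaves an Adam step almost unchanged.

```python
import math


def adam_first_step(grad, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    # Hypothetical helper: one bias-corrected Adam update for a single
    # scalar parameter, starting from m = v = 0.
    m = (1 - beta1) * grad
    v = (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1)
    v_hat = v / (1 - beta2)
    return -lr * m_hat / (math.sqrt(v_hat) + eps)


batch_size = 32
# Made-up per-example gradients for one parameter (illustration only).
per_sample_grads = [0.01 * (i % 5 - 2) + 0.003 for i in range(batch_size)]

grad_of_mean = sum(per_sample_grads) / batch_size  # loss reduced with a mean
grad_of_sum = sum(per_sample_grads)                # unreduced loss vector -> sum

step_mean = adam_first_step(grad_of_mean)
step_sum = adam_first_step(grad_of_sum)
print(grad_of_sum / grad_of_mean, step_mean, step_sum)
```

Because `m` and `sqrt(v)` both scale by the same factor at every step, the same cancellation holds after the first step too; only `eps` breaks it slightly.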
UPDATE:
I've managed to solve at least part of the problem: replacing the low-level TF loss function

```python
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=tf.stop_gradient(labels), logits=model_output))
```

with its Keras equivalent

```python
loss = tf.keras.backend.categorical_crossentropy(target=labels, output=model_output, from_logits=False)
```

improves the situation. The low-level TensorFlow training loop no longer starts diverging after a thousand training steps, but it still converges much more slowly than `model.fit()`. This raises two questions:
- Something must still be missing. What else needs to be taken care of in my low-level TF training loop for it to match the performance of Keras' built-in training methods?
- What does `tf.keras.backend.categorical_crossentropy()` do that `tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2())` doesn't, such that the latter performs much worse?
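To make the second question concrete, here is a pure-Python sketch of both losses for a single example. It assumes that the `from_logits=False` branch of `categorical_crossentropy` rescales each row to sum to 1, clips it to `[epsilon, 1 - epsilon]` (epsilon assumed to be Keras' default of 1e-7) and returns `-sum(target * log(output))`, as in the TF 1.x releases I'm aware of; newer releases may treat a preceding softmax differently. It also shows what happens if the Dense layer's output has already gone through a softmax (as with `activation='softmax'` in the model code above) and is then passed as `logits`: softmax is applied twice. All function names are hypothetical.

```python
import math

EPSILON = 1e-7  # assumed to match tf.keras.backend.epsilon()'s default


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_xent_with_logits(labels, logits):
    # Per-example value of softmax_cross_entropy_with_logits_v2:
    # -sum(labels * log(softmax(logits))).
    probs = softmax(logits)
    return -sum(t * math.log(p) for t, p in zip(labels, probs))


def keras_style_xent_from_probs(labels, probs):
    # Sketch of the from_logits=False path: renormalize, clip, then
    # take -sum(labels * log(probs)).
    total = sum(probs)
    clipped = [min(max(p / total, EPSILON), 1.0 - EPSILON) for p in probs]
    return -sum(t * math.log(p) for t, p in zip(labels, clipped))


labels = [0.0, 1.0, 0.0]
logits = [1.0, 4.0, 0.5]
probs = softmax(logits)  # what a softmax-activated Dense layer would output

correct = keras_style_xent_from_probs(labels, probs)
direct = softmax_xent_with_logits(labels, logits)
double_softmax = softmax_xent_with_logits(labels, probs)  # probs fed in as logits

# With 257 classes, even a perfect one-hot "probability" fed in as logits
# cannot push the double-softmax loss anywhere near zero.
n = 257
one_hot = [0.0] * (n - 1) + [1.0]
floor = softmax_xent_with_logits(one_hot, one_hot)

print(correct, direct, double_softmax, floor)
```

The first two values agree, while the double-softmax loss stays high even for a confident, correct prediction, and its gradients are correspondingly weak.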
from Training a tf.keras model with a basic low-level TensorFlow training loop doesn't work