Wednesday 10 March 2021

How to do gradient accumulation with custom fit in tf.keras?

Please leave at least a brief comment with your thoughts so that I can improve my question. Thanks. :)


I'm trying to train a tf.keras model with Gradient Accumulation (GA). But I don't want to use it in a custom training loop; instead, I want to customize the .fit() method by overriding train_step. Is that possible? How can I accomplish this?

Also, I know the pros of using GA, but what are the major cons of using it? Why doesn't it come as a default, or at least an optional, feature of the framework?

# overriding train_step: my attempt
# (not implemented correctly yet and needs fixing)
import tensorflow as tf

class CustomTrainStep(tf.keras.Model):
    def __init__(self, n_gradients, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_gradients = n_gradients
        self.gradient_accumulation = [tf.zeros_like(this_var)
                                      for this_var in self.trainable_variables]

    def train_step(self, data):
        x, y = data
        batch_size = tf.cast(tf.shape(x)[0], tf.float32)  
        # Gradient Tape
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        # Calculate batch gradients
        gradients = tape.gradient(loss, self.trainable_variables)
        # Accumulate batch gradients
        accum_gradient = [acum_grad + grad for acum_grad, grad in
                          zip(self.gradient_accumulation, gradients)]
        accum_gradient = [this_grad / batch_size for this_grad in accum_gradient]
        # apply accumulated gradients
        self.optimizer.apply_gradients(zip(accum_gradient, self.trainable_variables))
        # TODO: reset self.gradient_accumulation 
        # update metrics
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

Please run and check it with the following toy setup.

# Model 
size = 32
input = tf.keras.Input(shape=(size,size,3))
efnet = tf.keras.applications.DenseNet121(weights=None,
                                          include_top = False, 
                                          input_tensor = input)
base_maps = tf.keras.layers.GlobalAveragePooling2D()(efnet.output) 
base_maps = tf.keras.layers.Dense(units=10, activation='softmax', 
                                             name='primary')(base_maps) 
custom_model = CustomTrainStep(n_gradients=10, inputs=[input], outputs=[base_maps])

# bind all
custom_model.compile(
    loss = tf.keras.losses.CategoricalCrossentropy(),
    metrics = ['accuracy'],
    optimizer = tf.keras.optimizers.Adam() )
# data 
(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = tf.expand_dims(x_train, -1)
x_train = tf.repeat(x_train, 3, axis=-1)
x_train = tf.divide(x_train, 255)
x_train = tf.image.resize(x_train, [size,size]) # if we want to resize 
y_train = tf.one_hot(y_train, depth=10)

# customized fit 
custom_model.fit(x_train, y_train, batch_size=64, epochs=3, verbose = 1)

Update

I've found that others have also tried to achieve this and ended up with the same issue. One of them got a workaround, here, but it's quite messy and I think there should be a better approach.
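
Here is a rough sketch of what I think a cleaner approach might look like (not fully verified): keep the accumulators as non-trainable tf.Variables so they persist and can be updated inside the traced train_step, count the accumulated steps, and use tf.cond to apply and reset the gradients only every n_gradients steps. The class name GAModel and the helper apply_accu_gradients are just illustrative names I made up, not an established Keras API.

# rough sketch: accumulate in tf.Variables, apply every n_gradients steps
import tensorflow as tf

class GAModel(tf.keras.Model):
    def __init__(self, n_gradients, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_gradients = tf.constant(n_gradients, dtype=tf.int32)
        # counter and accumulators as non-trainable variables so they persist
        # across train_step calls inside tf.function
        self.n_acum_step = tf.Variable(0, dtype=tf.int32, trainable=False)
        self.gradient_accumulation = [
            tf.Variable(tf.zeros_like(v), trainable=False)
            for v in self.trainable_variables]

    def train_step(self, data):
        self.n_acum_step.assign_add(1)
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        # add this mini-batch's gradients to the running sums
        gradients = tape.gradient(loss, self.trainable_variables)
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign_add(gradients[i])
        # apply and reset only once every n_gradients steps
        tf.cond(tf.equal(self.n_acum_step, self.n_gradients),
                self.apply_accu_gradients, lambda: None)
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def apply_accu_gradients(self):
        self.optimizer.apply_gradients(
            zip(self.gradient_accumulation, self.trainable_variables))
        # reset the counter and the accumulators for the next cycle
        self.n_acum_step.assign(0)
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign(
                tf.zeros_like(self.trainable_variables[i]))

It should drop into the toy setup above in place of CustomTrainStep, e.g. GAModel(n_gradients=10, inputs=[input], outputs=[base_maps]). Note that this sums the gradients rather than averaging them, so they (or the learning rate) may need to be scaled by 1/n_gradients, and layers like BatchNormalization still only ever see the small per-step batches, so it is not exactly equivalent to training with a 10x larger batch.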



from How to do gradient accumulation with custom fit in tf.keras?
