Thursday, 9 June 2022

Strange memory usage of a custom LSTM layer in TensorFlow and training gets killed

I'm trying to create a custom LSTMCell in TensorFlow. I'm training on CPU only (no GPU), with 24 GB of RAM. First, I wrote an LSTMCell that mirrors the default LSTMCell. The code is given below:

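import tensorflow as tf
from tensorflow.keras import backend as K

class LSTMCell(tf.keras.layers.AbstractRNNCell):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    @property
    def state_size(self):
        # Two states per step: hidden state h and cell state c.
        return [self.units, self.units]

    @property
    def output_size(self):
        return self.units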

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.kernel = self.add_weight(shape=(input_dim, self.units * 4),name='kernel',initializer='uniform')
        self.recurrent_kernel = self.add_weight(shape=(self.units, self.units * 4),name='recurrent_kernel',initializer='uniform')
        self.bias = self.add_weight(shape=(self.units * 4,),name='bias',initializer='uniform')

    def _compute_carry_and_output_fused(self, z, c_tm1):
        z0, z1, z2, z3 = z
        i = K.sigmoid(z0)
        f = K.sigmoid(z1)
        c = f * c_tm1 + i * K.tanh(z2)
        o = K.sigmoid(z3)
        return c, o

    def call(self, inputs, states, training=None):
        h_tm1 = states[0] 
        c_tm1 = states[1]
        z = K.dot(inputs, self.kernel)
        z += K.dot(h_tm1, self.recurrent_kernel)
        z = K.bias_add(z, self.bias)
        z = tf.split(z, num_or_size_splits=4, axis=1)
        c, o = self._compute_carry_and_output_fused(z, c_tm1)
        # Hidden state: output gate times tanh of the new cell state, as in the default LSTMCell.
        h = o * K.tanh(c)
        self.h = h
        self.c = c
        return h, [h,c]

This cell works fine and uses only about 8 GB of RAM. I then modified the cell for my needs by adding a second trainable weight for each kernel, which roughly doubles the number of parameters. The code is below:

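class LSTMCell(tf.keras.layers.AbstractRNNCell):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    @property
    def state_size(self):
        # Two states per step: hidden state h and cell state c.
        return [self.units, self.units]

    @property
    def output_size(self):
        return self.units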

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.kernel = self.add_weight(shape=(input_dim, self.units * 4),name='kernel',initializer='uniform')
        self.recurrent_kernel = self.add_weight(shape=(self.units, self.units * 4),name='recurrent_kernel',initializer='uniform')
        self.bias = self.add_weight(shape=(self.units * 4,),name='bias',initializer='uniform')
        self.kernel_bits = self.add_weight(shape=(input_dim, self.units * 4),name='_diffq_k',initializer='uniform',trainable=True)
        self.recurrent_kernel_bits = self.add_weight(shape=(self.units, self.units * 4),name='_diffq_rk',initializer='uniform',trainable=True)

    def _compute_carry_and_output_fused(self, z, c_tm1):
        z0, z1, z2, z3 = z
        i = K.sigmoid(z0)
        f = K.sigmoid(z1)
        c = f * c_tm1 + i * K.tanh(z2)
        o = K.sigmoid(z3)
        return c, o

    def call(self, inputs, states, training=None):
        h_tm1 = states[0] 
        c_tm1 = states[1]
        z = K.dot(inputs, self.kernel + self.kernel_bits)
        z += K.dot(h_tm1, self.recurrent_kernel + self.recurrent_kernel_bits)
        z = K.bias_add(z, self.bias)
        z = tf.split(z, num_or_size_splits=4, axis=1)
        c, o = self._compute_carry_and_output_fused(z, c_tm1)
        # Hidden state: output gate times tanh of the new cell state, as in the default LSTMCell.
        h = o * K.tanh(c)
        self.h = h
        self.c = c
        return h, [h,c]

Now when I try to train with this cell, it consumes all of my RAM within a few seconds and the process gets killed. The model that I'm using is given below:

input_shape = (1874, 1024)
inputs = tf.keras.layers.Input(shape=input_shape, name="input_layer")
lstm = tf.keras.layers.RNN(LSTMCell(units=input_shape[1]), return_sequences=True)
x = lstm(inputs)
model = tf.keras.models.Model(inputs, x, name="my_model")
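
To put the weight shapes from build() into numbers, here is a quick count for this model, where both the input dimension and the number of units are 1024. It is only arithmetic on the shapes shown above. The float32 storage size is an assumption (it is the Keras default dtype).

```python
# Weight shapes copied from build(); input_dim and units both come from
# input_shape = (1874, 1024) in the model above.
input_dim = 1024
units = 1024
bytes_per_float32 = 4  # assumption: weights are stored as float32

kernel = input_dim * units * 4       # shape (input_dim, units * 4)
recurrent_kernel = units * units * 4  # shape (units, units * 4)
bias = units * 4                      # shape (units * 4,)

base_params = kernel + recurrent_kernel + bias
# The modified cell adds kernel_bits and recurrent_kernel_bits (no second bias).
modified_params = base_params + kernel + recurrent_kernel

base_mib = base_params * bytes_per_float32 / 2**20
modified_mib = modified_params * bytes_per_float32 / 2**20

print(f"base cell:     {base_params:,} parameters, {base_mib:.1f} MiB")
print(f"modified cell: {modified_params:,} parameters, {modified_mib:.1f} MiB")
```

The weights themselves come to tens of MiB in both cases, so the extra parameters alone do not account for a jump of many gigabytes.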

For the same dataset, the RAM consumption of the two cells differs drastically. I have tried reducing the input dimensions, and the largest LSTM I can train within my 24 GB is 128 units. If I go above that, the RAM fills up and training gets killed. I implemented the same thing in PyTorch and had no such issue. Can anyone point out the cause of this problem?
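
One hypothesis that can be checked with plain arithmetic (an assumption, not a confirmed diagnosis): the sums `self.kernel + self.kernel_bits` and `self.recurrent_kernel + self.recurrent_kernel_bits` are evaluated inside `call`, which runs once per timestep. If TensorFlow keeps each of those results for the backward pass, memory grows with the sequence length of 1874. The sketch below estimates that upper bound. `summed_kernel_gib` is a hypothetical helper written for this illustration, and float32 storage is assumed.

```python
def summed_kernel_gib(units, input_dim, timesteps, bytes_per_value=4):
    """Upper-bound estimate: one retained copy of both summed kernels per timestep."""
    kernel_sum = input_dim * units * 4   # kernel + kernel_bits
    recurrent_sum = units * units * 4    # recurrent_kernel + recurrent_kernel_bits
    per_step_bytes = (kernel_sum + recurrent_sum) * bytes_per_value
    return per_step_bytes * timesteps / 2**30


timesteps = 1874  # sequence length from input_shape
for units in (128, 1024):
    # In the model above, input_dim equals units (units=input_shape[1]).
    gib = summed_kernel_gib(units, input_dim=units, timesteps=timesteps)
    print(f"units={units}: about {gib:.2f} GiB")
```

Under that assumption, the estimate is under 1 GiB for 128 units and well above 24 GB for 1024 units, which would be consistent with the behaviour described above.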


