Wednesday, 25 July 2018

Correct use of tf.while_loop when variables are created inside the body

I'm using a tf.while_loop in TensorFlow to iterate over a tensor and extract specific slices along a given dimension. At each step, I need to use a decoder RNN to generate a sequence of output symbols. I'm using the code provided in tf.contrib.seq2seq, in particular tf.contrib.seq2seq.dynamic_decode. The code looks similar to the following:

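import tensorflow as tf
from tensorflow.contrib.rnn import GRUCell
from tensorflow.contrib.seq2seq import AttentionWrapper

def decoder_condition(i, data, source_seq_len, ta_outputs):
    # Keep looping while the index is below max_loop_len (defined elsewhere).
    return tf.less(i, max_loop_len)

def decoder_body(i, data, source_seq_len, ta_outputs):
    # Slice out step i along the second dimension.
    curr_data = data[:, i, :]
    # memory_sequence_length should be a [batch_size] vector.
    curr_source_seq_len = source_seq_len[:, i, :]
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        2 * self.opt["encoder_rnn_h_size"],
        curr_data,
        memory_sequence_length=curr_source_seq_len
    )
    cell = GRUCell(num_units)
    cell = AttentionWrapper(cell, attention_mechanism)
    # ... other code that initialises all the variables required
    # for the RNN decoder
    # dynamic_decode returns (final_outputs, final_state, final_sequence_lengths).
    final_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder,
        maximum_iterations=self.opt["max_sys_seq_len"],
        swap_memory=True
    )
    # Assuming a BasicDecoder, final_outputs.rnn_output holds the logits.
    # Writing it already creates the data dependency, so no explicit
    # tf.control_dependencies block is needed.
    ta_outputs = ta_outputs.write(i, final_outputs.rnn_output)

    # The body must return one value per loop variable, in the same order.
    return i + 1, data, source_seq_len, ta_outputs

loop_index = tf.constant(0)
ta_outputs = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
outputs = tf.while_loop(
    decoder_condition,
    decoder_body,
    loop_vars=[
        loop_index,
        data,
        data_source_len,
        ta_outputs
    ],
    swap_memory=True,
    back_prop=True,
    parallel_iterations=1
)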
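A stripped-down version of the same pattern can help separate the loop mechanics from decoder problems. The sketch below assumes the TF 1.x graph API (reached through tf.compat.v1 so it should also run under TF 2.x) and replaces the decoder with a hypothetical per-step reduction; `loop_over_slices` is an illustrative name, not part of the original code. It slices `data[:, i, :]` with the loop index, writes each step's result to a TensorArray, and stacks the results after the loop. `stack()` produces a time-major tensor, so it is transposed back to batch-major before being returned.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1


def loop_over_slices(values):
    """Apply a per-step op to data[:, i, :] inside tf.while_loop and
    return the stacked results in batch-major order [batch, steps]."""
    graph = tf.Graph()
    with graph.as_default():
        data = tf.constant(values, dtype=tf.float32)  # [batch, steps, depth]
        num_steps = tf.shape(data)[1]

        def cond(i, ta):
            return tf.less(i, num_steps)

        def body(i, ta):
            curr = data[:, i, :]                      # [batch, depth]
            # Hypothetical stand-in for the decoder: sum over depth.
            step_out = tf.reduce_sum(curr, axis=1)    # [batch]
            return i + 1, ta.write(i, step_out)

        ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
        _, ta_final = tf.while_loop(cond, body, [tf.constant(0), ta],
                                    parallel_iterations=1)
        # stack() is time-major: [steps, batch]; transpose to [batch, steps].
        result = tf.transpose(ta_final.stack(), [1, 0])
        with tf1.Session(graph=graph) as sess:
            return sess.run(result)
```

Comparing the loop's result with an equivalent vectorised computation (here `values.sum(axis=2)`) is a cheap way to confirm that the indexing and the TensorArray ordering are what you expect.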

As you can see, I create several objects that depend on the input at the current step i. I'm using tf.AUTO_REUSE in the enclosing variable scope so that the variables are reused even though I create new objects at each step. Unfortunately, my decoder does not seem to train properly: it keeps generating incorrect values. I've already checked that the input data to the decoder RNN is correct, so I suspect I'm misusing TensorArray or while_loop in some way.
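On the tf.AUTO_REUSE point: reuse in a variable scope is keyed on variable names, not on Python objects, so a newly created object only shares weights if it requests the same fully-qualified variable name. The sketch below uses the TF 1.x-style API via tf.compat.v1; the scope `proj` and the helper `make_projection` are hypothetical. It calls the same function twice and checks that only one trainable variable exists afterwards.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def count_shared_variables():
    """Build the same projection twice under AUTO_REUSE and report
    whether both calls got the same variable, plus all trainable names."""
    graph = tf.Graph()
    with graph.as_default():
        def make_projection(x):
            # Both calls ask for "proj/kernel"; with AUTO_REUSE the second
            # call returns the existing variable instead of creating one.
            with tf1.variable_scope("proj", reuse=tf1.AUTO_REUSE):
                kernel = tf1.get_variable("kernel", shape=[4, 3])
            return tf.matmul(x, kernel), kernel

        x = tf.ones([2, 4])
        _, k1 = make_projection(x)
        _, k2 = make_projection(x)
        names = [v.name for v in tf1.trainable_variables()]
    return k1 is k2, names
```

Listing `tf.trainable_variables()` in the real model in the same way may show whether the decoder ends up with a single set of weights or several differently named copies.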

So my main questions are:

  1. Is TensorFlow correctly propagating the gradients for each variable that is created inside the while loop?
  2. Is it possible to create objects inside the while loop that depend on specific slices of a tensor obtained with the loop index?
  3. Does the back_prop parameter guarantee that gradients are propagated during training? Should it be set to False during inference?
  4. In general, are there any sanity checks I can use to spot possible errors in my implementation?
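For questions 1 and 4, one practical sanity check is to build a toy loop with the same structure and ask TensorFlow for the gradients directly: `tf.gradients` returns `None` for a variable that is not connected to the loss, and the numeric value can be compared with a hand-derived gradient. The sketch assumes the TF 1.x graph API via tf.compat.v1; the weight `w`, the shapes and the squared-sum loss are hypothetical stand-ins for the real decoder. Unlike the original code, `w` is created once outside the loop and only used inside it.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1


def gradient_sanity_check(seed=0):
    """Return (autodiff gradient, analytic gradient) for a toy while_loop."""
    rng = np.random.RandomState(seed)
    x_val = rng.rand(2, 3, 4).astype(np.float32)    # [batch, steps, depth]

    graph = tf.Graph()
    with graph.as_default():
        data = tf1.placeholder(tf.float32, shape=[2, 3, 4])
        num_steps = tf.shape(data)[1]
        # Hypothetical weight, created once and shared by every iteration.
        w = tf1.get_variable("w", shape=[4, 5])

        def cond(i, ta):
            return tf.less(i, num_steps)

        def body(i, ta):
            out = tf.matmul(data[:, i, :], w)       # [batch, 5]
            return i + 1, ta.write(i, out)

        ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
        _, ta_final = tf.while_loop(cond, body, [tf.constant(0), ta],
                                    back_prop=True, parallel_iterations=1)
        loss = tf.reduce_sum(tf.square(ta_final.stack()))

        # Check 1: a None gradient would mean w is disconnected from the loss.
        (grad_w,) = tf.gradients(loss, [w])
        assert grad_w is not None, "w receives no gradient"

        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            grad_val, w_val = sess.run([grad_w, w], feed_dict={data: x_val})

    # Check 2: loss = sum_t ||x_t w||^2, so dloss/dw = 2 * X^T X w,
    # where X stacks the rows of every step.
    flat_x = x_val.reshape(-1, 4).astype(np.float64)
    expected = 2.0 * flat_x.T @ flat_x @ w_val.astype(np.float64)
    return grad_val, expected
```

If the autodiff gradient matches the analytic one on a toy loop but the real model still misbehaves, the problem is more likely in how the per-step objects and their variables are built than in the gradient flow through while_loop and TensorArray.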

Thanks!


