Wednesday, 13 November 2019

Training a GAN in Keras with .fit_generator()

I have been training a conditional GAN architecture similar to Pix2Pix with the following training loop:

# valid / invalid are the usual label arrays (ones for real, zeros for fake).
for epoch in range(start_epoch, end_epoch):
    for batch_i, (input_batch, target_batch) in enumerate(dataLoader.load_batch(batch_size)):
        fake_batch = self.generator.predict(input_batch)

        # Train the discriminator on real and generated samples separately.
        d_loss_real = self.discriminator.train_on_batch(target_batch, valid)
        d_loss_fake = self.discriminator.train_on_batch(fake_batch, invalid)
        d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

        # Train the generator through the combined model (discriminator frozen).
        g_loss = self.combined.train_on_batch([target_batch, input_batch],
                                              [valid, target_batch])

Now this works well, but it is not very efficient: the data loader quickly becomes the time bottleneck. I have looked into the .fit_generator() function that Keras provides, which lets the data loader run in worker threads (or processes) and trains much faster.

self.combined.fit_generator(generator=trainLoader,
                            validation_data=evalLoader,
                            callbacks=[checkpointCallback, historyCallback],
                            workers=1,
                            use_multiprocessing=True)
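
One caveat I ran into: Keras documents keras.utils.Sequence as the only generator type that is safe to combine with use_multiprocessing=True, so trainLoader should be a Sequence subclass. A minimal sketch of what I mean (PairedImageLoader and load_image_pairs are placeholder names, not my actual code):

import numpy as np
from tensorflow import keras

class PairedImageLoader(keras.utils.Sequence):
    """Yields (input_batch, target_batch) pairs for .fit_generator()."""

    def __init__(self, image_paths, batch_size):
        self.image_paths = image_paths
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch.
        return int(np.ceil(len(self.image_paths) / self.batch_size))

    def __getitem__(self, idx):
        paths = self.image_paths[idx * self.batch_size:(idx + 1) * self.batch_size]
        # load_image_pairs stands in for the actual disk I/O and preprocessing.
        input_batch, target_batch = load_image_pairs(paths)
        return input_batch, target_batch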

It took me some time to see that this fit_generator() setup was incorrect: I was no longer training my generator and discriminator separately, and the discriminator wasn't being trained at all, since it is set to trainable = False inside the combined model. That ruins any kind of adversarial loss; I might as well train the generator by itself with MSE.
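
For reference, the combined model is wired the usual Pix2Pix way, roughly like the sketch below (simplified to a single input; the exact losses and weights are illustrative, not my actual values):

from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

# The discriminator is frozen inside the combined model; it is only
# supposed to be updated by its own separate train_on_batch() calls.
self.discriminator.trainable = False

input_img = Input(shape=img_shape)
fake_img = self.generator(input_img)
validity = self.discriminator(fake_img)

# Two outputs: adversarial validity plus the generated image itself,
# so the generator gets an adversarial loss and an L1 reconstruction loss.
self.combined = Model(input_img, [validity, fake_img])
self.combined.compile(loss=['mse', 'mae'], loss_weights=[1, 100], optimizer='adam')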

Now my question is whether there is some workaround, such as training my discriminator inside a custom callback that is triggered on each batch of the .fit_generator() call. It is possible to create custom callbacks, for example:

import tensorflow as tf

class MyCustomCallback(tf.keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        # Only the batch index and logs arrive here; the input/target
        # arrays that train_on_batch() needs are not passed in.
        discriminator.train_on_batch(...)
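
Fleshed out, I imagine it would have to look something like the sketch below. Since on_train_batch_end never sees the data, the callback has to pull batches itself; DiscriminatorTrainer, data_loader.sample(), and the single-unit label shape are my assumptions (a PatchGAN discriminator would need patch-shaped labels):

import numpy as np
import tensorflow as tf

class DiscriminatorTrainer(tf.keras.callbacks.Callback):
    def __init__(self, gan_generator, discriminator, data_loader, batch_size):
        super().__init__()
        self.gan_generator = gan_generator  # the GAN's image generator
        self.discriminator = discriminator
        self.data_loader = data_loader
        self.batch_size = batch_size

    def on_train_batch_end(self, batch, logs=None):
        # Pull a fresh batch, since the one .fit_generator() just used
        # is not exposed to callbacks.
        input_batch, target_batch = self.data_loader.sample(self.batch_size)
        fake_batch = self.gan_generator.predict(input_batch)

        valid = np.ones((len(target_batch), 1))
        invalid = np.zeros((len(fake_batch), 1))
        self.discriminator.train_on_batch(target_batch, valid)
        self.discriminator.train_on_batch(fake_batch, invalid)

The obvious downsides: every batch is loaded twice, and the discriminator never sees exactly the batch the generator was just trained on.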

Another possibility would be to parallelise the original training loop, but I am afraid that I have no time to do that right now.
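
If I do find the time, I suspect the cleanest version is to keep the loop and only parallelise the data loading with tf.keras.utils.OrderedEnqueuer, which prefetches batches from a Sequence in background workers. A sketch, assuming trainLoader is a keras.utils.Sequence as above:

import numpy as np
from tensorflow.keras.utils import OrderedEnqueuer

# Prefetch batches in background workers; the GAN logic itself is unchanged.
enqueuer = OrderedEnqueuer(trainLoader, use_multiprocessing=True)
enqueuer.start(workers=4, max_queue_size=8)
batch_stream = enqueuer.get()  # infinite generator of prefetched batches

for epoch in range(start_epoch, end_epoch):
    for _ in range(len(trainLoader)):
        input_batch, target_batch = next(batch_stream)

        fake_batch = self.generator.predict(input_batch)
        d_loss_real = self.discriminator.train_on_batch(target_batch, valid)
        d_loss_fake = self.discriminator.train_on_batch(fake_batch, invalid)
        d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

        g_loss = self.combined.train_on_batch([target_batch, input_batch],
                                              [valid, target_batch])

enqueuer.stop()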



