I have been training a conditional GAN architecture similar to Pix2Pix with the following training loop:
for epoch in range(start_epoch, end_epoch):
    for batch_i, (input_batch, target_batch) in enumerate(dataLoader.load_batch(batch_size)):
        # Generate fake images from the conditioning input
        fake_batch = self.generator.predict(input_batch)
        # Train the discriminator on real and generated samples
        d_loss_real = self.discriminator.train_on_batch(target_batch, valid)
        d_loss_fake = self.discriminator.train_on_batch(fake_batch, invalid)
        d_loss = np.add(d_loss_fake, d_loss_real) * 0.5
        # Train the generator through the combined model (discriminator frozen)
        g_loss = self.combined.train_on_batch([target_batch, input_batch], [valid, target_batch])
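(valid and invalid are not shown in the loop above; they are the real/fake label arrays for the discriminator. As a minimal sketch, assuming a PatchGAN-style discriminator output whose exact patch shape is an assumption here:)

import numpy as np

# Hypothetical label arrays; the patch shape depends on the actual
# discriminator output (e.g. a PatchGAN producing 16x16x1 patches).
patch_shape = (16, 16, 1)
valid = np.ones((batch_size,) + patch_shape)
invalid = np.zeros((batch_size,) + patch_shape)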
Now this loop works well, but it is not very efficient, as the data loader quickly becomes a time bottleneck. I have looked into the .fit_generator() function that Keras provides, which allows the data loading to run in worker threads or processes and runs much faster.
self.combined.fit_generator(generator=trainLoader,
                            validation_data=evalLoader,
                            callbacks=[checkpointCallback, historyCallback],
                            workers=1,
                            use_multiprocessing=True)
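(With use_multiprocessing=True, Keras expects the loader to be a keras.utils.Sequence rather than a plain generator. A simplified sketch of what a Sequence-style trainLoader could look like for the combined model; the class name, constructor arguments and patch shape are assumptions, not my actual loader:)

from tensorflow import keras
import numpy as np

class CombinedLoader(keras.utils.Sequence):
    """Illustrative loader yielding batches shaped for the combined model:
    inputs [target_batch, input_batch], targets [valid, target_batch]."""

    def __init__(self, input_images, target_images, batch_size, patch_shape=(16, 16, 1)):
        self.input_images = input_images
        self.target_images = target_images
        self.batch_size = batch_size
        self.patch_shape = patch_shape

    def __len__(self):
        # Number of batches per epoch
        return int(np.ceil(len(self.input_images) / self.batch_size))

    def __getitem__(self, idx):
        s = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        inputs, targets = self.input_images[s], self.target_images[s]
        valid = np.ones((len(inputs),) + self.patch_shape)
        return [targets, inputs], [valid, targets]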
It took me some time to see that this approach was incorrect: I was no longer training my generator and discriminator separately, and the discriminator was not being trained at all, since it is set to trainable = False in the combined model. This essentially ruins any kind of adversarial loss, and I might as well train my generator by itself with MSE.
Now my question is whether there is some workaround, such as training my discriminator inside a custom callback that is triggered after each batch of the .fit_generator() call. It is possible to create custom callbacks, like this for example:
class MyCustomCallback(tf.keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        # would need access to the current real/fake batches here
        discriminator.train_on_batch()
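(To make the idea concrete, a rough sketch of what I have in mind; since Keras does not pass the batch data to the callback, the callback pulls its own batch from the loader. The loader method next_discriminator_batch and the patch shape are placeholders, not existing APIs:)

import numpy as np
import tensorflow as tf

class DiscriminatorTrainer(tf.keras.callbacks.Callback):
    """Sketch: train the discriminator after every generator batch."""

    def __init__(self, generator, discriminator, loader, patch_shape=(16, 16, 1)):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.loader = loader
        self.patch_shape = patch_shape

    def on_train_batch_end(self, batch, logs=None):
        # Keras does not expose the batch data here, so fetch a batch ourselves
        input_batch, target_batch = self.loader.next_discriminator_batch()
        fake_batch = self.generator.predict(input_batch)
        valid = np.ones((len(input_batch),) + self.patch_shape)
        invalid = np.zeros((len(input_batch),) + self.patch_shape)
        self.discriminator.train_on_batch(target_batch, valid)
        self.discriminator.train_on_batch(fake_batch, invalid)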
Another possibility would be to parallelise the original training loop, but I am afraid that I have no time to do that right now.
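(For completeness, by parallelising the loop I mean something like a background thread that prefetches batches into a queue while the GPU trains. This is only a sketch of the idea, not something I have implemented:)

import queue
import threading

def prefetch(load_batch_fn, batch_size, buffer_size=8):
    """Run the data loader in a background thread and yield prefetched batches."""
    q = queue.Queue(maxsize=buffer_size)
    sentinel = object()

    def worker():
        for batch in load_batch_fn(batch_size):
            q.put(batch)
        q.put(sentinel)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        batch = q.get()
        if batch is sentinel:
            break
        yield batch

# Usage inside the original loop (otherwise unchanged):
# for batch_i, (input_batch, target_batch) in enumerate(prefetch(dataLoader.load_batch, batch_size)):
#     ...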
from Training GAN in keras with .fit_generator()