I am training a GAN to perform style transfer between two image domains (source S and target T). Since class information is available, I have an extra Q network (a LeNet) in addition to G and D. Q classifies the generated target-domain images against their labels, and its error is propagated to the generator together with D's. Looking at how the system converges, I have noticed that the D loss always starts at about 8 and drops only slightly, to about 4.5, while the G loss starts at about 1 and quickly drops to 0.2. The Q network uses categorical cross-entropy as its loss.
[Figure: D and G loss over the training iterations]
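To make the setup concrete, here is a minimal NumPy sketch of how Q's categorical cross-entropy on generated images could be computed and added to the generator's adversarial term. The names `categorical_crossentropy`, `generator_objective` and the weight `q_weight` are hypothetical; the post does not say how the D and Q terms are weighted when they are propagated to G.

```python
import numpy as np

def categorical_crossentropy(labels_onehot: np.ndarray, probs: np.ndarray, eps: float = 1e-7) -> float:
    """Mean categorical cross-entropy between one-hot class labels and Q's softmax outputs."""
    probs = np.clip(probs, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-np.sum(labels_onehot * np.log(probs), axis=-1)))

def generator_objective(adv_loss: float, q_loss: float, q_weight: float = 1.0) -> float:
    """Hypothetical combined signal for G: adversarial term from D plus a weighted Q term."""
    return adv_loss + q_weight * q_loss

# Two generated target-domain images whose true classes are 0 and 2 (3 classes in total).
labels = np.eye(3)[[0, 2]]
q_probs = np.array([[0.8, 0.1, 0.1],
                    [0.2, 0.2, 0.6]])
q_loss = categorical_crossentropy(labels, q_probs)  # mean of -ln(0.8) and -ln(0.6)
total = generator_objective(adv_loss=0.5, q_loss=q_loss, q_weight=1.0)
```

Because both terms are summed, a large `q_weight` lets the classifier dominate G's gradients, which is one more thing to keep in mind when reading the G curve.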
The loss functions of D and G are:
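```python
from keras import backend as K
BATCH_SIZE = 10  # D receives BATCH_SIZE real images followed by BATCH_SIZE generated ones

def discriminator_loss(y_true, y_pred):
    # Real half labelled 1, generated half labelled 0; Keras 2 order is (target, output)
    labels = K.concatenate([K.ones_like(K.flatten(y_pred[:BATCH_SIZE])),
                            K.zeros_like(K.flatten(y_pred[BATCH_SIZE:]))])
    return K.mean(K.binary_crossentropy(labels, K.flatten(y_pred)), axis=-1)

def discriminator_on_generator_loss(y_true, y_pred):
    # G is rewarded when D labels its outputs as real (1)
    return K.mean(K.binary_crossentropy(K.ones_like(K.flatten(y_pred)), K.flatten(y_pred)), axis=-1)

def generator_l1_loss(y_true, y_pred):
    # Pixel-wise L1 distance between generated and target images
    return K.mean(K.abs(K.flatten(y_pred) - K.flatten(y_true)), axis=-1)
```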
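One way to read the D numbers is to compare them with reference values of the same loss. The NumPy sketch below re-implements `discriminator_loss` outside Keras; the names `bce` and `discriminator_loss_np` are hypothetical, and the clipping constant is assumed to match Keras' default `K.epsilon()` of 1e-7, so results are approximate. A discriminator that outputs 0.5 everywhere scores ln 2 ≈ 0.69, the per-element loss is capped near -ln(1e-7) ≈ 16.1, and a mean around 8 can arise, for instance, when D saturates and labels every image as fake.

```python
import numpy as np

EPS = 1e-7  # assumed to mirror Keras' default K.epsilon()

def bce(target: np.ndarray, output: np.ndarray, eps: float = EPS) -> np.ndarray:
    """Element-wise binary cross-entropy with the output clipped to [eps, 1 - eps]."""
    output = np.clip(output, eps, 1.0 - eps)
    return -(target * np.log(output) + (1.0 - target) * np.log(1.0 - output))

def discriminator_loss_np(d_real: np.ndarray, d_fake: np.ndarray) -> float:
    """NumPy version of discriminator_loss: real outputs labelled 1, generated labelled 0."""
    preds = np.concatenate([d_real.ravel(), d_fake.ravel()])
    labels = np.concatenate([np.ones(d_real.size), np.zeros(d_fake.size)])
    return float(np.mean(bce(labels, preds)))

batch = 10
# D cannot tell real from fake (outputs 0.5 everywhere): loss is ln 2.
chance = discriminator_loss_np(np.full(batch, 0.5), np.full(batch, 0.5))
# D is confidently wrong on every sample: loss approaches -ln(EPS).
worst = discriminator_loss_np(np.zeros(batch), np.ones(batch))
# D says "fake" for everything: wrong on all real images, right on all generated ones.
all_fake = discriminator_loss_np(np.zeros(batch), np.zeros(batch))
```

Seen this way, a D loss that hovers between 8 and 4.5 suggests saturated, confidently wrong outputs on a large share of samples rather than a D that is merely uncertain.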
Does it make sense that the D loss is always that high? How should the D and G losses be interpreted? Should the loss of D be small at the beginning and rise over the iterations? Is it a good idea to restrain D relative to G with a loss threshold? Finally, during training, does it make sense to compute the losses on a validation set rather than on the training set I am using (i.e. instead of calling train_on_batch directly, use fit and then evaluate on a test set)?
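The threshold and validation ideas in these questions can be sketched together as a training loop. This is a hypothetical heuristic, not a standard recipe: `train_with_d_threshold`, its callable hooks and the threshold value 0.3 are all assumptions for illustration. D is measured on every iteration but only updated while its loss is above the threshold, G is updated every iteration, and losses on held-out data are recorded periodically.

```python
from typing import Callable, List, Tuple

def train_with_d_threshold(
    d_loss: Callable[[], float],
    d_train: Callable[[], float],
    g_train: Callable[[], float],
    val_losses: Callable[[], Tuple[float, float]],
    n_iters: int,
    d_threshold: float = 0.3,
    eval_every: int = 100,
) -> Tuple[int, List[Tuple[int, float, float]]]:
    """Alternate D/G updates, skipping D's update while its current loss is below d_threshold.

    All callables are hypothetical hooks: d_loss measures D on the current batch without
    updating it, d_train/g_train perform one update, val_losses returns (D, G) losses on
    held-out data.
    """
    d_updates = 0
    history: List[Tuple[int, float, float]] = []
    for it in range(1, n_iters + 1):
        if d_loss() > d_threshold:  # D is not "too strong" yet, so let it learn
            d_train()
            d_updates += 1
        g_train()  # G is updated every iteration
        if it % eval_every == 0:  # periodic check on data not used for training
            d_val, g_val = val_losses()
            history.append((it, d_val, g_val))
    return d_updates, history

# Toy usage: D's "loss" falls when D trains and rises when G trains.
state = {"d": 1.0}

def toy_d_train() -> float:
    state["d"] = max(0.0, state["d"] - 0.25)
    return state["d"]

def toy_g_train() -> float:
    state["d"] += 0.1
    return 0.0

updates, hist = train_with_d_threshold(
    d_loss=lambda: state["d"],
    d_train=toy_d_train,
    g_train=toy_g_train,
    val_losses=lambda: (state["d"], 0.0),
    n_iters=10,
    eval_every=5,
)
```

In Keras, `d_train` and `g_train` could wrap `train_on_batch` on the discriminator and on the combined model, while `d_loss` and `val_losses` could wrap `test_on_batch` on training and held-out batches respectively, so the held-out numbers never feed back into the weights.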
EDIT:
Regarding the losses, I think discriminator_loss and discriminator_on_generator_loss are the standard GAN loss functions, right? Then there is an extra loss function, generator_l1_loss, whose purpose I do not understand. It looks like the Wasserstein GAN loss without weight clipping (or simply the L1 distance between the generated results and the targets).
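A common reading, taken from the pix2pix conditional GAN, is that `generator_l1_loss` is a reconstruction term: with paired source/target images it pulls G's output toward the ground-truth target, while the adversarial term keeps the output realistic, and the two are summed with a weight (pix2pix uses λ = 100). The sketch below illustrates that objective under the assumption that this setup is pix2pix-like, and shows a WGAN critic loss for contrast, since WGAN compares critic scores on real and generated batches rather than pixels. The function names and the default `l1_weight` are assumptions for illustration.

```python
import numpy as np

def l1_loss(generated: np.ndarray, target: np.ndarray) -> float:
    """Mean absolute pixel difference, as in generator_l1_loss."""
    return float(np.mean(np.abs(generated - target)))

def generator_total_loss(d_on_fake: np.ndarray, generated: np.ndarray, target: np.ndarray,
                         l1_weight: float = 100.0, eps: float = 1e-7) -> float:
    """pix2pix-style generator objective: adversarial BCE against label 1 plus weighted L1."""
    d_on_fake = np.clip(d_on_fake, eps, 1.0 - eps)
    adversarial = float(np.mean(-np.log(d_on_fake)))  # same as discriminator_on_generator_loss
    return adversarial + l1_weight * l1_loss(generated, target)

def wgan_critic_loss(critic_real: np.ndarray, critic_fake: np.ndarray) -> float:
    """WGAN critic loss for contrast: a difference of mean critic scores, not a pixel distance."""
    return float(np.mean(critic_fake) - np.mean(critic_real))
```

The L1 term needs a target image for every generated image, whereas the WGAN loss only needs unpaired batches of real and generated samples; that difference is one way to tell the two apart.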
from Conditional GAN for domain translation
