I'll preface this by saying I don't have much experience with neural networks, and even less with PyTorch. I'm trying to add mixed precision to a WGAN-GP implementation (not mine) so that I can save GPU memory and train a bit faster.
I got the code from here, but I wrote my own Generator/Discriminator models, which I've put at the bottom.
The training loop looks as follows:
scaler1 = torch.cuda.amp.GradScaler()
scaler2 = torch.cuda.amp.GradScaler()

for epoch in range(EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        cur_batch_size = real.shape[0]

        # Train Critic
        with torch.cuda.amp.autocast():
            noise = torch.randn(cur_batch_size, LATENT_SIZE, 1, 1).to(device)
            fake = gen(noise)
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)
            gp = gradient_penalty(critic, real, fake, device=device, scaler=scaler1)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )

        critic.zero_grad()
        # loss_critic.backward(retain_graph=True)
        # opt_critic.step()
        scaler1.scale(loss_critic).backward(retain_graph=True)
        scaler1.unscale_(opt_critic)
        scaler1.step(opt_critic)

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        with torch.cuda.amp.autocast():
            gen_fake = critic(fake).reshape(-1)
            loss_gen = -torch.mean(gen_fake)

        gen.zero_grad()
        # loss_gen.backward()
        # opt_gen.step()
        scaler2.scale(loss_gen).backward(retain_graph=True)
        scaler2.unscale_(opt_gen)
        scaler2.step(opt_gen)

        scaler1.update()
        scaler2.update()
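For comparison, the standard single-optimizer recipe from the PyTorch AMP documentation looks roughly like the sketch below; model, loss_fn, optimizer, and loader are placeholders here, not names from my code:

scaler = torch.cuda.amp.GradScaler()

for data, target in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        output = model(data)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()  # backward on the scaled loss, outside autocast
    scaler.step(optimizer)         # the step is skipped if inf/NaN gradients are found
    scaler.update()                # the scale factor is adjusted once per iteration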
And the gradient penalty function:
def gradient_penalty(critic, real, fake, device="cpu", scaler=None):
    BATCH_SIZE, C, H, W = real.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * alpha + fake * (1 - alpha)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores if scaler is None else scaler.scale(mixed_scores),
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient / scaler.get_scale() if scaler is not None else gradient
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty
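The idea behind the scaler handling is that scaling the critic scores by a factor S before calling autograd.grad scales the resulting gradient by the same S, so dividing by scaler.get_scale() should recover the unscaled gradient. A tiny standalone check of that identity (the names here are only for illustration):

import torch

x = torch.randn(4, requires_grad=True)
S = 1024.0  # stands in for scaler.get_scale()

g_scaled = torch.autograd.grad(S * (x ** 2).sum(), x)[0]
g_plain = torch.autograd.grad((x ** 2).sum(), x)[0]
print(torch.allclose(g_scaled / S, g_plain))  # True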
I've tested this without mixed precision and it seems to do well enough, but after I implemented mixed precision, the discriminator loss becomes NaN after a few batches. The generator loss appears to be normal (although it starts out negative, which I'm not sure is OK; without mixed precision it becomes positive later).
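In case it helps with reproducing this, a minimal check that could go inside the training loop to catch the first non-finite critic loss (loss_critic, batch_idx, and scaler1 are the names from the loop above):

if not torch.isfinite(loss_critic).all():
    print(f"critic loss became non-finite at batch {batch_idx}, "
          f"current scale = {scaler1.get_scale()}")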
The following are my Generator and Discriminator models:
import numpy as np
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self, targetSize, channels, features, latentSize):
        super(Generator, self).__init__()
        mult = int(np.log(targetSize) / np.log(2) - 3)
        startFactor = 2 ** mult
        self.network = nn.Sequential(
            nn.ConvTranspose2d(latentSize, features * startFactor, 4, 1, 0, bias=False),
            nn.BatchNorm2d(features * startFactor),
            nn.LeakyReLU(0.2),
            *sum([self.__block(int(features * startFactor / (2 ** i)),
                               int(features * startFactor / (2 ** (i + 1))))
                  for i in range(mult)], []),
            nn.ConvTranspose2d(features, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def __block(self, in_features, out_features):
        layers = [nn.ConvTranspose2d(in_features, out_features, 4, 2, 1, bias=False)]
        layers.append(nn.BatchNorm2d(out_features))
        layers.append(nn.LeakyReLU(0.2))
        return layers

    def forward(self, inp):
        return self.network(inp)


class Discriminator(nn.Module):
    def __init__(self, targetSize, channels, features):
        super(Discriminator, self).__init__()
        mult = int(np.log(targetSize) / np.log(2) - 3)
        startFactor = 2 ** mult
        self.network = nn.Sequential(
            nn.Conv2d(channels, features, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),
            *sum([self.__block(int(features * (2 ** i)), int(features * (2 ** (i + 1))))
                  for i in range(mult)], []),
            nn.Conv2d(features * startFactor, 1, 4, 2, 0, bias=False),
        )

    def __block(self, in_features, out_features):
        layers = [nn.Conv2d(in_features, out_features, 4, 2, 1, bias=False)]
        layers.append(nn.InstanceNorm2d(out_features, affine=True))
        layers.append(nn.LeakyReLU(0.2))
        return layers

    def forward(self, inp):
        return self.network(inp)
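For context on what these build, here is a hypothetical instantiation; the values (64x64 RGB images, features=64, latentSize=100) are example numbers, not my actual hyperparameters. With these, the intent of the mult/startFactor formula is mult = 3 and startFactor = 8, so the generator upsamples 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64 while the critic downsamples back to a single score per image:

import torch

# Hypothetical example values, not my real hyperparameters.
targetSize, channels, features, latentSize = 64, 3, 64, 100

gen = Generator(targetSize, channels, features, latentSize)
critic = Discriminator(targetSize, channels, features)

z = torch.randn(2, latentSize, 1, 1)
fake = gen(z)
print(fake.shape, critic(fake).shape)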
NOTE: The FP16 chart ends at around step 140 because the loss is NaN from that point onward.
I've also created pastebins for the raw data: FP32 and FP16
from Pytorch mixed precision causing discriminator loss to go to NaN in WGAN-GP