Thursday, 16 December 2021

Pytorch mixed precision causing discriminator loss to go to NaN in WGAN-GP

I'll preface this by saying I don't have much experience with neural networks, and even less with PyTorch. I'm trying to add mixed precision to a WGAN-GP implementation (not mine) so that I can save GPU memory and train a bit faster.

I got the code from here, but I made my own Generator/Discriminator models, which I will put at the bottom.

The training loop looks as follows:


# Mixed-precision WGAN-GP training loop.
# Each optimizer gets its own GradScaler: scaler1 scales the critic loss
# (and the scores differentiated inside gradient_penalty), scaler2 scales
# the generator loss.
scaler1 = torch.cuda.amp.GradScaler()
scaler2 = torch.cuda.amp.GradScaler()

for epoch in range(EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        cur_batch_size = real.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)] - LAMBDA_GP * gp
        with torch.cuda.amp.autocast():
            # Allocate the noise directly on the target device instead of
            # creating it on the CPU and copying it over.
            noise = torch.randn(cur_batch_size, LATENT_SIZE, 1, 1, device=device)
            fake = gen(noise)

            # The critic update only needs the generated images as data, so
            # cut them off from the generator graph: the critic backward pass
            # then no longer runs through the generator and does not need
            # retain_graph=True. requires_grad_ is set because
            # gradient_penalty differentiates the critic scores with respect
            # to an interpolation of its inputs, so at least one input must
            # be tracked by autograd.
            fake_for_critic = fake.detach().requires_grad_(True)

            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake_for_critic).reshape(-1)
            gp = gradient_penalty(
                critic, real, fake_for_critic, device=device, scaler=scaler1
            )
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )

        # Backward and optimizer step run outside autocast.
        critic.zero_grad()
        scaler1.scale(loss_critic).backward()
        # No explicit unscale_ is needed: nothing inspects or clips the
        # gradients before the step, and step() unscales them itself.
        scaler1.step(opt_critic)
        scaler1.update()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        with torch.cuda.amp.autocast():
            # Fresh critic forward pass on the graph-attached fake images so
            # the loss can backpropagate into the generator.
            gen_fake = critic(fake).reshape(-1)
            loss_gen = -torch.mean(gen_fake)

        gen.zero_grad()
        # Nothing reuses this graph afterwards, so it is not retained.
        scaler2.scale(loss_gen).backward()
        scaler2.step(opt_gen)
        scaler2.update()


And the gradient penalty function:


def gradient_penalty(critic, real, fake, device="cpu", scaler=None):
    """Return the WGAN-GP gradient penalty for a batch of real/fake images.

    Each real image is mixed with its fake counterpart using one random
    coefficient per sample, the critic is evaluated on the mixture, and the
    penalty is the batch mean of ``(||d critic / d mixture||_2 - 1) ** 2``.

    Args:
        critic: Callable mapping an image batch to per-sample scores.
        real: Batch of real images, shape ``(N, C, H, W)``.
        fake: Batch of generated images, broadcast-compatible with ``real``.
            It does not have to require grad.
        device: Kept for backward compatibility only. The mixing
            coefficients are created on ``real.device``.
        scaler: Optional gradient scaler exposing ``scale()`` and
            ``get_scale()``. When given, the critic scores are scaled before
            differentiation and the resulting gradient is divided by the
            current scale, so the penalty is computed on unscaled gradients.

    Returns:
        Scalar tensor holding the penalty. It stays attached to the autograd
        graph (``create_graph=True``) so it can be part of the critic loss.
    """
    batch_size = real.shape[0]

    # One mixing coefficient per sample; broadcasting spreads it over
    # C, H and W without materialising a full-size tensor.
    alpha = torch.rand(batch_size, 1, 1, 1, device=real.device)
    interpolated_images = real * alpha + fake * (1 - alpha)

    # autograd.grad needs its `inputs` to be tracked by autograd. Without
    # this, the call fails whenever neither `real` nor `fake` requires grad
    # (e.g. a detached `fake`). It is a no-op if the tensor already does.
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)
    outputs = mixed_scores if scaler is None else scaler.scale(mixed_scores)

    # Take the gradient of the scores with respect to the images.
    # retain_graph defaults to the value of create_graph, so the graph is
    # kept for the later backward pass through the penalty.
    (gradient,) = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=outputs,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
    )

    if scaler is not None:
        # Undo the loss scaling before measuring the gradient norm.
        gradient = gradient / scaler.get_scale()

    # reshape (rather than view) also accepts non-contiguous gradients.
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

I've tested this without mixed precision, and it seems to work well enough, but after I implemented mixed precision, the discriminator loss becomes NaN after a few batches. The generator loss appears to be normal (it starts out negative, which I'm not sure is OK, but without mixed precision it becomes positive later).

The following are my Generator and Discriminator models:


class Generator(nn.Module):
    """Transposed-convolution generator turning a latent tensor into an image.

    The layer stack is:
      1. ``ConvTranspose2d`` (kernel 4, stride 1, no padding) from
         ``latentSize`` to ``features * 2**stages`` channels, followed by
         ``BatchNorm2d`` and ``LeakyReLU(0.2)``;
      2. ``stages`` stride-2 blocks, each halving the channel count;
      3. a stride-2 ``ConvTranspose2d`` from ``features`` to ``channels``
         followed by ``Tanh``.

    Args:
        targetSize: Used only to derive the number of intermediate blocks,
            ``stages = int(log(targetSize) / log(2) - 3)``.
        channels: Number of channels of the produced image.
        features: Channel count entering the final layer.
        latentSize: Number of channels of the latent input.
    """

    def __init__(self, targetSize, channels, features, latentSize):
        super().__init__()

        # Number of intermediate stride-2 blocks.
        stages = int(np.log(targetSize) / np.log(2) - 3)
        # Channel count after the first layer; halved by every block.
        width = features * 2**stages

        modules = [
            nn.ConvTranspose2d(latentSize, width, 4, 1, 0, bias=False),
            nn.BatchNorm2d(width),
            nn.LeakyReLU(0.2),
        ]
        for _ in range(stages):
            modules.extend(self.__block(width, width // 2))
            width //= 2
        modules.append(nn.ConvTranspose2d(features, channels, 4, 2, 1, bias=False))
        modules.append(nn.Tanh())

        self.network = nn.Sequential(*modules)

    def __block(self, in_features, out_features):
        """Return the layers of one stride-2 block as a list."""
        return [
            nn.ConvTranspose2d(in_features, out_features, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_features),
            nn.LeakyReLU(0.2),
        ]

    def forward(self, inp):
        """Run ``inp`` through the sequential network and return the result."""
        return self.network(inp)
    
class Discriminator(nn.Module):
    """Convolutional critic producing one score map per input image.

    The layer stack is:
      1. ``Conv2d`` (kernel 4, stride 2, padding 1) from ``channels`` to
         ``features`` channels, followed by ``LeakyReLU(0.2)``;
      2. ``stages`` stride-2 blocks, each doubling the channel count and
         using ``InstanceNorm2d(affine=True)``;
      3. a final ``Conv2d`` (kernel 4, stride 2, no padding) from
         ``features * 2**stages`` channels to a single channel.

    Args:
        targetSize: Used only to derive the number of intermediate blocks,
            ``stages = int(log(targetSize) / log(2) - 3)``.
        channels: Number of channels of the input image.
        features: Channel count produced by the first layer.
    """

    def __init__(self, targetSize, channels, features):
        super().__init__()

        # Number of intermediate stride-2 blocks.
        stages = int(np.log(targetSize) / np.log(2) - 3)

        modules = [
            nn.Conv2d(channels, features, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2),
        ]
        # Channel count entering the next block; doubled by every block.
        width = features
        for _ in range(stages):
            modules.extend(self.__block(width, width * 2))
            width *= 2
        modules.append(nn.Conv2d(features * 2**stages, 1, 4, 2, 0, bias=False))

        self.network = nn.Sequential(*modules)

    def __block(self, in_features, out_features):
        """Return the layers of one stride-2 block as a list."""
        return [
            nn.Conv2d(in_features, out_features, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_features, affine=True),
            nn.LeakyReLU(0.2),
        ]

    def forward(self, inp):
        """Run ``inp`` through the sequential network and return the result."""
        return self.network(inp)

FP16

FP32

NOTE: The FP16 chart ends on step ~140 because it becomes NaN from that point onwards.

I've also created pastebins for the raw data: FP32 and FP16



from Pytorch mixed precision causing discriminator loss to go to NaN in WGAN-GP

No comments:

Post a Comment