Monday, 3 July 2023

Minimal diffusion model (DDIM) for MNIST

For the purpose of learning I created a minimal DDIM for the MNIST dataset. Everything besides the math of diffusions I consider "extras."

Here is my list:

  • U-Net (removed - replaced with something simpler)
  • Positional embeddings (removed - part of unet)
  • Diffusion Schedule (added it back in case it helps)
  • Normalization of the dataset (left it in there for now)

The reason for a minimal example is because I do not understand the contribution of these other tricks. Therefore, If I start with something simpler - I can see the contribution of additional optimizations.

I expected to see some pictures that resemble a number but I do not. Loss goes down very slowly but not good enough.

I may have a bug or "what I have" it's just not enough. What would it take to make this minimal example to barely work? Any help would be greatly appreciated

The code is borrowed from this great Keras example: https://keras.io/examples/generative/ddim/

Here is my working code:

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os

print("tf version: ", tf.__version__)

# data
diffusion_steps = 20
image_size = 28

# sampling
min_signal_rate = 0.02
max_signal_rate = 0.95

# optimization
batch_size = 64
num_epochs = 1000
learning_rate = 1e-3

x0 = tf.keras.Input(shape=(28, 28, 1))
t0 = tf.keras.Input(shape=(1, 1,  1))

combined = tf.keras.layers.Add()([x0, t0])

x = tf.keras.layers.Flatten()(combined)
x = tf.keras.layers.Dense(7 * 7 * 64, activation="relu")(x)
x = tf.keras.layers.Reshape((7, 7, 64))(x)
x = tf.keras.layers.Conv2DTranspose(
    64, 3, activation="relu", strides=2, padding="same"
)(x)
x = tf.keras.layers.Conv2DTranspose(
    32, 3, activation="relu", strides=2, padding="same"
)(x)
output = tf.keras.layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
network = tf.keras.Model(inputs=[x0, t0], outputs=output)
# print(network.summary())


class DiffusionModel(tf.keras.Model):
    def __init__(self, network):
        super().__init__()

        self.normalizer = tf.keras.layers.Normalization()
        self.network = network

    def compile(self, **kwargs):
        super().compile(**kwargs)
        self.noise_loss_tracker = tf.keras.metrics.Mean(name="n_loss")
        self.image_loss_tracker = tf.keras.metrics.Mean(name="i_loss")

    @property
    def metrics(self):
        return [self.noise_loss_tracker, self.image_loss_tracker]

    def denormalize(self, images):
        return tf.clip_by_value(images, 0.0, 1.0)

    def diffusion_schedule(self, diffusion_times):
        # diffusion times -> angles
        start_angle = tf.acos(max_signal_rate)
        end_angle = tf.acos(min_signal_rate)
        diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)
        # angles -> signal and noise rates
        signal_rates = tf.cos(diffusion_angles)
        noise_rates = tf.sin(diffusion_angles)
        # note that their squared sum is always: sin^2(x) + cos^2(x) = 1
        return noise_rates, signal_rates

    # predictive stage
    def denoise(self, noisy_images, noise_rates, signal_rates, training):
        # predict noise component and calculate the image component using it
        pred_noises = self.network([noisy_images, noise_rates**2], training=training)
        pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
        return pred_noises, pred_images

    def reverse_diffusion(self, initial_noise, steps):
        # reverse diffusion = sampling
        batch = initial_noise.shape[0]
        step_size = 1.0 / steps

        # important line:
        # at the first sampling step, the "noisy image" is pure noise
        # but its signal rate is assumed to be nonzero (min_signal_rate)
        next_noisy_images = initial_noise
        for step in range(diffusion_steps):
            noisy_images = next_noisy_images
            diffusion_times = tf.ones((batch, 1, 1, 1)) - step * step_size
            noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
            pred_noises, pred_images = self.denoise(
                noisy_images, noise_rates, signal_rates, training=False
            )

            # this new noisy image will be used in the next step
            next_diffusion_times = diffusion_times - step_size
            next_noise_rates, next_signal_rates = self.diffusion_schedule(
            next_diffusion_times
            )
            next_noisy_images = (
            next_signal_rates * pred_images + next_noise_rates * pred_noises
           )
        return pred_images

    def generate(self, num_images, steps):
        # noise -> images -> denormalized images
        initial_noise = tf.random.normal(shape=(num_images, image_size, image_size, 1))
        generated_images = self.reverse_diffusion(initial_noise, steps)
        generated_images = self.denormalize(generated_images)
        return generated_images

    def train_step(self, images):
        # normalize images to have standard deviation of 1, like the noises
        images = self.normalizer(images, training=True)

        noises = tf.random.normal(shape=(batch_size, image_size, image_size, 1))
        diffusion_times = tf.random.uniform(
            shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
        )
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)

        # mix the images with noises accordingly
        noisy_images = signal_rates * images + noise_rates * noises

        with tf.GradientTape() as tape:
            # train the network to separate noisy images to their components
            pred_noises, pred_images = self.denoise(
                noisy_images, noise_rates, signal_rates, training=True
            )

            noise_loss = self.loss(noises, pred_noises)  # used for training
            image_loss = self.loss(images, pred_images)  # only used as metric

        gradients = tape.gradient(noise_loss, self.network.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        self.noise_loss_tracker.update_state(noise_loss)
        self.image_loss_tracker.update_state(image_loss)
        return {m.name: m.result() for m in self.metrics}

    def plot_images(
        self,
        epoch=None,
        logs=None,
        num_rows=3,
        num_cols=6,
        write_to_file=True,
        output_dir="output",
    ):
        # plot random generated images for visual evaluation of generation quality
        generated_images = self.generate(
            num_images=num_rows * num_cols,
            steps=diffusion_steps,
        )

        plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
        for row in range(num_rows):
            for col in range(num_cols):
                index = row * num_cols + col
                plt.subplot(num_rows, num_cols, index + 1)
                plt.imshow(generated_images[index])
                plt.axis("off")

        plt.tight_layout()

        if write_to_file:
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            if epoch is not None:
                filename = os.path.join(
                    output_dir, "image_epoch_{:04d}.png".format(epoch)
                )
            else:
                import time

                timestr = time.strftime("%Y%m%d-%H%M%S")
                filename = os.path.join(output_dir, "image_{}.png".format(timestr))
            plt.savefig(filename)
        else:
            plt.show()

        plt.close()


# create and compile the model
model = DiffusionModel(network)
# below tensorflow 2.9:
# pip install tensorflow_addons
# import tensorflow_addons as tfa
# optimizer=tfa.optimizers.AdamW
model.compile(
    optimizer=tf.keras.optimizers.experimental.AdamW(learning_rate=learning_rate),
    loss=tf.keras.losses.mean_absolute_error,
)
# pixelwise mean absolute error is used as loss

# save the best model based on the noise loss
checkpoint_path = "checkpoints/diffusion_model"
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    monitor="i_loss",
    mode="min",
    save_best_only=True,
)

(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255

dataset = tf.data.Dataset.from_tensor_slices(mnist_digits)
dataset = dataset.batch(batch_size, drop_remainder=True)

# calculate mean and variance of training dataset for normalization
model.normalizer.adapt(mnist_digits)

# run training and plot generated images periodically
model.fit(
    dataset,
    epochs=num_epochs,
    batch_size=batch_size,
    callbacks=[
        tf.keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
        checkpoint_callback,
    ],
)

# load the best model and generate images
model.load_weights(checkpoint_path)
model.plot_images(write_to_file=False)

Edit

  • Removed commented out block as pointed out by @xdurch0

Unfortunately, still no luck. Just to clarify, to answer this question one can either provide a network that is simpler than u-net that we can recognize some digits, or explain why we need u-net.



from Minimal diffusion model (DDIM) for MNIST

No comments:

Post a Comment