Diffusion Models#

In this lecture, we will be exploring the architecture of Diffusion models! To understand the details, you will need to know about Self-Attention and Transformers. Please refer to the Attention, Transformers, and Vision Transformer notebooks for some help if you need it!

import os
import subprocess
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from tqdm.auto import tqdm
from torchvision.datasets import ImageFolder
from torchvision import transforms 
from PIL import Image
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from transformers import get_cosine_schedule_with_warmup
import itertools
def wget_data(url: str, local_path='./tmp_data'):
  os.makedirs(local_path, exist_ok=True)

  p = subprocess.Popen(["wget", "-nc", "-P", local_path, url], stderr=subprocess.PIPE, encoding='UTF-8')
  rc = None

  while rc is None:
    line = p.stderr.readline().strip('\n')
    if len(line) > 0:
      print(line)
    rc = p.poll()

Diffusion Intuition#

Diffusion models are a powerful form of generative AI, and the intuition behind them is pretty simple! We first add noise to an image (the forward diffusion process) and then learn how to remove noise from the image (the backward diffusion process). Although the noise we add is random, how much of it we add is not! More specifically, the noise we add to an image follows a specific noise schedule. In modern diffusion models there is a large variety of noise schedulers available, but they all do something similar: schedule the variance. Therefore, adding noise is easy and follows some rules, but removing noise is hard and has to be learned!

The noise we add to the image is always gaussian noise, and a gaussian distribution is characterized by only two parameters, the mean \(\mu\) and variance \(\sigma^2\). When we sample noise from this distribution, we assume the mean is 0, so we will be scheduling just the variance.

I keep saying schedule, but why is this? We don’t actually add noise once, but multiple times! In most diffusion models today, we will add noise 1000 times to the image. By doing so, the original image will become essentially gaussian noise. The scheduler controls the intensity of noise at every step: if you have a small variance, then the noise is very small, and if you have a large variance then the noise will be large. You could add the same amount of noise to an image 1000 times, but that may not be preferable. The scheduler can start by adding small amounts of noise at the beginning and gradually increase the intensity of noise at every step. The reverse diffusion process then takes the noisy image and the number of times noise was added, and attempts to remove the noise.

A couple of really important clarifications before we go into the math though!

  • If we add 500 steps of noise to an image, the model will attempt to reconstruct it back to the case where there was NO noise (the original image at step 0)

  • During inference, we pass in random noise with the maximum timestep of our model; let’s pretend it’s 1000. The model will then remove as much noise as it can. We can then pass this denoised image in with a timestep of 999, and the model will attempt to denoise it again. We take the denoised image again, pass it in with timestep 998, rinse and repeat!

Therefore, inference is slightly different from training. When training, we pass in actual images that have been destroyed with noise. But even if we add noise to an image, the original image is still somewhere in there, just hidden away. When inferencing, we only provide pure noise, and the model attempts to create something from nothing. This is why we repeatedly apply the denoising model for all timesteps from 1000 down to 1. Even though the model was trained to undo all the timesteps of noise in one shot (if we pass in an image with 500 timesteps of noise, it predicts what is needed to get back to the original image at step 0), when the input is actually pure noise we have to reapply the denoising step by step to hopefully generate something useful!
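To make that inference loop concrete, here is a minimal sketch of it in code. Treat this purely as an illustration: model and sampler are placeholders for the denoising network and noise scheduler we build later in this notebook.

### Hypothetical sketch of the inference (sampling) loop described above ###
num_timesteps = 1000

### Start from pure gaussian noise ###
image = torch.randn(1, 3, 64, 64)

### Walk the timesteps backwards, removing one step of noise at a time ###
for t in reversed(range(num_timesteps)):
    timestep = torch.tensor([t])
    predicted_noise = model(image, timestep)                        # model predicts the noise in the image
    image = sampler.remove_noise(image, timestep, predicted_noise)  # scheduler removes one step of noise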

A Mathematical Overview to Understand Diffusion#

There is a lot of interesting math that goes into diffusion, making statisticians super excited (physicists like me less so). Some of it is just nice to know, but some of the heavier details are necessary to get a better grasp. There are things in the derivation that still don’t make total sense to me, so I will try to highlight some of the most important parts here. Let’s break this up into bite-sized pieces.

Forward Diffusion#

Forward diffusion is the process of iteratively adding noise to data according to some variance schedule for \(T\) timesteps. We will denote the forward process with the variable \(q\), which gives the conditional distribution of a noisy image \(x_T\) given the original image \(x_0\). We can then write all the steps \(1,...,T\) of noise as the following:

\[ \Large q(x_{1:T}|x_0) = \prod_{t=1}^Tq(x_t|x_{t-1}) \]

As you can see, we are just multiplying together the distributions for every consecutive pair of timesteps to get our full probability distribution over all the timesteps. The original DDPM paper then says that we can write a single one of these consecutive timestep probabilities as follows:

\[ \Large \boxed{ q(x_t|x_{t-1}) = N(x_t;\sqrt{1-B_t}x_{t-1}, B_t) } \]

where \(B_1, B_2, ..., B_T\) is a fixed noise schedule for the variance.

Basically, \(x_t\) should be sampled from a normal distribution with mean \(\sqrt{1-B_t}x_{t-1}\), based on the previous timestep, and variance \(B_t\). But where did this even come from? At every step we add some noise, but if we don’t scale \(x_{t-1}\) down as we add noise we will have a variance explosion, so we can do this proof:


Proof

Let \(x_0 \sim N(\mu, 1)\), so it has some arbitrary mean and a variance of 1. We want to add noise to the image \(x_0\) but ALWAYS KEEP THE VARIANCE AT 1 to avoid any variance explosions. Let’s pretend our noise is just randomly sampled from a standard normal distribution:

\[ \Large \epsilon \sim N(0,1) \]

To compute \(x_1\) we will scale \(x_0\) by some unknown factor \(a\) and scale the added noise by another unknown factor \(b\). Therefore we can write it as:

\[ \Large x_1 = ax_0 + b\epsilon \]

Now a quick review of normal distributions! If we multiply a normal distribution by some constant what happens?

\[ \Large Var(aX) = a^2Var(X) \]

Remember, our variance scheduler determines the amount of noise added at every step. Currently the noise is just standard normal, but we want the noise added at step 1 to have variance \(B_1\). That should be easy enough to do because of the variance scaling rule! Since \(\epsilon_1\) has variance 1, then:

\[ \Large Var(\sqrt{B_1}\epsilon_1) = (\sqrt{B_1})^2Var(\epsilon_1) = B_1 * 1 = B_1 \]

So now we know in our equation \(x_1 = ax_0 + b\epsilon\) that b should be \(\sqrt{B_1}\), so we can write our first step out as:

\[ \Large x_1 = ax_0 + \sqrt{B_1}\epsilon \]

Now the second condition! We want to make sure that \(x_1\) has an overall variance of 1 at the end to avoid variance explosions, so we can compute \(a\) to make that happen. Because \(x_0\) and \(\epsilon\) are independent, their variances simply add:

\[ \Large Var(x_1) = Var(ax_0 + \sqrt{B_1}\epsilon) = Var(ax_0) + Var(\sqrt{B_1}\epsilon) \]

Again remember, \(x_0\) has a variance of 1 and \(\epsilon\) also has a variance of 1, so by the variance rule we can compute:

\[ \Large Var(x_1) = a^2Var(x_0) + B_1Var(\epsilon) = a^2 + B_1 \]

This variance has to equal 1, so we can solve for \(a\):

\[ \Large a^2 + B_1 = 1 ~~~\text{ therefore }~~~ a = \sqrt{1 - B_1} \]

Therefore we have:

\[ \Large x_1 = \sqrt{1-B_1}x_0 + \sqrt{B_1}\epsilon \]

and we can write the more general form for an arbitrary \(t\):

\[ \Large x_t = \sqrt{1-B_t}x_{t-1} + \sqrt{B_t}\epsilon \]

Which can also be written as \(q(x_t|x_{t-1}) = N(x_t;\sqrt{1-B_t}x_{t-1}, B_t)\)
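As a quick numerical sanity check of this result, we can verify that one noising step really does keep the variance at 1 (the value of B below is an arbitrary choice for illustration):

### Numerical check: one forward step keeps the variance at 1 ###
B = 0.1                                    # arbitrary example beta for this check
x_0 = torch.randn(1_000_000) + 3.0         # x_0 ~ N(3, 1)
epsilon = torch.randn(1_000_000)           # epsilon ~ N(0, 1)

x_1 = (1 - B)**0.5 * x_0 + B**0.5 * epsilon
print(x_1.var())                           # should be very close to 1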


So first things first, we need a list of numbers to be our schedule for noise! There are a ton of schedulers out there (linear, cosine, etc…) so let’s just go with a simple linear scheduler:

Linear Scheduler#

beta_start = 0.0001
beta_end = 0.02
num_training_steps = 1000
beta_schedule = torch.linspace(beta_start, beta_end, num_training_steps)
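For comparison, here is a sketch of the cosine schedule introduced in the Improved DDPM paper (the offset s=0.008 and the 0.999 clamp follow that paper; treat this as an illustrative alternative rather than the schedule we use below):

### A sketch of the cosine beta schedule (Improved DDPM) as an alternative ###
def cosine_beta_schedule(num_training_steps, s=0.008):
    steps = torch.arange(num_training_steps + 1)
    f = torch.cos(((steps / num_training_steps) + s) / (1 + s) * math.pi / 2) ** 2
    alpha_bar = f / f[0]                            # cumulative alphas, normalized so alpha_bar[0] = 1
    betas = 1 - (alpha_bar[1:] / alpha_bar[:-1])    # recover per-step betas from the cumulative product
    return betas.clamp(max=0.999)

cosine_schedule = cosine_beta_schedule(num_training_steps)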

Increasing Computational Efficiency

So we have made it back to our original equation! But we still have a problem, which has more to do with practical efficiency than with theory. If we want an image with 500 steps of noise added to it, we would need a for loop to repeatedly add noise, but wouldn’t it be nice to be able to add all 500 steps of noise in just a single calculation? We are just adding gaussian noise over and over to an image, and we know a sum of gaussians is also a gaussian, so we should be able to do this.

In the paper they do this by introducing a new variable that is just a manipulation of our old variable \(B_t\)

\[ \Large a_t = 1 - B_t \text{ and } \bar{a}_t = \prod_{s=1}^t a_s \]

for example:

\[ \Large \bar{a}_2 = a_1 * a_2 = (1-B_1)*(1-B_2) \]

We can go ahead and compute these pieces as well then!

Alpha Computation#


### Compute Alpha ###
alpha = 1 - beta_schedule

### Compute Cumulative Alpha ###
alpha_cumulative_prod = torch.cumprod(alpha, dim=-1)

So time for some fun algebraic manipulations! Remember our previous equation for \(x_t\)?

\[ \Large x_t = \sqrt{1-b_t}x_{t-1} + \sqrt{B_t}\epsilon \]

Let’s rewrite this in terms of our new variable \(a_t\):

\[ \Large x_t = \sqrt{a_t}x_{t-1} + \sqrt{1 - a_t}\epsilon \]

Well \(x_{t-1}\) can be written very similarly:

\[ \Large x_{t-1} = \sqrt{a_{t-1}}x_{t-2} + \sqrt{1-a_{t-1}}\epsilon \]

Let’s substitute \(x_{t-1}\) into our formula for \(x_t\):

\[ \Large x_t = \sqrt{a_t} [ \sqrt{a_{t-1}}x_{t-2} + \sqrt{1-a_{t-1}}\epsilon] + \sqrt{1 - a_t}\epsilon \]
\[ \Large = \sqrt{a_t}\sqrt{a_{t-1}}x_{t-2} + \sqrt{a_t}\sqrt{1 - a_{t-1}}\epsilon + \sqrt{1 - a_t}\epsilon \]

Now how can we simplify this? This was probably one of the trickier parts of the derivation that I didn’t catch at first glance. Remember again, \(\epsilon\) is just a random variable coming from a standard normal distribution, and by multiplying it by some constant we are adjusting its variance. We have two terms in our equation above using epsilon (and they are independent draws of noise), so let’s try to combine them! Let’s create some temporary random variables using our variance rule from earlier:

\[ \Large \sqrt{a_t}\sqrt{1 - a_{t-1}}\epsilon = X \sim N(0, a_t(1 - a_{t-1})) \]
\[ \Large \sqrt{1 - a_t}\epsilon = Y \sim N(0, 1 - a_t) \]

Therefore \(X\) and \(Y\) are both normally distributed with different variances. So a quick review: what happens when we add two independent normal random variables together with different means and variances?

\[ \Large N(\mu_1, \sigma_1^2) + N(\mu_2, \sigma_2^2) = N(\mu_1+\mu_2, \sigma_1^2+\sigma_2^2) \]

So similarly, let’s add together our two normal distributions \(X\) and \(Y\) above.

\[ \Large X + Y = N(0, a_t(1 - a_{t-1})) + N(0, 1 - a_t) = N(0, a_t(1 - a_{t-1}) + (1 - a_t)) \]
\[ \Large = N(0, 1 - a_ta_{t-1}) \]

Again, by the variance rule we can write a sample from this normal distribution as the standard normal \(\epsilon\) multiplied by the square root of the variance we want:

\[ \Large N(0, 1 - a_ta_{t-1}) = \sqrt{1 - a_ta_{t-1}}\epsilon \]

Therefore our final equation for \(x_t\) can be written as:

\[ \Large x_t = \sqrt{a_t}\sqrt{a_{t-1}}x_{t-2} + (\sqrt{1 - a_ta_{t-1}})\epsilon \]
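If combining the two noise terms feels like a leap, a quick Monte Carlo check confirms the variance works out (the particular values of the alphas below are arbitrary):

### Check: the combined noise terms have variance 1 - a_t * a_(t-1) ###
a_t, a_t_minus_1 = 0.95, 0.97              # arbitrary example alphas
eps_1 = torch.randn(1_000_000)             # independent standard normal draws
eps_2 = torch.randn(1_000_000)

combined = (a_t * (1 - a_t_minus_1))**0.5 * eps_1 + (1 - a_t)**0.5 * eps_2
print(combined.var())                      # should be close to the value below
print(1 - a_t * a_t_minus_1)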

Now what if we substituted in the equation for \(x_{t-2}\), like we did earlier for \(x_{t-1}\)? The same thing happens! We can continue substituting repeatedly, and a pattern emerges:

\[ \Large x_t = \sqrt{\prod_{i=1}^t a_i}\,x_0 + \sqrt{1 - \prod_{i=1}^t a_i}\,\epsilon \]

Which is then written with the shorthand \(\bar{a}_t\) notation.

\[ \Large x_t = \sqrt{\bar{a}_t}x_0 + \sqrt{1 - \bar{a}_t}\,\epsilon \quad\quad \text{i.e.} \quad\quad q(x_t|x_0) = N(x_t; \sqrt{\bar{a}_t}x_0, 1 - \bar{a}_t) \]

So we did it! If you made it this far, you should now understand all the details about the forward diffusion process and how we can make it more efficient. Let’s do a quick implementation of the formulas above:


### What is the image we want to add noise to ? ###
x_0 = IMAGE

### What is the timestep we want? Lets add 500 steps of noise ###
timestep = 500

### Grab the corresponding cumulative alpha for 500 steps of noise ###
alpha_bar_t = alpha_cumulative_prod[timestep]

### Compute Mean Coefficient ###
mean_coeff = alpha_bar_t ** 0.5

### Compute Variance Coefficient ###
var_coeff = (1 - alpha_bar_t) ** 0.5

### Generate some Random Noise for the Epsilon ###
epsilon = torch.randn(x_0.shape)

### Use the reparameterization trick to sample from the distribution ###
noisy_image = mean_coeff*x_0 + var_coeff*epsilon
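As a sanity check that this closed-form jump matches the step-by-step process (in distribution, since the noise draws differ), here is a quick comparison using the beta_schedule and alpha_cumulative_prod defined above:

### Compare iterative noising against the closed-form jump ###
x_0 = torch.ones(100_000)                  # constant "image" so the statistics are easy to compare
timestep = 500

### Iterative: apply the single-step formula at every step up to t ###
x_t = x_0.clone()
for t in range(timestep + 1):
    x_t = (1 - beta_schedule[t])**0.5 * x_t + beta_schedule[t]**0.5 * torch.randn_like(x_t)

### Closed form: jump straight to timestep t in one shot ###
alpha_bar_t = alpha_cumulative_prod[timestep]
x_t_direct = alpha_bar_t**0.5 * x_0 + (1 - alpha_bar_t)**0.5 * torch.randn_like(x_0)

### Both means should be near sqrt(alpha_bar_t), both variances near (1 - alpha_bar_t) ###
print(x_t.mean(), x_t_direct.mean())
print(x_t.var(), x_t_direct.var())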

Backward Diffusion#

Forward diffusion is not too bad! Basically we have a scheduler, so we know the sequence of gaussian distributions we want to sample noise from, and we iteratively add steps of noise to our image.

On the other hand, how do we undo the steps of noise? Unfortunately, this is much harder and we will have to use a Neural Network to approximate this Backward Diffusion step.

Let’s write this out with some math again. We want to represent the following: what is the distribution over the full reverse trajectory that removes all \(T\) steps of noise?

\[ \Large P_\theta(x_{0:T}) = P(x_T)\prod_{t=1}^TP_\theta(x_{t-1}|x_t) \]

Notice the conditional \(P_\theta(x_{t-1}|x_t)\): we want the previous timestep given a future one, which is backwards from the forward diffusion process. We can represent each of these conditional probabilities as the following:

\[ \Large P_\theta(x_{t-1}|x_t) = N(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t,t)) \]

where \(\mu_\theta(x_t, t), \Sigma_\theta(x_t,t)\) have to be learned by the neural network.

Again, we have no idea what these reverse probability distributions are. But what was the original problem we wanted to solve? We want to compute \(P(x_0)\), the distribution of the original image data. But right now we have the joint distribution \(P_\theta(x_{0:T})\), which means we need to marginalize our distribution. If you don’t know what marginalizing a distribution means, just take a quick look here. At a high level, if you have a joint distribution over several variables and you only want one of them, you can integrate out the ones that you don’t. In our case it will look something like this:

\[ \Large P(x_0) = \int_{x_1}\int_{x_2}\int_{x_3}...\int_{x_T} P(x_{0:T})dx_1dx_2...dx_T \]

This would be an incredibly expensive process though so instead we will take a hint from Variational AutoEncoders and use the Evidence Lower Bound.

The full mathematical derivation is out of scope for building intuition about diffusion models, as I try to push a more hands-on approach than a theory-driven one. I was writing the derivation out myself when making this lecture and then I found an incredible resource! Lilian Weng has a great blog post, What are Diffusion Models?, which covers all of the ground-level math you need to really know the reverse process. Definitely give it a read; you will need to know a bit about the ELBO and Bayes’ rule to follow it, but it should be pretty intuitive. In our case, let’s just go over the results of the reverse diffusion process rather than the derivation. This is the recipe given by the original DDPM paper.

First, we will make a small change to the forward diffusion process. Up to now, we have been writing forward diffusion as \(q(x_t|x_{t-1})\). Instead, we will also condition on \(x_0\), so it will look like \(q(x_t|x_{t-1}, x_0)\). We can then use Bayes’ rule to get an expression for \(q(x_{t-1}|x_t, x_0)\). The reason for this is that, if you do the derivation, you will see that the reverse conditional probability is only tractable when we also condition on \(x_0\). This is the first main derivation shown in Lilian Weng’s blog post that you can look at. The result is:

\[ \Large q(x_{t-1}|x_t, x_0) \sim N(x_{t-1}; \tilde{\mu}(x_t, x_0), \Sigma_\theta(x_t,t)) \]

We can then write \(\tilde{\mu}_t\) as:

\[ \Large \tilde{\mu}_t = \frac{1}{\sqrt{a_t}}\left(x_t - \frac{1-a_t}{\sqrt{1-\bar{a}_t}}\epsilon_t\right) \]

and our full expression for \(q(x_{t-1}|x_t, x_0)\) becomes:

\[ \Large q(x_{t-1}|x_t, x_0) \sim N(x_{t-1}; \frac{1}{\sqrt{a_t}}\left(x_t - \frac{1-a_t}{\sqrt{1-\bar{a}_t}}\epsilon_t\right), \Sigma_\theta(x_t,t)) \]

In the original DDPM paper, ONLY the mean of the reverse process, \(\mu_\theta\), is learned, and the variance \(\Sigma_t\) is fixed according to the scheduler. So we can set \(\Sigma_\theta(x_t,t) \to \sigma_t^2I\).

In later papers, like the Improved Diffusion Models paper, the model learns both the mean and the variance of the reverse process. Here, though, because \(\sigma_t^2I\) is set ahead of time, we can directly compute it. The DDPM paper sets this as:

\[ \Large \sigma_t^2 = \frac{1-\bar{a}_{t-1}}{1-\bar{a}_{t}}B_t \]

We could also have used \(\sigma_t^2 = B_t\) and directly used our beta schedule, but the DDPM paper found no experimental difference between the two, so we will just go with the previous expression as it’s what is typically used in the implementations I’ve seen!

So our final form of the reverse process is:

\[ \Large \boxed{ q(x_{t-1}|x_t, x_0) \sim N(x_{t-1}; \frac{1}{\sqrt{a_t}}\left(x_t - \frac{1-a_t}{\sqrt{1-\bar{a}_t}}\epsilon_t\right), \sigma_t^2I) } \]

I know I skipped over a ton of mathematical detail again, maybe later I’ll make a deeper dive into the reverse process! But for now this should be enough to get started and have a basic understanding.

Let’s quickly look at the simplified code that will compute all of this! There are three things we need: (1) the noisy image, (2) what timestep we are on, and (3) the predicted noise.


### Pass in the things we need ###
timestep = 500
noisy_image = NOISY_IMAGE
predicted_noise = PREDICTED_NOISE

### Compute Sigma (b_t * (1 - cumulative_a_(t-1)) / (1 - cumulative_a)) * noise ###
alpha_bar_t = alpha_cumulative_prod[timestep]
alpha_bar_t_prev = alpha_cumulative_prod[timestep - 1]
beta_t = beta_schedule[timestep]
noise = torch.randn_like(noisy_image)
variance = beta_t * (1 - alpha_bar_t_prev) / (1 - alpha_bar_t)
sigma_z = noise * variance**0.5

### Compute Noise Coefficient (1 - a_t / sqrt(1 - cumulative_a)) where 1 - a_t = b_t ###
beta_t = beta_schedule[timestep]
alpha_bar_t = alpha_cumulative_prod[timestep]
root_one_minus_alpha_bar_t = (1 - alpha_bar_t) ** 0.5
noise_coefficient = beta_t / root_one_minus_alpha_bar_t

### Compute 1 / sqrt(a_t) ###
reciprocal_root_a_t = (alpha[timestep]**-0.5)

### Compute Denoised Image ###
denoised = reciprocal_root_a_t * (noisy_image - (noise_coefficient * predicted_noise)) + sigma_z

How do we learn this?#

As you can see, the entire problem now boils down to: given an image that is noisy, predict the noise so we can remove it!

The DDPM paper states that we can take an image, predict the noise, and use Mean Squared Error as our loss function between the predicted and true noise. In practice, you could also use Mean Absolute Error or the Huber loss (which behaves like L2 for small errors and L1 for large ones). There are some mathematical justifications for this too. If you think back to Variational Autoencoders, what we are trying to do is maximize the Evidence Lower Bound (ELBO), because learning the likelihood of the data-generating distribution directly is impossible, but we can maximize a lower bound on it. The reason we are not taking a similar approach here is that the authors of DDPM claim that, as long as each step adds a sufficiently small amount of noise, we can approximate the objective with Mean Squared Error. You can take a closer look here if you want to see more!
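In code, the training objective is exactly this: inject noise, predict it, and compare with MSE. Here is a hedged sketch of one training step; images, model, and sampler are placeholders for the batch, the denoising network, and the noise scheduler we define later in this notebook.

### Hypothetical sketch of a single training step with an MSE objective ###
loss_fn = nn.MSELoss()

timesteps = torch.randint(0, 1000, (images.shape[0],))           # random timestep for each image
noisy_images, true_noise = sampler.add_noise(images, timesteps)  # forward diffusion in one shot
predicted_noise = model(noisy_images, timesteps)                 # model predicts the injected noise

loss = loss_fn(predicted_noise, true_noise)                      # could also be L1 or Huber loss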

Define Sampler#

We will now take everything we saw previously and put it all together to create our scheduler! There are again a lot of different schedulers that exist, but for now we will stick to the basic linear beta scheduler. Again, at a high level, the sampler will have two methods:

  • add_noise: will take an image and a number of timesteps, and add that many steps of noise to it

  • remove_noise: will take a noisy image, the timestep of noise it was at, and the model’s predicted noise, and removes the noise according to the scheduling parameters

The other change from the pseudocode before is that we need to make sure our adding/removing of noise can be applied to batches of images, and can occur on the correct device if it isn’t running on the CPU!

class Sampler:
    def __init__(self, num_training_steps=1000, beta_start=0.0001, beta_end=0.02):
        self.num_training_steps = num_training_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        ### Define Basic Beta Scheduler ###
        self.beta_schedule = self.linear_beta_schedule()

        ### Compute Alphas for Direction 0 > t Noise Calculation ###
        self.alpha = 1 - self.beta_schedule
        self.alpha_cumulative_prod = torch.cumprod(self.alpha, dim=-1)
    
    def linear_beta_schedule(self):
        return torch.linspace(self.beta_start, self.beta_end, self.num_training_steps)

    def _repeated_unsqueeze(self, target_shape, input):
        while target_shape.dim() > input.dim():
            input = input.unsqueeze(-1)
        return input
    
    def add_noise(self, inputs, timesteps):

        batch_size, c, h, w = inputs.shape

        ### Grab the Device we want to place tensors on ###
        device = inputs.device
        
        alpha_cumulative_prod_timesteps = self.alpha_cumulative_prod[timesteps].to(device)
        
        ### Compute Mean Coefficient ###
        mean_coeff = alpha_cumulative_prod_timesteps ** 0.5

        ### Compute Variance Coefficient ###
        var_coeff = (1 - alpha_cumulative_prod_timesteps) ** 0.5

        ### Reshape mean_coeff and var_coeff to have shape (batch x 1 x 1 x 1) so we can broadcast with input (batch x c x height x width) ###
        mean_coeff = self._repeated_unsqueeze(inputs, mean_coeff)
        var_coeff = self._repeated_unsqueeze(inputs, var_coeff)

        ### Generate some Noise X ~ N(0,1) (rand_like will automatically place on same device as the inputs) ###
        noise = torch.randn_like(inputs)
        
        ### Compute Mean (mean_coef * x_0) ###
        mean = mean_coeff * inputs

        ### Compute Variance ###
        var = var_coeff * noise

        ### Compute Noisy Data ###
        noisy_image = mean + var

        return noisy_image, noise
        
    def remove_noise(self, input, timestep, predicted_noise):

        assert (input.shape == predicted_noise.shape), "Shapes of noise pattern and input image must be identical!!"
        
        b, c, h, w = input.shape

        ### Grab Device to Place Tensors On ###
        device = input.device

        ### Create a mask (if timestep == 0 sigma_z will also be 0 so we need to save this for later ###
        greater_than_0_mask = (timestep >= 1).int()

        
        ### Compute Sigma (b_t * (1 - cumulative_a_(t-1)) / (1 - cumulative_a)) * noise ###
        alpha_cumulative_t = self.alpha_cumulative_prod[timestep].to(device)
        alpha_cumulative_prod_t_prev = self.alpha_cumulative_prod[timestep - 1].to(device) # (timestep - 1) if timestep is 0 is WRONG! we will multiply by 0 later
        beta_t = self.beta_schedule[timestep].to(device)
        noise = torch.randn_like(input)
        variance = beta_t * (1 - alpha_cumulative_prod_t_prev) / (1 - alpha_cumulative_t)

        ### 0 out the variance for if the timestep == 0 ###
        variance = variance * greater_than_0_mask
        variance = self._repeated_unsqueeze(input, variance)
        sigma_z = noise * variance**0.5

        ### Compute Noise Coefficient (1 - a_t / sqrt(1 - cumulative_a)) where 1 - a_t = b_t ###
        beta_t = self.beta_schedule[timestep].to(device)
        alpha_cumulative_t = self.alpha_cumulative_prod[timestep].to(device)
        root_one_minus_cumulative_alpha_t = (1 - alpha_cumulative_t) ** 0.5
        noise_coefficient = beta_t / root_one_minus_cumulative_alpha_t
        noise_coefficient = self._repeated_unsqueeze(input, noise_coefficient)
        

        ### Compute 1 / sqrt(a_t) ###
        reciprocal_root_a_t = (self.alpha[timestep]**-0.5).to(device)
        reciprocal_root_a_t = self._repeated_unsqueeze(input, reciprocal_root_a_t)
        
        ### Compute Denoised Image ###
        denoised = reciprocal_root_a_t * (input - (noise_coefficient * predicted_noise)) + sigma_z
 
        return denoised
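
Here is a quick shape check of the sampler on a random batch (since we have no model yet, the "predicted noise" passed to remove_noise is just random):

### Quick shape check of the Sampler on a random batch ###
sampler = Sampler()

images = torch.randn(4, 3, 64, 64)
timesteps = torch.randint(0, 1000, (4,))

noisy_images, noise = sampler.add_noise(images, timesteps)
print(noisy_images.shape, noise.shape)     # both torch.Size([4, 3, 64, 64])

### Exercise remove_noise with a random stand-in for the predicted noise ###
denoised = sampler.remove_noise(noisy_images, timesteps, torch.randn_like(noisy_images))
print(denoised.shape)                      # torch.Size([4, 3, 64, 64])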

Let’s Test our Scheduler#

For this, we need a sample image, transformations to go from images to tensors, and transformations to go back from tensors to images. For the transformations, we will make use of the ones given in the Huggingface Annotated Diffusion!

wget_data('https://raw.githubusercontent.com/illinois-mlp/MachineLearningForPhysics/main/data/kitten.png')
File ‘./tmp_data/kitten.png’ already there; not retrieving.
### Define an Image Size ###
image_size = 256

### Load Image ###
image = Image.open("./tmp_data/kitten.png").convert("RGB")

### Init Sampler ###
sampler = Sampler()

### Transformations ###
image2tensor_transform = transforms.Compose([
                    transforms.Resize((image_size, image_size)), # Resize Image
                    transforms.ToTensor(), # Convert to tensor (will scale from 0 to 1)
                    transforms.Lambda(lambda t: (t*2) - 1), # Change scale to be -1 to 1
                    transforms.Lambda(lambda t: t.unsqueeze(0))
])

tensor2image_transform = transforms.Compose([
        transforms.Lambda(lambda t: t.squeeze(0)), # Remove batch dimension
        transforms.Lambda(lambda t: (t + 1) / 2), # Scale back to 0 to 1
        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # Make channels last 
        transforms.Lambda(lambda t: t * 255.), # Scale back to 0 to 255
        transforms.Lambda(lambda t: t.cpu().numpy().astype(np.uint8)), # Convert to numpy
        transforms.ToPILImage(), # Convert to PIL
    ])


### Check Transforms ###
tensor2image_transform(image2tensor_transform(image))
../../_images/c6b956bbcf05335d4b00af86164adea2fe75af6f05abd8cd46121c07738181ca.png

Define Model#

We can now define the structure of the model. The parts are:

SelfAttention, MLP, and TransformerBlock#

Please refer to the Attention, Transformers, and Vision Transformer notebooks for some help understanding the code below if you need it.

class SelfAttention(nn.Module):

  def __init__(self,
               in_channels,
               num_heads=12, 
               attn_p=0,
               proj_p=0,
               fused_attn=True):

    super().__init__()
    assert in_channels % num_heads == 0
    self.num_heads = num_heads
    self.head_dim = int(in_channels / num_heads)
    self.scale = self.head_dim ** -0.5
    self.fused_attn = fused_attn  

    self.qkv = nn.Linear(in_channels, in_channels*3)
    self.attn_p = attn_p
    self.attn_drop = nn.Dropout(attn_p)
    self.proj = nn.Linear(in_channels, in_channels)
    self.proj_drop = nn.Dropout(proj_p)

  def forward(self, x):
    batch_size, seq_len, embed_dim = x.shape
      
    qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
    qkv = qkv.permute(2,0,3,1,4)
    q,k,v = qkv.unbind(0)

    if self.fused_attn:
      x = F.scaled_dot_product_attention(q,k,v, dropout_p=self.attn_p)
    else:
      attn = (q @ k.transpose(-2,-1)) * self.scale
      attn = attn.softmax(dim=-1)
      attn = self.attn_drop(attn)
      x = attn @ v
    
    x = x.transpose(1,2).reshape(batch_size, seq_len, embed_dim)
    x = self.proj(x)
    x = self.proj_drop(x)
    
    return x

class MLP(nn.Module):
    def __init__(self, 
                 in_channels,
                 mlp_ratio=4, 
                 act_layer=nn.GELU,
                 mlp_p=0):

        super().__init__()
        hidden_features = int(in_channels * mlp_ratio)
        self.fc1 = nn.Linear(in_channels, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(mlp_p)
        self.fc2 = nn.Linear(hidden_features, in_channels)
        self.drop2 = nn.Dropout(mlp_p)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
        
class TransformerBlock(nn.Module):
    def __init__(self,
                 in_channels,
                 fused_attention=True,
                 num_heads=4, 
                 mlp_ratio=2,
                 proj_p=0,
                 attn_p=0,
                 mlp_p=0,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        
        super().__init__()
        self.norm1 = norm_layer(in_channels, eps=1e-6)

        self.attn = SelfAttention(in_channels=in_channels,
                                  num_heads=num_heads, 
                                  attn_p=attn_p,
                                  proj_p=proj_p,
                                  fused_attn=fused_attention)
        
        self.norm2 = norm_layer(in_channels, eps=1e-6)
        self.mlp = MLP(in_channels=in_channels,
                       mlp_ratio=mlp_ratio,
                       act_layer=act_layer,
                       mlp_p=mlp_p)
        
    def forward(self, x):
        batch_size, channels, height, width = x.shape
      
        ### Reshape to batch_size x (height*width) x channels
        x = x.reshape(batch_size, channels, height*width).permute(0,2,1)
        
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))

        x = x.permute(0,2,1).reshape(batch_size, channels, height, width)
        return x

rand_img = torch.randn(4,32,64,64)
t = TransformerBlock(in_channels=32)
t(rand_img).shape
torch.Size([4, 32, 64, 64])

Sinusoidal Time Embeddings#

The model needs to have some idea of “time”. In this case, it will be, what timestep of noise am I at? To do so, we can directly use sinusoidal time embeddings proposed in the Attention is all you need paper. If you haven’t seen this before, here is a super high level overview. The formula for the embeddings is as follows:

\[ \Large PE_{position, 2i} = \sin\left(\frac{position}{10000^{2i/d_{model}}}\right) \]
\[ \Large PE_{position, 2i+1} = \cos\left(\frac{position}{10000^{2i/d_{model}}}\right) \]

The position in our case is the timestep (out of the 1000 timesteps of noise we want to have), and \(d_{model}\) is a hyperparameter here. If you remember back to language models, we know \(d_{model}\) ahead of time: it is the length of the vectors we use to represent tokens. Here, the embedding depth will change as we go, because if you think about how typical convolutional models work, we get more channels as we go deeper into the model. So at the beginning we may only have 64 channels, so the \(d_{model}\) is 64, but later we may have 512 channels, so then the \(d_{model}\) is 512. Therefore we will create time embeddings for some given length \(d_{model}\), and then expand them as needed later in the architecture.

The index \(i\) refers to which index we are accessing within the \(d_{model}\) dimensions. Technically we should be interleaving \(\sin\) and \(\cos\) for every even and odd index (\(2i\) and \(2i+1\)), but experimentally I don’t think it matters if we take all of our even sin values and then just concatenate all the odd cos values, so we will do that to keep it easy.

We can then use linear layers to pick the length of these vectors, so we will pass in two hyperparameters: time_embed_dim, which will be the initial dimension of our time embeddings, and scaled_time_embed_dim, which will be the size of the projected time embedding after some linear layers. We will also use the SiLU activation function, for no reason other than that’s what I’ve seen people use online!

class SinusoidalTimeEmbedding(nn.Module):
    def __init__(self, time_embed_dim, scaled_time_embed_dim):
        super().__init__()
        self.inv_freqs = nn.Parameter(1.0 / (10000 ** (torch.arange(0, time_embed_dim, 2).float() / (time_embed_dim/2))), requires_grad=False)
        
        self.time_mlp = nn.Sequential(nn.Linear(time_embed_dim, scaled_time_embed_dim), 
                                      nn.SiLU(), 
                                      nn.Linear(scaled_time_embed_dim, scaled_time_embed_dim), 
                                      nn.SiLU())
    def forward(self, timesteps):
        timestep_freqs = timesteps.unsqueeze(1) * self.inv_freqs.unsqueeze(0)
        embeddings = torch.cat([torch.sin(timestep_freqs), torch.cos(timestep_freqs)], axis=-1)
        embeddings = self.time_mlp(embeddings)
        return embeddings

s = SinusoidalTimeEmbedding(time_embed_dim=128,scaled_time_embed_dim=256)
timesteps = torch.tensor([1,2,3])
s(timesteps).shape
torch.Size([3, 256])

UNet, Residual Blocks, Upsample Blocks#

UNet is a convolutional neural network that was developed for image segmentation. The network is based on a fully convolutional neural network whose architecture was modified and extended to work with fewer training images and to yield more precise segmentation. The UNet architecture has also been employed in diffusion models for iterative image denoising, which is the technology that underlies many modern image generation models, such as DALL-E, Midjourney, and Stable Diffusion.

The main idea of UNets is to supplement a usual contracting network by successive layers, where pooling operations are replaced by upsampling operators. Hence these layers increase the resolution of the output. A successive convolutional layer can then learn to assemble a precise output based on this information.

Our implementation will be pretty similar to the originally proposed UNet, except for one key difference:

In the original UNet, the upsampling (decoder) was done through a series of Transpose Convolutions. Transpose Convolutions are fine for image segmentation, but when it comes to image generation they can lead to a problem known as the checkerboard effect. You can see some cool examples of this here! At a high level, transpose convolutions can overlap more in some areas of an image than in others, so during generation more emphasis is put on specific locations of the image than on other parts. This leads to darker or more saturated chunks in the image, producing the checkerboard pattern. In practice, using a simple upsampling method like nearest neighbors or linear interpolation followed by a regular convolution works best, so that’s what we will do!
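As a small illustration of this design choice, both roads get you the same doubled resolution; we simply prefer the interpolation route (the layer sizes here are arbitrary):

### Two ways to double spatial resolution ###
x = torch.randn(1, 64, 16, 16)

### Transposed convolution (prone to checkerboard artifacts when generating images) ###
transpose_up = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
print(transpose_up(x).shape)               # torch.Size([1, 64, 32, 32])

### Interpolation followed by a regular convolution (what we will use) ###
interp_up = nn.Sequential(nn.Upsample(scale_factor=2),
                          nn.Conv2d(64, 64, kernel_size=3, padding="same"))
print(interp_up(x).shape)                  # torch.Size([1, 64, 32, 32])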

This implementation is also more for teaching, and we will be dynamically defining the structure of the model before building the layers so you can see how it comes together.

Sidenote: We will be utilizing a method known as GroupNorm in this architecture.

LayerNorm, which we have been using for transformers, normalizes across all channels (in this case, across all channels of our image). GroupNorm instead chunks the channels so that each group of channels is normalized separately. In practice this gives better location-specific normalization and can be used instead of LayerNorm or BatchNorm.
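Here is a tiny example of GroupNorm on an image-shaped tensor (the group count of 16 mirrors what we use in the blocks below):

### GroupNorm normalizes over groups of channels rather than all channels at once ###
x = torch.randn(4, 64, 32, 32)
group_norm = nn.GroupNorm(num_groups=16, num_channels=64)   # 64 channels split into 16 groups of 4
print(group_norm(x).shape)                                  # torch.Size([4, 64, 32, 32])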

Let’s Start with the first piece: ResidualBlocks#

The ResidualBlock will be a block of convolutions and normalizations, and it is where we incorporate our time embeddings into the model. The crucial property of our residual block is that NO UPSAMPLING OR DOWNSAMPLING SHOULD OCCUR: the spatial size of the image that goes into the ResidualBlock must come out unchanged (only the number of channels may change). And you may be wondering, how do we incorporate time embeddings? Time embeddings are a vector, but our image is a 3D tensor? Easy! We just reshape, broadcast, and add them together! Let’s just look at that piece first.

batch_size = 4

### Create some random images with 64 channels, and size of 32 by 32 ###
random_image = torch.randn(batch_size, 64, 32, 32)
print("Image Tensor", random_image.shape)

### Our random images have 64 channels, so our embeddings need to have a d_model of 64 ###
random_time_embeddings = torch.randn(batch_size, 64)
print("Original Time Embeddings", random_time_embeddings.shape)

### To add together a (b, 64, 32, 32) with a (b, 64) we need to add some extra dimensions to our embeddings ###
random_time_embeddings = random_time_embeddings.unsqueeze(-1).unsqueeze(-1)
print("Reshaped Time Embeddings", random_time_embeddings.shape)

## Add them together! ##
timed_images = random_image + random_time_embeddings
print("Output", timed_images.shape)
Image Tensor torch.Size([4, 64, 32, 32])
Original Time Embeddings torch.Size([4, 64])
Reshaped Time Embeddings torch.Size([4, 64, 1, 1])
Output torch.Size([4, 64, 32, 32])

One thing to keep in mind, though, is that our random_time_embeddings had exactly the number of channels we needed. In our case, we need to make sure to resize the original output from our SinusoidalTimeEmbedding to match the required number of channels. Let’s put it together!

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, groupnorm_num_groups, time_embed_dim):
        super().__init__()
        
        ### Time Embedding Expansion to Out Channels ###
        self.time_expand = nn.Linear(time_embed_dim, out_channels)

        ### Input Convolutions + GroupNorm ###
        self.groupnorm_1 = nn.GroupNorm(groupnorm_num_groups, in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding="same")

        ### Input + Time Embedding Convolutions + GroupNorm ###
        self.groupnorm_2 = nn.GroupNorm(groupnorm_num_groups, out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding="same")

        ### Residual Layer ###
        self.residual_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()

    def forward(self, x, time_embeddings):

        residual_connection = x

        ### Time Expansion to Out Channels ###
        time_embed = self.time_expand(time_embeddings)
        
        ### Input GroupNorm and Convolutions ###
        x = self.groupnorm_1(x)
        x = F.silu(x)
        x = self.conv_1(x)

        ### Add Time Embeddings ###
        x = x + time_embed.reshape((*time_embed.shape, 1, 1))

        ### Group Norm and Conv Again! ###
        x = self.groupnorm_2(x)
        x = F.silu(x)
        x = self.conv_2(x)

        ### Add Residual and Return ###
        x = x + self.residual_connection(residual_connection)
        
        return x
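
A quick sanity check that the block preserves the spatial size while mapping to a new channel count (the sizes below are arbitrary):

### Quick check: spatial size is preserved, channels go from in_channels to out_channels ###
res_block = ResidualBlock(in_channels=64, out_channels=128, groupnorm_num_groups=16, time_embed_dim=256)

images = torch.randn(4, 64, 32, 32)
time_embeddings = torch.randn(4, 256)
print(res_block(images, time_embeddings).shape)   # torch.Size([4, 128, 32, 32])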
        

Upsampling#

The next step is easy. As we saw before, we just want to do a regular upsample by a factor of 2 using any type of interpolation, and then use a convolutional layer afterwards that keeps the same size as the upsampled image.

class UpSampleBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding='same')
        )

    def forward(self, inputs):
        batch, channels, height, width = inputs.shape
        upsampled = self.upsample(inputs)
        assert (upsampled.shape == (batch, channels, height*2, width*2))
        return upsampled
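
A quick check that the block doubles the spatial resolution (keeping the channel count the same so the assert inside the block holds):

### Quick check: UpSampleBlock doubles height and width ###
upsample_block = UpSampleBlock(in_channels=64, out_channels=64)
print(upsample_block(torch.randn(4, 64, 16, 16)).shape)   # torch.Size([4, 64, 32, 32])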

Building the UNet#

This is probably the hardest part, especially for someone who hasn’t made a UNet before. Let’s take it a piece at a time! The things we need to tell our UNet are:

  • in_channels: How many input channels does our input image have? Probably 3 if RGB

  • start_dim: What is the starting dimension of the first convolution projection? There will be 3 channels into the first convolution, but how many channels do you want out? This should be 128 for a good model, but we will keep it at 64 so it’s not too big and you can train it!

  • dim_mults: A tuple of expansion factors. For example, if our start_dim is 64 and the dim_mults are (1,2,4), then the numbers of channels in the following convolution blocks will be (64*1, 64*2, 64*4). Therefore the final output of the encoder will have 256 channels. The decoder will also use the dim_mults, but in reverse order, meaning it will start at 64*4 and work its way back to 64*1.

  • residual_blocks_per_group: After every down- or upsample, we can repeat Residual blocks (which don’t change any shapes) multiple times to have more parameters in our model

  • time_embed_dim: What is the expected time embedding size from our SinusoidalTimeEmbeddings

High-level overview of some of the specifics of the UNet architecture#

For this drawing, we will only use one residual block per group just to keep it simple!

class UNET(nn.Module):
    def __init__(self, in_channels=3, start_dim=64, dim_mults=(1,2,4), residual_blocks_per_group=1, groupnorm_num_groups=16, time_embed_dim=128):
        super().__init__()

        #######################################
        ### COMPUTE ALL OF THE CONVOLUTIONS ###
        #######################################
        
        ### Store Number of Input channels from Original Image ###
        self.input_image_channels = in_channels
        
        ### Get Number of Channels at Each Block ###
        channel_sizes = [start_dim*i for i in dim_mults]
        starting_channel_size, ending_channel_size = channel_sizes[0], channel_sizes[-1]

        ### Compute the Input/Output Channel Sizes for Every Convolution of Encoder ###
        self.encoder_config = []
        
        for idx, d in enumerate(channel_sizes):
            ### For Every Channel Size add "residual_blocks_per_group" number of Residual Blocks that DONT Change the number of channels ###
            for _ in range(residual_blocks_per_group):
                self.encoder_config.append(((d, d), "residual")) # Shape: (Batch x Channels x Height x Width) -> (Batch x Channels x Height x Width)

            ### After Residual Blocks include Downsampling (by factor of 2) but dont change number of channels ###
            self.encoder_config.append(((d,d), "downsample")) # Shape: (Batch x Channels x Height x Width) -> (Batch x Channels x Height/2 x Width/2)

            ### Compute Attention ###
            self.encoder_config.append((d, "attention"))
            
            ### If we are not at the last channel size include a channel upsample (typically by factor of 2) ###
            if idx < len(channel_sizes) - 1:
                self.encoder_config.append(((d,channel_sizes[idx+1]), "residual")) # Shape: (Batch x Channels x Height x Width) -> (Batch x Channels*2 x Height x Width)
            
        ### The Bottleneck will have "residual_blocks_per_group" number of ResidualBlocks each with the input/output of our final channel size###
        self.bottleneck_config = []
        for _ in range(residual_blocks_per_group):
            self.bottleneck_config.append(((ending_channel_size, ending_channel_size), "residual"))

        ### Store a variable of the final Output Shape of our Encoder + Bottleneck so we can compute Decoder Shapes ###
        out_dim = ending_channel_size

        ### Reverse our Encoder config to compute the Decoder ###
        reversed_encoder_config = self.encoder_config[::-1]

        ### The output of our reversed encoder will be the number of channels added for residual connections ###
        self.decoder_config = []
        for idx, (metadata, type) in enumerate(reversed_encoder_config):
            ### Flip in_channels, out_channels with the previous out_dim added on ###
            if type != "attention":
                enc_in_channels, enc_out_channels = metadata
            
                self.decoder_config.append(((out_dim+enc_out_channels, enc_in_channels), "residual"))
                        
                if type == "downsample":
                    ### If we did a downsample in our encoder, we need to upsample in our decoder ###
                    self.decoder_config.append(((enc_in_channels, enc_in_channels), "upsample"))
    
                ### The new out_dim will be the number of output channels from our block (or the cooresponding encoder input channels) ###
                out_dim = enc_in_channels
            else:
                in_channels = metadata
                self.decoder_config.append((in_channels, "attention"))

        ### Add Extra Residual Block for residual from input convolution ###
        # hint: We know that the initial convolution will have starting_channel_size
        # and the output of our decoder will also have starting_channel_size, so the
        # final ResidualBlock we need will need to go from starting_channel_size*2 to starting_channel_size

        self.decoder_config.append(((starting_channel_size*2, starting_channel_size), "residual"))
        
        #######################################
        ### ACTUALLY BUILD THE CONVOLUTIONS ###
        #######################################

        ### Intial Convolution Block ###
        self.conv_in_proj = nn.Conv2d(self.input_image_channels, 
                                      starting_channel_size, 
                                      kernel_size=3, 
                                      padding="same")
        
        self.encoder = nn.ModuleList()
        for metadata, type in self.encoder_config:
            if type == "residual":
                in_channels, out_channels = metadata
                self.encoder.append(ResidualBlock(in_channels=in_channels,
                                                  out_channels=out_channels,
                                                  groupnorm_num_groups=groupnorm_num_groups,
                                                  time_embed_dim=time_embed_dim))
            elif type == "downsample":
                in_channels, out_channels = metadata
                self.encoder.append(
                    nn.Conv2d(in_channels, 
                              out_channels, 
                              kernel_size=3, 
                              stride=2, 
                              padding=1)
                    )
            elif type == "attention":
                in_channels = metadata
                self.encoder.append(TransformerBlock(in_channels))

        
        ### Build Bottleneck Blocks ###
        self.bottleneck = nn.ModuleList()
        
        for (in_channels, out_channels), _ in self.bottleneck_config:
            self.bottleneck.append(ResidualBlock(in_channels=in_channels,
                                                 out_channels=out_channels,
                                                 groupnorm_num_groups=groupnorm_num_groups,
                                                 time_embed_dim=time_embed_dim))

        ### Build Decoder Blocks ###
        self.decoder = nn.ModuleList()
        for metadata, type in self.decoder_config:
            if type == "residual":
                in_channels, out_channels = metadata
                self.decoder.append(ResidualBlock(in_channels=in_channels,
                                                  out_channels=out_channels,
                                                  groupnorm_num_groups=groupnorm_num_groups,
                                                  time_embed_dim=time_embed_dim))
            elif type == "upsample":
                in_channels, out_channels = metadata
                self.decoder.append(UpSampleBlock(in_channels=in_channels, 
                                                  out_channels=out_channels))

            elif type == "attention":
                in_channels = metadata
                self.decoder.append(TransformerBlock(in_channels))

        ### Output Convolution ###
        self.conv_out_proj = nn.Conv2d(in_channels=starting_channel_size, 
                                       out_channels=self.input_image_channels,
                                       kernel_size=3, 
                                       padding="same")

        
    def forward(self, x, time_embeddings):
        residuals = []

        ### Pass Through Projection and Store Residual ###
        x = self.conv_in_proj(x)
        residuals.append(x)

        ### Pass through encoder and store residuals ##
        for module in self.encoder:
            if isinstance(module, (ResidualBlock)):
                x = module(x, time_embeddings)
                residuals.append(x)
            elif isinstance(module, nn.Conv2d):
                x = module(x)
                residuals.append(x)
            else:
                x = module(x)

        ### Pass Through BottleNeck ###
        for module in self.bottleneck:
            x = module(x, time_embeddings)

        ### Pass through Decoder while Concatenating Residuals ###
        for module in self.decoder:
            if isinstance(module, ResidualBlock):
                residual_tensor = residuals.pop()
                x  = torch.cat([x, residual_tensor], axis=1)
                x = module(x, time_embeddings)
            else:
                x = module(x)

        ### Map back to num_channels for final output ###
        x = self.conv_out_proj(x)
        
        return x

m = UNET(start_dim=128, dim_mults=(1,2))
rand_image = torch.randn(4, 3, 256, 256)
rand_time_embeddings = torch.randn(4, 128)
out = m(rand_image, rand_time_embeddings)
out.shape
torch.Size([4, 3, 256, 256])

Putting it all together: The Diffusion Model#

Let’s stick together all the modules we have constructed to define our final Diffusion model!

class Diffusion(nn.Module):
    def __init__(self, 
                 in_channels=3, 
                 start_dim=128, 
                 dim_mults=(1,2,4), 
                 residual_blocks_per_group=1, 
                 groupnorm_num_groups=16, 
                 time_embed_dim=128, 
                 time_embed_dim_ratio=2):

        super().__init__()
        self.in_channels = in_channels
        self.start_dim = start_dim
        self.dim_mults = dim_mults
        self.residual_blocks_per_group = residual_blocks_per_group
        self.groupnorm_num_groups = groupnorm_num_groups

        self.time_embed_dim = time_embed_dim
        self.scaled_time_embed_dim = int(time_embed_dim * time_embed_dim_ratio)

        self.sinusoid_time_embeddings = SinusoidalTimeEmbedding(time_embed_dim=self.time_embed_dim,
                                                                scaled_time_embed_dim=self.scaled_time_embed_dim)

        self.unet = UNET(in_channels=in_channels, 
                         start_dim=start_dim, 
                         dim_mults=dim_mults, 
                         residual_blocks_per_group=residual_blocks_per_group, 
                         groupnorm_num_groups=groupnorm_num_groups,  
                         time_embed_dim=self.scaled_time_embed_dim)

    def forward(self, noisy_inputs, timesteps):

        ### Embed the Timesteps ###
        timestep_embeddings = self.sinusoid_time_embeddings(timesteps)
        
        ### Pass Images + Time Embeddings through UNET ###
        noise_pred = self.unet(noisy_inputs, timestep_embeddings)

        return noise_pred
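
As a quick check of the full model, integer timesteps go in and a predicted-noise tensor with the same shape as the input comes out (this small configuration is an arbitrary choice so the check runs quickly):

### Quick check of the full Diffusion model ###
diffusion = Diffusion(start_dim=64, dim_mults=(1,2))
noisy_images = torch.randn(2, 3, 64, 64)
timesteps = torch.randint(0, 1000, (2,))
print(diffusion(noisy_images, timesteps).shape)   # torch.Size([2, 3, 64, 64])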

Some Helper Functions#

@torch.no_grad()
def sample_plot_image(step_idx, 
                      total_timesteps, 
                      sampler, 
                      image_size,
                      num_channels,
                      plot_freq, 
                      model,
                      num_gens,
                      path_to_generated_dir,
                      device):

    ### Convert Tensor back to Image (From Huggingface Annotated Diffusion) ###
    tensor2image_transform = transforms.Compose([
        transforms.Lambda(lambda t: t.squeeze(0)),
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)),
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.cpu().numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    images = torch.randn((num_gens, num_channels, image_size, image_size))
    num_images_per_gen = (total_timesteps // plot_freq)

    images_to_vis = [[] for _ in range(num_gens)]
    for t in np.arange(total_timesteps)[::-1]:
        ts = torch.full((num_gens, ), t)
        noise_pred = model(images.to(device), ts.to(device)).detach().cpu()
        images = sampler.remove_noise(images, ts, noise_pred)
        if t % plot_freq == 0:
            for idx, image in enumerate(images):
                images_to_vis[idx].append(tensor2image_transform(image))


    images_to_vis = list(itertools.chain(*images_to_vis))

    fig, axes = plt.subplots(nrows=num_gens, ncols=num_images_per_gen, figsize=(num_images_per_gen, num_gens))
    plt.tight_layout()
    for ax, image in zip(axes.ravel(), images_to_vis):
        ax.imshow(image)
        ax.axis("off")
    fig.subplots_adjust(wspace=0.05, hspace=0.05)
    plt.savefig(os.path.join(path_to_generated_dir, f"step_{step_idx}.png"))
    plt.show()
    plt.close()

Get CelebA dataset#

wget_data('https://courses.physics.illinois.edu/phys503/fa2023/data/celeba.tgz')
File ‘./tmp_data/celeba.tgz’ already there; not retrieving.

Untar CelebA dataset#

subprocess.call( ['/usr/bin/pwd'])
subprocess.call( ['tar','zxvpf', './tmp_data/celeba.tgz', '--exclude=._*'] )
subprocess.call( ['mv', './celeba', './tmp_data/'] )

Train Function#

def train(image_size=64, 
        evaluation_interval=3750,
          total_timesteps=500, 
          plot_freq_interval=50, 
          num_generations=5, 
          num_training_steps=75000, 
          num_input_channels=3, 
          batch_size=64,
          path_to_generated="./tmp_data"): 
    
    torch.backends.cudnn.benchmark = True

    device = "cuda" if torch.cuda.is_available() else "cpu"
        
    ### Define Basic Image Transformations (From Huggingface Annotated Diffusion) ###
    image2tensor = transforms.Compose([
                    transforms.Resize((image_size, image_size)),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(), 
                    transforms.Lambda(lambda t: (t*2) - 1)
                ])
    dataset = ImageFolder('./tmp_data/celeba', transform=image2tensor)
    trainloader = DataLoader(dataset, 
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=8,
#                             num_workers=0,
                            pin_memory=True)

    model = Diffusion(in_channels=num_input_channels).to(device)

    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("Number of Parameters:", params)

    ### MODEL TRAINING INPUTS ###
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=0.0005)
    scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer, 
                                                num_warmup_steps=2500, 
                                                num_training_steps=num_training_steps)

    ddpm_sampler = Sampler(num_training_steps=total_timesteps)

    loss_fn = nn.MSELoss()

    progress_bar = tqdm(range(num_training_steps))
    completed_steps = 0

    train = True
    while train:
        training_losses = []
        for images, _ in trainloader:
            batch_size = images.shape[0]
        
            ### Random Sample T ###
            timesteps = torch.randint(0,total_timesteps,(batch_size,))
        
            ### Get Noisy Images ###
            noisy_images, noise = ddpm_sampler.add_noise(images, timesteps)
        
            ### Get Noise Prediction ###
            noise_pred = model(noisy_images.to(device), timesteps.to(device))

            ### Compute Error ###
            loss = loss_fn(noise_pred, noise.to(device))

            training_losses.append(loss.cpu().item())
            
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)

            progress_bar.update(1)
            completed_steps += 1

            if (completed_steps % evaluation_interval == 0):
                loss_mean = np.mean(training_losses)
                print("Training Loss:", loss_mean)
                print("Learning Rate:", optimizer.param_groups[-1]["lr"])

                training_losses = []
                print("Saving Image Generation")
                sample_plot_image(step_idx=completed_steps, 
                                  total_timesteps=total_timesteps, 
                                  sampler=ddpm_sampler, 
                                  image_size=image_size,
                                  num_channels=num_input_channels,
                                  plot_freq=plot_freq_interval, 
                                  model=model,
                                  num_gens=num_generations,
                                  path_to_generated_dir=path_to_generated,
                                  device=device)
                
            if completed_steps >= num_training_steps:
                print("Training Completed!!!")
                train = False
                break
train()
Number of Parameters: 52956419
Training Loss: 0.027614978325476693
Learning Rate: 0.0004996333534627809
Saving Image Generation
../../_images/ff41ffac63aa22577dd837f196d252226fa5d6f5bf8213e54daf55f6372c6c17.png
Training Loss: 0.025593497773651545
Learning Rate: 0.0004941551389275216
Saving Image Generation
../../_images/e0b69b0d6a70fca2a8b619fdc6db6de9bc9ffc508fbcebb5b61a43b787de13a2.png
Training Loss: 0.025227434365461543
Learning Rate: 0.0004822441799541979
Saving Image Generation
../../_images/1e8a13fcc3751b52a433bd5fc5d136b0b31618fd0d3c7584e2830e384123e91e.png
Training Loss: 0.024866870749177258
Learning Rate: 0.0004642142940418973
Saving Image Generation
../../_images/5e5edb216f6220d6d1e016a89ede0cc5f5109c6324edfddde4c2338ef904dcb4.png
Training Loss: 0.024508819502956962
Learning Rate: 0.0004405405137819091
Saving Image Generation
../../_images/c8c903ef6d598e83bb42f164b6ffdd7caae6ad1851230c24230e68fb3211e83a.png
Training Loss: 0.02422389824263002
Learning Rate: 0.00041184657119545693
Saving Image Generation
../../_images/50aedc8104b12c41be5f591efe347333fab005e7d2dba51d41aea46efa18cf5b.png
Training Loss: 0.023754292109136727
Learning Rate: 0.00037888846429425546
Saving Image Generation
../../_images/7a1fd77f3df1b81462b68773e8e6d30ac0368b5598fedba375b9d10e6a9ba811.png
Training Loss: 0.023777675831005235
Learning Rate: 0.00034253453883497867
Saving Image Generation
../../_images/f70691fa0a01c34fe57f1194b5d5bd3c85811ce7da25eba9fa9ef6cef0e5366c.png
Training Loss: 0.023423723777551114
Learning Rate: 0.00030374261005275605
Saving Image Generation
../../_images/44d76c4737cd3e49433ea58633ce0e0ee6540d0563a0c9c5627e7f6bf007f192.png
Training Loss: 0.023408989196680847
Learning Rate: 0.0002635347271463544
Saving Image Generation
../../_images/a580cb768b8a775a0c40c9ed72641141e6f500ad515a155cbdf4453148ad060a.png
Training Loss: 0.02353073466245247
Learning Rate: 0.0002229702453940146
Saving Image Generation
../../_images/2c36905e6dd32660501a9349017cf30b0ee6625aeab4b0c5275d73185822450c.png
Training Loss: 0.022872181113302355
Learning Rate: 0.00018311791536769484
Saving Image Generation
../../_images/710fb0613f5b55003fa36b2966860db0bc281e5785e7aa2fe03d15cbfbd782f9.png
Training Loss: 0.022874412042576644
Learning Rate: 0.00014502772460993384
Saving Image Generation
../../_images/ef5ff19642d531281aa2fba471c6a9ebe0f7076b857abaf36f9322f203a87526.png
Training Loss: 0.022812383052935408
Learning Rate: 0.00010970323365940444
Saving Image Generation
../../_images/865f65895acb03b4959e706b41655073df77789667dd7b0a3290ca4084ba1045.png
Training Loss: 0.022587639776394507
Learning Rate: 7.807513528664414e-05
Saving Image Generation
../../_images/e303cfc405670e99c0eaf52529288d3e3d9141fc3f8da5712b52a6d58e447ff4.png
Training Loss: 0.022739307398649346
Learning Rate: 5.097673357358906e-05
Saving Image Generation
../../_images/080e1bb6d74a1881ad2a10a528bf3762e339e792fa7c14ca8f19e1fe14f9a6fa.png
Training Loss: 0.0229581450047188
Learning Rate: 2.9121988888494292e-05
Saving Image Generation
../../_images/911b1360eb5ac6b02eb247d22ab3a8e19c61a759237aa2ebffe59579dace8ac9.png
Training Loss: 0.02254007502479196
Learning Rate: 1.3086707204299414e-05
Saving Image Generation
../../_images/81a4c492c0e4811ee6346896f13e05aa76563d3e63562739440054316a9a4270.png
Training Loss: 0.02248887126526337
Learning Rate: 3.293369364618465e-06
Saving Image Generation
../../_images/e5799c3845ac8c1e72bbdab4ef4eb52912a501c232c7eeb9e12a0c51e38a50c8.png
Training Loss: 0.022647713878687973
Learning Rate: 0.0
Saving Image Generation
../../_images/63d444deae24ad5eed1e6ed3c59f63159c0da3c9e21e7e4274356993e71ef6c3.png
Training Completed!!!

Acknowledgments#

  • Initial version: Mark Neubauer

© Copyright 2025