The Art of Spread: Crafting Your Own Diffusion Model

Matthew MacFarquhar
12 min read · Feb 21, 2024

Introduction

Diffusion models are a cutting-edge approach to image generation in machine learning, offering a framework for synthesizing high-quality images with fine-grained control. Unlike generative models that produce an image in a single pass, such as GANs and VAEs, diffusion models build an image through many small refinement steps, each of which removes a little noise. These models have drawn significant attention for producing sharp, diverse, and coherent images, and they are reshaping applications in computer vision, art generation, and beyond. In this article, we dig into the principles behind diffusion models and build a small one from scratch.

How do Diffusion models work?

Diffusion

Diffusion models get their name from the chemical process of diffusion, in which an area of highly concentrated molecules gradually spreads out through space. If we think of the set of meaningful, noise-free 128x128 grayscale images as a tight cluster in a 16,384-dimensional space (one dimension per pixel), we can diffuse an image across that space by repeatedly adding random noise to it until it is just a random point in the space.

Why is that useful for image generation? Because a diffusion model learns to reverse this process: starting from a random point in the 16,384-dimensional space, it can work its way back to a sensible 128x128 image.

Our Diffusion model has a forward pass to slowly add noise, and then a backward pass to slowly predict the noise and remove it.

Forward Process

In the forward process, we take an image and a time step and produce a slightly noisier image by scaling the image down a little and adding a small amount of noise drawn from a standard normal (Gaussian) distribution. We repeat this over many time steps until the image is pure noise.
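In the standard DDPM formulation, which the scheduler code later in this article follows, one forward step with noise level β_t (taken from a linear schedule) can be written as:

```latex
q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\right)
\quad\Longleftrightarrow\quad
x_t = \sqrt{1-\beta_t}\,x_{t-1} + \sqrt{\beta_t}\,\epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I)
```

The standard normal noise ε is scaled by sqrt(β_t) before it is added, and the previous image is shrunk by sqrt(1 - β_t), so the overall variance stays bounded as the steps pile up.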

Luckily, the total amount of noise at every time step can be precomputed, so we never need to apply the noise step by step: we can jump straight from a clean image to its noisy version at any time step we like. In practice, we never run the forward process iteratively; we simply generate the noisy image for a randomly chosen time step inside the training loop.
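Here is a minimal pure-Python check of that claim, using plain floats instead of tensors (the helper names are illustrative and not part of the article's code). Instead of sampling noise, it tracks the coefficient on x0 and the variance of the accumulated independent noise terms. Applying the single step over and over gives the same numbers as the closed form sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise that the scheduler precomputes.

```python
import math


def linear_betas(num_timesteps, beta_start, beta_end):
    """Same values as torch.linspace(beta_start, beta_end, num_timesteps)."""
    step = (beta_end - beta_start) / (num_timesteps - 1)
    return [beta_start + i * step for i in range(num_timesteps)]


def iterated_coefficients(betas, t):
    """Apply x_s = sqrt(1 - beta_s) * x_{s-1} + sqrt(beta_s) * eps_s for s = 0..t.

    Returns (coefficient on x0, variance of the summed noise terms).
    """
    signal, noise_var = 1.0, 0.0
    for s in range(t + 1):
        signal *= math.sqrt(1.0 - betas[s])
        # Old noise is scaled by sqrt(1 - beta_s); new independent noise adds beta_s
        noise_var = (1.0 - betas[s]) * noise_var + betas[s]
    return signal, noise_var


def closed_form_coefficients(betas, t):
    """x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    alpha_bar = 1.0
    for s in range(t + 1):
        alpha_bar *= 1.0 - betas[s]
    return math.sqrt(alpha_bar), 1.0 - alpha_bar


betas = linear_betas(1000, 0.0001, 0.02)
for t in (0, 10, 500, 999):
    print(t, iterated_coefficients(betas, t), closed_form_coefficients(betas, t))
```

Because the two agree for every t, the training loop can sample t at random and build x_t in one shot.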

Backward Pass

This is where the learning happens. During training, we take a clean image, pick a random time step t, and use the forward process to turn it into the noisy image Zt. We then run Zt and t through a U-Net, which predicts the noise that was added; removing a scaled portion of that predicted noise takes Zt one step back toward Zt-1. By training on many different values of t, the U-Net learns to remove a little bit of noise no matter where its input falls on the image-to-noise spectrum. At generation time, we start from pure noise ZT and call the model over and over again to denoise it bit by bit until we reach Z0, a sensible image.

Model Architecture

Time Embedding

To feed the time step into the network, we use a sinusoidal positional encoding, the same kind transformers use to encode token positions. Positional encodings are a natural fit whenever the network contains attention modules, as ours will.

Sine and cosine are periodic, and for any fixed offset k, the encoding of position n+k is a linear function (a rotation) of the encoding of position n, which makes it easy for attention to relate positions that are k steps apart. These functions are also bounded between -1 and 1, which keeps every embedding value on a consistent scale. Without sin and cos, the raw values t / 10000^(i / (d/2)), where d is the embedding size (t_emb_dim in the code), would range from t itself (up to 999 here) in the first dimension down to values thousands of times smaller in the last ones.
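Both properties can be checked with a small pure-Python mirror of get_time_embedding for a single time step (the function names below are hypothetical). The sketch confirms that every entry stays in [-1, 1] and that the embedding of t + k can be computed from the embedding of t alone by rotating each (sin, cos) pair by an angle that depends only on k.

```python
import math


def time_embedding(t, t_emb_dim):
    """Pure-Python version of get_time_embedding for one time step t."""
    half = t_emb_dim // 2
    factors = [10000 ** (i / half) for i in range(half)]
    return [math.sin(t / f) for f in factors] + [math.cos(t / f) for f in factors]


def shift_embedding(emb, k, t_emb_dim):
    """Compute the embedding of t + k from the embedding of t only.

    Uses the angle-addition identities, so each (sin, cos) pair is
    rotated by k / factor_i, an angle independent of t.
    """
    half = t_emb_dim // 2
    sins, coss = emb[:half], emb[half:]
    new_sin, new_cos = [], []
    for i in range(half):
        angle = k / (10000 ** (i / half))
        new_sin.append(sins[i] * math.cos(angle) + coss[i] * math.sin(angle))
        new_cos.append(coss[i] * math.cos(angle) - sins[i] * math.sin(angle))
    return new_sin + new_cos


emb_500 = time_embedding(500, 128)
print(min(emb_500), max(emb_500))  # both within [-1, 1]
print(shift_embedding(emb_500, 37, 128)[:3])
print(time_embedding(537, 128)[:3])  # matches the line above up to rounding
```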

UNet

Our UNet follows the classic UNet structure with some multi-head attention blocks added. It has three types of blocks: Down, Mid, and Up, each of which can be repeated multiple times. Down blocks reduce the image resolution while increasing the number of channels, Mid blocks keep the resolution unchanged and let the model do more processing in the low-resolution latent space, and Up blocks reduce the number of channels while upsampling back to the original resolution. The one constraint is that the Up and Down blocks must be repeated the same number of times, because each Up block is joined to its corresponding Down block by a residual connection (the red lines in the diagram).

Down Blocks

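In each Down block, we perform a residual convolution block: two convolutions with the time embedding injected in between, plus a 1x1 convolution on the block input so it can be added back to the output. We then pass the data through a multi-head attention block, which lets every spatial position attend to every other position and so produces more contextually aware features, and finally, when down_sample is set, downsample with a strided convolution. Before each Down block runs, the UNet stores its input so the corresponding Up block can reuse it as a residual connection.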
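To see why the down-sampling layer nn.Conv2d(kernel_size=4, stride=2, padding=1) halves the resolution, we can plug numbers into the output-size formula from the PyTorch nn.Conv2d documentation. The helper below is a hypothetical illustration, not part of the model code.

```python
def conv2d_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output height (or width) of nn.Conv2d, per the formula in the PyTorch docs."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# The 3x3, stride-1, padding-1 convolutions keep the size unchanged
print(conv2d_out_size(28, kernel_size=3, stride=1, padding=1))  # 28
# The 4x4, stride-2, padding-1 down-sampling conv halves even sizes
print(conv2d_out_size(28, kernel_size=4, stride=2, padding=1))  # 14
print(conv2d_out_size(14, kernel_size=4, stride=2, padding=1))  # 7
# Odd sizes round down, so a third halving would turn 7 into 3
print(conv2d_out_size(7, kernel_size=4, stride=2, padding=1))   # 3
```

The last line suggests one plausible reason the third Down block in this model does not downsample: for 28x28 MNIST images, 7 does not halve cleanly.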

Mid Blocks

In our UNet’s bottleneck, we have a convolution block followed by an attention block and finally another convolution block. Like the DownBlock, the convolution block is two convolutions with the time embeddings injected in-between. None of these convolution blocks do any downsampling so our feature maps remain the same size as they were on input.

Up Blocks

Each Up block first upsamples its input with a ConvTranspose2d layer (when up_sample is set) and then concatenates the corresponding Down block's residual feature map along the channel dimension. It then runs one convolution block (two convolutions with the time embedding injected in between) followed by one attention layer.
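The matching check for the up-sampling layer uses the output-size formula from the PyTorch nn.ConvTranspose2d documentation (again, the helper name is hypothetical). With kernel_size=4, stride=2, and padding=1, the layer exactly doubles the size, so the upsampled features line up with the stored 14x14 and 28x28 feature maps, which torch.cat along dim=1 requires.

```python
def conv_transpose2d_out_size(size, kernel_size, stride=1, padding=0,
                              output_padding=0, dilation=1):
    """Output height (or width) of nn.ConvTranspose2d, per the formula in the PyTorch docs."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


print(conv_transpose2d_out_size(7, kernel_size=4, stride=2, padding=1))   # 14
print(conv_transpose2d_out_size(14, kernel_size=4, stride=2, padding=1))  # 28
```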

Code

Noise Scheduler

We need to add different amounts of noise depending on our sampled time step. We will encapsulate this logic in its own class.

import torch


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class LinearNoiseScheduler:
    def __init__(self, num_timesteps, beta_start, beta_end):
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end

        # Linear beta schedule and the quantities derived from it,
        # all kept on the same device as the model inputs
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps).to(device)
        self.alphas = 1. - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1. - self.alpha_cum_prod)

    def add_noise(self, original, noise, t):
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, one t per batch element
        original_shape = original.shape
        batch_size = original_shape[0]

        sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod[t].reshape(batch_size)
        sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod[t].reshape(batch_size)

        # Reshape (B,) -> (B, 1, 1, 1) so the weights broadcast over C, H, W
        for _ in range(len(original_shape) - 1):
            sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
            sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)

        return sqrt_alpha_cum_prod * original + sqrt_one_minus_alpha_cum_prod * noise

    def sample_prev_timestep(self, xt, noise_pred, t):
        # Estimate of the clean image, obtained by inverting add_noise
        x0 = (xt - self.sqrt_one_minus_alpha_cum_prod[t] * noise_pred) / self.sqrt_alpha_cum_prod[t]
        x0 = torch.clamp(x0, -1., 1.)

        # Mean of x_{t-1} given x_t and the predicted noise
        mean = xt - ((self.betas[t] * noise_pred) / self.sqrt_one_minus_alpha_cum_prod[t])
        mean = mean / torch.sqrt(self.alphas[t])

        if t == 0:
            # Final step: return the mean without adding any noise
            return mean, x0
        else:
            variance = (1 - self.alpha_cum_prod[t - 1]) / (1 - self.alpha_cum_prod[t])
            variance = variance * self.betas[t]
            sigma = variance ** 0.5
            z = torch.randn_like(xt)
            return mean + sigma * z, x0

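In the constructor, we precompute, for every time step, how much noise to mix in (stored in the sqrt_one_minus_alpha_cum_prod array) and how much of the original image to keep (stored in the sqrt_alpha_cum_prod array). Writing alpha_bar_t for alpha_cum_prod[t], the cumulative product of (1 - beta) up to step t, the noise formula is x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.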

add_noise takes in our original images, a same-shaped tensor of random noise, and a time step t for each image. It looks up the two weights for each t in the precomputed arrays, reshapes them so they broadcast over the channel and pixel dimensions, and returns the weighted sum of image and noise.

sample_prev_timestep takes in the noisy image Xt, the predicted noise, and the time step t. It first estimates the original image X0 by subtracting the scaled predicted noise from Xt and dividing by sqrt_alpha_cum_prod[t]. It then computes the mean of the distribution over Xt-1. At time step 0, it returns that mean directly. At any other time step, it also computes the variance and returns the mean plus standard normal noise multiplied by the standard deviation, which draws a sample from that Gaussian.
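The two formulas inside sample_prev_timestep can be sanity-checked on plain floats. This sketch assumes the same linear schedule as the config used later in the article, and the helper names are hypothetical. If the noise prediction were perfect, the x0 estimate would recover the original value exactly, and the variance used for t > 0 is always a bit smaller than beta_t itself.

```python
import math


def schedule(num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
    """Linear betas and their cumulative products (alpha_cum_prod)."""
    step = (beta_end - beta_start) / (num_timesteps - 1)
    betas = [beta_start + i * step for i in range(num_timesteps)]
    alpha_bars, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        alpha_bars.append(prod)
    return betas, alpha_bars


def predict_x0(xt, noise_pred, alpha_bar_t):
    # Invert x_t = sqrt(ab) * x0 + sqrt(1 - ab) * eps for x0
    return (xt - math.sqrt(1.0 - alpha_bar_t) * noise_pred) / math.sqrt(alpha_bar_t)


def posterior_variance(betas, alpha_bars, t):
    # Same expression as in sample_prev_timestep, valid for t > 0
    return (1.0 - alpha_bars[t - 1]) / (1.0 - alpha_bars[t]) * betas[t]


betas, alpha_bars = schedule()
x0, eps, t = 0.3, -1.2, 400
xt = math.sqrt(alpha_bars[t]) * x0 + math.sqrt(1.0 - alpha_bars[t]) * eps
print(predict_x0(xt, eps, alpha_bars[t]))                   # 0.3 up to rounding
print(posterior_variance(betas, alpha_bars, t) <= betas[t])  # True
```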

UNet

import torch
import torch.nn as nn


def get_time_embedding(time_steps, t_emb_dim):
    # Sinusoidal embedding: frequencies fall off geometrically across the t_emb_dim // 2 dimensions
    factor = 10000 ** (torch.arange(start=0, end=t_emb_dim // 2, device=time_steps.device) / (t_emb_dim // 2))

    t_emb = time_steps[:, None].repeat(1, t_emb_dim // 2) / factor
    return torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)


class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, t_emb_dim, down_sample, num_heads):
        super().__init__()
        self.down_sample = down_sample

        self.resnet_conv_first = nn.Sequential(
            nn.GroupNorm(8, in_channels),  # 8 groups; in_channels must be divisible by 8
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1)
        )

        self.t_emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(t_emb_dim, out_channels)
        )

        self.resnet_conv_second = nn.Sequential(
            nn.GroupNorm(8, out_channels),  # 8 groups; out_channels must be divisible by 8
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1)
        )

        self.attention_norm = nn.GroupNorm(8, out_channels)
        self.attention = nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
        # 1x1 conv so the block input can be added to the output when channel counts differ
        self.residual_input_conv = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
        # Strided 4x4 conv halves the height and width
        self.down_sample_conv = nn.Conv2d(out_channels, out_channels, kernel_size=(4, 4), stride=(2, 2), padding=1) if self.down_sample else nn.Identity()

    def forward(self, x, t_emb):
        out = x

        # Resnet block, with the time embedding broadcast over height and width
        resnet_input = out
        out = self.resnet_conv_first(out)
        out = out + self.t_emb_layer(t_emb)[:, :, None, None]
        out = self.resnet_conv_second(out)
        out = out + self.residual_input_conv(resnet_input)

        # Attention block: each of the h*w positions is a token with `channels` features
        batch_size, channels, h, w = out.shape
        in_attn = out.reshape(batch_size, channels, h * w)
        in_attn = self.attention_norm(in_attn)
        in_attn = in_attn.transpose(1, 2)
        out_attn, _ = self.attention(in_attn, in_attn, in_attn)
        out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
        out = out + out_attn

        out = self.down_sample_conv(out)
        return out


class MidBlock(nn.Module):
    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads):
        super().__init__()

        # Two resnet blocks: one before and one after the attention layer
        self.resnet_conv_first = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, in_channels),  # 8 groups; in_channels must be divisible by 8
                nn.SiLU(),
                nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1)
            ),
            nn.Sequential(
                nn.GroupNorm(8, out_channels),  # 8 groups; out_channels must be divisible by 8
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1)
            )
        ])

        self.t_emb_layer = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(t_emb_dim, out_channels)
            ),
            nn.Sequential(
                nn.SiLU(),
                nn.Linear(t_emb_dim, out_channels)
            )
        ])

        self.resnet_conv_second = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(8, out_channels),  # 8 groups; out_channels must be divisible by 8
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1)
            ),
            nn.Sequential(
                nn.GroupNorm(8, out_channels),  # 8 groups; out_channels must be divisible by 8
                nn.SiLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1)
            )
        ])

        self.attention_norm = nn.GroupNorm(8, out_channels)
        self.attention = nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
        self.residual_input_conv = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1)),
            nn.Conv2d(out_channels, out_channels, kernel_size=(1, 1))
        ])

    def forward(self, x, t_emb):
        out = x

        # First resnet block
        resnet_input = out
        out = self.resnet_conv_first[0](out)
        out = out + self.t_emb_layer[0](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[0](out)
        out = out + self.residual_input_conv[0](resnet_input)

        # Attention block
        batch_size, channels, h, w = out.shape
        in_attn = out.reshape(batch_size, channels, h * w)
        in_attn = self.attention_norm(in_attn)
        in_attn = in_attn.transpose(1, 2)
        out_attn, _ = self.attention(in_attn, in_attn, in_attn)
        out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
        out = out + out_attn

        # Second resnet block
        resnet_input = out
        out = self.resnet_conv_first[1](out)
        out = out + self.t_emb_layer[1](t_emb)[:, :, None, None]
        out = self.resnet_conv_second[1](out)
        out = out + self.residual_input_conv[1](resnet_input)

        return out


class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, t_emb_dim, up_sample, num_heads):
        super().__init__()
        self.up_sample = up_sample

        self.resnet_conv_first = nn.Sequential(
            nn.GroupNorm(8, in_channels),  # 8 groups; in_channels must be divisible by 8
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1)
        )

        self.t_emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(t_emb_dim, out_channels)
        )

        self.resnet_conv_second = nn.Sequential(
            nn.GroupNorm(8, out_channels),  # 8 groups; out_channels must be divisible by 8
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1)
        )

        self.attention_norm = nn.GroupNorm(8, out_channels)
        self.attention = nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
        self.residual_input_conv = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
        # Transposed conv doubles height and width; it runs before the concat, on in_channels // 2 channels
        self.up_sample_conv = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)) if self.up_sample else nn.Identity()

    def forward(self, x, out_down, t_emb):
        # Upsample, then concatenate the skip connection along the channel dimension
        x = self.up_sample_conv(x)
        x = torch.cat([x, out_down], dim=1)
        out = x

        # Resnet block
        resnet_input = out
        out = self.resnet_conv_first(out)
        out = out + self.t_emb_layer(t_emb)[:, :, None, None]
        out = self.resnet_conv_second(out)
        out = out + self.residual_input_conv(resnet_input)

        # Attention block
        batch_size, channels, h, w = out.shape
        in_attn = out.reshape(batch_size, channels, h * w)
        in_attn = self.attention_norm(in_attn)
        in_attn = in_attn.transpose(1, 2)
        out_attn, _ = self.attention(in_attn, in_attn, in_attn)
        out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
        out = out + out_attn

        return out


class Unet(nn.Module):
    def __init__(self, im_channels):
        super().__init__()
        # Channel counts per stage; each list has one more entry than the blocks built from it
        self.down_channels = [32, 64, 128, 256]
        self.mid_channels = [256, 256, 128]
        self.up_channels = [128, 64, 32, 16]
        self.t_emb_dim = 128
        self.down_sample = [True, True, False]
        self.up_sample = [False, True, True]  # mirror image of down_sample

        # Small MLP applied to the sinusoidal time embedding
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
        )
        self.conv_in = nn.Conv2d(im_channels, self.down_channels[0], kernel_size=(3, 3), padding=1)

        self.downs = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i + 1], self.t_emb_dim, down_sample=self.down_sample[i], num_heads=4))

        self.mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1], self.t_emb_dim, num_heads=4))

        self.ups = nn.ModuleList([])
        for i in range(len(self.up_channels) - 1):
            # Input channels are doubled because each Up block concatenates a skip connection
            self.ups.append(UpBlock(self.up_channels[i] * 2, self.up_channels[i + 1], self.t_emb_dim, up_sample=self.up_sample[i], num_heads=4))

        self.norm_out = nn.GroupNorm(8, 16)
        self.conv_out = nn.Conv2d(16, im_channels, kernel_size=(3, 3), padding=1)

    def forward(self, x, t):
        out = self.conv_in(x)
        t_emb = get_time_embedding(t, self.t_emb_dim)
        t_emb = self.t_proj(t_emb)

        # Store each Down block's input for the matching Up block
        down_outs = []
        for down in self.downs:
            down_outs.append(out)
            out = down(out, t_emb)

        for mid in self.mids:
            out = mid(out, t_emb)

        # Skip connections are consumed in reverse order
        for up in self.ups:
            down_out = down_outs.pop()
            out = up(out, down_out, t_emb)

        out = self.norm_out(out)
        out = nn.SiLU()(out)
        out = self.conv_out(out)
        return out

Our UNet uses the Down, Mid, and Up blocks we talked about: three Down blocks, two Mid blocks, and three Up blocks (each channel list has one more entry than the number of blocks built from it). The numbers in the arrays are channel counts, i.e. the number of feature maps each convolution produces. Two of the Down blocks downsample, and each downsample halves the height and width, so for 28x28 MNIST inputs the smallest feature maps are 7x7. The model takes in a noisy image Xt and a time step t and predicts the noise that was added to the original image.
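The block counts and sizes above can be traced without running any layers. This hypothetical helper follows the (channels, size) bookkeeping of Unet.forward for a square input, assuming the channel lists and down_sample flags from the code, and deriving up_sample as the reverse of down_sample (the model hard-codes that same list). It raises an error if a skip connection would not line up.

```python
def trace_unet_shapes(im_size=28, im_channels=1,
                      down_channels=(32, 64, 128, 256),
                      mid_channels=(256, 256, 128),
                      up_channels=(128, 64, 32, 16),
                      down_sample=(True, True, False)):
    """Return (stage, channels, height) after each block of Unet.forward."""
    shapes = []
    ch, size = down_channels[0], im_size      # conv_in keeps the spatial size
    skips = []
    for i in range(len(down_channels) - 1):
        skips.append((ch, size))              # stored before each DownBlock
        ch = down_channels[i + 1]
        size = size // 2 if down_sample[i] else size
        shapes.append(("down", ch, size))
    for i in range(len(mid_channels) - 1):
        ch = mid_channels[i + 1]              # Mid blocks never change the size
        shapes.append(("mid", ch, size))
    up_sample = list(reversed(down_sample))
    for i in range(len(up_channels) - 1):
        size = size * 2 if up_sample[i] else size
        skip_ch, skip_size = skips.pop()
        if skip_size != size:
            raise ValueError(f"skip connection is {skip_size}x{skip_size}, expected {size}x{size}")
        if ch + skip_ch != up_channels[i] * 2:
            raise ValueError("concatenated channels do not match UpBlock in_channels")
        ch = up_channels[i + 1]
        shapes.append(("up", ch, size))
    shapes.append(("out", im_channels, size))  # conv_out maps back to image channels
    return shapes


for stage in trace_unet_shapes():
    print(stage)
```

The trace shows three "down", two "mid", and three "up" entries, with 7x7 as the smallest feature map and a 1-channel 28x28 output.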

Training

We specify a configuration YAML file, which we load in both training and sampling, so it is easy to try out different hyperparameters. Below is the setup I used.

diffusion_params:
  num_timesteps: 1000
  beta_start: 0.0001
  beta_end: 0.02

model_params:
  im_channels: 1
  im_size: 28

train_params:
  task_name: 'default'
  batch_size: 64
  num_epochs: 40
  num_samples: 100
  num_grid_rows: 10
  lr: 0.0001
  ckpt_name: 'ddpm_ckpt.pth'
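A quick way to sanity-check the diffusion_params is to compute the weights the scheduler ends up using at the last time step. This plain-Python sketch reproduces the linear schedule from the values above; the variable names are illustrative. The signal weight sqrt_alpha_cum_prod[-1] comes out well below 0.01 and the noise weight is essentially 1, so with these settings x_T is practically pure noise, which is what sampling assumes when it starts from torch.randn.

```python
import math

# Values taken from diffusion_params in the config above
num_timesteps, beta_start, beta_end = 1000, 0.0001, 0.02

step = (beta_end - beta_start) / (num_timesteps - 1)
alpha_bar = 1.0
for i in range(num_timesteps):
    alpha_bar *= 1.0 - (beta_start + i * step)  # cumulative product of alphas

signal_weight = math.sqrt(alpha_bar)        # corresponds to sqrt_alpha_cum_prod[-1]
noise_weight = math.sqrt(1.0 - alpha_bar)   # corresponds to sqrt_one_minus_alpha_cum_prod[-1]
print(f"signal weight {signal_weight:.4f}, noise weight {noise_weight:.4f}")
```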

Now we will jump into the actual training code.

import argparse
import os
import numpy as np

import torch
import yaml
from torch.optim import Adam
from torchvision import transforms

from LinearNoiseScheduler import LinearNoiseScheduler
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from tqdm import tqdm

from Unet import Unet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train(args):
    with open(args.config_path, 'r') as file:
        config = yaml.safe_load(file)
    print(config)

    diffusion_config = config['diffusion_params']
    model_config = config['model_params']
    train_config = config['train_params']

    scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], beta_start=diffusion_config['beta_start'], beta_end=diffusion_config['beta_end'])

    # ToTensor gives pixels in [0, 1]; Normalize rescales them to [-1, 1]
    trfms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )
    mnist = datasets.MNIST(root="dataset/", transform=trfms, download=True, train=True)
    mnist_loader = DataLoader(mnist, batch_size=train_config['batch_size'], shuffle=True)

    model = Unet(model_config['im_channels']).to(device)
    model.train()

    if not os.path.exists(train_config['task_name']):
        os.mkdir(train_config['task_name'])

    # Resume from an existing checkpoint if there is one
    ckpt_path = os.path.join(train_config['task_name'], train_config['ckpt_name'])
    if os.path.exists(ckpt_path):
        print('Loading checkpoint')
        model.load_state_dict(torch.load(ckpt_path, map_location=device))

    num_epochs = train_config['num_epochs']
    optimizer = Adam(model.parameters(), lr=train_config['lr'])
    criterion = torch.nn.MSELoss()

    for epoch_idx in range(num_epochs):
        losses = []
        for im, _ in tqdm(mnist_loader):
            optimizer.zero_grad()
            im = im.float().to(device)

            # Sample noise and a random time step for every image in the batch
            noise = torch.randn_like(im)
            t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device)

            # Jump straight to x_t and ask the model to predict the noise that was added
            noisy_im = scheduler.add_noise(im, noise, t)
            noise_pred = model(noisy_im, t)

            loss = criterion(noise_pred, noise)
            losses.append(loss.item())
            loss.backward()
            optimizer.step()

        print('Finished epoch:{} | Loss: {:.4f}'.format(epoch_idx + 1, np.mean(losses)))
        torch.save(model.state_dict(), ckpt_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Arguments for ddpm training')
    parser.add_argument('--config', dest='config_path',
                        default='default.yaml', type=str)
    args = parser.parse_args()
    train(args)

We will be building a diffusion model that generates MNIST digits, so we train on the MNIST dataset. We load our config file and instantiate the noise scheduler, the model, the optimizer, and the loss function.

Then we start the training loop. For each batch of images, we generate random noise with the same shape as the images and pick a random time step t for each image. We use the noise, images, and t to create noisy images corresponding to time step t. The model then predicts the noise, and we train using the MSE between the predicted and actual noise.

Sampling

import argparse

import torchvision
import yaml
import torch
import os
from tqdm import tqdm
from torchvision.utils import make_grid

from LinearNoiseScheduler import LinearNoiseScheduler
from Unet import Unet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def sample(model, scheduler, train_config, model_config, diffusion_config):
    r"""
    Sample stepwise by going backward one timestep at a time.
    We save the current image xt as a grid after every step.
    """
    samples_dir = os.path.join(train_config['task_name'], 'samples')
    os.makedirs(samples_dir, exist_ok=True)

    # Start from pure Gaussian noise x_T
    xt = torch.randn((train_config['num_samples'],
                      model_config['im_channels'],
                      model_config['im_size'],
                      model_config['im_size'])).to(device)
    for i in tqdm(reversed(range(diffusion_config['num_timesteps']))):
        # Get prediction of noise (one time step value per sample in the batch)
        t = torch.full((xt.shape[0],), i, dtype=torch.long, device=device)
        noise_pred = model(xt, t)

        # Use scheduler to get xt-1 (the x0 estimate it also returns is unused here)
        xt, _ = scheduler.sample_prev_timestep(xt, noise_pred, torch.as_tensor(i).to(device))

        # Save xt, rescaled from [-1, 1] to [0, 1]
        ims = torch.clamp(xt, -1., 1.).detach().cpu()
        ims = (ims + 1) / 2
        grid = make_grid(ims, nrow=train_config['num_grid_rows'])
        img = torchvision.transforms.ToPILImage()(grid)
        img.save(os.path.join(samples_dir, 'x0_{}.png'.format(i)))
        img.close()

def infer(args):
    with open(args.config_path, 'r') as file:
        config = yaml.safe_load(file)

    print(config)
    diffusion_config = config['diffusion_params']
    model_config = config['model_params']
    train_config = config['train_params']

    model = Unet(model_config['im_channels']).to(device)
    model.load_state_dict(torch.load(os.path.join(train_config['task_name'], train_config['ckpt_name']), map_location=device))

    model.eval()

    scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], beta_start=diffusion_config['beta_start'], beta_end=diffusion_config['beta_end'])

    with torch.no_grad():
        sample(model, scheduler, train_config, model_config, diffusion_config)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Arguments for ddpm image generation')
    parser.add_argument('--config', dest='config_path',
                        default='default.yaml', type=str)
    args = parser.parse_args()
    infer(args)

For sampling, we first load our saved model checkpoint and set up the noise scheduler. Then we generate n random noise images (n is set by num_samples in the config YAML) of size 28x28 with one color channel. Starting at time step T, we use the model to predict the noise in the current image, and the scheduler uses that prediction to remove a small amount of noise, giving us the image at time step T-1. We repeat this until we reach time step 0, and the resulting images are our generated MNIST digits.
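The loop can be sketched for a single pixel in plain Python, using the same mean and variance expressions as sample_prev_timestep and the schedule from the config. The oracle_noise function is hypothetical: it stands in for a perfectly trained U-Net by computing the exact noise relative to a known target pixel x0, something a real model never has access to. With that oracle, the loop lands exactly on x0 at time step 0 no matter which random noise was drawn along the way, which shows that the scheduler math is self-consistent.

```python
import math
import random


def make_schedule(num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
    step = (beta_end - beta_start) / (num_timesteps - 1)
    betas = [beta_start + i * step for i in range(num_timesteps)]
    alphas = [1.0 - b for b in betas]
    alpha_bars, prod = [], 1.0
    for a in alphas:
        prod *= a
        alpha_bars.append(prod)
    return betas, alphas, alpha_bars


def oracle_noise(xt, x0, alpha_bar_t):
    # Hypothetical perfect model: the eps satisfying xt = sqrt(ab) * x0 + sqrt(1 - ab) * eps
    return (xt - math.sqrt(alpha_bar_t) * x0) / math.sqrt(1.0 - alpha_bar_t)


def sample_pixel(x0, seed=0):
    rng = random.Random(seed)
    betas, alphas, alpha_bars = make_schedule()
    xt = rng.gauss(0.0, 1.0)  # start from pure noise x_T
    for t in reversed(range(len(betas))):
        eps = oracle_noise(xt, x0, alpha_bars[t])
        mean = (xt - betas[t] * eps / math.sqrt(1.0 - alpha_bars[t])) / math.sqrt(alphas[t])
        if t == 0:
            xt = mean  # no noise is added on the final step
        else:
            var = (1.0 - alpha_bars[t - 1]) / (1.0 - alpha_bars[t]) * betas[t]
            xt = mean + math.sqrt(var) * rng.gauss(0.0, 1.0)
    return xt


print(sample_pixel(0.9, seed=3))  # 0.9 up to floating-point rounding
```

A trained U-Net only approximates this noise from the noisy image itself, which is why real samples are new digits rather than copies of one target.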

Conclusion

Below are the samples at timesteps 1000, 800, 600, 400, 200, and 0 in that order.

Congratulations! I hope this toy diffusion model has given you some intuition for how these models work. By predicting a small amount of noise to remove at each time step, we can reverse the diffusion process and reliably turn random noise into high-quality, sensible images.


I am a software engineer working for Amazon living in SF/NYC.
