Diffusion Models

1. Introduction

The goal of generative models is to learn the underlying distribution of data and use it to generate new, realistic samples. Traditional generative models, such as Generative Adversarial Networks (GANs) and Variational Autoencoders (VAEs), face specific challenges: GANs can be unstable to train and prone to mode collapse, while VAEs often produce blurry samples.

Diffusion models are a class of generative models that offer a new solution, combining stability and high-quality generation. These models generate data by reversing a gradual noise addition process. The model learns to undo the noise step-by-step, transforming noise into structured data, allowing for more robust generation.

The intuition is that instead of trying to directly generate data from a complex distribution in one step (as in GANs), the model does so progressively.

2. Key Concepts

2.1 Forward Process (Diffusion)

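At each timestep \( t \), the noisy data \( x_t \) is a linear combination of the clean data \( x_0 \) and noise \( \epsilon \): $$ x_t = \sqrt{\alpha_{\text{cumprod}, t}}\,x_0 + \sqrt{1 - \alpha_{\text{cumprod}, t}}\,\epsilon$$ where \( \epsilon \sim \mathcal{N}(0, I) \) is Gaussian noise, \( \alpha_t = 1 - \beta_t \) for a noise schedule \( \beta_t \), and \( \alpha_{\text{cumprod}, t} = \prod_{s=1}^{t} \alpha_s \) is the cumulative product of the \( \alpha_s \) up to step \( t \) (written \( \bar{\alpha}_t \) in Section 2.6).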
    
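    import torch

    def forward_diffusion(x_0, t, alphas_cumprod, device):
        # Sample Gaussian noise with the same shape as the clean batch
        noise = torch.randn_like(x_0).to(device)
        # Per-example coefficients, reshaped to broadcast over (B, C, H, W)
        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod[t]).view(-1, 1, 1, 1)
        sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod[t]).view(-1, 1, 1, 1)
        # Closed-form jump from x_0 directly to x_t
        x_t = sqrt_alphas_cumprod * x_0 + sqrt_one_minus_alphas_cumprod * noise
        return x_t, noise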
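
The closed-form expression above can be checked numerically against a step-by-step noising chain. The sketch below assumes the usual one-step transition \( x_t = \sqrt{\alpha_t}\,x_{t-1} + \sqrt{\beta_t}\,\epsilon \) from Ho et al. (2020), which is not written out in the text, and an illustrative linear schedule. It tracks the coefficient of \( x_0 \) and the accumulated noise variance through all steps and compares them with the closed form.

```python
import torch

# Illustrative linear schedule (assumed values, not taken from the text above).
timesteps = 1000
betas = torch.linspace(1e-4, 0.02, timesteps, dtype=torch.float64)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

# Apply the assumed one-step transition repeatedly, tracking only
# the coefficient of x_0 and the variance of the accumulated noise.
signal_coeff = torch.tensor(1.0, dtype=torch.float64)
noise_var = torch.tensor(0.0, dtype=torch.float64)
for t in range(timesteps):
    signal_coeff = torch.sqrt(alphas[t]) * signal_coeff
    noise_var = alphas[t] * noise_var + betas[t]

# Closed-form values from the equation for x_t.
closed_form_coeff = torch.sqrt(alphas_cumprod[-1])
closed_form_var = 1.0 - alphas_cumprod[-1]
```

Both pairs should agree up to floating-point error, which is why training can jump straight to any timestep without simulating the whole chain.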
    
    

2.2 Reverse Process (Denoising)

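The reverse process is modeled as a conditional distribution: $$p_{\theta}(x_{t-1}|x_{t}) = \mathcal{N}(x_{t-1}; \mu_{\theta}(x_t,t), \Sigma_{\theta}(x_t,t))$$ where \( \mu_{\theta} \) and \( \Sigma_{\theta} \) are the mean and covariance of the transition, parameterized by the model with parameters \( \theta \). Each step in the reverse process predicts the less noisy sample at time \( t - 1 \) from the current sample \( x_t \), based on the following decomposition: $$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \alpha_{\text{cumprod}, t}}} \epsilon_\theta(x_t, t) \right) + \sigma_t z$$ where \( \epsilon_\theta(x_t, t) \) is the noise predicted by the model, \( \alpha_t = 1 - \beta_t \), \( z \sim \mathcal{N}(0, I) \) is fresh Gaussian noise (omitted at the final step), and \( \sigma_t \) is the noise scale, set to \( \sqrt{\beta_t} \) in the code below.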
    
    def reverse_process(model, x_T, betas, alphas_cumprod, device, timesteps):
        # Iteratively denoise, starting from x_T and ending at x_0
        x = x_T
        for t in reversed(range(timesteps)):
            # Same timestep index for every example in the batch
            t_tensor = torch.full((x.size(0),), t, dtype=torch.long, device=device)
            predicted_noise = model(x, t_tensor)
            beta = betas[t]
            alpha = 1 - beta
            alpha_cumprod = alphas_cumprod[t]

            # Mean of the reverse step: remove the predicted noise, then rescale
            x = (x - beta / torch.sqrt(1 - alpha_cumprod) * predicted_noise) / torch.sqrt(alpha)

            # Add fresh noise with sigma_t = sqrt(beta_t) at every step except the last
            if t > 0:
                x = x + torch.sqrt(beta) * torch.randn_like(x)

        return x
    
    

2.3 Markov Process

Both the forward and reverse processes are modeled as Markov chains, meaning each state depends only on the state at the adjacent timestep: \( x_t \) depends only on \( x_{t-1} \) in the forward process, and \( x_{t-1} \) depends only on \( x_t \) in the reverse process. This makes it easier to define both processes using simple transition distributions.

2.4 Noise Schedule

The noise schedule \( \beta_t \) controls how much noise is added at each time step. It is important for stabilizing training and ensuring the model can reverse the diffusion process.

Popular choices for \( \beta_t \) include the linear schedule (Ho et al., 2020) and the cosine schedule (Nichol and Dhariwal, 2021).
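
As one concrete case, a linear schedule can be sketched as follows. The function name `linear_beta_schedule` is hypothetical, and the default endpoints (1e-4 to 0.02 over 1000 steps) are the values reported by Ho et al. (2020), used here as an assumption rather than something the text above specifies.

```python
import torch

def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Return `timesteps` linearly spaced beta values (assumed endpoints)."""
    return torch.linspace(beta_start, beta_end, timesteps)

timesteps = 1000
betas = linear_beta_schedule(timesteps)

# Quantities derived from the schedule and used by the forward/reverse code.
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
```

With these values `alphas_cumprod` decays from just under 1 to nearly 0, so \( x_T \) is close to pure Gaussian noise.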

2.5 Loss Function

The model is trained to predict the noise added at each timestep using a mean squared error (MSE) loss.

The training objective is: $$\mathcal{L}(\theta) = \mathbb{E}_{x_0,\epsilon,t}\left[ \left\| \epsilon - \epsilon_{\theta}(x_t, t) \right\|^2\right]$$
This objective encourages the network to predict the noise added at each step, allowing it to reverse the noise addition process.
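
One way to see why predicting the noise is enough to undo it: the forward equation can be solved for \( x_0 \) once \( \epsilon \) is known. The sketch below uses a hypothetical helper `predict_x0_from_noise` and an assumed linear schedule, and feeds in the true noise in place of a trained model's prediction.

```python
import torch

def predict_x0_from_noise(x_t, t, noise, alphas_cumprod):
    """Invert x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * noise for x_0."""
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    return (x_t - torch.sqrt(1 - a_bar) * noise) / torch.sqrt(a_bar)

torch.manual_seed(0)
# Assumed linear schedule, for illustration only.
alphas_cumprod = torch.cumprod(1.0 - torch.linspace(1e-4, 0.02, 1000), dim=0)

x_0 = torch.rand(4, 1, 8, 8) * 2 - 1          # fake images in [-1, 1]
t = torch.tensor([0, 100, 300, 500])          # one timestep per example
noise = torch.randn_like(x_0)

# Forward process, then exact inversion using the true noise.
a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
x_t = torch.sqrt(a_bar) * x_0 + torch.sqrt(1 - a_bar) * noise
x_0_recovered = predict_x0_from_noise(x_t, t, noise, alphas_cumprod)
```

A trained network only approximates the noise, so in practice the recovered \( x_0 \) is an estimate that improves as the MSE loss decreases.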

2.6 Training

To train a diffusion model, we sample a timestep \( t \), a data point \( x_0 \), and a noise vector \( \epsilon \), then optimize the loss:

  1. Sample: \( t \sim \text{Uniform}(\{1, \dots, T\}) \)
  2. Generate noisy data: \( x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1 - \bar{\alpha}_t}\epsilon \), where \( \bar{\alpha}_t \) is the cumulative product \( \alpha_{\text{cumprod}, t} \) from Section 2.1
  3. Train the model to predict the noise by minimizing: \( \mathcal{L}(\theta) = \mathbb{E}_{x_0,\epsilon,t}\left[ \left\| \epsilon - \epsilon_{\theta}(x_t, t) \right\|^2\right] \)
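
The three steps above can be sketched as a single optimization step. This is a minimal illustration under stated assumptions: `TinyNoisePredictor` is a hypothetical stand-in for a real denoiser such as a U-Net (it ignores the timestep), `train_step` is a hypothetical helper, and the schedule, batch, and optimizer settings are placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyNoisePredictor(nn.Module):
    """Hypothetical stand-in for a real denoiser; it ignores t."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def forward(self, x, t):
        return self.conv(x)

def train_step(model, optimizer, x_0, alphas_cumprod):
    timesteps = alphas_cumprod.shape[0]
    # 1. Sample one timestep per example (0-based indices into the schedule).
    t = torch.randint(0, timesteps, (x_0.size(0),), device=x_0.device)
    # 2. Generate noisy data with the closed-form forward process.
    noise = torch.randn_like(x_0)
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    x_t = torch.sqrt(a_bar) * x_0 + torch.sqrt(1 - a_bar) * noise
    # 3. Regress the predicted noise onto the true noise (mean over elements).
    loss = F.mse_loss(model(x_t, t), noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

torch.manual_seed(0)
betas = torch.linspace(1e-4, 0.02, 1000)      # assumed linear schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
model = TinyNoisePredictor()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batch = torch.rand(8, 1, 16, 16) * 2 - 1      # fake images scaled to [-1, 1]
loss_value = train_step(model, optimizer, batch, alphas_cumprod)
```

In a real training run this step would be repeated over batches from a data loader, with a model that also conditions on `t`.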

2.7 Sampling

After training, new samples can be generated from the diffusion model by reversing the forward process:

  1. Start with random noise: \( x_T \sim \mathcal{N}(0, I) \)
  2. For each timestep \( t = T, T - 1, \dots, 1 \):
    • Sample \( x_{t-1} \sim p_{\theta}(x_{t-1}|x_{t}) \)
  3. The final sample \( x_0 \) approximates a draw from the data distribution.

The iterative denoising process allows for high-quality and diverse data generation.
        
    
    def generate_and_save_samples(model, betas, alphas_cumprod, device, timesteps, num_samples=16):
        model.eval()
        with torch.no_grad():
            # Start from pure noise
            x_T = torch.randn(num_samples, 1, 128, 128, device=device)

            # Generate samples by reverse diffusion
            x_0 = reverse_process(model, x_T, betas, alphas_cumprod, device, timesteps)

            # Clamp the generated image to [-1, 1]
            x_0 = torch.clamp(x_0, -1, 1)

            # Rescale to [0, 1]
            x_0 = (x_0 + 1) / 2

        # Return the batch; the saving step is not shown in this excerpt
        return x_0
    
    

4. Conclusion

Diffusion models offer a powerful approach to generative modeling, progressively denoising noisy data to recover structured outputs. Their training stability and sample quality make them a strong alternative to GANs and VAEs, especially for generating high-quality images and other data modalities.

The full code is available to run in this Google Colab notebook.

5. References

  1. Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." Advances in neural information processing systems 33 (2020): 6840-6851.
  2. Nichol, Alexander Quinn, and Prafulla Dhariwal. "Improved denoising diffusion probabilistic models." International conference on machine learning. PMLR, 2021.
  3. Dhariwal, Prafulla, and Alexander Nichol. "Diffusion models beat GANs on image synthesis." Advances in neural information processing systems 34 (2021): 8780-8794.