Mathematics Behind Deep Generative Models: VAEs, GANs, and Diffusion Models
Raj Shaikh · 9 min read · 1864 words
1. Variational Autoencoders (VAEs): The Thoughtful Creator 🧠🎭
What Is a VAE?
A Variational Autoencoder (VAE) is a neural network that learns to compress data into a structured, probabilistic representation (the latent space) and to reconstruct data from it. Because that space is structured, sampling from it also generates new data. Think of it as AI’s version of daydreaming: it sees an image of a cat 🐱, imagines a different but similar cat, and then draws it.
How Does a VAE Work?
VAEs have two main parts:
- Encoder: Compresses input into a latent space (a compact numerical representation).
- Decoder: Reconstructs data from the latent space.
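The twist? VAEs don’t just compress. For each input, the encoder outputs a probability distribution over the latent space instead of a single point, and training keeps these distributions close to a simple prior (a standard normal). New data can then be generated by sampling from the prior and decoding.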
The VAE Workflow
- Input Data (\( x \)): Feed it into the encoder.
- Latent Space Distribution:
- The encoder doesn’t produce a single point; it estimates a mean (\( \mu \)) and variance (\( \sigma^2 \)) for a distribution.
- Sampling:
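- Sample \( z \) (latent vector) from the distribution: \[ z = \mu + \sigma \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \] where \( \cdot \) is elementwise multiplication. This trick (called the reparameterization trick) is what makes VAEs trainable: the randomness lives in \( \epsilon \), so gradients can flow through \( \mu \) and \( \sigma \) back into the encoder.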
- Reconstruction:
- Feed \( z \) into the decoder to reconstruct the input \( x' \).
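Objective Function: ELBO
VAEs maximize the Evidence Lower Bound (ELBO), a lower bound on the log-likelihood \( \log p(x) \). In practice, training minimizes the negative ELBO, which balances two terms:
- Reconstruction Loss:
- How well the decoder reconstructs \( x \): \[ \text{Reconstruction Loss} = \|x - x'\|^2 \] For real-valued data, this squared error corresponds to a Gaussian decoder likelihood, up to scaling and constants.
- KL Divergence:
- Keeps each encoder distribution \( q(z|x) \) close to the standard normal prior \( p(z) = \mathcal{N}(0, I) \): \[ D_{\text{KL}}(q(z|x) \parallel p(z)) = \int q(z|x) \log \frac{q(z|x)}{p(z)} dz \]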
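For the common case where the encoder outputs a diagonal Gaussian \( q(z|x) = \mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \) over \( d \) latent dimensions and the prior is \( p(z) = \mathcal{N}(0, I) \), the integral has a closed form. It follows from the KL divergence between two one-dimensional Gaussians, summed over dimensions, and it is the expression that the `kl_loss` line of the PyTorch example computes, with `log_var` standing for \( \log \sigma^2 \):

\[ D_{\text{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \parallel \mathcal{N}(0, I)\big) = \frac{1}{2} \sum_{j=1}^{d} \left( \sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2 \right) = -\frac{1}{2} \sum_{j=1}^{d} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right) \]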
The full objective:
\[ \mathcal{L}_{\text{VAE}} = \text{Reconstruction Loss} + \text{KL Divergence} \]
Numerical Example
Suppose we’re encoding an image of a cat:
- Encoder estimates \( \mu = [0.5, -1.0] \), \( \sigma = [0.2, 0.1] \).
- Sample \( z \): \[ z = \mu + \sigma \cdot \epsilon = [0.5, -1.0] + [0.2, 0.1] \cdot [\epsilon_1, \epsilon_2] \] If \( \epsilon = [1.0, -0.5] \), then: \[ z = [0.7, -1.05] \]
- Decoder reconstructs a similar-but-new cat image. 🎉
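The sampling step above can be checked in a few lines of PyTorch. This sketch fixes \( \epsilon \) to the example's values instead of drawing it randomly, so the result is reproducible, and also evaluates the closed-form KL term for the same \( \mu \) and \( \sigma \) (assuming a standard normal prior).

```python
import torch

# Encoder outputs from the worked example (two latent dimensions)
mu = torch.tensor([0.5, -1.0])
sigma = torch.tensor([0.2, 0.1])

# Fixed epsilon instead of torch.randn_like, so the result is reproducible
epsilon = torch.tensor([1.0, -0.5])

# Reparameterization trick: elementwise z = mu + sigma * epsilon
z = mu + sigma * epsilon
print(z)  # ≈ [0.70, -1.05]

# Closed-form KL(q(z|x) || N(0, I)) for the same encoder output, with log_var = log(sigma^2)
log_var = torch.log(sigma.pow(2))
kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
print(kl.item())  # ≈ 3.56
```

The small \( \sigma \) values make the KL term fairly large here: the prior has unit variance, so a very confident encoder is penalized.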
Why VAEs Matter
- Data Generation:
- Generate new samples, like unique faces or handwritten digits.
- Representation Learning:
- Learn compact, meaningful representations for tasks like clustering.
- Smooth Latent Space:
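- Unlike traditional autoencoders, VAEs encourage nearby points in the latent space to decode to similar outputs, because the KL term keeps the encodings overlapping and close to the prior instead of scattered as isolated points.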
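Data generation and latent smoothness both come down to decoding points from the latent space. A minimal sketch, assuming a decoder shaped like the one in the code example below (a single `nn.Linear` from a 2-D latent to a 10-D output); here it is untrained, so the outputs are meaningless numbers and only the mechanics matter. New samples come from decoding \( z \sim \mathcal{N}(0, I) \), and a straight-line walk between two latent points gives a sequence of gradually changing outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
latent_dim, output_dim = 2, 10

# Hypothetical stand-in decoder (untrained); in practice use the trained VAE's decoder
decoder = nn.Linear(latent_dim, output_dim)

with torch.no_grad():
    # Generation: sample latent vectors from the prior p(z) = N(0, I) and decode them
    z_prior = torch.randn(4, latent_dim)
    samples = decoder(z_prior)  # shape (4, 10)

    # Interpolation: walk on a straight line from z_a to z_b in latent space
    z_a, z_b = z_prior[0], z_prior[1]
    alphas = torch.linspace(0.0, 1.0, steps=5).unsqueeze(1)  # shape (5, 1)
    z_path = (1 - alphas) * z_a + alphas * z_b  # shape (5, 2)
    path_outputs = decoder(z_path)  # shape (5, 10)

print(samples.shape, path_outputs.shape)
```

With a trained VAE, the KL term pulls encodings toward the prior, which is why decoded prior samples and interpolated points tend to look plausible; a plain autoencoder offers no such guarantee.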
Code Example: VAE in Python
Here’s a simplified example using PyTorch:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Encoder: maps x to the parameters (mean, log-variance) of q(z|x)
class Encoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.fc_mu = nn.Linear(input_dim, latent_dim)
        self.fc_var = nn.Linear(input_dim, latent_dim)

    def forward(self, x):
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)  # predict log(sigma^2) so the variance is always positive
        return mu, log_var

# Decoder: maps a latent vector z back to data space
class Decoder(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(latent_dim, output_dim)

    def forward(self, z):
        return self.fc(z)

# VAE Model
class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.encoder = Encoder(input_dim, latent_dim)
        self.decoder = Decoder(latent_dim, input_dim)

    def reparameterize(self, mu, log_var):
        # z = mu + sigma * epsilon, with sigma = exp(0.5 * log_var)
        std = torch.exp(0.5 * log_var)
        epsilon = torch.randn_like(std)
        return mu + std * epsilon

    def forward(self, x):
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), mu, log_var

# Loss function: negative ELBO, averaged over the batch
def vae_loss(x, x_recon, mu, log_var):
    # Sum both terms over all elements so they are on the same scale
    recon_loss = F.mse_loss(x_recon, x, reduction="sum")
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return (recon_loss + kl_loss) / x.shape[0]

# Example usage: one optimization step on a random batch
input_dim = 10
latent_dim = 2
vae = VAE(input_dim, latent_dim)
optimizer = optim.Adam(vae.parameters(), lr=0.001)

x = torch.randn((5, input_dim))  # Example input
x_recon, mu, log_var = vae(x)
loss = vae_loss(x, x_recon, mu, log_var)

optimizer.zero_grad()
loss.backward()
optimizer.step()
print("Loss:", loss.item())
```
Fun Analogy
A VAE is like a recipe generator 🍰:
- Encoder: Reads a cake recipe and boils it down to “sweet, chocolatey, layered.”
- Latent Space: Randomly tweaks “sweet” to “sweeter” or “chocolatey” to “nutty.”
- Decoder: Writes a new recipe based on these tweaks. The result? A delicious variation of the original cake! 🎂
Mermaid.js Diagram: VAE Workflow
```mermaid
graph TD
    Input[Input Data x] --> Encoder[Encoder: Estimate mu and sigma]
    Encoder --> LatentSpace[Sample z from Latent Space]
    LatentSpace --> Decoder[Decoder: Reconstruct x']
    Decoder --> Output[Output Data x']
    Input --> Loss[Compute Loss]
    Output --> Loss
    LatentSpace --> Loss
```
2. Generative Adversarial Networks (GANs): The Creative Frenemies 🎭🤖
What Are GANs?
A Generative Adversarial Network (GAN) consists of two neural networks:
- Generator: The Picasso 🎨 that creates fake data (images, text, etc.).
- Discriminator: The Sherlock Holmes 🕵️‍♂️ that decides if the data is real or fake.
They’re trained together in an adversarial setup:
- The generator tries to fool the discriminator with convincing fakes.
- The discriminator tries to catch the generator’s lies.
How GANs Work
Step-by-Step Workflow
- Input: Start with random noise (\( z \)).
- Generator: Creates fake data (\( G(z) \)).
- Discriminator: Classifies data as real or fake.
- Real: Data from the true dataset (\( x \)).
- Fake: Data generated by the generator (\( G(z) \)).
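- Feedback Loop:
- The discriminator is updated to get better at telling real data from fakes.
- The generator is updated, using gradients that flow back through the discriminator, so that its fakes become harder to catch.
- Adversarial Training:
- Alternate these updates. Ideally, the generator's samples become indistinguishable from real data, at which point the discriminator can do no better than guessing (\( D = 0.5 \)). 🎉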
The Math Behind GANs
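GAN training is a minimax game: the discriminator maximizes the value function \( V(D, G) \), while the generator minimizes it:
\[ \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))] \]
Where:
- \( D(x) \): The discriminator's estimated probability that \( x \) is real data.
- \( G(z) \): Data generated by the generator from random noise \( z \).
- \( p_{\text{data}} \), \( p_z \): The distribution of real data and the noise distribution (typically a standard normal or uniform).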
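The "can't tell real from fake" end point can be made precise. Write \( p_g \) for the distribution of the generator's samples, so that \( V(D, G) = \int \big[ p_{\text{data}}(x) \log D(x) + p_g(x) \log(1 - D(x)) \big] dx \). For each \( x \), a function of the form \( a \log y + b \log(1 - y) \) is maximized at \( y = a / (a + b) \), which gives the optimal discriminator for a fixed generator. When \( p_g = p_{\text{data}} \), it outputs \( 1/2 \) everywhere and the value of the game is \( -\log 4 \):

\[ D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}, \qquad p_g = p_{\text{data}} \;\Rightarrow\; D^*(x) = \tfrac{1}{2}, \quad V(D^*, G) = \log \tfrac{1}{2} + \log \tfrac{1}{2} = -\log 4 \]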
Numerical Example
Imagine training a GAN to generate cat images:
- Generator:
- Takes noise (\( z \)) like \([0.1, -0.7]\) and outputs a fake image.
- Discriminator:
- Sees real cat images and fakes, assigning probabilities like:
- Real cat: \( D(x) = 0.9 \)
- Fake cat: \( D(G(z)) = 0.2 \)
- Feedback:
- Generator updates to make \( D(G(z)) \) closer to 1.
- Discriminator updates to push \( D(G(z)) \) closer to 0.
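Plugging the example probabilities into the objective shows what each network is pushed to do. This sketch computes, for a single real/fake pair, the value \( V \), the discriminator's binary cross-entropy loss (which is \( -V \)), the generator's original minimax term \( \log(1 - D(G(z))) \), and the "non-saturating" alternative \( -\log D(G(z)) \) that is common in practice and that the PyTorch example uses via `BCELoss` with real labels. The probabilities are the illustrative values from the example, not outputs of a trained model.

```python
import math

# Illustrative discriminator outputs from the example above
d_real = 0.9  # D(x) for a real cat image
d_fake = 0.2  # D(G(z)) for a generated cat image

# Value function V(D, G) for this single pair of samples
value = math.log(d_real) + math.log(1 - d_fake)

# The discriminator maximizes V, i.e. minimizes the BCE loss -V
d_loss = -value

# Generator, minimax form: minimize log(1 - D(G(z)))
g_loss_minimax = math.log(1 - d_fake)

# Generator, non-saturating form: minimize -log D(G(z))
g_loss_nonsat = -math.log(d_fake)

print(f"V = {value:.4f}, D loss = {d_loss:.4f}")
print(f"G loss (minimax) = {g_loss_minimax:.4f}, G loss (non-saturating) = {g_loss_nonsat:.4f}")
```

When the discriminator confidently rejects fakes (\( D(G(z)) \) near 0), the gradient of \( \log(1 - D(G(z))) \) through the discriminator's sigmoid output becomes very small, while \( -\log D(G(z)) \) still gives a strong signal. That is why the non-saturating form is usually preferred.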
Why GANs Are Awesome
- High-Quality Generation:
- GANs can create stunningly realistic images, videos, and audio.
- Creative Applications:
- From art generation to deepfakes.
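- No Need for Explicit Data Distributions:
- GANs never write down a density or likelihood for the data; the generator learns to produce samples directly, guided only by the discriminator.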
Challenges of GANs
- Training Instability:
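- The generator and discriminator must stay roughly balanced. If one overpowers the other (for example, a discriminator that rejects every fake), the other gets little useful gradient, and training can stall or oscillate.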
- Mode Collapse:
- The generator produces limited variations, losing diversity.
- Hard to Evaluate:
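- There is no likelihood to report, so quality is judged with proxy metrics such as the Fréchet Inception Distance (FID) or the Inception Score, or by human inspection.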
Code Example: A Simple GAN in PyTorch
Here’s how to create a basic GAN:
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Generator: maps noise z to a fake data vector in [-1, 1]
class Generator(nn.Module):
    def __init__(self, noise_dim, data_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(noise_dim, 128),
            nn.ReLU(),
            nn.Linear(128, data_dim),
            nn.Tanh()
        )

    def forward(self, z):
        return self.fc(z)

# Discriminator: outputs the probability that its input is real
class Discriminator(nn.Module):
    def __init__(self, data_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(data_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.fc(x)

# Hyperparameters
noise_dim = 100
data_dim = 28 * 28  # Example for flattened 28x28 image
generator = Generator(noise_dim, data_dim)
discriminator = Discriminator(data_dim)

# Loss and Optimizers
loss_fn = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)

# Training Loop (Simplified: one random batch per "epoch")
for epoch in range(10):
    # Stand-in for a batch of real data, scaled to [-1, 1] to match the Tanh output range
    real_data = torch.rand(64, data_dim) * 2 - 1
    fake_data = generator(torch.randn(64, noise_dim))  # Generate fake data
    real_labels = torch.ones(64, 1)
    fake_labels = torch.zeros(64, 1)

    # Train Discriminator: push D(x) toward 1 and D(G(z)) toward 0.
    # detach() stops this loss from updating the generator.
    d_loss = loss_fn(discriminator(real_data), real_labels) + \
             loss_fn(discriminator(fake_data.detach()), fake_labels)
    optimizer_d.zero_grad()
    d_loss.backward()
    optimizer_d.step()

    # Train Generator: label fakes as real, i.e. minimize -log D(G(z))
    # (the common "non-saturating" variant of the minimax objective)
    g_loss = loss_fn(discriminator(fake_data), real_labels)
    optimizer_g.zero_grad()
    g_loss.backward()
    optimizer_g.step()

    print(f"Epoch {epoch + 1}: D Loss = {d_loss.item():.4f}, G Loss = {g_loss.item():.4f}")
```
Fun Analogy
GANs are like two siblings competing for a prize:
- The Generator is the crafty younger sibling trying to sneak cookies from the jar (fake data). 🍪
- The Discriminator is the overprotective older sibling watching like a hawk (real vs. fake). 👀
The more they compete, the better they both get at their roles. Eventually, the generator becomes a master cookie thief! 🎉
Mermaid.js Diagram: GAN Workflow
```mermaid
graph TD
    RandomNoise[Random Noise z] --> Generator[Generator: Create Fake Data]
    Generator --> Discriminator[Discriminator: Classify Real vs. Fake]
    Discriminator --> RealFake[Real or Fake?]
    RealData[Real Data x] --> Discriminator
    RealFake --> Feedback[Feedback to Update G and D]
    Feedback --> Generator
    Feedback --> Discriminator
```
3. Diffusion Models: Turning Noise Into Masterpieces 🎨✨
What Are Diffusion Models?
A Diffusion Model learns to generate data by reversing a noise process. It starts with random noise and iteratively denoises it to generate realistic data like images, text, or even molecules.
Think of it as a reverse time machine:
- Start with data (e.g., a picture of a cat 🐱).
- Gradually add noise until all structure is lost (pure static).
- Learn how to reverse this process, step by step, to recreate (or generate new) data.
How Diffusion Models Work
1. The Forward Process (Noising Data)
Gradually add Gaussian noise to data over \( T \) timesteps:
\[ q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} \cdot x_{t-1}, \beta_t \cdot I) \]
Where:
- \( x_t \): Data at timestep \( t \).
- \( \beta_t \): The variance of the noise added at step \( t \); the sequence \( \beta_1, \ldots, \beta_T \) is the noise schedule.
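Because every step adds independent Gaussian noise, the steps compose, and \( x_t \) can be sampled from \( x_0 \) in a single jump. Writing \( \alpha_t = 1 - \beta_t \) and \( \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s \) (the notation of the DDPM formulation), induction on \( t \) gives a signal coefficient of \( \sqrt{\bar{\alpha}_t} \) and a noise variance of \( \alpha_t (1 - \bar{\alpha}_{t-1}) + (1 - \alpha_t) = 1 - \bar{\alpha}_t \):

\[ q(x_t | x_0) = \mathcal{N}\big(x_t; \sqrt{\bar{\alpha}_t} \cdot x_0, (1 - \bar{\alpha}_t) \cdot I\big), \qquad x_t = \sqrt{\bar{\alpha}_t} \cdot x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \]

As \( \bar{\alpha}_t \) shrinks toward 0, the signal term vanishes. For example, with a linear schedule from \( \beta_1 = 10^{-4} \) to \( \beta_T = 0.02 \) over \( T = 1000 \) steps, \( \sqrt{\bar{\alpha}_T} \) is about 0.006.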
After enough steps, the data is essentially pure noise:
\[ q(x_T) \approx \mathcal{N}(0, I) \]
2. The Reverse Process (Denoising)
Learn to reverse the noising process:
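\[ p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)) \]
Here, a neural network predicts the mean \( \mu_\theta \) needed to undo the noise at each step. The variance \( \Sigma_\theta \) is learned in some variants and fixed to a schedule-based value (such as \( \beta_t \cdot I \)) in others.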
3. Training Objective
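In practice, the network rarely outputs \( \mu_\theta \) directly. It is trained to predict either the original data \( x_0 \) or, most commonly, the added noise \( \epsilon \), from which \( \mu_\theta \) can be computed. With noise prediction, the simplified training objective is:
\[ \mathcal{L}_{\text{simple}} = \mathbb{E}_{x_0, \epsilon, t} \left[ \|\epsilon - \epsilon_\theta(x_t, t)\|^2 \right] \]
Where:
- \( \epsilon \sim \mathcal{N}(0, I) \): The noise used to produce \( x_t \) from \( x_0 \) in the forward process, with \( t \) drawn uniformly from \( \{1, \ldots, T\} \).
- \( \epsilon_\theta(x_t, t) \): The model's prediction of that noise, given the noisy input and the timestep.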
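A minimal sketch of training with \( \mathcal{L}_{\text{simple}} \), under stated assumptions: toy 2-D data, a linear \( \beta \) schedule from \( 10^{-4} \) to 0.02 over \( T = 1000 \) steps, and a hypothetical tiny MLP `NoisePredictor` that sees \( x_t \) and the timestep scaled to (0, 1] (real models use U-Nets or transformers with learned timestep embeddings). Each step draws a random timestep per sample, jumps straight to \( x_t \) using \( \bar{\alpha}_t = \prod_{s \le t}(1 - \beta_s) \), and regresses the network's output onto the noise that was actually added.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)            # beta_1..beta_T (stored at index t-1)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)    # alpha_bar_t = prod_{s<=t} (1 - beta_s)

class NoisePredictor(nn.Module):
    """Hypothetical toy epsilon_theta(x_t, t): an MLP on [x_t, t / T]."""
    def __init__(self, data_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(data_dim + 1, hidden),
            nn.ReLU(),
            nn.Linear(hidden, data_dim),
        )

    def forward(self, x_t, t):
        t_feat = (t.float() / T).unsqueeze(1)  # crude timestep feature in (0, 1]
        return self.net(torch.cat([x_t, t_feat], dim=1))

data_dim = 2
model = NoisePredictor(data_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def training_step(x_0):
    batch = x_0.shape[0]
    t = torch.randint(1, T + 1, (batch,))                  # random timestep per sample, in 1..T
    a_bar = alpha_bars[t - 1].unsqueeze(1)                 # shape (batch, 1) for broadcasting
    eps = torch.randn_like(x_0)                            # the noise we add (regression target)
    x_t = a_bar.sqrt() * x_0 + (1 - a_bar).sqrt() * eps    # sample from q(x_t | x_0)
    # L_simple; mean over elements instead of a sum is only a constant rescaling
    loss = ((eps - model(x_t, t)) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

x_0 = torch.randn(128, data_dim) * 0.5 + 1.0  # toy "dataset" batch
losses = [training_step(x_0) for _ in range(200)]
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Because the timestep and noise are resampled every step, individual loss values are noisy; the trend over many steps is what matters.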
Numerical Example
Imagine we’re generating a simple 2D image:
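- Start with Data: \[ x_0 = \text{a 2D grid of pixel values} \]
- Forward Process:
- Add a little noise at each step, shrinking the signal slightly so the variance stays bounded: \[ x_t = \sqrt{1 - \beta_t} \cdot x_{t-1} + \sqrt{\beta_t} \cdot \epsilon_t, \quad \epsilon_t \sim \mathcal{N}(0, I), \quad t = 1, \ldots, T \] until \( x_T \) is essentially pure noise.
- Reverse Process:
- Start from fresh noise \( x_T \sim \mathcal{N}(0, I) \) and use the model's noise prediction to take one denoising step at a time. With \( \alpha_t = 1 - \beta_t \) and \( \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s \), the standard DDPM update is: \[ x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \cdot \epsilon_\theta(x_t, t) \right) + \sigma_t \cdot z, \quad z \sim \mathcal{N}(0, I) \] where \( \sigma_t^2 \) is typically set to \( \beta_t \) and \( z = 0 \) at the last step. The loop ends at \( x_0 \): a newly generated image.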
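A minimal sketch of DDPM-style ancestral sampling, which starts from pure noise and applies one denoising update per timestep from \( t = T \) down to 1. As an assumption made only so the example runs on its own, `eps_model` is a hypothetical placeholder that always predicts zero noise; in practice it would be the trained noise predictor \( \epsilon_\theta \). The schedule is the same linear one (\( 10^{-4} \) to 0.02, \( T = 1000 \)).

```python
import torch

torch.manual_seed(0)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)       # beta_1..beta_T (stored at index t-1)
alphas = 1.0 - betas                         # alpha_t = 1 - beta_t
alpha_bars = torch.cumprod(alphas, dim=0)    # alpha_bar_t = prod_{s<=t} alpha_s

def eps_model(x_t, t):
    """Hypothetical placeholder for a trained epsilon_theta(x_t, t); always predicts zero."""
    return torch.zeros_like(x_t)

@torch.no_grad()
def sample(shape):
    x = torch.randn(shape)  # x_T ~ N(0, I)
    for t in range(T, 0, -1):
        beta, alpha, alpha_bar = betas[t - 1], alphas[t - 1], alpha_bars[t - 1]
        eps_hat = eps_model(x, t)
        # Mean of p_theta(x_{t-1} | x_t), computed from the predicted noise
        mean = (x - beta / torch.sqrt(1 - alpha_bar) * eps_hat) / torch.sqrt(alpha)
        if t > 1:
            x = mean + torch.sqrt(beta) * torch.randn_like(x)  # sigma_t^2 = beta_t
        else:
            x = mean  # no noise is added at the final step
    return x

x_gen = sample((4, 28, 28))
print(x_gen.shape)
```

With the zero placeholder, the output is just rescaled noise; the loop structure is what carries over to a real model. The thousand sequential network calls are also the source of the slow generation discussed below.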
Why Diffusion Models Are Awesome
- High-Quality Generation:
- Diffusion models produce state-of-the-art images, often rivaling GANs.
- Stable Training:
- Unlike GANs, they don’t suffer from adversarial instability.
- Flexibility:
- Can generate diverse types of data, from images to molecules.
Challenges of Diffusion Models
- Slow Generation:
- Reversing the noise process involves many steps (up to thousands).
- Computational Cost:
- Training and inference are resource-intensive.
Code Example: Diffusion Model in PyTorch
Here’s a simple implementation of the forward noising process:
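```python
import torch

# Forward process: jump from x_0 straight to x_t using the closed form
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
# where alpha_bar_t is the cumulative product of (1 - beta_s) up to step t.
# (Using beta_t alone would apply only a single noising step to x_0.)
def forward_diffusion(x_0, t, beta_schedule):
    alpha_bar_t = torch.cumprod(1 - beta_schedule, dim=0)[t]
    noise = torch.randn_like(x_0)
    x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * noise
    return x_t, noise

# Example Usage
x_0 = torch.randn((1, 28, 28))  # Example image (28x28 pixels)
beta_schedule = torch.linspace(0.0001, 0.02, 1000)  # Linear noise schedule as a tensor
t = 10  # Timestep (0-based index into the schedule)
x_t, noise = forward_diffusion(x_0, t, beta_schedule)
print("Noised Data at Timestep t:\n", x_t)
```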
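A quick empirical check of the one-jump formula, assuming the same linear schedule: applying the single-step rule \( x_t = \sqrt{1 - \beta_t} \cdot x_{t-1} + \sqrt{\beta_t} \cdot \epsilon \) repeatedly to many copies of a fixed \( x_0 \) should produce samples with mean \( \sqrt{\bar{\alpha}_t} \cdot x_0 \) and variance \( 1 - \bar{\alpha}_t \), where \( \bar{\alpha}_t = \prod_{s \le t}(1 - \beta_s) \). The step count, sample size, and starting value are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

t = 300          # number of single steps to apply
x0_value = 2.0   # a fixed scalar "pixel" value
n = 100_000      # number of independent noising chains

# Apply q(x_s | x_{s-1}) for s = 1..t, one step at a time
x = torch.full((n,), x0_value)
for s in range(1, t + 1):
    beta = betas[s - 1]
    x = torch.sqrt(1 - beta) * x + torch.sqrt(beta) * torch.randn(n)

# Closed-form moments of q(x_t | x_0)
expected_mean = torch.sqrt(alpha_bars[t - 1]) * x0_value
expected_var = 1 - alpha_bars[t - 1]
print(f"empirical mean {x.mean().item():.4f} vs closed form {expected_mean.item():.4f}")
print(f"empirical var  {x.var().item():.4f} vs closed form {expected_var.item():.4f}")
```

The agreement is why training can sample any \( x_t \) in one step instead of simulating the whole chain.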
Fun Analogy
Diffusion Models are like unmixing a smoothie 🥤:
- Blend a perfectly good strawberry-banana smoothie (data).
- Gradually add more random ingredients (noise) until it’s unrecognizable sludge. 🤢
- Train a chef (model) to reverse the process and extract the original fruits step by step. 🍓🍌
- Voilà! You have a delicious smoothie again (or a totally new one).
Mermaid.js Diagram: Diffusion Model Workflow
```mermaid
graph TD
    Data[Original Data x0] --> AddNoise[Forward Process: Add Noise]
    AddNoise --> Noise[Pure Noise xT]
    Noise --> RemoveNoise[Reverse Process: Denoise Step by Step]
    RemoveNoise --> Output[Generated Data x0']
```