摘要
扩散模型作为生成式AI领域中技术最为扎实、效果最为突出的方向之一,已在图像生成、音频合成及分子设计等多个任务上全面超越了GAN与VAE。本文将从数学基础出发,系统推导前向扩散与反向去噪的核心流程,并基于PyTorch从零构建一个完整可运行的扩散模型实现方案。代码实现将涵盖噪声调度机制、损失函数计算、采样策略选择,以及实际开发中常见的问题与解决思路。全文以逻辑推导结合工程实践的形式呈现,适合具备深度学习基础、希望深入理解DDPM底层原理的开发者细致研读。

应用场景
扩散模型的核心优势在于能够从随机噪声中还原出高保真的数据分布,这一特性使其在众多领域得到广泛应用:
- 文本到图像生成,典型代表包括Stable Diffusion与DALL-E 3
- 图像超分辨率重建、修复与编辑任务
- 音频内容生成,例如AudioLDM框架
- 分子构象生成与药物发现
- 时间序列预测与数据插值
- 三维点云数据生成
简而言之,任何需要从随机噪声中生成高质量、多样化样本的应用场景,扩散模型均能胜任。
核心原理
扩散模型最精巧之处在于其双过程设计理念。
第一个过程为前向扩散。对原始数据 x₀ 逐步添加高斯噪声,经过 T 步后转化为纯噪声 x_T ~ N(0, I)。该过程是一个马尔可夫链,每一步的转移核定义为:
q(x_t | x_{t-1}) = N(x_t; sqrt(1 - β_t) * x_{t-1}, β_t * I)
其中 β_t 是预先定义的噪声调度参数,通常从 1e-4 线性递增至 0.02。
借助重参数化技巧,可直接从 x₀ 推导出任意时刻 t 对应的 x_t:
x_t = sqrt(ᾱ_t) * x₀ + sqrt(1 - ᾱ_t) * ε
这里 α_t = 1 - β_t,ᾱ_t = ∏_{i=1}^t α_i,ε ~ N(0, I)。
第二个过程为反向去噪。需要训练一个神经网络 ε_θ(x_t, t) 来预测所添加的噪声 ε,并逐步将其去除,从而从 x_T 恢复出 x₀。反向过程同样是一个马尔可夫链,其转移核表示为:
p_θ(x_{t-1} | x_t) = N(x_{t-1}; μ_θ(x_t, t), σ_t² * I)
μ_θ 的推导基于变分下界,最终简化为:
μ_θ(x_t, t) = (1 / sqrt(α_t)) * (x_t - (β_t / sqrt(1 - ᾱ_t)) * ε_θ(x_t, t))
σ_t² = β_t(DDPM原始设定),也可采用更复杂的调度方案。
损失函数设计简洁直观:训练阶段最小化预测噪声与真实噪声之间的均方误差:
L = E_{t, x₀, ε} [ || ε - ε_θ( sqrt(ᾱ_t) * x₀ + sqrt(1 - ᾱ_t) * ε, t ) ||² ]
采样过程同样直观:从 x_T ~ N(0, I) 出发,对 t 从 T 到 1 进行迭代:
x_{t-1} = (1 / sqrt(α_t)) * (x_t - (β_t / sqrt(1 - ᾱ_t)) * ε_θ(x_t, t)) + σ_t * z
其中 z ~ N(0, I),当 t=1 时 z=0。
详细步骤
1. 定义噪声调度
采用线性调度策略:β_t 从 β₁ 线性递增至 β_T,随后计算 α_t 及 ᾱ_t。
2. 构建神经网络
基于UNet架构,包含下采样、上采样及跳跃连接模块。网络输入为带噪图像 x_t 与时间步 t,时间步通过正弦位置编码嵌入,并注入到每个残差块中。
3. 训练循环
- 从数据集中采样 x₀
- 随机采样 t ~ Uniform(1, T)
- 采样噪声 ε ~ N(0, I)
- 计算 x_t = sqrt(ᾱ_t) * x₀ + sqrt(1 - ᾱ_t) * ε
- 网络预测 ε_pred = ε_θ(x_t, t)
- 计算损失 MSE(ε, ε_pred) 并反向传播更新参数
4. 采样(推理)
- 从标准正态分布采样 x_T
- 对 t 从 T 到 1 执行以下步骤:
- 预测噪声 ε_pred = ε_θ(x_t, t)
- 计算 x_{t-1} 的均值 μ
- 若 t>1,添加噪声 σ_t * z
- 返回 x₀
完整可运行代码
以下代码基于PyTorch实现了一个简化版扩散模型,在MNIST数据集上完成训练并生成手写数字图像。代码可直接运行,每行均附有详细注释。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import math
# ---------- 1. 噪声调度 ----------
def linear_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
"""线性噪声调度,返回beta_t, alpha_t, alpha_bar_t"""
betas = torch.linspace(beta_start, beta_end, T, dtype=torch.float32)
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0) # 累积乘积
return betas, alphas, alphas_bar
# ---------- 2. 时间嵌入 ----------
class SinusoidalPosEmb(nn.Module):
"""正弦位置编码,将时间步t映射为embedding向量"""
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, t):
# t: [batch_size],取值范围[0, T-1]
device = t.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = t[:, None].float() * emb[None, :] # [batch, half_dim]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) # [batch, dim]
return emb
# ---------- 3. 简易UNet ----------
class SimpleUNet(nn.Module):
"""适用于MNIST的轻量UNet,输入通道1,输出通道1"""
def __init__(self, T, time_emb_dim=128):
super().__init__()
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(time_emb_dim),
nn.Linear(time_emb_dim, time_emb_dim),
nn.ReLU(),
nn.Linear(time_emb_dim, time_emb_dim),
nn.ReLU()
)
# 下采样
self.enc1 = nn.Conv2d(1, 64, 3, padding=1)
self.enc2 = nn.Conv2d(64, 128, 3, padding=1, stride=2)
self.enc3 = nn.Conv2d(128, 256, 3, padding=1, stride=2)
# 中间
self.mid = nn.Conv2d(256, 256, 3, padding=1)
# 上采样
self.dec3 = nn.ConvTranspose2d(256 + 256, 128, 4, stride=2, padding=1)
self.dec2 = nn.ConvTranspose2d(128 + 128, 64, 4, stride=2, padding=1)
self.dec1 = nn.Conv2d(64 + 64, 1, 3, padding=1)
# 时间embedding映射到各层
self.time_proj1 = nn.Linear(time_emb_dim, 64)
self.time_proj2 = nn.Linear(time_emb_dim, 128)
self.time_proj3 = nn.Linear(time_emb_dim, 256)
self.time_proj_mid = nn.Linear(time_emb_dim, 256)
self.time_proj_d3 = nn.Linear(time_emb_dim, 128)
self.time_proj_d2 = nn.Linear(time_emb_dim, 64)
def forward(self, x, t):
# x: [batch, 1, 28, 28], t: [batch]
time_emb = self.time_mlp(t) # [batch, time_emb_dim]
# 编码
e1 = self.enc1(x) # [batch, 64, 28, 28]
e1 = e1 + self.time_proj1(time_emb)[:, :, None, None]
e1 = F.relu(e1)
e2 = self.enc2(e1) # [batch, 128, 14, 14]
e2 = e2 + self.time_proj2(time_emb)[:, :, None, None]
e2 = F.relu(e2)
e3 = self.enc3(e2) # [batch, 256, 7, 7]
e3 = e3 + self.time_proj3(time_emb)[:, :, None, None]
e3 = F.relu(e3)
# 中间
m = self.mid(e3) # [batch, 256, 7, 7]
m = m + self.time_proj_mid(time_emb)[:, :, None, None]
m = F.relu(m)
# 解码(跳跃连接)
d3 = torch.cat([m, e3], dim=1) # [batch, 512, 7, 7]
d3 = self.dec3(d3) # [batch, 128, 14, 14]
d3 = d3 + self.time_proj_d3(time_emb)[:, :, None, None]
d3 = F.relu(d3)
d2 = torch.cat([d3, e2], dim=1) # [batch, 256, 14, 14]
d2 = self.dec2(d2) # [batch, 64, 28, 28]
d2 = d2 + self.time_proj_d2(time_emb)[:, :, None, None]
d2 = F.relu(d2)
d1 = torch.cat([d2, e1], dim=1) # [batch, 128, 28, 28]
out = self.dec1(d1) # [batch, 1, 28, 28]
return out
# ---------- 4. 扩散模型封装 ----------
class DiffusionModel:
def __init__(self, T=1000, device='cuda'):
self.T = T
self.device = device
self.betas, self.alphas, self.alphas_bar = linear_beta_schedule(T)
self.betas = self.betas.to(device)
self.alphas = self.alphas.to(device)
self.alphas_bar = self.alphas_bar.to(device)
# 预计算一些常数
self.sqrt_alphas_bar = torch.sqrt(self.alphas_bar)
self.sqrt_one_minus_alphas_bar = torch.sqrt(1.0 - self.alphas_bar)
self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
self.posterior_variance = self.betas # sigma_t^2 = beta_t
self.model = SimpleUNet(T).to(device)
self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4)
def train_step(self, x0):
"""单步训练"""
batch_size = x0.shape[0]
t = torch.randint(0, self.T, (batch_size,), device=self.device).long()
noise = torch.randn_like(x0)
# 前向加噪
sqrt_ab = self.sqrt_alphas_bar[t][:, None, None, None]
sqrt_one_minus_ab = self.sqrt_one_minus_alphas_bar[t][:, None, None, None]
xt = sqrt_ab * x0 + sqrt_one_minus_ab * noise
# 预测噪声
noise_pred = self.model(xt, t.float())
# 损失
loss = F.mse_loss(noise_pred, noise)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item()
@torch.no_grad()
def sample(self, batch_size=16):
"""DDPM采样,从x_T逐步去噪到x_0"""
x = torch.randn(batch_size, 1, 28, 28, device=self.device)
for t in reversed(range(self.T)):
t_tensor = torch.full((batch_size,), t, device=self.device, dtype=torch.float32)
noise_pred = self.model(x, t_tensor)
# 计算均值
sqrt_recip_alpha = self.sqrt_recip_alphas[t]
beta = self.betas[t]
sqrt_one_minus_ab = self.sqrt_one_minus_alphas_bar[t]
x_mean = sqrt_recip_alpha * (x - (beta / sqrt_one_minus_ab) * noise_pred)
if t > 0:
noise = torch.randn_like(x)
sigma = torch.sqrt(self.posterior_variance[t])
x = x_mean + sigma * noise
else:
x = x_mean
return x
# ---------- 5. 训练与采样 ----------
def main():
device = 'cuda' if torch.cuda.is_a vailable() else 'cpu'
print(f'Using device: {device}')
# 数据加载
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]) # 归一化到[-1, 1]
])
dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)
# 模型初始化
diffusion = DiffusionModel(T=200, device=device) # T=200加速演示
n_epochs = 5
# 训练
for epoch in range(n_epochs):
total_loss = 0.0
for batch_idx, (images, _) in enumerate(dataloader):
images = images.to(device)
loss = diffusion.train_step(images)
total_loss += loss
if batch_idx % 100 == 0:
print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss:.6f}')
a vg_loss = total_loss / len(dataloader)
print(f'Epoch {epoch+1} finished, A verage Loss: {a vg_loss:.6f}')
# 采样
print('Generating samples...')
samples = diffusion.sample(batch_size=16)
# 将样本从[-1,1]反归一化到[0,1]用于保存
samples = (samples + 1.0) / 2.0
samples = torch.clamp(samples, 0.0, 1.0)
# 保存为numpy数组,可用于可视化
samples_np = samples.cpu().numpy()
np.sa ve('generated_mnist.npy', samples_np)
print(f'Samples sa ved to generated_mnist.npy, shape: {samples_np.shape}')
if __name__ == '__main__':
main()
运行结果说明
代码执行后,控制台会逐轮打印损失值,通常损失会从约0.5逐步下降至0.05以下。采样完成后,当前目录会生成一个numpy文件 generated_mnist.npy,维度为(16, 1, 28, 28),包含16张生成的手写数字图像。每个像素值均归一化至[0, 1]区间,可直接使用matplotlib的imshow进行可视化,肉眼可清晰辨别数字轮廓。若希望进一步提升生成效果,可将T增加至1000并延长训练轮次,生成质量会显著改善,逼近真实MNIST样本的视觉水平。
常见问题与避坑
1. 训练不收敛或loss震荡
- 原因:学习率设置过高或batch size过小。建议采用Adam优化器,学习率设为1e-4,batch size至少为64。
- 检查:确认输入x₀已归一化至[-1, 1]区间(使用Normalize([0.5], [0.5])即可),因为噪声ε服从标准正态分布,若x₀尺度不匹配,梯度更新容易出现问题。
2. 生成样本全黑或全白
- 原因:采样过程中噪声调度使用有误。常见情况是未对x_t进行正确缩放,或σ_t计算出现偏差。
- 检查:采样循环中的x_mean公式需仔细核对,此外当t>0时,所添加的噪声标准差必须严格为 sqrt(β_t)。
3. 生成样本模糊
- 原因:T取值过小(小于100),或网络容量不足以捕捉数据分布。扩散模型需要足够的步数才能恢复出精细细节。
- 解决:将T调整为1000,同时加深UNet网络结构(增加通道数或层数)。
4. 显存溢出
- 原因:batch size或T设置过大,导致中间变量占用过多显存。
- 解决:适当减小batch size,或采用梯度累积策略,必要时可使用FP16混合精度训练。
5. 训练时loss为NaN
- 原因:数值计算不稳定。需检查 ᾱ_t 是否出现0值(当T较大时,累积乘积可能下溢)。
- 解决:切换至float64精度,或为 ᾱ_t 添加一个极小值eps(如1e-12)。更推荐的方案是采用cosine调度替代线性调度。
6. 采样速度慢
- 原因:DDPM需要逐步迭代T次,当T=1000时推理耗时较长。
- 解决:改用DDIM采样(确定性采样,步数可压缩至50-100),或采用DPM-solver等加速策略。
7. 模型过拟合
- 原因:训练数据集规模较小,或网络容量过大。
- 解决:引入数据增强手段(随机翻转、旋转等),或添加dropout正则化层。
总结
至此,从数学推导到工程实现,扩散模型的核心技术细节已得到完整梳理。最关键的要点在于理解两个核心机制:前向过程的闭合形式(即 x_t 与 x₀ 之间的直接关系),以及反向过程的变分下界简化策略。训练阶段的核心任务就是单纯地预测噪声,而采样阶段则是逐步将其去除。本文提供的代码实现了完整的训练与采样流程,可直接在MNIST数据集上生成手写数字图像。
与GAN相比,扩散模型的优势十分显著:训练过程更加稳定,模式覆盖更加全面。不过采样速度较慢始终是一个瓶颈。目前工业界普遍采用DDIM、LCM、知识蒸馏等手段来加速推理。在吃透本基础实现之后,再去研读DDPM、DDIM、Score SDE等相关论文,会更加得心应手。
最后建议,完成代码运行后,不妨自行调整T值、尝试不同的噪声调度策略(如cosine调度)、或修改网络结构,观察生成质量的变化——这样对扩散模型各组件的作用机理才能真正掌握。
