游乐游手机版
首页/AI教程/文章详情

PyTorch从零实现DDPM:时间嵌入与UNet扩散调度完整复现

时间:2026-06-23 15:28
从数学原理出发,用PyTorch从零搭建完整可跑的扩散模型,涵盖噪声调度、损失函数、采样策略及常见坑点。详细实现时间嵌入与UNet结构,适合已有深度学习基础的读者深入理解扩散模型底层逻辑。

摘要

扩散模型作为生成式AI领域中技术最为扎实、效果最为突出的方向之一,已在图像生成、音频合成及分子设计等多个任务上全面超越了GAN与VAE。本文将从数学基础出发,系统推导前向扩散与反向去噪的核心流程,并基于PyTorch从零构建一个完整可运行的扩散模型实现方案。代码实现将涵盖噪声调度机制、损失函数计算、采样策略选择,以及实际开发中常见的问题与解决思路。全文以逻辑推导结合工程实践的形式呈现,适合具备深度学习基础、希望深入理解DDPM底层原理的开发者细致研读。

PyTorch从零搭建DDPM:时间嵌入+UNet网络+扩散调度完整复现

应用场景

扩散模型的核心优势在于能够从随机噪声中还原出高保真的数据分布,这一特性使其在众多领域得到广泛应用:

  • 文本到图像生成,典型代表包括Stable Diffusion与DALL-E 3
  • 图像超分辨率重建、修复与编辑任务
  • 音频内容生成,例如AudioLDM框架
  • 分子构象生成与药物发现
  • 时间序列预测与数据插值
  • 三维点云数据生成

简而言之,任何需要从随机噪声中生成高质量、多样化样本的应用场景,扩散模型均能胜任。

核心原理

扩散模型最精巧之处在于其双过程设计理念。

第一个过程为前向扩散。对原始数据 x₀ 逐步添加高斯噪声,经过 T 步后转化为纯噪声 x_T ~ N(0, I)。该过程是一个马尔可夫链,每一步的转移核定义为:

q(x_t | x_{t-1}) = N(x_t; sqrt(1 - β_t) * x_{t-1}, β_t * I)

其中 β_t 是预先定义的噪声调度参数,通常从 1e-4 线性递增至 0.02。

借助重参数化技巧,可直接从 x₀ 推导出任意时刻 t 对应的 x_t:

x_t = sqrt(ᾱ_t) * x₀ + sqrt(1 - ᾱ_t) * ε

这里 α_t = 1 - β_t,ᾱ_t = ∏_{i=1}^t α_i,ε ~ N(0, I)。

第二个过程为反向去噪。需要训练一个神经网络 ε_θ(x_t, t) 来预测所添加的噪声 ε,并逐步将其去除,从而从 x_T 恢复出 x₀。反向过程同样是一个马尔可夫链,其转移核表示为:

p_θ(x_{t-1} | x_t) = N(x_{t-1}; μ_θ(x_t, t), σ_t² * I)

μ_θ 的推导基于变分下界,最终简化为:

μ_θ(x_t, t) = (1 / sqrt(α_t)) * (x_t - (β_t / sqrt(1 - ᾱ_t)) * ε_θ(x_t, t))

σ_t² = β_t(DDPM原始设定),也可采用更复杂的调度方案。

损失函数设计简洁直观:训练阶段最小化预测噪声与真实噪声之间的均方误差:

L = E_{t, x₀, ε} [ || ε - ε_θ( sqrt(ᾱ_t) * x₀ + sqrt(1 - ᾱ_t) * ε, t ) ||² ]

采样过程同样直观:从 x_T ~ N(0, I) 出发,对 t 从 T 到 1 进行迭代:

x_{t-1} = (1 / sqrt(α_t)) * (x_t - (β_t / sqrt(1 - ᾱ_t)) * ε_θ(x_t, t)) + σ_t * z

其中 z ~ N(0, I),当 t=1 时 z=0。

详细步骤

1. 定义噪声调度

采用线性调度策略:β_t 从 β₁ 线性递增至 β_T,随后计算 α_t 及 ᾱ_t。

2. 构建神经网络

基于UNet架构,包含下采样、上采样及跳跃连接模块。网络输入为带噪图像 x_t 与时间步 t,时间步通过正弦位置编码嵌入,并注入到每个残差块中。

3. 训练循环

  • 从数据集中采样 x₀
  • 随机采样 t ~ Uniform(1, T)
  • 采样噪声 ε ~ N(0, I)
  • 计算 x_t = sqrt(ᾱ_t) * x₀ + sqrt(1 - ᾱ_t) * ε
  • 网络预测 ε_pred = ε_θ(x_t, t)
  • 计算损失 MSE(ε, ε_pred) 并反向传播更新参数

4. 采样(推理)

  • 从标准正态分布采样 x_T
  • 对 t 从 T 到 1 执行以下步骤:
    • 预测噪声 ε_pred = ε_θ(x_t, t)
    • 计算 x_{t-1} 的均值 μ
    • 若 t>1,添加噪声 σ_t * z
  • 返回 x₀

完整可运行代码

以下代码基于PyTorch实现了一个简化版扩散模型,在MNIST数据集上完成训练并生成手写数字图像。代码可直接运行,每行均附有详细注释。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import math

# ---------- 1. 噪声调度 ----------
def linear_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
    """线性噪声调度,返回beta_t, alpha_t, alpha_bar_t"""
    betas = torch.linspace(beta_start, beta_end, T, dtype=torch.float32)
    alphas = 1.0 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)  # 累积乘积
    return betas, alphas, alphas_bar

# ---------- 2. 时间嵌入 ----------
class SinusoidalPosEmb(nn.Module):
    """正弦位置编码,将时间步t映射为embedding向量"""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    def forward(self, t):
        # t: [batch_size],取值范围[0, T-1]
        device = t.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t[:, None].float() * emb[None, :]  # [batch, half_dim]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)  # [batch, dim]
        return emb

# ---------- 3. 简易UNet ----------
class SimpleUNet(nn.Module):
    """适用于MNIST的轻量UNet,输入通道1,输出通道1"""
    def __init__(self, T, time_emb_dim=128):
        super().__init__()
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )
        # 下采样
        self.enc1 = nn.Conv2d(1, 64, 3, padding=1)
        self.enc2 = nn.Conv2d(64, 128, 3, padding=1, stride=2)
        self.enc3 = nn.Conv2d(128, 256, 3, padding=1, stride=2)
        # 中间
        self.mid = nn.Conv2d(256, 256, 3, padding=1)
        # 上采样
        self.dec3 = nn.ConvTranspose2d(256 + 256, 128, 4, stride=2, padding=1)
        self.dec2 = nn.ConvTranspose2d(128 + 128, 64, 4, stride=2, padding=1)
        self.dec1 = nn.Conv2d(64 + 64, 1, 3, padding=1)
        # 时间embedding映射到各层
        self.time_proj1 = nn.Linear(time_emb_dim, 64)
        self.time_proj2 = nn.Linear(time_emb_dim, 128)
        self.time_proj3 = nn.Linear(time_emb_dim, 256)
        self.time_proj_mid = nn.Linear(time_emb_dim, 256)
        self.time_proj_d3 = nn.Linear(time_emb_dim, 128)
        self.time_proj_d2 = nn.Linear(time_emb_dim, 64)

    def forward(self, x, t):
        # x: [batch, 1, 28, 28], t: [batch]
        time_emb = self.time_mlp(t)  # [batch, time_emb_dim]
        # 编码
        e1 = self.enc1(x)  # [batch, 64, 28, 28]
        e1 = e1 + self.time_proj1(time_emb)[:, :, None, None]
        e1 = F.relu(e1)
        e2 = self.enc2(e1)  # [batch, 128, 14, 14]
        e2 = e2 + self.time_proj2(time_emb)[:, :, None, None]
        e2 = F.relu(e2)
        e3 = self.enc3(e2)  # [batch, 256, 7, 7]
        e3 = e3 + self.time_proj3(time_emb)[:, :, None, None]
        e3 = F.relu(e3)
        # 中间
        m = self.mid(e3)  # [batch, 256, 7, 7]
        m = m + self.time_proj_mid(time_emb)[:, :, None, None]
        m = F.relu(m)
        # 解码(跳跃连接)
        d3 = torch.cat([m, e3], dim=1)  # [batch, 512, 7, 7]
        d3 = self.dec3(d3)  # [batch, 128, 14, 14]
        d3 = d3 + self.time_proj_d3(time_emb)[:, :, None, None]
        d3 = F.relu(d3)
        d2 = torch.cat([d3, e2], dim=1)  # [batch, 256, 14, 14]
        d2 = self.dec2(d2)  # [batch, 64, 28, 28]
        d2 = d2 + self.time_proj_d2(time_emb)[:, :, None, None]
        d2 = F.relu(d2)
        d1 = torch.cat([d2, e1], dim=1)  # [batch, 128, 28, 28]
        out = self.dec1(d1)  # [batch, 1, 28, 28]
        return out

# ---------- 4. 扩散模型封装 ----------
class DiffusionModel:
    def __init__(self, T=1000, device='cuda'):
        self.T = T
        self.device = device
        self.betas, self.alphas, self.alphas_bar = linear_beta_schedule(T)
        self.betas = self.betas.to(device)
        self.alphas = self.alphas.to(device)
        self.alphas_bar = self.alphas_bar.to(device)
        # 预计算一些常数
        self.sqrt_alphas_bar = torch.sqrt(self.alphas_bar)
        self.sqrt_one_minus_alphas_bar = torch.sqrt(1.0 - self.alphas_bar)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.posterior_variance = self.betas  # sigma_t^2 = beta_t
        self.model = SimpleUNet(T).to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4)

    def train_step(self, x0):
        """单步训练"""
        batch_size = x0.shape[0]
        t = torch.randint(0, self.T, (batch_size,), device=self.device).long()
        noise = torch.randn_like(x0)
        # 前向加噪
        sqrt_ab = self.sqrt_alphas_bar[t][:, None, None, None]
        sqrt_one_minus_ab = self.sqrt_one_minus_alphas_bar[t][:, None, None, None]
        xt = sqrt_ab * x0 + sqrt_one_minus_ab * noise
        # 预测噪声
        noise_pred = self.model(xt, t.float())
        # 损失
        loss = F.mse_loss(noise_pred, noise)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    @torch.no_grad()
    def sample(self, batch_size=16):
        """DDPM采样,从x_T逐步去噪到x_0"""
        x = torch.randn(batch_size, 1, 28, 28, device=self.device)
        for t in reversed(range(self.T)):
            t_tensor = torch.full((batch_size,), t, device=self.device, dtype=torch.float32)
            noise_pred = self.model(x, t_tensor)
            # 计算均值
            sqrt_recip_alpha = self.sqrt_recip_alphas[t]
            beta = self.betas[t]
            sqrt_one_minus_ab = self.sqrt_one_minus_alphas_bar[t]
            x_mean = sqrt_recip_alpha * (x - (beta / sqrt_one_minus_ab) * noise_pred)
            if t > 0:
                noise = torch.randn_like(x)
                sigma = torch.sqrt(self.posterior_variance[t])
                x = x_mean + sigma * noise
            else:
                x = x_mean
        return x

# ---------- 5. 训练与采样 ----------
def main():
    device = 'cuda' if torch.cuda.is_a vailable() else 'cpu'
    print(f'Using device: {device}')
    # 数据加载
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])  # 归一化到[-1, 1]
    ])
    dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)
    # 模型初始化
    diffusion = DiffusionModel(T=200, device=device)  # T=200加速演示
    n_epochs = 5
    # 训练
    for epoch in range(n_epochs):
        total_loss = 0.0
        for batch_idx, (images, _) in enumerate(dataloader):
            images = images.to(device)
            loss = diffusion.train_step(images)
            total_loss += loss
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss:.6f}')
        a vg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch+1} finished, A verage Loss: {a vg_loss:.6f}')
    # 采样
    print('Generating samples...')
    samples = diffusion.sample(batch_size=16)
    # 将样本从[-1,1]反归一化到[0,1]用于保存
    samples = (samples + 1.0) / 2.0
    samples = torch.clamp(samples, 0.0, 1.0)
    # 保存为numpy数组,可用于可视化
    samples_np = samples.cpu().numpy()
    np.sa ve('generated_mnist.npy', samples_np)
    print(f'Samples sa ved to generated_mnist.npy, shape: {samples_np.shape}')

if __name__ == '__main__':
    main()

运行结果说明

代码执行后,控制台会逐轮打印损失值,通常损失会从约0.5逐步下降至0.05以下。采样完成后,当前目录会生成一个numpy文件 generated_mnist.npy,维度为(16, 1, 28, 28),包含16张生成的手写数字图像。每个像素值均归一化至[0, 1]区间,可直接使用matplotlib的imshow进行可视化,肉眼可清晰辨别数字轮廓。若希望进一步提升生成效果,可将T增加至1000并延长训练轮次,生成质量会显著改善,逼近真实MNIST样本的视觉水平。

常见问题与避坑

1. 训练不收敛或loss震荡

  • 原因:学习率设置过高或batch size过小。建议采用Adam优化器,学习率设为1e-4,batch size至少为64。
  • 检查:确认输入x₀已归一化至[-1, 1]区间(使用Normalize([0.5], [0.5])即可),因为噪声ε服从标准正态分布,若x₀尺度不匹配,梯度更新容易出现问题。

2. 生成样本全黑或全白

  • 原因:采样过程中噪声调度使用有误。常见情况是未对x_t进行正确缩放,或σ_t计算出现偏差。
  • 检查:采样循环中的x_mean公式需仔细核对,此外当t>0时,所添加的噪声标准差必须严格为 sqrt(β_t)。

3. 生成样本模糊

  • 原因:T取值过小(小于100),或网络容量不足以捕捉数据分布。扩散模型需要足够的步数才能恢复出精细细节。
  • 解决:将T调整为1000,同时加深UNet网络结构(增加通道数或层数)。

4. 显存溢出

  • 原因:batch size或T设置过大,导致中间变量占用过多显存。
  • 解决:适当减小batch size,或采用梯度累积策略,必要时可使用FP16混合精度训练。

5. 训练时loss为NaN

  • 原因:数值计算不稳定。需检查 ᾱ_t 是否出现0值(当T较大时,累积乘积可能下溢)。
  • 解决:切换至float64精度,或为 ᾱ_t 添加一个极小值eps(如1e-12)。更推荐的方案是采用cosine调度替代线性调度。

6. 采样速度慢

  • 原因:DDPM需要逐步迭代T次,当T=1000时推理耗时较长。
  • 解决:改用DDIM采样(确定性采样,步数可压缩至50-100),或采用DPM-solver等加速策略。

7. 模型过拟合

  • 原因:训练数据集规模较小,或网络容量过大。
  • 解决:引入数据增强手段(随机翻转、旋转等),或添加dropout正则化层。

总结

至此,从数学推导到工程实现,扩散模型的核心技术细节已得到完整梳理。最关键的要点在于理解两个核心机制:前向过程的闭合形式(即 x_t 与 x₀ 之间的直接关系),以及反向过程的变分下界简化策略。训练阶段的核心任务就是单纯地预测噪声,而采样阶段则是逐步将其去除。本文提供的代码实现了完整的训练与采样流程,可直接在MNIST数据集上生成手写数字图像。

与GAN相比,扩散模型的优势十分显著:训练过程更加稳定,模式覆盖更加全面。不过采样速度较慢始终是一个瓶颈。目前工业界普遍采用DDIM、LCM、知识蒸馏等手段来加速推理。在吃透本基础实现之后,再去研读DDPM、DDIM、Score SDE等相关论文,会更加得心应手。

最后建议,完成代码运行后,不妨自行调整T值、尝试不同的噪声调度策略(如cosine调度)、或修改网络结构,观察生成质量的变化——这样对扩散模型各组件的作用机理才能真正掌握。

来源:https://juejin.cn/post/7653415873380188202
上一篇生产级RAG架构设计实战:完整流程与最佳实践 下一篇掌握色彩原理,拍照调色不再难
本站内容用于信息整理与展示,如有侵权或内容问题请及时联系处理。

相关推荐

补充同频道和同主题内容,方便继续浏览更多相关内容。

同类最新

继续查看同栏目最近更新的文章。

更多
CapCut AI Docker 一键部署:镜像拉取、端口映射与数据目录配置教程
AI教程 · 2026-06-30

CapCut AI Docker 一键部署:镜像拉取、端口映射与数据目录配置教程

CapCutAI容器化部署需先确认镜像来源与授权范围,再完成环境准备、镜像拉取、端口映射、数据目录挂载和启动验证,适合本地试用、团队内网演示与轻量化AI剪辑服务管理。

CapCut AI Windows本地安装配置2026最新版含下载与环境要求
AI教程 · 2026-06-30

CapCut AI Windows本地安装配置2026最新版含下载与环境要求

CapCutAI与剪映AI在Windows端适合短视频、口播、课程和营销素材剪辑,安装前需确认系统、显卡、存储与网络条件,优先选择官方渠道下载,并完成账号、素材目录、硬件加速和导出参数配置。

Veo新手保姆级安装教程:从下载到首次运行
AI教程 · 2026-06-30

Veo新手保姆级安装教程:从下载到首次运行

Veo适合用文字生成短视频,新手应先确认官方入口、准备账号与设备环境,再按网页或应用方式完成启用。首次运行重点在提示词、参数、素材合规与结果保存,避免使用非官方安装包。

Veo本地模型运行下载路径设置与性能优化指南
AI教程 · 2026-06-30

Veo本地模型运行下载路径设置与性能优化指南

Veo本地模型部署需先确认模型来源与硬件条件,再完成下载校验、目录规划、路径配置和推理参数优化。重点关注显存占用、依赖版本、缓存位置、授权范围与常见报错处理。

Veo安装失败解决指南:常见报错与日志排查及升级回滚方案
AI教程 · 2026-06-30

Veo安装失败解决指南:常见报错与日志排查及升级回滚方案

Veo安装失败通常与系统环境、依赖版本、网络源、权限和缓存有关。排查时应先确认版本要求,再查看安装日志,按报错类型处理,并提前备份项目,确保升级与回滚可控。