算法可视化与交互学习平台
学习 DiT:从 U-Net Denoiser 到 Diffusion TransformerDiT: From U-Net Denoisers to Diffusion Transformers
在 No.3 Transformer、No.4 VAE、No.5 DDPM、No.10 Latent Diffusion 与 No.11 条件生成的基础上,理解 DiT 如何把 noisy latent 切成 patch tokens,通过 timestep/class embeddings 与 adaLN-Zero 调制 Transformer blocks,在不改变 DDPM 训练目标和采样规则的前提下,用 Transformer 替换 U-Net denoiser。
先抓住 DiT 的核心变化
DiT(Diffusion Transformer)的一句话版本:
因此 DiT 不是新的扩散概率公式,而是一种 denoiser 架构。它把 noisy latent 切成 patch tokens,用 Self-Attention 建模 patch 之间的关系,再把 timestep 与类别条件通过 adaLN-Zero 注入每个 Transformer block。
从 No.3、4、5、10、11 走到 DiT
| 已有模块 | 带进 DiT 的知识 | 本模块不再重复什么 |
|---|---|---|
| No.3 Transformer | token、Self-Attention、MLP、残差连接 | 不重新推导 Q/K/V 基础 |
| No.4 VAE | image 与 latent 的压缩/解码 | 不重新推导 KL 与重参数化 |
| No.5 DDPM | forward 加噪、epsilon prediction、reverse sampling | 不重新推导 DDPM |
| No.10 Latent Diffusion | 在 z_t 而不是像素 x_t 上扩散 | 不重新解释为何压缩 |
| No.11 Cross Attention | 条件如何改变 denoiser、CFG 如何工作 | 原始 DiT 的 class condition 主要走 adaLN,不等同于 Cross Attention |
U-Net Denoiser vs DiT
| 维度 | U-Net | DiT |
|---|---|---|
| 基本单元 | 卷积 feature map | latent patch tokens |
| 空间建模 | 局部卷积 + 多尺度 down/up sampling | Self-Attention 直接连接所有 patch |
| timestep 条件 | 注入 ResBlock | 通过 adaLN 调制 Transformer block |
| 类别条件 | embedding / attention / feature fusion | 原始 DiT 将 class embedding 与 timestep embedding 相加后送入 adaLN |
| 输出 | 与 latent 同形状的噪声张量 | patch 输出 unpatchify 后仍是同形状噪声张量 |
两者都可以实现同一个函数 epsilon_theta(z_t,t,c)。区别不是“谁才是 diffusion”,而是谁来承担 denoiser。
Patchify:latent 怎样变成 Transformer tokens
Patchify 是 DiT 把二维 noisy latent 接到 Transformer 上的接口:先把 切成 个不重叠的 小块,再把每块展平、投影到 hidden size ,并加入二维位置编码,得到 Transformer 可以处理的 token sequence 。
1. 为什么 latent 不能原样交给 Transformer
卷积网络天然按通道、高度、宽度处理二维 feature map;标准 Transformer 的输入接口则是一串 token,每个 token 都是 D 维向量。因此问题不是 latent 中的信息不能被 Transformer 理解,而是张量组织方式不匹配:必须把空间网格改写成序列。Patchify 就是二者之间的形状适配器。
2. 先从 Stable Diffusion 的 latent 开始
以 512×512 RGB 图像和 8 倍空间压缩的 VAE 为例,Encoder 输出 64×64×4 latent。它有 64×64=4096 个空间位置,每个位置保存 4 个 latent features。DiT 接收的是扩散过程中的 noisy latent z_t,而不是原始 RGB pixels。
3. 切 Patch:空间网格变成较短的块序列
选择 patch size P=8 时,64×64 latent 被划分成 8×8 个不重叠区域,所以 N=64。P 不是固定常数,而是模型设计选择:P 越大,token 越少、Attention 越便宜,但每个 token 覆盖的空间更粗;P 越小,空间粒度更细,但 token 更多。
4. Flatten:每个二维小块变成一个向量
一个 patch 同时覆盖 P×P 个空间位置和 C 个通道,所以它含有 P²C 个数。示例中每个 8×8×4 patch 展平为 256 维向量。Flatten 只是重排元素,不学习参数,也不会丢失这个 patch 内的数值。
5. Linear Projection:patch vector 变成 token embedding
Transformer Block 的 hidden size 是 D,例如 768 或 1024,而 patch 的原始维度是 P²C。可学习矩阵 W_E 把每个 patch 独立投影到统一的 D 维空间。这个投影既完成维度转换,也学习怎样组合 patch 内的局部 latent features。
6. 为什么代码常用 Conv2d 代替显式切块 + Linear
kernel_size=P、stride=P 的 Conv2d 会在每个不重叠 P×P 区域上使用同一组权重,数学效果等价于“取 patch、展平、乘同一个 Linear”。Conv2d 直接输出 [B,D,H/P,W/P],随后 flatten 空间轴并 transpose 成 [B,N,D],实现更紧凑。
7. 加入二维位置编码:告诉 Attention 每个 patch 在哪里
Self-Attention 本身只看 token 内容;若打乱 token 顺序而不提供位置信息,它无法区分左上角与右下角。每个 patch token 因此要加上与网格坐标 (row,column) 对应的位置向量。位置编码不改变 token 数或 hidden size。
8. Patch size 为什么直接决定 Attention 成本
Self-Attention 要建立 N×N 的 token 关系矩阵,核心成本随 N² 增长。如果把 64×64 latent 的每个位置都当 token,N=4096;使用 P=8 后 N=64,Attention 关系数从约 1677 万降到 4096。代价是空间粒度变粗,因此 P 是质量与算力之间的重要旋钮。
9. 为什么先在 latent 中做 Patchify
若直接对 512×512 RGB 图像使用 P=8,会得到 64×64=4096 tokens;先用 VAE 压缩到 64×64,再使用相同 P=8,只得到 8×8=64 tokens。VAE 先降低空间分辨率,Patchify 再控制 token 粒度,两层降维共同让全局 Self-Attention 可计算。
10. ViT、DiT 与 Unpatchify:同一入口,不同任务
ViT 对 RGB patches 编码后做分类;DiT 对 noisy latent patches 编码后预测每个 patch 中的噪声。Final Layer 把每个 D 维 token 投影回 P²C 个数,Unpatchify 再按原网格拼回 [B,C,H,W],这样 Scheduler 才能用预测噪声更新 z_t。当前 Toy DiT 默认 latent=4×4×4、P=1、D=64,因此真实输入是 16 个 64 维 token。
timestep 与类别条件如何进入 DiT
原始 DiT 不把 timestep 或 class label 当作普通图像 patch。它先分别编码,再相加形成统一条件向量:
t_embed = timestep_embedding(t) c = time_mlp(t_embed) + class_embedding(y)
c 的形状是 [B,D]。它不增加 token 数,而是为每个样本生成一组 shift、scale 和 gate,调制全部 DiT blocks。
adaLN-Zero:用条件调制每个 DiT Block
adaLN-Zero = adaptive LayerNorm + zero initialization。条件 c 由 timestep 与类别/文本条件组成,它为每个样本、每个 DiT Block 动态生成归一化的 shift/scale 和两条残差分支的 gate;零初始化让深层网络从接近恒等映射开始,再逐步学习每一层应当修改多少。
1. 先从普通 LayerNorm 开始
Transformer 常见数据流是 token h 先经过 LayerNorm,再进入 Attention 或 MLP,最后走残差连接。普通 LayerNorm 的 gamma 与 beta 是训练得到但推理时固定的参数;同一层面对不同 timestep 和不同类别时,使用的是同一组 gamma、beta。
2. adaLN:让 gamma、beta 由条件动态决定
adaptive LayerNorm 不再只依赖一组固定的缩放和平移。条件向量 c 经过 MLP,为当前样本生成动态 shift 与 scale。换一个 timestep、类别或文本条件,归一化后的 token 就会被用不同方式重新缩放和偏移。
3. 为什么 Diffusion 特别需要 adaptive modulation
Diffusion 的 denoiser 在不同 t 面对的是不同任务:高噪声阶段主要恢复全局结构,中间阶段稳定轮廓与关系,低噪声阶段修正局部边界和细节。adaLN 让同一个 Transformer Block 根据当前 t 动态改变工作方式,而不是在所有噪声等级上使用完全相同的特征变换。
4. adaLN-Zero 一次生成六组参数
一个 DiT Block 有 Self-Attention 和 MLP 两条残差分支,所以条件 MLP 不是只输出一组 shift/scale,而是输出六组参数。msa 表示 multi-head self-attention。
5. Attention 分支的完整数据流
先对 token 做无 affine 参数的 LayerNorm,再用当前条件生成的 shift/scale 调制。Self-Attention 处理的是调制后的 latent patch tokens,最后乘 gate_msa 才写回主干。
6. MLP 分支使用独立的调制与 gate
Attention 更新后的 x 再进入第二次 LayerNorm。MLP 分支拥有自己的 shift_mlp、scale_mlp 和 gate_mlp,因此条件可以分别控制“patch 之间交换信息”和“每个 patch 内部变换特征”的强度。
7. gate 是每个残差分支的可学习油门
普通 Transformer 直接把 Attention/MLP 输出加回 x。adaLN-Zero 先乘条件 gate:gate 接近 0 时这条分支几乎关闭;绝对值增大时,分支对主干的修改增强。不同样本、不同 timestep 会得到不同 gate,因此它不仅是静态层权重。
8. Zero initialization 为什么让深层 DiT 更稳定
Toy DiT 将 modulation MLP 的最后一层权重与 bias 初始化为 0,最终预测头也从 0 开始。于是初始 shift、scale、gate 都接近 0,每个 Block 近似 x -> x,不会让许多随机初始化层从第一步起连续强行改写 token。训练先学会打开 gate,再学习更复杂的条件变换。DiT 论文的消融实验中,adaLN-Zero 是所比较条件注入方式里效果最好的设计。
9. 为什么代码写 1 + scale
如果直接写 h_mod = scale * LN(h) + shift,那么 scale 初始化为 0 会把归一化特征整体压成 0。写成 1 + scale 后,scale=0 对应单位缩放,shift=0 对应零偏移;调制分支初始保留原始 LayerNorm 特征。
10. 一个直观比喻:条件控制的两组加工台
普通 Transformer Block 上来就让 Attention 和 MLP 加工 token;adaLN 根据 t/类别决定如何加工;adaLN-Zero 还给两条加工线分别加上油门,而且初始油门为 0。模型训练时逐步学会:哪一层、哪个 timestep、哪个条件应该开多大。
11. adaLN-Zero 与 Cross Attention 不解决同一个问题
Cross Attention 让图像 Query 主动读取一组外部 token 的 Key/Value,适合保留逐 token 语义;adaLN-Zero 把条件压成控制向量,整体调节 Block 的归一化和残差强度。原始 DiT 用 timestep/class embedding 驱动 adaLN-Zero;现代文本生成模型也可以把文本池化向量用于调制,或同时保留 Cross/Joint Attention。
12. 与当前 Toy DiT 代码逐行对齐
当前实验的 _DiTBlock.forward() 正是上述完整逻辑。condition 由 timestep embedding 与 class embedding 相加得到,ada(condition) 输出六组参数,两条残差分支分别调制和门控。
逐步演算:一个条件如何打开 DiT 残差分支
h_norm = [0.5, -0.5][0.5, -0.5]
[0.2, 0.2]
[0.1, -0.1]
[0.5, -0.5]
输出头与 DDPM loss:目标没有改变
DiT 输出每个 patch 的噪声预测,unpatchify 后恢复为与 z_t 同形状的张量,再与 forward process 中真实加入的 epsilon 做 MSE。
代码
输出形状
DiT 采样与 CFG:Scheduler 仍然照常工作
训练时随机把一部分 class label 替换成 null class,同一个 DiT 就同时学会有条件和无条件噪声预测。生成时:
z_t = torch.randn(latent_shape)
for t in scheduler.timesteps:
eps_null = dit(z_t, t, null_class)
eps_class = dit(z_t, t, selected_class)
eps = eps_null + cfg_scale * (eps_class - eps_null)
z_t = scheduler.step(eps, t, z_t)Scheduler 不关心 denoiser 是 U-Net 还是 DiT。只要网络输出约定的噪声张量,它就能继续执行同一反向更新。
Patch size、token 数与计算量
DiT 的一个关键工程旋钮是 patch size。设 latent 是 H×W,patch size 是 P:
| latent | patch | tokens | 影响 |
|---|---|---|---|
| 4×4 | 2×2 | 4 | 快,但空间粒度粗 |
| 4×4 | 1×1 | 16 | 细粒度,更高 Attention 成本 |
| 8×8 | 1×1 | 64 | token 数增加 4 倍,Attention 矩阵元素增加 16 倍 |
扩大 hidden size、depth 或减少 patch size都会增加计算量。DiT 的核心优势之一,是可以沿 Transformer 的宽度、深度和 token 数进行规则化扩展。
Toy DiT 实验:训练 Transformer Denoiser,再从纯噪声生成
第一步在 CPU 上真实训练 TinyVAE 与 adaLN-Zero Tiny DiT。DiT 接收 noisy latent patch tokens、timestep embedding 和 shape class embedding,训练目标是 forward process 中真实加入的 epsilon。训练时随机丢弃部分 class condition,从而让同一个模型学会 CFG 的 conditional / null 两种预测。 训练完成后,选择类别并点击生成。生成 Action 会加载当前页面保存的真实模型权重,从纯 latent 高斯噪声执行完整反向采样,不会重新训练。 点击任一 Action 后,卡片底部会显示该次运行使用的完整 Python 代码,包括参数赋值、TinyVAE、DiT Block、adaLN-Zero、训练循环或 CFG 采样循环;代码区右上角可一键复制。
拆解 Toy DiT:公式如何落到代码
实验的主链不是“用 Transformer 直接画图”,而是让 Transformer 承担 DDPM 的 denoiser:
1. Patchify 不会改变扩散变量
tokens = patch_embed(z_t).flatten(2).transpose(1, 2) tokens = tokens + pos_embed
z_t 仍然是 DDPM 中的 noisy latent。Patchify 只是把它改写成 Transformer 可以处理的序列表示。
2. 条件不是额外图像 token
condition = time_mlp(timestep_embedding(t)) + label_embed(labels) shift, scale, gate = modulation(condition)
条件向量为每个样本产生调制参数,控制归一化后的 token features 和残差分支。类别信息因此能影响所有 patch,但不会增加 patch token 数。
3. Self-Attention 负责 latent patch 之间的信息交换
attention_input = modulate(norm1(tokens), shift_msa, scale_msa) attention_output = self_attention(attention_input) tokens = tokens + gate_msa * attention_output
这里 Q/K/V 全部来自 latent patch tokens,所以是 Self-Attention。与 No.11 不同,class embedding 不作为文字 K/V。
4. 输出必须回到 latent 张量形状
patch_noise = final_layer(tokens, condition) predicted_noise = unpatchify(patch_noise) loss = mse(predicted_noise, true_noise)
Scheduler 需要与 z_t 同形状的噪声预测,因此 Transformer token 输出必须 unpatchify 回 [B,C,H,W]。
5. CFG 来自 condition dropout
labels[random_mask] = null_class # training eps_null = dit(z_t, t, null_class) eps_class = dit(z_t, t, selected_class) eps = eps_null + cfg_scale * (eps_class - eps_null)
这使同一个 DiT 同时学会有条件与无条件预测。生成按钮加载训练后状态,从同一纯噪声起点比较不同 CFG scale。