算法可视化与交互学习平台
学习 Cross Attention:Prompt 如何控制 Latent DiffusionCross Attention: How Prompts Guide Latent Diffusion
在 No.4 VAE、No.5 DDPM 和 No.10 Latent Diffusion 的基础上,理解 Cross Attention 如何把 prompt token 注入 denoising network,让 latent 空间中的每个位置在预测噪声时读取文字条件,从而控制图像生成方向。
先抓住 Cross Attention 的核心直觉
Cross Attention 的一句话版本:
No.10 已经说明,Latent Diffusion 可以在 z_t 上训练 denoiser 预测噪声 epsilon。但无条件版本只能学到“像训练数据”的图像,不能回答“按这段文字生成什么”。
Cross Attention 就是文字进入 denoising network 的关键接口。它不直接画图,也不直接输出像素,而是在每个去噪 timestep 中改变 U-Net / denoiser 的内部特征,让噪声预测变成:
其中 c 是 prompt 经过 text encoder 得到的 token embeddings。
它接在 No.10 的哪里
No.10 的 Latent Diffusion 已经有三段结构:VAE 把图像压缩成 latent,DDPM 在 latent 中加噪/去噪,Decoder 把干净 latent 翻译回图像。Cross Attention 加在中间的 denoiser 内部。
换句话说,Cross Attention 解决的是“条件从哪里进入模型”的问题。VAE 不负责理解文字,DDPM 公式本身也不理解文字;真正读 prompt 的地方,是 denoising network 中的 Cross Attention 层。
Self-Attention vs Cross-Attention
Self-Attention 和 Cross-Attention 用的是同一个核心公式,但 Q/K/V 的来源不同。
| 机制 | Q 来自哪里 | K / V 来自哪里 | 它在问什么 |
|---|---|---|---|
| Self-Attention | 同一组 token | 同一组 token | 这一组 token 内部,谁应该看谁 |
| Cross-Attention | 当前图像 / latent features | prompt token embeddings | 每个图像位置应该读取哪些文字语义 |
在文生图 Latent Diffusion 中,可以把图像 latent feature map 看成一张小网格。每个网格位置都发出一个 Query,去和 prompt 中的 token 做匹配。
Cross Attention 公式:Q 从图像来,K/V 从文字来
H 是当前 noisy latent 经过 U-Net block 得到的图像特征,C 是 prompt token embeddings。Cross Attention 让每个 latent 空间位置根据自己的 Query 去读取文字 token 的 Value。
为什么 Q 来自图像
因为当前要被更新的是图像 latent feature。每个空间位置都要问:为了让这里符合 prompt,我应该参考哪个文字 token?
为什么 K/V 来自文字
Key 负责被匹配,Value 负责提供语义信息。prompt token 不是直接生成图像,而是作为条件信息被 denoiser 反复读取。
最小代码
逐步演算:一个 latent 位置如何读取 prompt
prompt = "draw circle"
q_cell = [0.8, 0.4]
k_draw = [0.2, 0.6]
k_circle = [0.7, 0.2][0.8, 0.4]
[0.2, 0.6]
[0.7, 0.2]
[0.8, 0.4]
[0.2, 0.6]
[0.7, 0.2]
带 prompt 条件的 DDPM 训练目标
Cross Attention 不改变 DDPM 的监督信号。训练目标仍然是预测 forward process 中真实加入的噪声 epsilon,只是 denoiser 在预测时多读取了 prompt 条件 c。
和 No.10 的关系
No.10 的 loss 是 。No.11 只是把 denoiser 的输入扩展为 ,loss 的本质仍然是 epsilon prediction。
最小代码
Cross Attention 如何改变 denoiser 的预测
Cross Attention 的输出不是最终图片,而是一份文字上下文 context。这份上下文会和 U-Net / denoiser 的图像特征融合,然后影响最终的 noise_pred。
所以 prompt 的作用不是“把文字翻译成像素”,而是改变每一步噪声预测的方向。模型从纯噪声开始时,很多结构都不确定;每一步读取 prompt,都会让 latent 更倾向于符合文字条件的区域。
Cross Attention 与 CFG 的关系
Classifier-Free Guidance,简称 CFG,不是 Cross Attention 本身,但它经常和 Cross Attention 一起出现在 Stable Diffusion 这类模型里。
它的做法是同一个 z_t 跑两次 denoiser:一次不给 prompt 或给空 prompt,得到 epsilon_uncond;一次给真实 prompt,得到 epsilon_text。两者差值表示“文字条件把预测往哪里推”,guidance scale s 会放大这个方向。
| 概念 | 负责什么 |
|---|---|
| Cross Attention | 让 denoiser 能读取 prompt token |
| CFG | 放大有 prompt 和无 prompt 两次预测之间的差异 |
先理解 Cross Attention,再理解 CFG 会更稳。因为 CFG 的前提是模型已经能通过 Cross Attention 产生一份有文字条件的噪声预测。
Toy Cross Attention 实验:用提示词控制生成形状
这个实验分为两个真实计算阶段:先点击“① 训练 Cross Attention 模型”,训练 TinyVAE 与带 Cross Attention 的 latent DDPM,并把模型权重临时保存在当前页面;训练完成后,由用户自行在空白输入框中填写形状提示词,再点击“② 生成图形”。系统不会内置或自动提交默认提示词,空输入也不会回退为 circle。 支持 circle、square、triangle、diamond、cross、smile,也支持圆形、正方形、三角形、菱形、十字、笑脸等中文提示词。每次生成都会加载同一个训练后模型,从纯 latent 高斯噪声开始,按用户这一次输入的 Prompt 执行完整反向去噪。
拆解 Toy Cross Attention:公式如何落到代码
这张卡片只追踪一件事:Prompt 条件如何经过 Cross Attention,进入 latent DDPM 的噪声预测。Cross Attention 不直接输出形状,也不直接生成像素;它在每个 timestep 改写 denoiser 的内部特征,使预测噪声 epsilon_theta 带上文字条件。
1. Cross Attention 位于哪条算法链上
训练图像先由 VAE Encoder 压缩成干净 latent z0。DDPM forward process 随机选择 timestep t,并加入已知高斯噪声 epsilon:
t = torch.randint(0, T, (batch_size,)) noise = torch.randn_like(z0) a = alpha_bar[t].view(-1, 1, 1, 1) z_t = torch.sqrt(a) * z0 + torch.sqrt(1 - a) * noise
此时 z_t 同时含有图像结构和噪声。Denoiser 的任务不是猜最终图片,而是在已知 z_t、t 和 Prompt 条件时预测刚才加入的 epsilon。
2. 图像 latent 先变成可发问的特征 H
t_img = t.float().view(b, 1, 1, 1).expand(b, 1, h, w) / (T - 1) h = self.in_net(torch.cat([z_t, t_img, coords], dim=1)) h_flat = h.flatten(2).transpose(1, 2)
in_net 把 noisy latent、timestep 和空间坐标融合为图像特征 H。若 latent 分辨率为 4x4,则空间位置数 N=16。展平后:
每一个 latent 位置随后都会产生自己的 Query。直觉上,它在问:“为了正确去掉我这里的噪声,我应该从 Prompt 的哪些 token 读取信息?”
3. Prompt token 变成条件矩阵 C
text = self.token_emb(token_ids) + self.token_pos.unsqueeze(0)
Token embedding 加上 token 位置编码后得到条件矩阵 C。设 Prompt 有 L 个 token:
真实文生图模型通常由大型 text encoder 产生 C;Toy 实验用小型 nn.Embedding 代替。这个替换只缩小了文字编码器,不改变 Cross Attention 的计算结构。
4. Q 来自图像,K/V 来自文字
q = self.to_q(h_flat) k = self.to_k(text) v = self.to_v(text)
这是 Cross Attention 与 Self-Attention 最关键的区别:Q 和 K/V 来自两种不同模态。图像 latent 负责发问,Prompt token 负责提供可匹配、可读取的条件信息。
5. QK^T 计算每个 latent 位置应关注哪个 token
scores = q @ k.transpose(-2, -1) / math.sqrt(text_dim) attn = F.softmax(scores, dim=-1)
除以 sqrt(d) 是为了防止维度增大后点积幅度过大,导致 softmax 过早饱和。softmax(dim=-1) 则让每个 latent 位置在所有 Prompt token 之间分配注意力。
6. A 乘 V:把相关文字内容读回图像位置
context = attn @ v context = self.context_out(context)
对第 i 个 latent 位置,context 是所有 Value 的加权和:
因此 context 不是一个全局 Prompt 向量,而是每个 latent 位置各自读取文字后得到的条件向量。不同空间位置可以对同一 Prompt 形成不同的读取比例。
7. Context 如何真正改变噪声预测
context_map = context.transpose(1, 2).contiguous().view(
b, model_dim, latent_res, latent_res
)
noise_pred = self.out_net(torch.cat([h, context_map], dim=1))Context 被还原为空间 feature map,并与原图像特征 H 融合。最终输出仍是与 z_t 同形状的预测噪声:
所以 Prompt 控制图像的直接机制,是它改变了每一步“应该减去哪一部分噪声”的预测方向。
8. 训练目标没有被 Cross Attention 改写
pred_noise = denoiser(z_t, t, token_ids) loss = F.mse_loss(pred_noise, noise) optimizer.zero_grad() loss.backward() optimizer.step()
监督答案仍是 forward process 中由程序亲自加入的真实噪声 noise。Cross Attention 增加的是条件输入 C,并没有把 DDPM 训练改成分类任务,也没有让模型直接拟合目标形状像素。
9. 生成时为什么每一步都要重新计算 Attention
z_t = torch.randn(1, latent_channels, h, w)
for t in reversed(range(T)):
epsilon = denoiser(z_t, t, prompt_tokens)
z_t = ddpm_reverse_step(z_t, epsilon, t)
image = vae.decode(z_t * z_std + z_mean)生成从纯 latent 高斯噪声开始。整段 Prompt 的 token embeddings 在一次生成中保持不变,但 z_t 每一步都在变化,因此图像 Query Q、Attention 权重 A 和 context 都必须重新计算。早期步骤更多决定大结构,后期步骤逐渐修正边界和局部细节。
10. 怎样阅读实验里的 Attention 热力图
热力图说明“内部特征在这一步如何分配 Prompt 信息”,但单张热力图不能单独证明某个像素由某个词唯一决定。最终图像来自所有 timestep、所有 latent 位置和网络非线性共同累积的结果。
完整变量对齐
Toy 边界:实验中的小型 Prompt parser 只负责把输入转换成训练词表中的 token id,相当于简化版 tokenizer。它不是 Cross Attention 算法的核心;核心始终是图像 Q 与文字 K/V 的匹配、读取、融合,以及它对条件噪声预测的持续影响。
Stable Diffusion 全链路:把已经学过的组件拼起来
到这里不再引入一套新的生成公式,而是把 No.4、No.5、No.10 和 No.11 拼成一个完整系统。Stable Diffusion 可以理解为:在压缩 latent 中运行、由文字条件控制、用 CFG 引导并由 Scheduler 推进采样的 Latent Diffusion 系统。
1. 已学模块怎样对应 Stable Diffusion
| 已有模块 | Stable Diffusion 中的职责 | 关键变量 |
|---|---|---|
| No.4 VAE | 训练时把图像压缩成 latent;生成结束后把 latent 解码成图像 | x、z0、x_hat |
| No.5 DDPM | 定义 forward 加噪、噪声预测目标和 reverse denoising | t、epsilon、alpha_bar_t |
| No.10 Latent Diffusion | 把扩散过程从像素空间搬到压缩 latent 空间 | z_t、epsilon_theta(z_t,t) |
| No.11 Cross Attention | 让 denoiser 在每一步读取 Prompt embeddings | C、Q/K/V、context |
| No.11 CFG | 放大文字条件相对无条件预测带来的方向变化 | epsilon_uncond、epsilon_text、s |
这些组件不是串行训练出来的一整块黑盒。VAE、Text Encoder 和条件 denoiser 各有明确边界;推理 Pipeline 负责把它们按正确顺序连接起来。
2. Training 与 Inference 是两条不同链路
| 阶段 | Training | Inference / Generation |
|---|---|---|
| 起点 | 真实训练图像 x | 纯 latent 高斯噪声 z_T |
| VAE Encoder | 需要:x -> z0 | 通常不需要 |
| 真实噪声答案 | 程序采样并加入 epsilon,因此监督答案已知 | 不存在真实答案,只能反复使用模型预测 |
| timestep | 每个 batch 随机抽一个 t | 按 Scheduler 给出的 timestep 顺序逐步运行 |
| 参数更新 | 计算 loss、反向传播、更新 denoiser | 不更新参数,只执行前向预测 |
| VAE Decoder | 训练 denoiser 时通常不必每步解码 | 最后把 z0 解码为图像 |
3. CFG 为什么需要同一个 z_t 跑两次 denoiser
eps_uncond = unet(z_t, t, context=negative_or_empty_context) eps_text = unet(z_t, t, context=prompt_context) eps_cfg = eps_uncond + guidance_scale * (eps_text - eps_uncond)
eps_uncond 表示“不强调这段文字时,模型认为应该怎样去噪”;eps_text 表示“读取当前 Prompt 后应该怎样去噪”。两者差值就是文字条件造成的方向变化,CFG 用 guidance_scale 放大它。Negative Prompt 并不是第三个独立模型,它通常进入无条件分支,作为需要远离的文字上下文。
4. Scheduler 不负责理解 Prompt
Cross Attention 和 CFG 决定这一步采用哪一个噪声预测;Scheduler 决定如何用该预测把 z_t 更新成下一状态。它保存 timestep 顺序、噪声日程和具体数值更新规则,但不读取 Prompt,也不学习图像语义。
因此更换 Scheduler 可以改变采样步数、速度和轨迹,但不会替代 denoiser,也不会替代 Cross Attention。
5. 完整 Stable Diffusion 采样循环
# 1. 文字条件只编码一次
prompt_context = text_encoder(tokenizer(prompt))
uncond_context = text_encoder(tokenizer(negative_prompt_or_empty))
# 2. 不从图像开始,而是从纯 latent 噪声开始
z = torch.randn(latent_shape)
# 3. Scheduler 给出实际采样 timestep
for t in scheduler.timesteps:
eps_uncond = unet(z, t, context=uncond_context)
eps_text = unet(z, t, context=prompt_context)
eps_cfg = eps_uncond + guidance_scale * (eps_text - eps_uncond)
z = scheduler.step(eps_cfg, t, z).prev_sample
# 4. 只在去噪结束后把 latent 翻译回像素
image = vae.decode(z)循环中 Text Encoder 的输出保持不变,但 z_t 每一步都变化,所以 U-Net 内部的图像 Query、Cross Attention 权重和噪声预测会在每个 timestep 重新计算。