算法可视化与交互学习平台

学习 Cross Attention:Prompt 如何控制 Latent DiffusionCross Attention: How Prompts Guide Latent Diffusion

在 No.4 VAE、No.5 DDPM 和 No.10 Latent Diffusion 的基础上,理解 Cross Attention 如何把 prompt token 注入 denoising network,让 latent 空间中的每个位置在预测噪声时读取文字条件,从而控制图像生成方向。

Deep LearningIntermediateFree
Kernel
1

先抓住 Cross Attention 的核心直觉

Cross Attention 的一句话版本:

图像 latent 中的每个空间位置发出 Query, prompt token 提供 Key / Value, denoiser 用读到的文字语义来决定这一处应该怎样去噪。

No.10 已经说明,Latent Diffusion 可以在 z_t 上训练 denoiser 预测噪声 epsilon。但无条件版本只能学到“像训练数据”的图像,不能回答“按这段文字生成什么”。

Cross Attention 就是文字进入 denoising network 的关键接口。它不直接画图,也不直接输出像素,而是在每个去噪 timestep 中改变 U-Net / denoiser 的内部特征,让噪声预测变成:

epsilon_theta(z_t, t) -> 无文字条件 epsilon_theta(z_t, t, c) -> 有 prompt 条件

其中 c 是 prompt 经过 text encoder 得到的 token embeddings。

2

它接在 No.10 的哪里

No.10 的 Latent Diffusion 已经有三段结构:VAE 把图像压缩成 latent,DDPM 在 latent 中加噪/去噪,Decoder 把干净 latent 翻译回图像。Cross Attention 加在中间的 denoiser 内部。

No.4 VAE: image x -> Encoder -> latent z -> Decoder -> image x_hat No.5 DDPM: z0 -> zt, train epsilon_theta(zt, t) No.10 Latent Diffusion: 在 latent z 上做 DDPM No.11 Cross Attention: prompt -> text encoder -> token embeddings c zt + t + c -> denoiser -> predicted epsilon

换句话说,Cross Attention 解决的是“条件从哪里进入模型”的问题。VAE 不负责理解文字,DDPM 公式本身也不理解文字;真正读 prompt 的地方,是 denoising network 中的 Cross Attention 层。

3

Self-Attention vs Cross-Attention

Self-Attention 和 Cross-Attention 用的是同一个核心公式,但 Q/K/V 的来源不同。

机制Q 来自哪里K / V 来自哪里它在问什么
Self-Attention同一组 token同一组 token这一组 token 内部,谁应该看谁
Cross-Attention当前图像 / latent featuresprompt token embeddings每个图像位置应该读取哪些文字语义

在文生图 Latent Diffusion 中,可以把图像 latent feature map 看成一张小网格。每个网格位置都发出一个 Query,去和 prompt 中的 token 做匹配。

Self-Attention: Q = X Wq, K = X Wk, V = X Wv Cross-Attention in text-to-image: Q = H Wq # H: image/latent feature map K = C Wk # C: text token embeddings V = C Wv
4

Cross Attention 公式:Q 从图像来,K/V 从文字来

H 是当前 noisy latent 经过 U-Net block 得到的图像特征,C 是 prompt token embeddings。Cross Attention 让每个 latent 空间位置根据自己的 Query 去读取文字 token 的 Value。

当前 noisy latent 的图像特征;每个空间位置是一条 query 来源
image/latent features
text encoder 输出的 prompt token embeddings
prompt token embeddings
由 latent feature 线性投影得到的 Query
queries from latent features
由文字 token embedding 投影得到的 Key
keys from text tokens
由文字 token embedding 投影得到的 Value
values from text tokens
latent positions x prompt tokens 的注意力权重矩阵
attention weights
每个 latent 位置读到的文字上下文
text context

为什么 Q 来自图像

因为当前要被更新的是图像 latent feature。每个空间位置都要问:为了让这里符合 prompt,我应该参考哪个文字 token?

为什么 K/V 来自文字

Key 负责被匹配,Value 负责提供语义信息。prompt token 不是直接生成图像,而是作为条件信息被 denoiser 反复读取。

最小代码

q = Wq(latent_features) k = Wk(text_embeddings) v = Wv(text_embeddings) attn = softmax(q @ k.transpose(-2, -1) / sqrt(d)) context = attn @ v
5

逐步演算:一个 latent 位置如何读取 prompt

准备一个 latent 位置的 Query 和两个 prompt token 的 Key
prompt = "draw circle"
q_cell = [0.8, 0.4]
k_draw = [0.2, 0.6]
k_circle = [0.7, 0.2]
Initial Variables
q_cell
[0.8, 0.4]
k_draw
[0.2, 0.6]
k_circle
[0.7, 0.2]
scale
1.414
Step 1 Variables
q_cell
[0.8, 0.4]
k_draw
[0.2, 0.6]
k_circle
[0.7, 0.2]
Step 1 / 4
6

带 prompt 条件的 DDPM 训练目标

Cross Attention 不改变 DDPM 的监督信号。训练目标仍然是预测 forward process 中真实加入的噪声 epsilon,只是 denoiser 在预测时多读取了 prompt 条件 c。

第 t 个噪声等级下的 latent
noisy latent
prompt token embeddings / text condition
text condition
forward 加噪时真实加入的噪声
true added noise
读取 prompt 后预测的噪声
prompt-conditioned predicted noise

和 No.10 的关系

No.10 的 loss 是 。No.11 只是把 denoiser 的输入扩展为 ,loss 的本质仍然是 epsilon prediction。

最小代码

noise = torch.randn_like(z0) zt = sqrt(alpha_bar[t]) * z0 + sqrt(1 - alpha_bar[t]) * noise text = text_encoder(prompt_tokens) noise_pred = denoiser(zt, t, text) loss = mse(noise_pred, noise)
7

Cross Attention 如何改变 denoiser 的预测

Cross Attention 的输出不是最终图片,而是一份文字上下文 context。这份上下文会和 U-Net / denoiser 的图像特征融合,然后影响最终的 noise_pred

无条件 denoiser: features = image_block(z_t, t) noise_pred = output_block(features) 带 Cross Attention 的 denoiser: features = image_block(z_t, t) context = cross_attention(Q=features, K=text, V=text) features = fuse(features, context) noise_pred = output_block(features)

所以 prompt 的作用不是“把文字翻译成像素”,而是改变每一步噪声预测的方向。模型从纯噪声开始时,很多结构都不确定;每一步读取 prompt,都会让 latent 更倾向于符合文字条件的区域。

8

Cross Attention 与 CFG 的关系

Classifier-Free Guidance,简称 CFG,不是 Cross Attention 本身,但它经常和 Cross Attention 一起出现在 Stable Diffusion 这类模型里。

它的做法是同一个 z_t 跑两次 denoiser:一次不给 prompt 或给空 prompt,得到 epsilon_uncond;一次给真实 prompt,得到 epsilon_text。两者差值表示“文字条件把预测往哪里推”,guidance scale s 会放大这个方向。

概念负责什么
Cross Attention让 denoiser 能读取 prompt token
CFG放大有 prompt 和无 prompt 两次预测之间的差异

先理解 Cross Attention,再理解 CFG 会更稳。因为 CFG 的前提是模型已经能通过 Cross Attention 产生一份有文字条件的噪声预测。

9

Toy Cross Attention 实验:用提示词控制生成形状

这个实验分为两个真实计算阶段:先点击“① 训练 Cross Attention 模型”,训练 TinyVAE 与带 Cross Attention 的 latent DDPM,并把模型权重临时保存在当前页面;训练完成后,由用户自行在空白输入框中填写形状提示词,再点击“② 生成图形”。系统不会内置或自动提交默认提示词,空输入也不会回退为 circle。 支持 circle、square、triangle、diamond、cross、smile,也支持圆形、正方形、三角形、菱形、十字、笑脸等中文提示词。每次生成都会加载同一个训练后模型,从纯 latent 高斯噪声开始,按用户这一次输入的 Prompt 执行完整反向去噪。

Parameter Panel
7 Params
10

拆解 Toy Cross Attention:公式如何落到代码

这张卡片只追踪一件事:Prompt 条件如何经过 Cross Attention,进入 latent DDPM 的噪声预测。Cross Attention 不直接输出形状,也不直接生成像素;它在每个 timestep 改写 denoiser 的内部特征,使预测噪声 epsilon_theta 带上文字条件。

noisy latent z_t + timestep t -> image feature H -> Query Q prompt tokens -> text embedding C -> Key K, Value V Q 与 K 匹配 -> attention weights A A 对 V 加权求和 -> text context H 与 context 融合 -> predicted epsilon predicted epsilon -> DDPM reverse step -> z_(t-1)

1. Cross Attention 位于哪条算法链上

训练图像先由 VAE Encoder 压缩成干净 latent z0。DDPM forward process 随机选择 timestep t,并加入已知高斯噪声 epsilon

t = torch.randint(0, T, (batch_size,))
noise = torch.randn_like(z0)
a = alpha_bar[t].view(-1, 1, 1, 1)
z_t = torch.sqrt(a) * z0 + torch.sqrt(1 - a) * noise

此时 z_t 同时含有图像结构和噪声。Denoiser 的任务不是猜最终图片,而是在已知 z_tt 和 Prompt 条件时预测刚才加入的 epsilon

2. 图像 latent 先变成可发问的特征 H

t_img = t.float().view(b, 1, 1, 1).expand(b, 1, h, w) / (T - 1)
h = self.in_net(torch.cat([z_t, t_img, coords], dim=1))
h_flat = h.flatten(2).transpose(1, 2)

in_net 把 noisy latent、timestep 和空间坐标融合为图像特征 H。若 latent 分辨率为 4x4,则空间位置数 N=16。展平后:

H : [B, d_model, h, w] H_flat : [B, N, d_model] N : h * w H_flat[:, i, :] : 第 i 个 latent 位置当前看见的图像特征

每一个 latent 位置随后都会产生自己的 Query。直觉上,它在问:“为了正确去掉我这里的噪声,我应该从 Prompt 的哪些 token 读取信息?”

3. Prompt token 变成条件矩阵 C

text = self.token_emb(token_ids) + self.token_pos.unsqueeze(0)

Token embedding 加上 token 位置编码后得到条件矩阵 C。设 Prompt 有 L 个 token:

C : [B, L, d_text] C[:, j, :] : 第 j 个 Prompt token 的语义向量

真实文生图模型通常由大型 text encoder 产生 C;Toy 实验用小型 nn.Embedding 代替。这个替换只缩小了文字编码器,不改变 Cross Attention 的计算结构。

4. Q 来自图像,K/V 来自文字

q = self.to_q(h_flat)
k = self.to_k(text)
v = self.to_v(text)
Q : [B, N, d] 每个 latent 位置发出的查询 K : [B, L, d] 每个 Prompt token 用于被匹配的索引 V : [B, L, d] 匹配成功后真正读取的语义内容

这是 Cross Attention 与 Self-Attention 最关键的区别:Q 和 K/V 来自两种不同模态。图像 latent 负责发问,Prompt token 负责提供可匹配、可读取的条件信息。

5. QK^T 计算每个 latent 位置应关注哪个 token

scores = q @ k.transpose(-2, -1) / math.sqrt(text_dim)
attn = F.softmax(scores, dim=-1)
scores : [B, N, L] attn : [B, N, L] attn[b, i, j] = 第 b 张图中,第 i 个 latent 位置读取第 j 个 Prompt token 的比例 对固定的 latent 位置 i: sum_j attn[b, i, j] = 1

除以 sqrt(d) 是为了防止维度增大后点积幅度过大,导致 softmax 过早饱和。softmax(dim=-1) 则让每个 latent 位置在所有 Prompt token 之间分配注意力。

6. A 乘 V:把相关文字内容读回图像位置

context = attn @ v
context = self.context_out(context)

对第 i 个 latent 位置,context 是所有 Value 的加权和:

因此 context 不是一个全局 Prompt 向量,而是每个 latent 位置各自读取文字后得到的条件向量。不同空间位置可以对同一 Prompt 形成不同的读取比例。

7. Context 如何真正改变噪声预测

context_map = context.transpose(1, 2).contiguous().view(
    b, model_dim, latent_res, latent_res
)
noise_pred = self.out_net(torch.cat([h, context_map], dim=1))

Context 被还原为空间 feature map,并与原图像特征 H 融合。最终输出仍是与 z_t 同形状的预测噪声:

没有文字条件:epsilon_theta(z_t, t) 加入 Cross Attention:epsilon_theta(z_t, t, C) Prompt 改变 -> C 改变 -> K/V 改变 -> attention 与 context 改变 -> predicted epsilon 改变 -> 下一步 z_(t-1) 改变

所以 Prompt 控制图像的直接机制,是它改变了每一步“应该减去哪一部分噪声”的预测方向。

8. 训练目标没有被 Cross Attention 改写

pred_noise = denoiser(z_t, t, token_ids)
loss = F.mse_loss(pred_noise, noise)
optimizer.zero_grad()
loss.backward()
optimizer.step()

监督答案仍是 forward process 中由程序亲自加入的真实噪声 noise。Cross Attention 增加的是条件输入 C,并没有把 DDPM 训练改成分类任务,也没有让模型直接拟合目标形状像素。

9. 生成时为什么每一步都要重新计算 Attention

z_t = torch.randn(1, latent_channels, h, w)
for t in reversed(range(T)):
    epsilon = denoiser(z_t, t, prompt_tokens)
    z_t = ddpm_reverse_step(z_t, epsilon, t)
image = vae.decode(z_t * z_std + z_mean)

生成从纯 latent 高斯噪声开始。整段 Prompt 的 token embeddings 在一次生成中保持不变,但 z_t 每一步都在变化,因此图像 Query Q、Attention 权重 A 和 context 都必须重新计算。早期步骤更多决定大结构,后期步骤逐渐修正边界和局部细节。

10. 怎样阅读实验里的 Attention 热力图

热力图的一行:一个 latent 空间位置 / 一个 Query 热力图的一列:一个 Prompt token / 一个 Key-Value 对 单元格 A[i,j]:位置 i 从 token j 读取信息的比例

热力图说明“内部特征在这一步如何分配 Prompt 信息”,但单张热力图不能单独证明某个像素由某个词唯一决定。最终图像来自所有 timestep、所有 latent 位置和网络非线性共同累积的结果。

完整变量对齐

z0 -> VAE 编码后的干净 latent z_t -> forward process 加噪后的 latent H -> in_net([z_t, t, coords]) C -> token_emb(prompt_tokens) Q -> to_q(H_flat) K, V -> to_k(C), to_v(C) S -> Q @ K.T / sqrt(d) A -> softmax(S) context -> A @ V epsilon_theta -> out_net([H, context_map]) training target -> forward process 中真实加入的 epsilon reverse result -> 根据 epsilon_theta 得到 z_(t-1)

Toy 边界:实验中的小型 Prompt parser 只负责把输入转换成训练词表中的 token id,相当于简化版 tokenizer。它不是 Cross Attention 算法的核心;核心始终是图像 Q 与文字 K/V 的匹配、读取、融合,以及它对条件噪声预测的持续影响。

11

Stable Diffusion 全链路:把已经学过的组件拼起来

到这里不再引入一套新的生成公式,而是把 No.4、No.5、No.10 和 No.11 拼成一个完整系统。Stable Diffusion 可以理解为:在压缩 latent 中运行、由文字条件控制、用 CFG 引导并由 Scheduler 推进采样的 Latent Diffusion 系统。

Prompt -> Tokenizer -> Text Encoder -> C_prompt 空 Prompt / Negative Prompt -> Text Encoder -> C_uncond z_T ~ N(0, I) | v Conditional Denoiser / U-Net Q <- noisy latent features K,V <- text embeddings | v CFG 合并有条件与无条件噪声预测 | v Scheduler: z_t -> z_(t-1) | 重复所有 timestep v z_0 -> VAE Decoder -> 最终图像

1. 已学模块怎样对应 Stable Diffusion

已有模块Stable Diffusion 中的职责关键变量
No.4 VAE训练时把图像压缩成 latent;生成结束后把 latent 解码成图像xz0x_hat
No.5 DDPM定义 forward 加噪、噪声预测目标和 reverse denoisingtepsilonalpha_bar_t
No.10 Latent Diffusion把扩散过程从像素空间搬到压缩 latent 空间z_tepsilon_theta(z_t,t)
No.11 Cross Attention让 denoiser 在每一步读取 Prompt embeddingsCQ/K/Vcontext
No.11 CFG放大文字条件相对无条件预测带来的方向变化epsilon_uncondepsilon_texts

这些组件不是串行训练出来的一整块黑盒。VAE、Text Encoder 和条件 denoiser 各有明确边界;推理 Pipeline 负责把它们按正确顺序连接起来。

2. Training 与 Inference 是两条不同链路

阶段TrainingInference / Generation
起点真实训练图像 x纯 latent 高斯噪声 z_T
VAE Encoder需要:x -> z0通常不需要
真实噪声答案程序采样并加入 epsilon,因此监督答案已知不存在真实答案,只能反复使用模型预测
timestep每个 batch 随机抽一个 t按 Scheduler 给出的 timestep 顺序逐步运行
参数更新计算 loss、反向传播、更新 denoiser不更新参数,只执行前向预测
VAE Decoder训练 denoiser 时通常不必每步解码最后把 z0 解码为图像
Training: image x -> VAE Encoder -> z0 z0 + sampled epsilon + random t -> z_t z_t + t + Prompt -> denoiser -> predicted epsilon MSE(predicted epsilon, sampled epsilon) -> update denoiser Inference: pure noise z_T + Prompt -> repeated conditional denoising -> z_0 -> VAE Decoder -> image

3. CFG 为什么需要同一个 z_t 跑两次 denoiser

eps_uncond = unet(z_t, t, context=negative_or_empty_context)
eps_text = unet(z_t, t, context=prompt_context)
eps_cfg = eps_uncond + guidance_scale * (eps_text - eps_uncond)

eps_uncond 表示“不强调这段文字时,模型认为应该怎样去噪”;eps_text 表示“读取当前 Prompt 后应该怎样去噪”。两者差值就是文字条件造成的方向变化,CFG 用 guidance_scale 放大它。Negative Prompt 并不是第三个独立模型,它通常进入无条件分支,作为需要远离的文字上下文。

4. Scheduler 不负责理解 Prompt

Cross Attention 和 CFG 决定这一步采用哪一个噪声预测;Scheduler 决定如何用该预测把 z_t 更新成下一状态。它保存 timestep 顺序、噪声日程和具体数值更新规则,但不读取 Prompt,也不学习图像语义。

Denoiser / U-Net: (z_t, t, text context) -> epsilon_prediction CFG: (epsilon_uncond, epsilon_text, guidance_scale) -> epsilon_cfg Scheduler: (z_t, epsilon_cfg, t) -> z_(t-1)

因此更换 Scheduler 可以改变采样步数、速度和轨迹,但不会替代 denoiser,也不会替代 Cross Attention。

5. 完整 Stable Diffusion 采样循环

# 1. 文字条件只编码一次
prompt_context = text_encoder(tokenizer(prompt))
uncond_context = text_encoder(tokenizer(negative_prompt_or_empty))

# 2. 不从图像开始,而是从纯 latent 噪声开始
z = torch.randn(latent_shape)

# 3. Scheduler 给出实际采样 timestep
for t in scheduler.timesteps:
    eps_uncond = unet(z, t, context=uncond_context)
    eps_text = unet(z, t, context=prompt_context)

    eps_cfg = eps_uncond + guidance_scale * (eps_text - eps_uncond)
    z = scheduler.step(eps_cfg, t, z).prev_sample

# 4. 只在去噪结束后把 latent 翻译回像素
image = vae.decode(z)

循环中 Text Encoder 的输出保持不变,但 z_t 每一步都变化,所以 U-Net 内部的图像 Query、Cross Attention 权重和噪声预测会在每个 timestep 重新计算。

6. 从课程变量读回真实系统

课程中的 TinyVAE -> Stable Diffusion 的 VAE 课程中的 latent z_t -> 压缩图像 latent 课程中的 token embedding C -> Text Encoder 输出 课程中的 Cross Attention -> U-Net 内部文字条件接口 课程中的 epsilon_theta -> Conditional U-Net / denoiser 课程中的 alpha_bar / timestep -> Scheduler 管理的噪声日程 课程中的 reverse loop -> Stable Diffusion sampling pipeline 课程中的 Decoder -> 最终 latent-to-image 解码
AI
问问 LLM:把 Cross Attention 讲回 Latent Diffusion