Qwen3.5采用了混合注意力机制,每四层注意力中,有三层是Gated Delta Attention,一层是带门控的Full Attention。GDN是一种线性注意力机制,要理解其结构,要从理解线性注意力机制开始。

1 线性注意力机制

在标准的Attention机制当中,我们执行的是:

这其中,Q是[seqlens,nhead,headdim]的张量,K和V则是[seqlens,nkvhead,headdim]的张量,计算复杂度是$O(S^2d)$。S表示seqlen。

对于$QK^TV$这三个矩阵的运算,如果没有softmax和因果掩码,则可以根据矩阵乘法结合律来先计算$K^TV$,这样就变成了Q乘上一个[headdim,headdim]的矩阵,整体的复杂度是$O(Sd^2)$,时间复杂度和序列长度是线性关系,这就是线性注意力机制的根源。

先暂时不考虑去掉softmax的影响,我们首先要把因果掩码加回来,这一点上,我们把QK^TV这个计算做一个拆分,按照每个token来考虑,每个Qi先和Kj计算注意力分数,然后和V的向量j乘,得到结果的一个分片。

image-20260318165041344

公式上表示为:

注意力分数作为标量可以移到后面去(vj移到前面将被视为列向量),而向量内积可以交换位置(交换后q_i为列向量,kj^T为行向量)

从而得出我们的Oi可以根据累积的v和k得到:

累加矩阵被定义为Si,由v和k的外积构成,由于kj本来就是以行向量存储在内存当中,所以读取的时候可以直接把他按照行向量读进来,而v则需要转置。

这样只要保存状态矩阵S,每次计算oi时更新Si,同时以线性复杂度计算注意力。v是[seqlen,nkvhead,headdim],k是[seqlen,nkvhead,headdim],向量一一对应做外积得到[nkvhead,headdim]的S。

上述的这种累积式逐步更新的S解决了因果掩码的问题,但是softmax被去掉的问题仍然还在,在常规注意力机制下,softmax会通过指数运算重点保留注意力高分的位置,再通过V得到结果,去掉softmax后,所有的K和V的信息都不经缩放累加到了S上,导致信息模糊。有不同的处理办法来解决这个问题,典型方法包括:

  • 添加衰减因子:$S_t = \lambda S_{t-1} + v_t k_t^\top$
  • 门控机制
  • 更复杂的核函数来映射q和k

2 Gated DeltaNet linear Attention

针对信息模糊的问题,DeltaNet提出了一种通过取出旧的value,删除旧的kv外积,再写入融合后的外积的方法,来控制信息的写入和更新。

怎么取出之前的v呢,假设 S 矩阵目前只存入了两对信息:$(k_1, v_1)$ 和 $(k_2, v_2)$。那么当前的S为:

可以用当前的k来检索,假设当前k是k3:

得到的就是我们认为的需要消除的旧记忆,在更新S的时候,要去掉这部分旧记忆,而新的vt则是旧的v值和新的v值的一个根据β控制比例的融合,和k^T做外积计算累加到S上。

image-20260318165101149

根据上图描述,DeltaNet可以展开,提取公因式紧凑的写为:

Gated DeltaNet进一步引入了遗忘门,来控制信息的衰减,其S的更新进一步添加了α参数,变为了:

3 Qwen3.5 GDN

之前的介绍只是介绍了GatedDeltaLinearAttention的S更新原理,实际的Qwen3.5线性注意力机制要包含更多内容,比如在下图中,我们可以看到QKV会先进行一个kernel为4的Conv1D + SiLU,这一步是为了补偿线性化导致的表达能力下降,关于其具体的原因, 为什么线性注意力要加Short Conv?中给出了进一步的解释,不过对于infra或者说inference的角度来说,恐怕还是有点难理解,这里我们就浅显的理解为是为了通过交互周围的token,来提高表达能力。

这张图看过去,除了上述的Conv1D,还有一些其他值得注意的地方,这些有些是linear Attention的常见技巧,如果是第一次了解一种linear attention,可能不太熟悉

  • Q和K的维度是相同的,而V的head更多:V决定读写什么内容,可以用更多的 value heads 提升表达能力/通道数,同时是 per-value-head 的门控参数:让不同 value heads 即使共享,也能用不同的衰减/更新强度形成差异化记忆[2]。
  • 新的参数:A_log,delta_bias,两个向量参数
  • 衰减因子的定义:α=exp(g),g = -exp(A_log) softplus(a + delta_bias)
  • QKV之外的投影:
    • b,学习率β=sigmoid(b)
    • a,用于衰减因子g的生成,图里的A Proj对应的是下面的a,与A_log无关
    • z,用于门控
  • L2Norm:防止Q和K数值爆炸
  • 输出门控与归一化:使用 RMSNorm 和门控处理输出。

image-20260318165110981

结合第2节的公式以及上述的补充,Qwen3.5的Gated Delta linearAttention的大体计算流程就可以理解了。下一小节,我们再结合代码看一下GDN计算的流程。

4 GDN

在site-packages/transformers/models/qwen3_5/modeling_qwen3_5.py下,可以看到GDN的具体计算过程。我们逐步来过一遍计算过程中的tensor变化,以qwen3.5-0.8b为例。

输入的tensor为[B, S, d]

1
2
3
4
5
6
def forward(self, hidden_states, cache_params=None, cache_position=None, attention_mask=None):
# 将 Padding 部分 mask 掉,防止对计算造成干扰
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)

# 记录序列信息
batch_size, seq_len, _ = hidden_states.shape

对QKV的投影:

1
2
3
4
5
6
7
mixed_qkv = self.in_proj_qkv(hidden_states) 
# shape: (B, S, 6144)
# 计算: 6144 = (KeyDim * NumHeads * 2) + (ValueDim * NumHeads)
# 这里 = (128*16*2) + (128*16) = 4096 + 2048 = 6144

mixed_qkv = mixed_qkv.transpose(1, 2)
# shape: (B, 6144, S) 转置以便 Conv1d 在序列维度(最后一维)上滑动

对于0.8B的qwen3.5,num_k_heads和num_v_heads都是16,更大尺寸中,k的头数会少于v的头数,后面的计算会需要repeat k。

接下来是门控和参数的投影,b和a对于每个头来说是标量

1
2
3
4
5
6
7
# z: 输出门控分支,用于最后的归一化融合
z = self.in_proj_z(hidden_states) # (B, S, 2048)
z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) # (B, S, 16, 128)

# b, a: 用于 Delta Rule 的核心控制参数 (控制更新强度与衰减)
b = self.in_proj_b(hidden_states) # (B, S, 16) - 这里的 16 映射到 Head 数
a = self.in_proj_a(hidden_states) # (B, S, 16)

casual conv:

1
2
3
4
5
6
7
8
if use_precomputed_states: # Decoding 模式 (推理阶段,seq_len=1)
mixed_qkv = self.causal_conv1d_update(mixed_qkv, conv_state, ...)
# 更新 conv_state 缓存,仅处理当前 single token
else: # Prefill 模式 (首字推理)
if cache_params is not None: # 如果需要缓存以便后续推理
cache_params.conv_states[self.layer_idx] = F.pad(mixed_qkv, ...)
# 执行全序列卷积
mixed_qkv = self.causal_conv1d_fn(x=mixed_qkv, ...)

卷积后的QKV拆分:

1
2
3
4
5
6
7
mixed_qkv = mixed_qkv.transpose(1, 2) # 恢复形状 (B, S, 6144)
query, key, value = torch.split(mixed_qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1)

# Reshape 为多头空间 (Head=16, Dim=128)
query = query.reshape(batch_size, seq_len, -1, self.head_k_dim) # (B, S, 16, 128)
key = key.reshape(batch_size, seq_len, -1, self.head_k_dim) # (B, S, 16, 128)
value = value.reshape(batch_size, seq_len, -1, self.head_v_dim) # (B, S, 16, 128)

衰减和学习参数的计算:

1
2
3
4
5
6
beta = b.sigmoid() # (B, S, 16) 范围 (0, 1)。作为写入权重,决定当前信息存入状态的比重。

# g: 衰减系数 (Decay Gate)
# 公式: g = -exp(A_log) * softplus(a + dt_bias)
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
# shape: (B, S, 16)。负值用于控制线性状态的指数衰减。

GatedDeltaRule,结合之前的流程图,可以注意到QK的l2norm被融合到了该算子当中去

1
2
3
4
5
6
7
8
if not use_precomputed_states: # Prefill
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
query, key, value, g=g, beta=beta, ...
)
else: # Decoding (使用 recurrent_state 进行单步增量更新)
core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
query, key, value, g=g, beta=beta, initial_state=recurrent_state, ...
)
  • core_attn_out: (B, S, 16, 128)。注意力机制提取后的特征。
  • last_recurrent_state: (B, 16, 128, 128)。这是一个庞大的二阶状态矩阵,存储了关联记忆。

最后是门控的处理:

1
2
3
4
5
6
7
8
9
10
# 展平准备进入高效的 Norm 算子
core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) # (B*S*16, 128)
z = z.reshape(-1, self.head_v_dim) # (B*S*16, 128)

# 利用 z 执行带门控的 RMSNorm
core_attn_out = self.norm(core_attn_out, z)

# 恢复序列并合并 Head 通道
core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) # (B, S, 2048)
output = self.out_proj(core_attn_out) # (B, S, 1024)

5 recurrent Gated Delta Rule

这通常是decode阶段的更新,因为这个阶段我们并不需要考虑seq维度的累加,但是其对于prefill阶段也是适用的,只要迭代更新seqlen次。对于输入,首先做qknorm,并把H调整到前面[B, H, S, Dk]

1
2
3
4
5
6
7
8
9
10
11
initial_dtype = query.dtype
if use_qk_l2norm_in_kernel:
query = l2norm(query, dim=-1, eps=1e-6)
key = l2norm(key, dim=-1, eps=1e-6)
query, key, value, beta, g = [
x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
]
batch_size, num_heads, sequence_length, k_head_dim = key.shape
v_head_dim = value.shape[-1]
scale = 1 / (query.shape[-1] ** 0.5)
query = query * scale #标准scale

状态的初始化,如果有存储了last state,那么会通过initial_state传入

1
2
3
4
5
6
core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
last_recurrent_state = (
torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
if initial_state is None
else initial_state.to(value)
)

状态更新,其实如果是普通的每个batch一个token的decode,这里的seqlen应该是1,q_t等是取出第i步的长度为head_dim的向量 [B,H,head_dim]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
for i in range(sequence_length):
q_t = query[:, :, i]
k_t = key[:, :, i]
v_t = value[:, :, i]
# [B, H] -> [B, H, 1, 1] 要去乘 [B, H, k_head_dim, v_head_dim] 的状态矩阵
g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
# [B, H] -> [B, H, 1] 要去乘(v_t - kv_mem),是[B H v_head_dim]的向量
beta_t = beta[:, :, i].unsqueeze(-1)

last_recurrent_state = last_recurrent_state * g_t

# [B, H, Dk, Dv] x [B, H, Dk, 1] 矩阵乘列向量
# 实际上算的是y = S^T x k [B, H, Dv]
kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
delta = (v_t - kv_mem) * beta_t

# [B,H, Dk,Dv] += [B,H,Dk, 1] * [B, H, 1, Dv]
last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
# St^T * q [B,H, Dk,Dv] x [B, H, Dk]
core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)

if not output_final_state:
last_recurrent_state = None
# [B, S, H, Dk]
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
return core_attn_out, last_recurrent_state

Reference

[1] Gated DeltaNet :Qwen和Kimi都在用的线性注意力机制 - Vinci叽里呱啦的文章 - 知乎 https://zhuanlan.zhihu.com/p/1971671251123705351

[2]【LLM】Qwen3.5解剖 - Plunck的文章 - 知乎 https://zhuanlan.zhihu.com/p/2005306558997882654

[3] 以Qwen3Next为例,从Inference视角学习LinearAttention - CodeLearner的文章 - 知乎 https://zhuanlan.zhihu.com/p/1993467979799732563

[4] https://spaces.ac.cn/archives/11320