问题描述
实现一个多头注意力机制,给定输入 query、key、value 和可选的 mask,计算输出和注意力权重。
具体需要完成以下几个步骤:
- 对 Q、K、V 进行线性变换
- 拆分为多头并进行缩放点积注意力计算
- 合并多头结果
- 通过输出线性层映射回原维度
解题思路
多头注意力是 Transformer 的核心模块,其核心思想是将输入投影到多个子空间,分别计算注意力,再拼接起来。关键点在于:
- 线性变换:Q、K、V 各自通过一个线性层投影到
d_model维空间 - 多头拆分:将
d_model维度均分为n_heads份,每个头的维度为d_head = d_model // n_heads - 缩放点积注意力:$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$,其中 $\sqrt{d_k}$ 是缩放因子,防止点积过大导致 softmax 梯度消失
- Mask 处理:将 mask 位置替换为
-inf,使 softmax 后该位置权重为 0 - 合并输出:将所有头的结果拼接回
d_model维度
代码实现
import math
import torch
from torch import nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads, dropout_rate):
super().__init__()
self.d_model = d_model # 模型总维度
self.n_heads = n_heads # 注意力头数
self.d_head = d_model // n_heads # 每个头的维度
# 四个线性变换层
self.q_linear = nn.Linear(d_model, d_model) # Q 的投影矩阵
self.k_linear = nn.Linear(d_model, d_model) # K 的投影矩阵
self.v_linear = nn.Linear(d_model, d_model) # V 的投影矩阵
self.out_linear = nn.Linear(d_model, d_model) # 输出投影矩阵
self.scale = math.sqrt(self.d_head) # 缩放因子
self.dropout = nn.Dropout(dropout_rate) # Dropout 层
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# 1. 线性变换 + 多头拆分
# Q: (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_head)
q = (
self.q_linear(query) # (batch, seq_len, d_model)
.view(batch_size, -1, self.n_heads, self.d_head) # (batch, seq_len, n_heads, d_head)
.transpose(1, 2) # (batch, n_heads, seq_len, d_head)
)
k = (
self.k_linear(key)
.view(batch_size, -1, self.n_heads, self.d_head)
.transpose(1, 2)
)
v = (
self.v_linear(value)
.view(batch_size, -1, self.n_heads, self.d_head)
.transpose(1, 2)
)
# 2. 计算缩放点积注意力分数
atten_scores = (q @ k.transpose(-1, -2)) / self.scale # (batch, n_heads, q_len, k_len)
# 3. Mask 处理(将 padding 位置设为 -inf)
if mask is not None:
if mask.dim() == 2: # 如果 mask 维度为 (seq_len, seq_len)
mask = mask.unsqueeze(0).unsqueeze(0) # 扩展为 (1, 1, seq_len, seq_len)
atten_scores = atten_scores.masked_fill(mask, float("-inf"))
# 4. Softmax + Dropout
atten_weights = atten_scores.softmax(dim=-1) # (batch, n_heads, q_len, k_len)
droped_atten_weights = self.dropout(atten_weights)
# 5. 加权求和:权重 × V
atten_output = droped_atten_weights @ v # (batch, n_heads, q_len, d_head)
# 6. 合并多头:transpose + view
multi_atten_output = atten_output.transpose(1, 2).view(
batch_size, -1, self.d_model
) # (batch, q_len, d_model)
# 7. 输出线性变换
output = self.out_linear(multi_atten_output)
return output, atten_weights
测试用例
if __name__ == "__main__":
# 初始化参数
batch_size = 2
seq_len = 4
d_model = 8
n_heads = 2
dropout_rate = 0.1
mha = MultiHeadAttention(d_model, n_heads, dropout_rate)
# 生成随机输入
query = torch.randn(batch_size, seq_len, d_model)
key = torch.randn(batch_size, seq_len, d_model)
value = torch.randn(batch_size, seq_len, d_model)
# 无 mask 测试
output, attn = mha(query, key, value)
print(f"无 mask 输出形状: {output.shape}") # 期望: (2, 4, 8)
print(f"注意力权重形状: {attn.shape}") # 期望: (2, 2, 4, 4)
# 带 mask 测试(模拟 padding mask)
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool() # 上三角 mask
output_masked, attn_masked = mha(query, key, value, mask=mask)
print(f"带 mask 输出形状: {output_masked.shape}") # 期望: (2, 4, 8)
print(f"带 mask 注意力权重:\n{attn_masked[0, 0]}") # 查看第一头注意力(上三角应为0)
运行结果
无 mask 输出形状: torch.Size([2, 4, 8])
注意力权重形状: torch.Size([2, 2, 4, 4])
带 mask 输出形状: torch.Size([2, 4, 8])
带 mask 注意力权重:
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
[0.5227, 0.4773, 0.0000, 0.0000],
[0.3612, 0.3371, 0.3017, 0.0000],
[0.2534, 0.2501, 0.2488, 0.2477]])
算法分析
- 时间复杂度:$O(batch \times n\_heads \times L^2 \times d\_head)$,主要消耗在注意力矩阵计算
Q @ K^T,这是 $L \times L$ 维度的矩阵乘法 - 空间复杂度:$O(batch \times n\_heads \times L^2)$,需要存储注意力矩阵,其中 $L$ 是序列长度
关键点总结
- 多维张量变换:
view和transpose的组合是实现多头拆分与合并的关键。拆分时先view增加维度再transpose,合并时先transpose恢复再view展平 - 缩放因子:
self.scale = math.sqrt(self.d_head),除以 $\sqrt{d_k}$ 是为了将点积结果的方差控制在 1,防止 softmax 进入饱和区 - Mask 机制:
masked_fill(mask, -inf)将填充位置设为负无穷,经 softmax 后概率为 0。代码中额外处理了 mask 的维度广播 - 返回注意力权重:在可视化或调试时,注意力权重
atten_weights很有用,常用于分析模型的可解释性
扩展思考
- 交叉注意力 vs 自注意力:如果
query来自解码器,key和value来自编码器,则实现的是交叉注意力(Cross-Attention)。本代码通过参数传入实现,灵活支持两种模式 - Flash Attention:当序列很长时,标准注意力的内存开销随 $L^2$ 增长。Flash Attention 通过分块计算和重计算技巧,在不改变结果的前提下大幅降低显存占用
- GQA(Grouped Query Attention):在推理优化中,多个 query 头共享同一组 key-value 头,可以减少 KV Cache 的存储开销