问题描述

实现一个多头注意力机制,给定输入 querykeyvalue 和可选的 mask,计算输出和注意力权重。

具体需要完成以下几个步骤:

  1. 对 Q、K、V 进行线性变换
  2. 拆分为多头并进行缩放点积注意力计算
  3. 合并多头结果
  4. 通过输出线性层映射回原维度

解题思路

多头注意力是 Transformer 的核心模块,其核心思想是将输入投影到多个子空间,分别计算注意力,再拼接起来。关键点在于:

  1. 线性变换:Q、K、V 各自通过一个线性层投影到 d_model 维空间
  2. 多头拆分:将 d_model 维度均分为 n_heads 份,每个头的维度为 d_head = d_model // n_heads
  3. 缩放点积注意力:$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$,其中 $\sqrt{d_k}$ 是缩放因子,防止点积过大导致 softmax 梯度消失
  4. Mask 处理:将 mask 位置替换为 -inf,使 softmax 后该位置权重为 0
  5. 合并输出:将所有头的结果拼接回 d_model 维度

代码实现

import math
import torch
from torch import nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout_rate):
        super().__init__()
        self.d_model = d_model          # 模型总维度
        self.n_heads = n_heads          # 注意力头数
        self.d_head = d_model // n_heads  # 每个头的维度

        # 四个线性变换层
        self.q_linear = nn.Linear(d_model, d_model)   # Q 的投影矩阵
        self.k_linear = nn.Linear(d_model, d_model)   # K 的投影矩阵
        self.v_linear = nn.Linear(d_model, d_model)   # V 的投影矩阵
        self.out_linear = nn.Linear(d_model, d_model) # 输出投影矩阵

        self.scale = math.sqrt(self.d_head)           # 缩放因子
        self.dropout = nn.Dropout(dropout_rate)       # Dropout 层

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # 1. 线性变换 + 多头拆分
        # Q: (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_head)
        q = (
            self.q_linear(query)                           # (batch, seq_len, d_model)
            .view(batch_size, -1, self.n_heads, self.d_head) # (batch, seq_len, n_heads, d_head)
            .transpose(1, 2)                                # (batch, n_heads, seq_len, d_head)
        )
        k = (
            self.k_linear(key)
            .view(batch_size, -1, self.n_heads, self.d_head)
            .transpose(1, 2)
        )
        v = (
            self.v_linear(value)
            .view(batch_size, -1, self.n_heads, self.d_head)
            .transpose(1, 2)
        )

        # 2. 计算缩放点积注意力分数
        atten_scores = (q @ k.transpose(-1, -2)) / self.scale  # (batch, n_heads, q_len, k_len)

        # 3. Mask 处理(将 padding 位置设为 -inf)
        if mask is not None:
            if mask.dim() == 2:                     # 如果 mask 维度为 (seq_len, seq_len)
                mask = mask.unsqueeze(0).unsqueeze(0)  # 扩展为 (1, 1, seq_len, seq_len)
            atten_scores = atten_scores.masked_fill(mask, float("-inf"))

        # 4. Softmax + Dropout
        atten_weights = atten_scores.softmax(dim=-1)      # (batch, n_heads, q_len, k_len)
        droped_atten_weights = self.dropout(atten_weights)

        # 5. 加权求和:权重 × V
        atten_output = droped_atten_weights @ v           # (batch, n_heads, q_len, d_head)

        # 6. 合并多头:transpose + view
        multi_atten_output = atten_output.transpose(1, 2).view(
            batch_size, -1, self.d_model
        )                                                 # (batch, q_len, d_model)

        # 7. 输出线性变换
        output = self.out_linear(multi_atten_output)
        return output, atten_weights

测试用例

if __name__ == "__main__":
    # 初始化参数
    batch_size = 2
    seq_len = 4
    d_model = 8
    n_heads = 2
    dropout_rate = 0.1

    mha = MultiHeadAttention(d_model, n_heads, dropout_rate)

    # 生成随机输入
    query = torch.randn(batch_size, seq_len, d_model)
    key = torch.randn(batch_size, seq_len, d_model)
    value = torch.randn(batch_size, seq_len, d_model)

    # 无 mask 测试
    output, attn = mha(query, key, value)
    print(f"无 mask 输出形状: {output.shape}")  # 期望: (2, 4, 8)
    print(f"注意力权重形状: {attn.shape}")       # 期望: (2, 2, 4, 4)

    # 带 mask 测试(模拟 padding mask)
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()  # 上三角 mask
    output_masked, attn_masked = mha(query, key, value, mask=mask)
    print(f"带 mask 输出形状: {output_masked.shape}")  # 期望: (2, 4, 8)
    print(f"带 mask 注意力权重:\n{attn_masked[0, 0]}") # 查看第一头注意力(上三角应为0)

运行结果

无 mask 输出形状: torch.Size([2, 4, 8])
注意力权重形状: torch.Size([2, 2, 4, 4])
带 mask 输出形状: torch.Size([2, 4, 8])
带 mask 注意力权重:
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5227, 0.4773, 0.0000, 0.0000],
        [0.3612, 0.3371, 0.3017, 0.0000],
        [0.2534, 0.2501, 0.2488, 0.2477]])

算法分析

  • 时间复杂度:$O(batch \times n\_heads \times L^2 \times d\_head)$,主要消耗在注意力矩阵计算 Q @ K^T,这是 $L \times L$ 维度的矩阵乘法
  • 空间复杂度:$O(batch \times n\_heads \times L^2)$,需要存储注意力矩阵,其中 $L$ 是序列长度

关键点总结

  1. 多维张量变换viewtranspose 的组合是实现多头拆分与合并的关键。拆分时先 view 增加维度再 transpose,合并时先 transpose 恢复再 view 展平
  2. 缩放因子self.scale = math.sqrt(self.d_head),除以 $\sqrt{d_k}$ 是为了将点积结果的方差控制在 1,防止 softmax 进入饱和区
  3. Mask 机制masked_fill(mask, -inf) 将填充位置设为负无穷,经 softmax 后概率为 0。代码中额外处理了 mask 的维度广播
  4. 返回注意力权重:在可视化或调试时,注意力权重 atten_weights 很有用,常用于分析模型的可解释性

扩展思考

  • 交叉注意力 vs 自注意力:如果 query 来自解码器,keyvalue 来自编码器,则实现的是交叉注意力(Cross-Attention)。本代码通过参数传入实现,灵活支持两种模式
  • Flash Attention:当序列很长时,标准注意力的内存开销随 $L^2$ 增长。Flash Attention 通过分块计算和重计算技巧,在不改变结果的前提下大幅降低显存占用
  • GQA(Grouped Query Attention):在推理优化中,多个 query 头共享同一组 key-value 头,可以减少 KV Cache 的存储开销