# 自注意力机制

# 直观理解

自注意力机制就像是阅读理解中的关键信息提取。想象你在阅读一篇文章时,大脑会自然地将注意力分配到不同的关键词上,并根据上下文理解它们之间的关系。这就是自注意力机制的核心思想。

# 技术原理

# 1. 查询-键-值(Query-Key-Value)模型

自注意力机制通过三个核心概念来实现信息的选择性关注:

  • 查询(Query):当前位置想要查找的信息
  • 键(Key):其他位置提供的线索
  • 值(Value):其他位置的实际内容

# 2. 注意力分数计算

注意力分数通过Query和Key的点积来计算,表示不同位置之间的相关性:

Attention(Q,K,V)=softmax(QKTdk)VAttention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V

其中:

  • dkd_k 是键向量的维度
  • dk\sqrt{d_k} 是缩放因子,用于控制点积的规模

# 数学推导

# 1. 线性变换

首先,输入序列X通过三个权重矩阵转换为Q、K、V:

Q=XWQQ = XW_Q K=XWKK = XW_K V=XWVV = XW_V

其中:

  • WQ,WK,WVW_Q, W_K, W_V 是可学习的参数矩阵
  • XX 是输入序列的嵌入表示

# 2. 注意力权重计算

  1. 计算注意力分数: S=QKTS = QK^T

  2. 缩放: Sscaled=SdkS_{scaled} = \frac{S}{\sqrt{d_k}}

  3. Softmax归一化: A=softmax(Sscaled)A = softmax(S_{scaled})

  4. 加权求和: O=AVO = AV

# 实现细节

class SelfAttention(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_k = d_k
        self.W_q = nn.Linear(d_model, d_k)
        self.W_k = nn.Linear(d_model, d_k)
        self.W_v = nn.Linear(d_model, d_k)
        
    def forward(self, x, mask=None):
        # x: [batch_size, seq_length, d_model]
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # 可选:使用mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # softmax归一化
        attention = F.softmax(scores, dim=-1)
        
        # 加权求和
        output = torch.matmul(attention, V)
        return output, attention

# 优化技巧

# 1. 注意力稀疏化

  1. 局部注意力

    • 只关注窗口范围内的token
    • 减少计算复杂度
  2. 稀疏注意力

    • 使用固定模式的稀疏连接
    • Longformer、BigBird等变体

# 2. 内存优化

  1. 渐进式计算

    • 分块计算注意力矩阵
    • 减少内存占用
  2. 线性复杂度变体

    • Performer
    • Linear Transformer

# 应用场景

  1. 序列建模

    • 机器翻译
    • 文本生成
    • 语音识别
  2. 图像处理

    • Vision Transformer
    • DETR目标检测

# 注意事项

  1. 计算复杂度

    • 时间复杂度:O(n²)
    • 空间复杂度:O(n²)
  2. 长序列处理

    • 考虑使用稀疏注意力
    • 选择合适的注意力变体
  3. 数值稳定性

    • 使用缩放因子
    • 注意数值溢出问题