# 自注意力机制
# 直观理解
自注意力机制就像是阅读理解中的关键信息提取。想象你在阅读一篇文章时,大脑会自然地将注意力分配到不同的关键词上,并根据上下文理解它们之间的关系。这就是自注意力机制的核心思想。
# 技术原理
# 1. 查询-键-值(Query-Key-Value)模型
自注意力机制通过三个核心概念来实现信息的选择性关注:
- 查询(Query):当前位置想要查找的信息
- 键(Key):其他位置提供的线索
- 值(Value):其他位置的实际内容
# 2. 注意力分数计算
注意力分数通过Query和Key的点积来计算,表示不同位置之间的相关性:
其中:
- 是键向量的维度
- 是缩放因子,用于控制点积的规模
# 数学推导
# 1. 线性变换
首先,输入序列X通过三个权重矩阵转换为Q、K、V:
其中:
- 是可学习的参数矩阵
- 是输入序列的嵌入表示
# 2. 注意力权重计算
计算注意力分数:
缩放:
Softmax归一化:
加权求和:
# 实现细节
class SelfAttention(nn.Module):
def __init__(self, d_model, d_k):
super().__init__()
self.d_k = d_k
self.W_q = nn.Linear(d_model, d_k)
self.W_k = nn.Linear(d_model, d_k)
self.W_v = nn.Linear(d_model, d_k)
def forward(self, x, mask=None):
# x: [batch_size, seq_length, d_model]
Q = self.W_q(x)
K = self.W_k(x)
V = self.W_v(x)
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# 可选:使用mask
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# softmax归一化
attention = F.softmax(scores, dim=-1)
# 加权求和
output = torch.matmul(attention, V)
return output, attention
# 优化技巧
# 1. 注意力稀疏化
局部注意力
- 只关注窗口范围内的token
- 减少计算复杂度
稀疏注意力
- 使用固定模式的稀疏连接
- Longformer、BigBird等变体
# 2. 内存优化
渐进式计算
- 分块计算注意力矩阵
- 减少内存占用
线性复杂度变体
- Performer
- Linear Transformer
# 应用场景
序列建模
- 机器翻译
- 文本生成
- 语音识别
图像处理
- Vision Transformer
- DETR目标检测
# 注意事项
计算复杂度
- 时间复杂度:O(n²)
- 空间复杂度:O(n²)
长序列处理
- 考虑使用稀疏注意力
- 选择合适的注意力变体
数值稳定性
- 使用缩放因子
- 注意数值溢出问题