What Is the Attention Mechanism
Attention mechanism: it imitates the selective focus humans apply when processing information, allowing a model to dynamically adjust its attention weights as it processes the input, so that important information is highlighted and unimportant information is ignored.
Core Idea
The mechanism computes the similarity between the query vector (Query) and the key vectors (Key) to determine the attention weights, and then takes a weighted sum over the value vectors (Value) to obtain the final output.
Q: a vector describing "what information the current position wants to find".
K: a vector describing "what information each position in the sequence can offer".
V: the content at each position that actually gets weighted and summed.
QK^T: measures the similarity between every pair of tokens; the larger the dot product, the more related the two tokens are.
\sqrt{d_k}: a scaling factor that keeps the magnitude of the dot products under control, so they do not blow up as the dimension grows.
Softmax: normalizes all similarity scores into a probability distribution.
AV: takes a weighted average of the information from all tokens using these weights, yielding a combined contextual semantic vector.
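Putting the components above together gives the standard scaled dot-product attention formula:
Attention(Q, K, V) = softmax\left(\frac{Q K^T}{\sqrt{d_k}}\right) V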
The Self-Attention Algorithm

Assume the input sequence X = [x_1, ..., x_L] is passed through a word embedding layer (Embedding) to give [h_1, ..., h_L] \in \R^{L \times D_n}.
Projection
Three linear transformations
W_Q \in \R^{D_n \times D_k}, W_K \in \R^{D_n \times D_k}, W_V \in \R^{D_n \times D_v}
map the same input X into three subspaces:
Q \in \R^{L \times D_k}, K \in \R^{L \times D_k}, V \in \R^{L \times D_v}
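Writing H \in \R^{L \times D_n} for the matrix that stacks the embeddings [h_1, ..., h_L] (just a name for the embedded input), the three projections are:
Q = H W_Q, K = H W_K, V = H W_V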
Computing the Similarity
- q_i: the Query vector at position i
- k_j: the Key vector at position j
- Score: \R^{D_k} \times \R^{D_k} \to \R: the similarity scoring function
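For dot-product attention the score is the inner product of the two vectors; writing S for the resulting L \times L score matrix:
S_{ij} = Score(q_i, k_j) = q_i^T k_j, i.e. S = Q K^T \in \R^{L \times L}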
Scaling
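Dividing by \sqrt{D_k} keeps the magnitude of the scores under control as the key dimension grows (writing \tilde{S} for the scaled scores):
\tilde{S} = \frac{S}{\sqrt{D_k}} = \frac{Q K^T}{\sqrt{D_k}}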
Softmax Normalization
A \in \R^{L \times L}: the attention weight matrix
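Softmax is applied row by row, so every row of A sums to 1:
A_{ij} = \frac{\exp(\tilde{S}_{ij})}{\sum_{j'=1}^{L} \exp(\tilde{S}_{ij'})}, i.e. A = softmax\left(\frac{Q K^T}{\sqrt{D_k}}\right)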
Weighted Sum
C \in \R^{L \times D_v}: the attention output matrix
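Each output row is a weighted average of the Value vectors, with the weights taken from the corresponding row of A:
C = A V, c_i = \sum_{j=1}^{L} A_{ij} v_j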
Code
Scaled Dot-Product Attention
import math

import torch
from torch import nn


class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, query, key, value, causal_mask=None, padding_mask=None):
        # query, key, value: [batch_size, seq_len, hidden_size]
        scale = 1.0 / math.sqrt(query.size(-1))
        # raw similarity scores: [batch_size, seq_len, seq_len]
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) * scale
        if causal_mask is not None:
            # causal_mask: [seq_len, seq_len], True marks positions to hide
            causal_mask = causal_mask.to(dtype=torch.bool, device=attention_scores.device)
            attention_scores = attention_scores.masked_fill(causal_mask, -1e9)
        if padding_mask is not None:
            # padding_mask: [batch_size, seq_len], True marks padding tokens
            padding_mask = padding_mask.to(dtype=torch.bool, device=attention_scores.device)
            padding_mask = padding_mask.unsqueeze(1)  # broadcast over query positions
            attention_scores = attention_scores.masked_fill(padding_mask, -1e9)
        # normalize the attention scores into probabilities
        attention_probs = torch.softmax(attention_scores, dim=-1)
        # apply the attention weights to the values
        # [batch_size, seq_len, hidden_size]
        attention_output = torch.matmul(attention_probs, value)
        return attention_output


def test_atten():
    batch_size = 128
    seq_len = 512
    hidden_size = 1024
    # create random query, key, value
    query = torch.randn(batch_size, seq_len, hidden_size)
    key = torch.randn(batch_size, seq_len, hidden_size)
    value = torch.randn(batch_size, seq_len, hidden_size)
    sdpa = ScaledDotProductAttention()
    output = sdpa(query, key, value)
    print("Query shape: ", query.shape)
    print("Key shape: ", key.shape)
    print("Value shape: ", value.shape)
    print("Output shape: ", output.shape)


if __name__ == "__main__":
    test_atten()
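The test above runs the module without any mask. As a minimal sketch (a hypothetical test_atten_with_masks helper, not part of the original test), the two mask arguments could be exercised like this, following the convention in the code that True marks a position to be masked out:

def test_atten_with_masks():
    # hypothetical sketch; sizes are chosen small purely for illustration
    batch_size, seq_len, hidden_size = 2, 8, 16
    query = torch.randn(batch_size, seq_len, hidden_size)
    key = torch.randn(batch_size, seq_len, hidden_size)
    value = torch.randn(batch_size, seq_len, hidden_size)
    # True above the diagonal hides future positions (causal attention)
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    # pretend the last two tokens of every sequence are padding (illustrative only)
    padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    padding_mask[:, -2:] = True
    sdpa = ScaledDotProductAttention()
    output = sdpa(query, key, value, causal_mask=causal_mask, padding_mask=padding_mask)
    print("Masked output shape: ", output.shape)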
Multi-Head Attention Code
import math

import torch
from torch import nn


class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        # linear projections for Q, K, V and the output
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)
        self.out_linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, causal_mask=None, padding_mask=None):
        # hidden_state: [batch_size, seq_len, hidden_size]
        batch_size, seq_len, _ = hidden_state.size()
        # [batch_size, seq_len, hidden_size]
        query = self.q_linear(hidden_state)
        key = self.k_linear(hidden_state)
        value = self.v_linear(hidden_state)

        def shape(x):
            # [batch_size, seq_len, hidden_size] -> [batch_size, num_heads, seq_len, head_dim]
            return x.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        query = shape(query)
        key = shape(key)
        value = shape(value)
        # [batch_size, num_heads, seq_len, seq_len]
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        if causal_mask is not None:
            # causal_mask: [seq_len, seq_len], True marks future positions to hide
            causal_mask = causal_mask.to(dtype=torch.bool, device=attention_scores.device)
            attention_scores = attention_scores.masked_fill(causal_mask, -1e9)
        if padding_mask is not None:
            # padding_mask: [batch_size, seq_len], True marks padding tokens
            padding_mask = padding_mask.to(dtype=torch.bool, device=attention_scores.device)
            # broadcast over heads and query positions
            padding_mask = padding_mask.unsqueeze(1).unsqueeze(1)
            attention_scores = attention_scores.masked_fill(padding_mask, -1e9)
        # normalize the attention scores into probabilities
        attention_probs = torch.softmax(attention_scores, dim=-1)
        # apply the attention weights to the values
        # [batch_size, num_heads, seq_len, head_dim]
        attention_output = torch.matmul(attention_probs, value)
        # reshape back to [batch_size, seq_len, hidden_size]
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.num_heads * self.head_dim
        )
        attention_output = self.out_linear(attention_output)
        return attention_output


def test_MHA():
    batch_size = 128
    seq_len = 512
    hidden_size = 1024
    num_heads = 8
    device = "cuda" if torch.cuda.is_available() else "cpu"
    hidden_state = torch.randn(batch_size, seq_len, hidden_size, device=device)
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    mha = MultiHeadAttention(hidden_size, num_heads).to(device)
    attention_output = mha(hidden_state, causal_mask=causal_mask)
    print("Input shape: ", hidden_state.shape)
    print("Attention output shape: ", attention_output.shape)


if __name__ == "__main__":
    test_MHA()
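test_MHA only exercises the causal mask. As a minimal sketch of how a padding mask could be passed as well (a hypothetical test_MHA_with_padding helper; pad_token_id and input_ids are made up for illustration), again following the convention that True marks a masked position:

def test_MHA_with_padding():
    # hypothetical sketch; sizes are chosen small purely for illustration
    batch_size, seq_len, hidden_size, num_heads = 2, 8, 32, 4
    pad_token_id = 0
    input_ids = torch.randint(1, 100, (batch_size, seq_len))
    input_ids[:, -2:] = pad_token_id  # pretend the last two tokens are padding
    padding_mask = input_ids.eq(pad_token_id)  # [batch_size, seq_len], True at padding positions
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    mha = MultiHeadAttention(hidden_size, num_heads)
    output = mha(hidden_state, causal_mask=causal_mask, padding_mask=padding_mask)
    print("Output shape with padding mask: ", output.shape)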