kyrie · Published 2025-11-11

Attention Mechanism

What Is the Attention Mechanism

Attention mechanism: mimics the selective focus humans apply when processing information; it allows a model to dynamically adjust its attention weights over the input, highlighting important information and ignoring unimportant information.

Core Idea

It computes the similarity between the query vectors (Query) and the key vectors (Key) to determine the attention weights, then takes a weighted sum of the value vectors (Value) to produce the final output.

Attention(Q,K,V)=softmax\left(\frac{QK^\top}{\sqrt{d_k}}\right)V

Q: a vector representation of "what information this position is looking for."

K: a vector representation of "what information each position in the sequence can provide."

V: the content at each position that actually gets weighted and summed.

QK^\top: measures the similarity between each token and every other token; the larger the dot product, the more relevant the pair.

\sqrt{d_k}: a scaling factor that keeps the magnitude of the dot products in check; without it the scores grow with the dimension and push the softmax toward saturation.

Softmax: normalizes all the similarities into a probability distribution.

AV: a weighted average of every token's information using these weights, yielding a single context-aware semantic vector.
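
A minimal PyTorch sketch of this formula, with toy sizes and random Q, K, V standing in for projected inputs (all sizes here are illustrative assumptions):

import math
import torch

# toy sizes: a 4-token sequence, d_k = d_v = 8 (illustrative)
seq_len, d_k, d_v = 4, 8, 8
Q = torch.randn(seq_len, d_k)   # what each position is looking for
K = torch.randn(seq_len, d_k)   # what each position can provide
V = torch.randn(seq_len, d_v)   # the content that gets mixed

scores = Q @ K.T / math.sqrt(d_k)         # QK^T / sqrt(d_k): pairwise similarities
weights = torch.softmax(scores, dim=-1)   # each row becomes a probability distribution
output = weights @ V                      # weighted average of V -> context vectors

print(weights.sum(dim=-1))                # all ones: every row sums to 1
print(output.shape)                       # torch.Size([4, 8])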

Self-Attention Algorithm

Assume the input sequence X=[x_1,\dots,x_L] is passed through a word embedding layer to obtain [h_1,\dots,h_L] \in \R^{L \times D_n}.

Projection

Through three sets of linear transformations

W_Q \in \R^{D_n \times D_k},\quad W_K \in \R^{D_n \times D_k},\quad W_V \in \R^{D_n \times D_v}

the same input X is mapped into three subspaces:

Q \in \R^{L \times D_k},\quad K \in \R^{L \times D_k},\quad V \in \R^{L \times D_v}

Q=XW_Q,\quad K=XW_K,\quad V=XW_V
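
A quick shape check of the projection step; the weight matrices below are random stand-ins for learned parameters, and the sizes are illustrative:

import torch

# illustrative sizes: L = 5 tokens, D_n = 16, D_k = D_v = 8
L, D_n, D_k, D_v = 5, 16, 8, 8
X = torch.randn(L, D_n)      # embedded input [h_1, ..., h_L]

W_Q = torch.randn(D_n, D_k)  # learned in practice; random stand-ins here
W_K = torch.randn(D_n, D_k)
W_V = torch.randn(D_n, D_v)

Q = X @ W_Q                  # [L, D_k]
K = X @ W_K                  # [L, D_k]
V = X @ W_V                  # [L, D_v]
print(Q.shape, K.shape, V.shape)  # torch.Size([5, 8]) for each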

Computing Similarity

  • q_i: the Query vector at position i

  • k_j: the Key vector at position j

  • \operatorname{Score}: \R^{D_k} \times \R^{D_k} \to \R: the similarity scoring function

\operatorname{Score}(q_i, k_j) = q_i \cdot k_j

Scaling

\operatorname{Score}(q_i, k_j) = \frac{q_i k_j^\top}{\sqrt{d_k}}

Softmax Normalization

A \in \R^{L \times L}: the attention weight matrix

A = [\alpha_{ij}], \quad \alpha_{ij} = \frac{\exp(\operatorname{Score}(q_i, k_j))}{\sum_{j'} \exp(\operatorname{Score}(q_i, k_{j'}))}

Weighted Sum

C \in \R^{L \times D_v}: the attention output matrix

C=AV
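
Putting the scoring, scaling, softmax, and weighted-sum steps together in a short sketch; Q, K, and V are random stand-ins for the projections above, and the sizes are illustrative:

import math
import torch

L, D_k, D_v = 5, 8, 8
# random stand-ins for the projected Q, K, V from the previous step
Q, K, V = torch.randn(L, D_k), torch.randn(L, D_k), torch.randn(L, D_v)

scores = (Q @ K.T) / math.sqrt(D_k)   # Score(q_i, k_j) = q_i k_j^T / sqrt(D_k)
A = torch.softmax(scores, dim=-1)     # attention weights, A in R^{L x L}
C = A @ V                             # C = AV, the attention output in R^{L x D_v}

print(A.shape, C.shape)               # torch.Size([5, 5]) torch.Size([5, 8])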

Code

Scaled Dot-Product Attention

import math
import torch
from torch import nn

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    
    def forward(self, query, key, value, causal_mask=None, padding_mask=None):
        # query, key, value: [batch_size, seq_len, hidden_size]

        scale = 1.0 / math.sqrt(query.size(-1))

        # [batch_size, seq_len, seq_len]
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) * scale

        # causal_mask: e.g. [seq_len, seq_len] (broadcast over the batch), True = blocked position
        if causal_mask is not None:
            causal_mask = causal_mask.to(dtype=torch.bool, device=attention_scores.device)
            attention_scores = attention_scores.masked_fill(causal_mask, -1e9)

        # padding_mask: [batch_size, seq_len], True = PAD token to be ignored
        if padding_mask is not None:
            padding_mask = padding_mask.to(dtype=torch.bool, device=attention_scores.device)
            # [batch_size, seq_len] -> [batch_size, 1, seq_len] so it broadcasts over queries
            padding_mask = padding_mask.unsqueeze(1)
            attention_scores = attention_scores.masked_fill(padding_mask, -1e9)

        # normalize the attention scores
        attention_probs = torch.softmax(attention_scores, dim=-1)

        # apply the attention scores to the value
        # [batch_size, seq_len, hidden_size]
        attention_output = torch.matmul(attention_probs, value)

        return attention_output
    

def test_atten():
    batch_size = 128
    seq_len = 512
    hidden_size = 1024

    # create random query, key, value
    query = torch.randn(batch_size, seq_len, hidden_size)
    key = torch.randn(batch_size, seq_len, hidden_size)
    value = torch.randn(batch_size, seq_len, hidden_size)

    sdpa = ScaledDotProductAttention()
    output = sdpa(query, key, value)

    print("Query shape: ", query.shape)
    print("Key shape: ", key.shape)
    print("Value shape: ", value.shape)
    print("Output shape: ", output.shape)

if __name__ == "__main__":
    test_atten()
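
The test above does not exercise the mask arguments. The following usage sketch reuses the ScaledDotProductAttention class defined above; the sizes are illustrative, and True marks positions to ignore:

def test_atten_with_masks():
    batch_size, seq_len, hidden_size = 2, 6, 16

    query = torch.randn(batch_size, seq_len, hidden_size)
    key = torch.randn(batch_size, seq_len, hidden_size)
    value = torch.randn(batch_size, seq_len, hidden_size)

    # upper-triangular True entries block attention to future positions
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

    # pretend the last two tokens of every sequence are padding
    padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    padding_mask[:, -2:] = True

    sdpa = ScaledDotProductAttention()
    output = sdpa(query, key, value, causal_mask=causal_mask, padding_mask=padding_mask)
    print("Masked output shape: ", output.shape)  # torch.Size([2, 6, 16])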

Multi-Head Attention Code

import math
import torch
from torch import nn

class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)

        # linear projection layers for Q, K, V
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)

        # output projection
        self.out_linear = nn.Linear(hidden_size, hidden_size)


    def forward(self, hidden_state, causal_mask=None, padding_mask=None):
        # hidden_state: [batch_size, seq_len, hidden_size]
        batch_size, seq_len, _ = hidden_state.size()

        # [batch_size, seq_len, hidden_size]
        query = self.q_linear(hidden_state)
        key = self.k_linear(hidden_state)
        value = self.v_linear(hidden_state)

        def shape(x):
            return x.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # reshape
        # [batch_size, seq_len, hidden_size] -> [batch_size, num_heads, seq_len, head_dim]
        query = shape(query)
        key   = shape(key)
        value = shape(value)

        # [batch_size, num_heads, seq_len, seq_len]
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale

        # causal_mask: e.g. [seq_len, seq_len], True = blocked; broadcasts over batch and heads
        if causal_mask is not None:
            causal_mask = causal_mask.to(dtype=torch.bool, device=attention_scores.device)
            attention_scores = attention_scores.masked_fill(causal_mask, -1e9)

        # padding_mask: [batch_size, seq_len], True = PAD token to be ignored
        if padding_mask is not None:
            padding_mask = padding_mask.to(dtype=torch.bool, device=attention_scores.device)
            # [batch_size, seq_len] -> [batch_size, 1, 1, seq_len] so it broadcasts over heads and queries
            padding_mask = padding_mask.unsqueeze(1).unsqueeze(1)
            attention_scores = attention_scores.masked_fill(padding_mask, -1e9)

        
        # normalize the attention scores
        attention_probs = torch.softmax(attention_scores, dim=-1)

        # apply the attention scores to the value
        # [batch_size, num_heads, seq_len, head_dim]
        attention_output = torch.matmul(attention_probs, value)

        # reshape back to [batch_size, seq_len, hidden_size]
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_heads * self.head_dim)

        attention_output = self.out_linear(attention_output)

        return attention_output

def test_MHA():
    batch_size = 128
    seq_len = 512
    hidden_size = 1024
    num_heads = 8

    device = "cuda" if torch.cuda.is_available() else "cpu"

    hidden_state = torch.randn(batch_size, seq_len, hidden_size, device=device)

    causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

    mha = MultiHeadAttention(hidden_size, num_heads).to(device)

    attention_output = mha(hidden_state, causal_mask=causal_mask)

    print("Input shape: ", hidden_state.shape)
    print("Attention output shape: ", attention_output.shape)


if __name__ == "__main__":
    test_MHA()
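
test_MHA only demonstrates the causal mask. The sketch below shows how a padding mask could be built from token ids and passed to MultiHeadAttention; the pad_token_id value and all sizes are illustrative assumptions:

def test_MHA_with_padding():
    batch_size, seq_len, hidden_size, num_heads = 2, 6, 16, 4
    pad_token_id = 0  # assumed id for the PAD token (illustrative)

    # fake token ids where the last two positions of each sequence are padding
    token_ids = torch.randint(1, 100, (batch_size, seq_len))
    token_ids[:, -2:] = pad_token_id
    padding_mask = token_ids.eq(pad_token_id)   # [batch_size, seq_len], True = PAD

    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    mha = MultiHeadAttention(hidden_size, num_heads)
    output = mha(hidden_state, padding_mask=padding_mask)
    print("Output shape with padding mask: ", output.shape)  # torch.Size([2, 6, 16])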

