自注意力机制(1)

自注意机制

1. 自注意机制的特点

考虑这样一个问题,输入长度为m的序列\(\{x_1, x_2,…,x_m\}\),序列中的元素都是向量,要求输出长度同样为m的序列\(\{c_1, c_2,…,c_m\}\),另外还有两个要求:

  1. 序列的长度m是不确定的,可以动态变化,但是神经网络的参数数量不能变。
  2. 输出的向量\(c_i\)不仅仅和\(x_i\)有关,而是依赖于所有新的输入向量\(\{x_1, x_2,…,x_m\}\)

传统的RNN不能解决上述问题,因此传统RNN的输出\(c_i\)只依赖于\(\{x_1, x_2,…,x_i\}\),而不依赖于\(\{x_{i+1},…,x_m\}\)。自注意机制就能很好的解决上述问题。

2. 数学形式

输入:\(X=\{x_1, x_2,…,x_m\}\)\(x_i\)\(d_{in}\times1\)的向量。

三个参数矩阵:\(W_q:d_q*d_{in}\); \(W_k:d_q*d_{in}\); \(W_v:d_{out}*d_{in}\)

无论输入序列有多长,参数矩阵不需要发生改变,这三个参数矩阵需要从训练数据中进行学习

输出:\(C=\{c_1, c_2,…,c_m\}\)\(c_i\)\(d_{out}\times1\)的向量。

计算步骤:

  1. 第一步将输入\(x_i\)映射为三元组\(\{q_i,k_i,v_i\}\)

    1. \(q_i=W_q*x_i\)\(q_i\)的大小是\(d_q\times1\)
    2. \(k_i=W_k*x_i\)\(k_i\)的大小为\(d_q*1\)
    3. \(v_i=W_v*x_i\)\(v_i\)的大小为\(d_{out}*1\)

    第一步将输出映射为三元组,上述是每个元素的计算过程。在实际计算中,会得到三个矩阵,\(Q=\{q_1, q_2,…,q_m\}\)大小为\(d_q\times m\)\(K=\{k_1,k_2,…,k_m\}\)大小为\(d_q\times m\)\(V=\{v_i, v_2,…,v_m\}\),大小为\(d_{out}\times m\)

  2. 第二步利用\(q_i\)\(K\)计算权重向量\(a_i\):

    1. \(a_i=\text{softmax}(<q_i,k_1>,<q_i, k_2>,…,<q_i, k_m>), i=1,..,m\)

    上述的<,>表示内积,\(\text{softmax}\)函数导致\(a_i\)中所有元素的和为1,每个元素对应着与\(\{x_1, x_2,…,x_m\}\)的重要程度,权重矩阵\(A=\{a_1,a_2,…,a_m\}\),大小为\(m \times m\)

  3. 第三步利用权重矩阵\(A\)\(V\)矩阵得到最终的输出矩阵\(C=\{c_1, c_2,…,c_m\}\),第\(i\)个输出向量\(c_i\)依赖于\(a_i\)\(\{v_1, v_2,…, v_m\}\):

    1. \(c_i=[v_1, v_2,..,v_m]*a_i=\sum_{j=1}^m a_i^j*v_j, i=1,..,m\)

    \(c_i\)是向量\(\{v_1, v_2,…, v_m\}\)的加权平均,权重是\(a_i=[a_i^1, a_i^2,…,a_i^m]\)\(c_i\)的大小是\(d_{out}\times 1\)。整个输出矩阵\(C\)大小为\(d_{out}\times m\)

为什么要叫“注意力”呢,我们看最后的输出\(c_i=a_i^1v_1+a_i^2v_2+\cdot \cdot+a_i^mv_m\),权重\(a_i=[a_i^1, a_i^2,…,a_i^m]\)反映出\(c_i\)最关注那些输入的\(v_i=W_v*x_i\),如果权重\(a_i^j\)大,说明\(x_j\)\(c_i\)的影响较大,应当重点关注。

3. Pytorch代码实现(单头自注意层)

import torch 
import torch.nn as nn
from math import sqrt

class Self_attention(nn.Module):
    def __init__(self, d_in, d_q, d_out):
        super(Self_attention, self).__init__()
        self.din = d_in
        self.dq = d_q
        self.dout = d_out
        
        self.Wq = nn.Linear(self.din, self.dq, bias=False)
        self.Wk = nn.Linear(self.din, self.dq, bias=False)
        self.Wv = nn.Linear(self.din, self.dout, bias=False)
        
        self._norm_fact = 1/sqrt(self.dq)   # 归一化层
        
    def forward(self, x):
        m, din = x.shape
        assert din == self.din   # 判断输入数据维度是否正确
        
        # 第一步
        Q = self.Wq(x)  # m*dq
        K = self.Wk(x)  # m*dq
        V = self.Wv(x)  # m*dout
        
        # 第二步
        A = torch.softmax(torch.matmul(Q, K.T)*self._norm_fact, dim=-1)  # m*m
  		
        # 第三步
        C = torch.matmul(A, V)  # m*dout
                
        return C
                     
请登录后发表评论

    没有回复内容