1 AFT-simple 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 class AFT_Simple (nn.Module): def __init__ (self, dim, hidden_dim=64 , **kwargs ): super ().__init__() self.w_q = nn.Linear(dim, hidden_dim) self.w_k = nn.Linear(dim, hidden_dim) self.w_v = nn.Linear(dim, hidden_dim) self.out = nn.Linear(hidden_dim, dim) def forward (self, x ): B, H, W, C = x.shape x = x.reshape(B, -1 , C) q = self.w_q(x) k = self.w_k(x) v = self.w_v(x) y = torch.sigmoid(q) * (torch.softmax(k, dim=1 ) * v).sum (dim=1 , keepdim=True ) return self.out(y).view(B, H, W, C)
2 AFT-full 2.1 2个问题 在复现代码的过程中发现了2个问题:
矩阵运算中@和 * 有什么区别?
a@b和b@a有什么区别?
2.1.1 2个问题 用numpy
的np.array()
定义矩阵A,B,如下:
1 2 3 4 5 6 7 import numpy as npA = np.array([ [1 , 2 , 3 ], [4 , 5 , 6 ] ]) B = np.array([ [6 , 5 ], [4 , 3 ], [2 , 1 ] ])
2.1.1.1 * 运算 “*”运算是将两个向量中每个元素进行相乘,是数乘运算,需要两个矩阵维度相同,所以需要A的转置与B做“*”运算,就是论文中的element-wise运算。
1 2 3 4 5 6 7 print (A.T*B) ''' [[ 6 20] [ 8 15] [ 6 6]] '''
2.1.1.2 @运算 “@”运算都可以起到矩阵乘法的作用。
1 2 3 4 5 print (A @ B)''' [[20 14] [56 41]] '''
3 AFT-conv 我的理解是在每个头中同时使用卷积运算。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 class AFT_Conv (nn.Module): def __init__ (self, dim, hidden_dim=64 , head_num=4 , kernel_size=7 , **kwargs ): super ().__init__() self.head_num = head_num self.hidden_dim = hidden_dim self.head_dim = hidden_dim // head_num self.w_q = nn.Linear(dim, hidden_dim) self.w_v = nn.Linear(dim, hidden_dim) self.w_k = nn.Linear(dim, hidden_dim) self.kernels = [ nn.Parameter(torch.Tensor(self.head_dim, self.head_dim, kernel_size, kernel_size), requires_grad=True ) for _ in range (head_num) ] self.conv2d = nn.ModuleList([ nn.Conv2d(in_channels=self.head_dim, out_channels=self.head_dim, kernel_size=kernel_size, padding=3 ) for _ in range (head_num) ]) for i in range (head_num): self.conv2d[i].weight.data = torch.exp(self.kernels[i]) - 1 self.out = nn.Linear(hidden_dim, dim) def attention (self, q, k, v, conv ): max_k = k.max (dim=0 , keepdims=True )[0 ] exp_k = torch.exp(k - max_k) num = conv(exp_k * v) + exp_k.sum (dim=1 , keepdim=True ) * v den = conv(exp_k) + exp_k.sum (dim=1 , keepdim=True ) y = torch.sigmoid(q) * num / den return y def forward (self, x ): B, H, W, C = x.shape assert C % self.head_num == 0 q = self.w_q(x).view(B, self.head_num, -1 , H, W) v = self.w_v(x).view(B, self.head_num, -1 , H, W) k = self.w_k(x).view(B, self.head_num, -1 , H, W) q_s = [q[:, i, :self.hidden_dim // self.head_num, :, :].contiguous() for i in range (self.head_num)] k_s = [k[:, i, :self.hidden_dim // self.head_num, :, :].contiguous() for i in range (self.head_num)] v_s = [v[:, i, :self.hidden_dim // self.head_num, :, :].contiguous() for i in range (self.head_num)] attentions = [self.attention(q_, k_, v_, conv) for conv, q_, k_, v_ in zip (self.conv2d, q_s, k_s, v_s)] y = torch.cat(attentions, dim=1 ).view(B, H, W, -1 ) return self.out(y)
4 AFT-local 4.1 代码解释 4.1.1 初始化 1 2 3 4 5 6 7 8 9 10 11 12 13 def __init__ (self, d_model, seq_len, local_window_size, bias ): super ().__init__() self.local_window_size = local_window_size self.query = nn.Linear(d_model, d_model, bias=bias) self.key = nn.Linear(d_model, d_model, bias=bias) self.value = nn.Linear(d_model, d_model, bias=bias) self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True ) self.w_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False ) self.activation = nn.Sigmoid() self.output = nn.Linear(d_model, d_model)
4.1.2 创建掩码 #torch.tril
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 @staticmethod def create_local_mask (seq_len, local_window_size ): """ 创建局部掩码 :param seq_len: :param local_window_size: :return: """ local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool ) local_mask = torch.tril(local_mask, local_window_size - 1 ) local_mask = torch.tril(local_mask, -(local_window_size - 1 )) return local_mask
4.1.3 forward函数 输入: query、key、value、mask
query
、key
和value
是存储查询、键和值的标记嵌入集合的张量。它们的形状为 [seq_len, batch_size, d_model]
mask
具有形状 [seq_len, seq_len, batch_size]
和 mask[i, j, b]
指示对于批次 b
,位置 i
处的查询是否有权访问密钥-位置 j
处的值
输出: 注意力模块计算结果
4.1.4 获取序列长度 1 seq_len, _, _ = query.shape
4.1.5 判断是否有掩码 mask
具有形状 [seq_len_q, seq_len_k, batch_size]
,其中第一个维度是查询维度。如果查询维度等于 1 ,它将被广播。
1 2 3 4 if mask is not None : assert mask.shape[0 ] == 1 or mask.shape[0 ] == query.shape[0 ] assert mask.shape[1 ] == key.shape[0 ] assert mask.shape[2 ] == 1 or mask.shape[2 ] == query.shape[1 ]
4.1.6 变换query、key和value 1 2 3 query = self.query(query) key = self.key(key) value = self.value(value)
由于$W^Q$、$W^K$和$W^V$的输入和输出维度相同,因此三个变量的形状没有发生变化。
4.1.7 计算成对位置偏置 1 pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len]
pos_bias
的形状为[seq_len, seq_len]
self.pos_bias
:这是一个预定义的位置偏差矩阵,它可能包含了模型中不同位置之间的某种偏差信息。这个矩阵的维度通常是根据模型能处理的最大序列长度来定义的。
seq_len
:当前处理的序列长度。由于不同的输入序列可能有不同的长度,而self.pos_bias
是基于最大长度定义的,因此需要截取与当前序列长度相匹配的部分。
self.local_mask[:seq_len, :seq_len]
:这是一个局部掩码(local mask),用于限制位置偏差的作用范围。它可能用于实现某种形式的局部注意力机制,即只允许序列中的某些位置相互注意。这个掩码通常是一个与seq_len
大小相同的二维矩阵,其中的元素可能是0(表示不允许)或1(表示允许)。
通过将self.pos_bias
和self.local_mask
的对应部分相乘,我们得到了一个调整后的位置偏差矩阵,它只保留了那些被局部掩码允许的位置偏差。
1 pos_bias = pos_bias.unsqueeze(-1 )
pos_bias
的形状为[seq_len, seq_len, 1]
unsqueeze(-1)
:这个操作是在pos_bias
的最后一个维度上增加一个大小为1的维度 。这通常是为了满足后续操作的维度要求。例如,在注意力机制中,位置偏差可能需要与其他张量(如注意力权重)进行广播(broadcasting)操作,而这些张量可能具有额外的维度。
1 pos_bias.masked_fill_(~mask, float ('-inf' ))
mask
:这是一个与当前序列长度seq_len
相关的布尔型掩码,用于指示哪些位置是有效的(通常是True
)或无效的(False
)。注意,这里的mask
可能与前面提到的self.local_mask
不同,尽管它们的作用类似,但可能具有不同的形状或逻辑。
~mask
:对mask
取反,得到一个与原mask
形状相同但逻辑相反的布尔型张量。在这个张量中,原mask
中为True
的位置现在为False
,反之亦然。
masked_fill_
:这是一个原地(in-place)操作,它会将pos_bias
中~mask
为True
(即原mask
为False
)的位置填充为float('-inf')
。在注意力机制中,这通常用于屏蔽掉那些不应该被注意到的位置,因为float('-inf')
在后续的softmax
操作中会被转换成接近0的概率值。
4.1.8 稳定softmax
计算 我们在计算指数之前减去$\max {t^{\prime}}\left(K {t^{\prime}}\right)$和$\max {t^{\prime}}\left(w {t, t^{\prime}}\right)$以稳定softmax
计算。
如果$x_i$很大,$exp(x_i)$就会变得很大,并且$\frac{\sum \exp \left(x_{i}\right) y_{i}}{\sum \exp \left(x_{i}\right)}$的计算变得不稳定。在计算分子和分母的指数之前减去一个常数将会抵消。并且可以帮助稳定计算,因此减去$\max (x_i)$来稳定计算。
1 2 max_key = key.max (dim=0 , keepdims=True )[0 ] max_pos_bias = pos_bias.max (dim=1 , keepdims=True )[0 ]
计算$\exp \left(K_{t^{\prime}}-\max {t^{\prime}}\left(K {t^{\prime}}\right)\right)$:
1 exp_key = torch.exp(key - max_key)
计算$\exp \left(w_{t, t^{\prime}}-\max {t^{\prime}}\left(w {t, t^{\prime}}\right)\right)$:
1 exp_pos_bias = torch.exp(pos_bias - max_pos_bias)
4.1.9 计算输出 #torch.einsum
接下来根据以下公式进行计算
$$ \begin{aligned} Y_{t} & =\sigma\left(Q_{t}\right) \odot \frac{\sum_{t^{\prime}=1}^{T} \exp \left(K_{t^{\prime}}+w_{t, t^{\prime}}\right) \odot V_{t^{\prime}}}{\sum_{t^{\prime}=1}^{T} \exp \left(K_{t^{\prime}}+w_{t, t^{\prime}}\right)} \ & =\sigma\left(Q_{t}\right) \odot \frac{\sum_{t^{\prime}=1}^{T} \exp \left(w_{t, t^{\prime}}\right) \odot \exp \left(K_{t^{\prime}}\right) \odot V_{t^{\prime}}}{\sum_{t^{\prime}=1}^{T} \exp \left(w_{t, t^{\prime}}\right) \odot \exp \left(K_{t^{\prime}}\right)}\end{aligned} $$
分子部分$\sum_{t^{\prime}=1}^{T} \exp \left(w_{t, t^{\prime}}\right) \odot \exp \left(K_{t^{\prime}}\right) \odot V_{t^{\prime}}$:
exp_pos_bias
的形状:[seq_len, seq_len, 1]
exp_key
的形状:[seq_len, batch_size, d_model]
value
的形状:[seq_len, batch_size, d_model]
exp_key * value
的形状:[seq_len, batch_size, d_model]
1 num = torch.einsum('ijb,jbd->ibd' , exp_pos_bias, exp_key * value)
num
的形状:[seq_len, batch_size, d_model]
分母部分$\sum_{t^{\prime}=1}^{T} \exp \left(w_{t, t^{\prime}}\right) \odot \exp \left(K_{t^{\prime}}\right)$:
1 den = torch.einsum('ijb,jbd->ibd' , exp_pos_bias, exp_key)
den
的形状:[seq_len, batch_size, d_model]
输出:
1 2 y = self.activation(query) * num / den return self.output(y)
最后输出的形状为[seq_len, batch_size, d_model]
4.2 每个变量的形状总结
变量名
形状
输入x
[seq_len, batch_size, d_model]
$W^Q$
[d_model, d_q]
→[d_model, d_model]
$W^K$
[d_model, d_k]
→[d_model, d_model]
$W^V$
[d_model, d_v]
→[d_model, d_model]
$w$
[seq_len, seq_len]
$w _ mask$
[seq_len, seq_len]
4.3 完整代码 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 import torchfrom torch import nnfrom typing import Optional from labml_helpers.module import Moduleclass AFTLocal (Module ): def __init__ (self, d_model, seq_len, local_window_size, bias ): super ().__init__() self.local_window_size = local_window_size self.query = nn.Linear(d_model, d_model, bias=bias) self.key = nn.Linear(d_model, d_model, bias=bias) self.value = nn.Linear(d_model, d_model, bias=bias) self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True ) self.w_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False ) self.activation = nn.Sigmoid() self.output = nn.Linear(d_model, d_model) @staticmethod def create_local_mask (seq_len, local_window_size ): """ 创建局部掩码 :param seq_len: :param local_window_size: :return: """ local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool ) local_mask = torch.tril(local_mask, local_window_size - 1 ) local_mask = torch.tril(local_mask, -(local_window_size - 1 )) return local_mask def forward (self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional [torch.Tensor] = None ): seq_len, _, _ = query.shape if mask is not None : assert mask.shape[0 ] == 1 or mask.shape[0 ] == query.shape[0 ] assert mask.shape[1 ] == key.shape[0 ] assert mask.shape[2 ] == 1 or mask.shape[2 ] == query.shape[1 ] query = self.query(query) key = self.key(key) value = self.value(value) pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len] pos_bias = pos_bias.unsqueeze(-1 ) pos_bias.masked_fill_(~mask, float ('-inf' )) max_key = key.max (dim=0 , keepdims=True )[0 ] max_pos_bias = pos_bias.max (dim=1 , keepdims=True )[0 ] exp_key = torch.exp(key - max_key) exp_pos_bias = torch.exp(pos_bias - max_pos_bias) num = torch.einsum('ijb,jbd->ibd' , exp_pos_bias, exp_key * value) den = torch.einsum('ijb,jbd->ibd' , exp_pos_bias, exp_key) y = self.activation(query) * num / den return self.output(y) def _test_local_mask (): from labml.logger import inspect inspect(AFTLocal.create_local_mask(10 , 4 )) if __name__ == '__main__' : _test_local_mask()