1 AFT-simple

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class AFT_Simple(nn.Module):
def __init__(self, dim, hidden_dim=64, **kwargs):
super().__init__()
self.w_q = nn.Linear(dim, hidden_dim)
self.w_k = nn.Linear(dim, hidden_dim)
self.w_v = nn.Linear(dim, hidden_dim)
self.out = nn.Linear(hidden_dim, dim)

def forward(self, x):
B, H, W, C = x.shape
x = x.reshape(B, -1, C)

q = self.w_q(x)
k = self.w_k(x)
v = self.w_v(x)

y = torch.sigmoid(q) * (torch.softmax(k, dim=1) * v).sum(dim=1, keepdim=True)
return self.out(y).view(B, H, W, C)

2 AFT-full

2.1 2个问题

在复现代码的过程中发现了2个问题:

  1. 矩阵运算中@和 * 有什么区别?
  2. a@b和b@a有什么区别?

2.1.1 2个问题

numpynp.array()定义矩阵A,B,如下:

1
2
3
4
5
6
7
import numpy as np
A = np.array([ [1, 2, 3],
[4, 5, 6] ])

B = np.array([ [6, 5],
[4, 3],
[2, 1] ])

2.1.1.1 * 运算

“*”运算是将两个向量中每个元素进行相乘,是数乘运算,需要两个矩阵维度相同,所以需要A的转置与B做“*”运算,就是论文中的element-wise运算。

1
2
3
4
5
6
7
print(A.T*B)

'''
[[ 6 20]
[ 8 15]
[ 6 6]]
'''

2.1.1.2 @运算

“@”运算都可以起到矩阵乘法的作用。

1
2
3
4
5
print(A @ B)
'''
[[20 14]
[56 41]]
'''

3 AFT-conv

我的理解是在每个头中同时使用卷积运算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class AFT_Conv(nn.Module):
def __init__(self, dim, hidden_dim=64, head_num=4, kernel_size=7, **kwargs):
super().__init__()
self.head_num = head_num
self.hidden_dim = hidden_dim
self.head_dim = hidden_dim // head_num

self.w_q = nn.Linear(dim, hidden_dim)
self.w_v = nn.Linear(dim, hidden_dim)
self.w_k = nn.Linear(dim, hidden_dim)
self.kernels = [
nn.Parameter(torch.Tensor(self.head_dim, self.head_dim, kernel_size, kernel_size), requires_grad=True)
for _ in range(head_num)
]
self.conv2d = nn.ModuleList([
nn.Conv2d(in_channels=self.head_dim, out_channels=self.head_dim, kernel_size=kernel_size, padding=3)
for _ in range(head_num)
])
for i in range(head_num):
self.conv2d[i].weight.data = torch.exp(self.kernels[i]) - 1
self.out = nn.Linear(hidden_dim, dim)

def attention(self, q, k, v, conv):
max_k = k.max(dim=0, keepdims=True)[0]
exp_k = torch.exp(k - max_k)

num = conv(exp_k * v) + exp_k.sum(dim=1, keepdim=True) * v
den = conv(exp_k) + exp_k.sum(dim=1, keepdim=True)

y = torch.sigmoid(q) * num / den
return y

def forward(self, x):
B, H, W, C = x.shape
assert C % self.head_num == 0

q = self.w_q(x).view(B, self.head_num, -1, H, W)
v = self.w_v(x).view(B, self.head_num, -1, H, W)
k = self.w_k(x).view(B, self.head_num, -1, H, W)

q_s = [q[:, i, :self.hidden_dim // self.head_num, :, :].contiguous() for i in range(self.head_num)]
k_s = [k[:, i, :self.hidden_dim // self.head_num, :, :].contiguous() for i in range(self.head_num)]
v_s = [v[:, i, :self.hidden_dim // self.head_num, :, :].contiguous() for i in range(self.head_num)]

attentions = [self.attention(q_, k_, v_, conv) for conv, q_, k_, v_ in zip(self.conv2d, q_s, k_s, v_s)]

y = torch.cat(attentions, dim=1).view(B, H, W, -1)
return self.out(y)

4 AFT-local

4.1 代码解释

4.1.1 初始化

1
2
3
4
5
6
7
8
9
10
11
12
13
def __init__(self, d_model, seq_len, local_window_size, bias):
super().__init__()
self.local_window_size = local_window_size
# Q K V
self.query = nn.Linear(d_model, d_model, bias=bias)
self.key = nn.Linear(d_model, d_model, bias=bias)
self.value = nn.Linear(d_model, d_model, bias=bias)
# 成对位置偏差
self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True)
# 掩码
self.w_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False)
self.activation = nn.Sigmoid()
self.output = nn.Linear(d_model, d_model)

4.1.2 创建掩码

#torch.tril

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
@staticmethod
def create_local_mask(seq_len, local_window_size):
"""
创建局部掩码
:param seq_len:
:param local_window_size:
:return:
"""
# 初始化为1
local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
# 将 [t+s,+∞] 设置为0
local_mask = torch.tril(local_mask, local_window_size - 1)
# 将 [-∞,t + s] 设置为0
local_mask = torch.tril(local_mask, -(local_window_size - 1))
return local_mask

4.1.3 forward函数

输入:query、key、value、mask

  • querykeyvalue是存储查询、键和值的标记嵌入集合的张量。它们的形状为 [seq_len, batch_size, d_model]
  • mask 具有形状 [seq_len, seq_len, batch_size]mask[i, j, b] 指示对于批次 b ,位置 i 处的查询是否有权访问密钥-位置 j 处的值

输出:注意力模块计算结果

4.1.4 获取序列长度

1
seq_len, _, _ = query.shape

4.1.5 判断是否有掩码

mask 具有形状 [seq_len_q, seq_len_k, batch_size] ,其中第一个维度是查询维度。如果查询维度等于 1 ,它将被广播。

1
2
3
4
if mask is not None:
assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
assert mask.shape[1] == key.shape[0]
assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]

4.1.6 变换query、key和value

1
2
3
query = self.query(query)
key = self.key(key)
value = self.value(value)

由于$W^Q$、$W^K$和$W^V$的输入和输出维度相同,因此三个变量的形状没有发生变化。

4.1.7 计算成对位置偏置

1
pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len]

pos_bias的形状为[seq_len, seq_len]

  • self.pos_bias:这是一个预定义的位置偏差矩阵,它可能包含了模型中不同位置之间的某种偏差信息。这个矩阵的维度通常是根据模型能处理的最大序列长度来定义的。
  • seq_len:当前处理的序列长度。由于不同的输入序列可能有不同的长度,而self.pos_bias是基于最大长度定义的,因此需要截取与当前序列长度相匹配的部分。
  • self.local_mask[:seq_len, :seq_len]:这是一个局部掩码(local mask),用于限制位置偏差的作用范围。它可能用于实现某种形式的局部注意力机制,即只允许序列中的某些位置相互注意。这个掩码通常是一个与seq_len大小相同的二维矩阵,其中的元素可能是0(表示不允许)或1(表示允许)。
  • 通过将self.pos_biasself.local_mask的对应部分相乘,我们得到了一个调整后的位置偏差矩阵,它只保留了那些被局部掩码允许的位置偏差。
1
pos_bias = pos_bias.unsqueeze(-1)

pos_bias的形状为[seq_len, seq_len, 1]

  • unsqueeze(-1):这个操作是在pos_bias最后一个维度上增加一个大小为1的维度。这通常是为了满足后续操作的维度要求。例如,在注意力机制中,位置偏差可能需要与其他张量(如注意力权重)进行广播(broadcasting)操作,而这些张量可能具有额外的维度。
1
pos_bias.masked_fill_(~mask, float('-inf'))
  • mask:这是一个与当前序列长度seq_len相关的布尔型掩码,用于指示哪些位置是有效的(通常是True)或无效的(False)。注意,这里的mask可能与前面提到的self.local_mask不同,尽管它们的作用类似,但可能具有不同的形状或逻辑。
  • ~mask:对mask取反,得到一个与原mask形状相同但逻辑相反的布尔型张量。在这个张量中,原mask中为True的位置现在为False,反之亦然。
  • masked_fill_:这是一个原地(in-place)操作,它会将pos_bias~maskTrue(即原maskFalse)的位置填充为float('-inf')。在注意力机制中,这通常用于屏蔽掉那些不应该被注意到的位置,因为float('-inf')在后续的softmax操作中会被转换成接近0的概率值。

4.1.8 稳定softmax计算

我们在计算指数之前减去$\max {t^{\prime}}\left(K{t^{\prime}}\right)$和$\max {t^{\prime}}\left(w{t, t^{\prime}}\right)$以稳定softmax计算。

如果$x_i$很大,$exp(x_i)$就会变得很大,并且$\frac{\sum \exp \left(x_{i}\right) y_{i}}{\sum \exp \left(x_{i}\right)}$的计算变得不稳定。在计算分子和分母的指数之前减去一个常数将会抵消。并且可以帮助稳定计算,因此减去$\max (x_i)$来稳定计算。

1
2
max_key = key.max(dim=0, keepdims=True)[0]
max_pos_bias = pos_bias.max(dim=1, keepdims=True)[0]

计算$\exp \left(K_{t^{\prime}}-\max {t^{\prime}}\left(K{t^{\prime}}\right)\right)$:

1
exp_key = torch.exp(key - max_key)

计算$\exp \left(w_{t, t^{\prime}}-\max {t^{\prime}}\left(w{t, t^{\prime}}\right)\right)$:

1
exp_pos_bias = torch.exp(pos_bias - max_pos_bias)

4.1.9 计算输出

#torch.einsum

接下来根据以下公式进行计算

$$
\begin{aligned} Y_{t} & =\sigma\left(Q_{t}\right) \odot \frac{\sum_{t^{\prime}=1}^{T} \exp \left(K_{t^{\prime}}+w_{t, t^{\prime}}\right) \odot V_{t^{\prime}}}{\sum_{t^{\prime}=1}^{T} \exp \left(K_{t^{\prime}}+w_{t, t^{\prime}}\right)} \ & =\sigma\left(Q_{t}\right) \odot \frac{\sum_{t^{\prime}=1}^{T} \exp \left(w_{t, t^{\prime}}\right) \odot \exp \left(K_{t^{\prime}}\right) \odot V_{t^{\prime}}}{\sum_{t^{\prime}=1}^{T} \exp \left(w_{t, t^{\prime}}\right) \odot \exp \left(K_{t^{\prime}}\right)}\end{aligned}
$$

分子部分$\sum_{t^{\prime}=1}^{T} \exp \left(w_{t, t^{\prime}}\right) \odot \exp \left(K_{t^{\prime}}\right) \odot V_{t^{\prime}}$:

  1. exp_pos_bias的形状:[seq_len, seq_len, 1]
  2. exp_key的形状:[seq_len, batch_size, d_model]
  3. value的形状:[seq_len, batch_size, d_model]
  4. exp_key * value的形状:[seq_len, batch_size, d_model]
1
num = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key * value)

num的形状:[seq_len, batch_size, d_model]

分母部分$\sum_{t^{\prime}=1}^{T} \exp \left(w_{t, t^{\prime}}\right) \odot \exp \left(K_{t^{\prime}}\right)$:

1
den = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key)

den的形状:[seq_len, batch_size, d_model]

输出:

1
2
y = self.activation(query) * num / den
return self.output(y)

最后输出的形状为[seq_len, batch_size, d_model]

4.2 每个变量的形状总结

变量名 形状
输入x [seq_len, batch_size, d_model]
$W^Q$ [d_model, d_q][d_model, d_model]
$W^K$ [d_model, d_k][d_model, d_model]
$W^V$ [d_model, d_v][d_model, d_model]
$w$ [seq_len, seq_len]
$w _ mask$ [seq_len, seq_len]

4.3 完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
from torch import nn
from typing import Optional
from labml_helpers.module import Module


class AFTLocal(Module):
def __init__(self, d_model, seq_len, local_window_size, bias):
super().__init__()
self.local_window_size = local_window_size
# Q K V
self.query = nn.Linear(d_model, d_model, bias=bias)
self.key = nn.Linear(d_model, d_model, bias=bias)
self.value = nn.Linear(d_model, d_model, bias=bias)
# 成对位置偏差
self.pos_bias = nn.Parameter(torch.zeros(seq_len, seq_len), requires_grad=True)
# 掩码
self.w_mask = nn.Parameter(self.create_local_mask(seq_len, local_window_size), requires_grad=False)
self.activation = nn.Sigmoid()
self.output = nn.Linear(d_model, d_model)

@staticmethod
def create_local_mask(seq_len, local_window_size):
"""
创建局部掩码
:param seq_len:
:param local_window_size:
:return:
"""
# 初始化为1
local_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
# 将 [t+s,+∞] 设置为0
local_mask = torch.tril(local_mask, local_window_size - 1)
# 将 [-∞,t + s] 设置为0
local_mask = torch.tril(local_mask, -(local_window_size - 1))
return local_mask

def forward(self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None):
seq_len, _, _ = query.shape

if mask is not None:
assert mask.shape[0] == 1 or mask.shape[0] == query.shape[0]
assert mask.shape[1] == key.shape[0]
assert mask.shape[2] == 1 or mask.shape[2] == query.shape[1]

query = self.query(query)
key = self.key(key)
value = self.value(value)

pos_bias = self.pos_bias[:seq_len, :seq_len] * self.local_mask[:seq_len, :seq_len]
pos_bias = pos_bias.unsqueeze(-1)
pos_bias.masked_fill_(~mask, float('-inf'))

max_key = key.max(dim=0, keepdims=True)[0]
max_pos_bias = pos_bias.max(dim=1, keepdims=True)[0]

exp_key = torch.exp(key - max_key)
exp_pos_bias = torch.exp(pos_bias - max_pos_bias)

num = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key * value)
den = torch.einsum('ijb,jbd->ibd', exp_pos_bias, exp_key)

y = self.activation(query) * num / den

return self.output(y)


def _test_local_mask():
from labml.logger import inspect
inspect(AFTLocal.create_local_mask(10, 4))


if __name__ == '__main__':
_test_local_mask()