transformer源代码学习

前言

代码参考文章:一文彻底搞懂 Transformer(图解+代码手撕)
以下说明的transformer的使用以翻译任务为例子,其目标是将中文句子翻译成英文句子。因此中文句子为源序列,英文句子为目标序列,原序列词典大小就是用于模型训练的所有中文词汇数量,目标序列词典大小则是用于模型训练的所有英文单词数量,所有的词汇在模型训练前都会被标上唯一的数字索引作为一套词典,相应地所有的单词也会在模型训练前都会被标上唯一的数字索引作为另一套词典。在模型训练时,输入数据(源序列)和目标数据(目标序列)都是以这些索引的形式提供的。

代码实现过程中的一个细节:pytorch的函数需要先实例化然后才能使用,后面注释为“定义xxx”就是在实例化。

transformer数据流

transformer结构图如下:
transformer

上图中左半部分为编码器,右半部分为解码器,按照该图的结构,可以在代码中首先大致实现其数据分别在编码器和解码器中的传输流程。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class transformer(nn.module):
def __init__():
super(Transformer, self).__init__()


def forward(self,src,tgt):
"""
Args:
src: 输入编码器的源序列,形状为(batch_size, seq_length)
tgt: 输入解码器的目标序列,形状为(batch_size, seq_length)
"""
# 编码器数据流向
encoder_embedding = self.encoder_embedding(src)
en_positional_encoding = self.positional_encoding(encoder_embedding)
src_embedded = en_positional_encoding
enc_output = self.encoder_layer(src_embedded)

# 解码器数据流向
decoder_embedding = self.decoder_embedding(tgt)
de_positional_encoding = self.positional_encoding(decoder_embedding)
tgt_embedded = de_positional_encoding
dec_output = self.decoder_layer(tgt_embedded, enc_output)


但是编码器和解码器一般都是多层堆叠,我们使用pytorch的nn.modulelist函数来简化堆叠过程

1
2
3
4
5
# 定义编码器和解码器的多层堆叠
self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff)
for _ in range(num_layers)])
self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff)
for _ in range(num_layers)])

定义了多层堆叠后,再将前面transformer类的forward中单层的编码器解码器改为多层,这样我们就可以得到编码器和解码器中的数据传输过程代码了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class transformer(nn.module):
def __init__():
"""
Args:
d_model: 词嵌入向量的维度
num_heads: 多头注意力模块的头数
d_ff: 前馈网络的隐藏层维度
num_layers: 编码器和解码器的堆叠层数
"""
super(Transformer, self).__init__(d_model,num_heads,d_ff,num_layers)
self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)])
self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)])

def forward(self,src,tgt):
"""
Args:
src: 输入编码器的源序列,形状为(batch_size, seq_length)
tgt: 输入解码器的目标序列,形状为(batch_size, seq_length)
"""
# 编码器数据流向
encoder_embedding = self.encoder_embedding(src)
en_positional_encoding = self.positional_encoding(encoder_embedding)
src_embedded = en_positional_encoding
enc_output = src_embedded
for enc_layer in self.encoder_layers:
enc_output = enc_layer(enc_output, src)

# 解码器数据流向
decoder_embedding = self.decoder_embedding(tgt)
de_positional_encoding = self.positional_encoding(decoder_embedding)
tgt_embedded = de_positional_encoding
dec_output = tgt_embedded
for dec_layer in self.decoder_layers:
dec_output = dec_layer(dec_output, enc_output, src, tgt)

数据流出解码器后还要经过线性层和softmax,线性层需要将词嵌入向量的维度转为目标词典大小的维度,每个维度代表了一个词,之所以这样做,是因为翻译任务被设置为通过目标句子的前面的词预测下一个词,只有经过了线性层转换维度,这样向量通过softmax之后才可以得到预测的每个词的为下一个词的概率。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class transformer(nn.module):
def __init__():
"""
Args:
d_model: 词嵌入向量的维度
num_heads: 多头注意力模块的头数
d_ff: 前馈网络的隐藏层维度
num_layers: 编码器和解码器的堆叠层数
tgt_vocab_size: 目标序列词典大小
"""
super(Transformer, self).__init__(d_model,num_heads,d_ff,num_layers,tgt_vocab_size)
# 定义编码器和解码器的多层堆叠
self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)])
self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)])

# 定义线性层
self.linear = nn.linear(d_model,tgt_vocab_size)


def forward(self,src,tgt):
"""
Args:
src: 输入编码器的源序列,形状为(batch_size, seq_length)
tgt: 输入解码器的目标序列,形状为(batch_size, seq_length)
"""
# 编码器数据流向
encoder_embedding = self.encoder_embedding(src)
en_positional_encoding = self.positional_encoding(encoder_embedding)
src_embedded = en_positional_encoding
enc_output = src_embedded
for enc_layer in self.encoder_layers:
enc_output = enc_layer(enc_output, src)

# 解码器数据流向
decoder_embedding = self.decoder_embedding(tgt)
de_positional_encoding = self.positional_encoding(decoder_embedding)
tgt_embedded = de_positional_encoding
dec_output = tgt_embedded
for dec_layer in self.decoder_layers:
dec_output = dec_layer(dec_output, enc_output, src, tgt)

# 线性层和softmax
output = F.log_softmax(self.linear(dec_output), dim=-1)
return output

在了解完整体的数据流向后,下面开始对每一部分模块的学习。

词嵌入向量生成

根据前面的代码,词嵌入向量生成的代码如下,接下来我们要进一步详细构建该模块

1
2
self.encoder_embedding(src)
self.decoder_embedding(src)

词嵌入向量可以使用nn.Embedding函数获得初始值,后续随训练过程进一步调整,假如词典大小为n,词嵌入向量维度为m,那么该函数就会生成n个m维的词嵌入向量,每个向量对应一个词,该函数也可以使用其他的模型预训练的词嵌入向量作为初始值。使用该函数的实例时实际上是在用词的索引来提取对应的词嵌入向量,该函数相当于一个查找表。

1
2
3
4
5
6
7
8
9
# 定义编码器和解码器的词嵌入层
"""
Args:
src_vocab_size: 原序列词典大小
tgt_vocab_size: 目标序列词典大小
d_model: 词嵌入向量的维度
"""
self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)

位置编码

1
self.positional_encoding(decoder_embedding)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 位置编码的实现
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super(PositionalEncoding, self).__init__()

# 计算位置编码
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(
torch.tensor(10000.0)) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)

def forward(self, x):
x = x + x + self.pe[:, :x.size(1)]
return x

单层编码器

1
EncoderLayer(d_model, num_heads, d_ff)

编码器包含了多头注意力和前馈网络两个模块

前馈网络

前馈网络的实现较为简单,只是进行了两次线性变换。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 前馈网络的代码实现
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff):
super(FeedForward, self).__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.relu = nn.ReLU()

def forward(self, x):
# 线性变换1
x = self.relu(self.linear1(x))

# 线性变换2
x = self.linear2(x)

return x

多头注意力

多头注意力模块的示意图如下
image.png
其中的scaled dot-product attention的具体实现如下图
image.png
其数学公式表示为
image.png
我们首先实现scaled dot-product attention。
矩阵乘法的实现使用torch.matmul实现,该函数对输入的两个向量的最后两个维度(刚好构成一个二维矩阵)进行矩阵乘法,如果输入的向量为三维及以上,则进行多矩阵乘法,意思就是只有最后两个维度会被视为矩阵并进行矩阵乘法,其他的维度会被视为矩阵数量的增加,如torch.matmul((3,2,2),(3,2,2))就是对两个2x2的矩阵进行三次矩阵乘法。
掩码的使用则用pytorch提供的masked_fill函数实现,它需要外部传入mask作为参数,mask与需要进行掩码处理的向量有相同的形状,函数可以将向量的mask=0的相应位置的量都替换为一个极小值从而实现掩码处理。

1
2
3
4
5
6
7
8
9
10
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_model)
# 如果提供了掩码,则应用掩码
if mask is not None:
scores += scores.masked_fill(mask == 0, -1e9)

# 计算注意力权重并应用softmax
attention_weights = torch.softmax(scores, dim=-1)

# 应用注意力到值
attention_output = torch.matmul(attention_weights, value)

数据经过linear层再到scaled dot-product attention的代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 线性投影
query = self.query_linear(query)
key = self.key_linear(key)
value = self.value_linear(value)

# 缩放点积注意力
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.depth)

# 如果提供了掩码,则应用掩码
if mask is not None:
scores += scores.masked_fill(mask == 0, -1e9)

# 计算注意力权重并应用softmax
attention_weights = torch.softmax(scores, dim=-1)

# 应用注意力到值
attention_output = torch.matmul(attention_weights, value)

数据再往下就涉及到多个头生成的数据的拼接了,多头的实现是通过分割Q,K,V矩阵的d_model维度实现的,矩阵的形状为(batchsize,sequence length,d_model),分割的实现将其形状变为(batchsize,sequence length,num_heads, depth),其中(num_heads) x (depth) = d_model,这样做的详细分析可以参考:这样图解Transformer应该没人看不懂了吧——多头注意力机制详解,实现分割的代码如下

1
2
3
def split_heads(self, x):
batch_size, seq_length, d_model = x.size()
return x.view(batch_size, seq_length, self.num_heads, self.depth).transpose(1, 2)

多头注意力的完整实现代码如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# 多头注意力的代码实现
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
self.num_heads = num_heads
self.d_model = d_model
assert d_model % num_heads == 0
self.depth = d_model // num_heads

# 查询、键和值的线性投影
self.query_linear = nn.Linear(d_model, d_model)
self.key_linear = nn.Linear(d_model, d_model)
self.value_linear = nn.Linear(d_model, d_model)

# 输出线性投影
self.output_linear = nn.Linear(d_model, d_model)

def split_heads(self, x):
batch_size, seq_length, d_model = x.size()
return x.view(batch_size, seq_length, self.num_heads, self.depth).transpose(1, 2)

def forward(self, query, key, value, mask=None):

# 线性投影
query = self.query_linear(query)
key = self.key_linear(key)
value = self.value_linear(value)

# 分割头部
query = self.split_heads(query)
key = self.split_heads(key)
value = self.split_heads(value)

# 缩放点积注意力
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.depth)

# 如果提供了掩码,则应用掩码
if mask is not None:
scores += scores.masked_fill(mask == 0, -1e9)

# 计算注意力权重并应用softmax
attention_weights = torch.softmax(scores, dim=-1)

# 应用注意力到值
attention_output = torch.matmul(attention_weights, value)

# 合并头部
batch_size, _, seq_length, d_k = attention_output.size()
attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size,
seq_length, self.d_model)

# 线性投影
attention_output = self.output_linear(attention_output)

return attention_output

编码器的完整实现代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 编码器的代码实现
class EncoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout):
super(EncoderLayer, self).__init__()
self.self_attention = MultiHeadAttention(d_model, num_heads)
self.feed_forward = FeedForward(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)

def forward(self, x, mask):

# 自注意力层
attention_output= self.self_attention(x, x,
x, mask)
attention_output = self.dropout(attention_output)
x = x + attention_output
x = self.norm1(x)

# 前馈层
feed_forward_output = self.feed_forward(x)
feed_forward_output = self.dropout(feed_forward_output)
x = x + feed_forward_output
x = self.norm2(x)

return x

单层解码器

1
DecoderLayer(d_model, num_heads, d_ff)

使用前面的多头注意力模块和前馈网络模块,容易得到解码器的代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# 解码器的代码实现
class DecoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout):
super(DecoderLayer, self).__init__()
self.masked_self_attention = MultiHeadAttention(d_model, num_heads)
self.enc_dec_attention = MultiHeadAttention(d_model, num_heads)
self.feed_forward = FeedForward(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)

def forward(self, x, encoder_output, src_mask, tgt_mask):

# 掩码的自注意力层
self_attention_output= self.masked_self_attention(x, x, x, tgt_mask)
self_attention_output = self.dropout(self_attention_output)
x = x + self_attention_output
x = self.norm1(x)

# 编码器-解码器注意力层
enc_dec_attention_output= self.enc_dec_attention(x, encoder_output,
encoder_output, src_mask)
enc_dec_attention_output = self.dropout(enc_dec_attention_output)
x = x + enc_dec_attention_output
x = self.norm2(x)

# 前馈层
feed_forward_output = self.feed_forward(x)
feed_forward_output = self.dropout(feed_forward_output)
x = x + feed_forward_output
x = self.norm3(x)

return x

transformer实现代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# TRANSFORMER的实现
class Transformer(nn.Module):
def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff,
max_len, dropout):
super(Transformer, self).__init__()

# 定义编码器和解码器的词嵌入层
self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)

# 定义位置编码层
self.positional_encoding = PositionalEncoding(d_model, max_len)

# 定义编码器和解码器的多层堆叠
self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)])
self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)])

# 定义线性层
self.linear = nn.Linear(d_model, tgt_vocab_size)
self.dropout = nn.Dropout(dropout)

# 生成掩码
def generate_mask(self, src, tgt):
src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
seq_length = tgt.size(1)
nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
tgt_mask = tgt_mask & nopeak_mask
return src_mask, tgt_mask

# 前向传播
def forward(self, src, tgt):
src_mask, tgt_mask = self.generate_mask(src, tgt)

# 编码器输入的词嵌入和位置编码
encoder_embedding = self.encoder_embedding(src)
en_positional_encoding = self.positional_encoding(encoder_embedding)
src_embedded = self.dropout(en_positional_encoding)

# 解码器输入的词嵌入和位置编码
decoder_embedding = self.decoder_embedding(tgt)
de_positional_encoding = self.positional_encoding(decoder_embedding)
tgt_embedded = self.dropout(de_positional_encoding)

enc_output = src_embedded
for enc_layer in self.encoder_layers:
enc_output = enc_layer(enc_output, src_mask)

dec_output = tgt_embedded
for dec_layer in self.decoder_layers:
dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

output = self.linear(dec_output)
return output