前言

关于transformer的理论部分,之前已经说过了,大家对于transformer应该有了一个大致的了解了。如果没有看过的可以看看这篇文章一文带你深度剖析什么叫Transformer。今天带大家来手把手具体如何实现基于transformer的机器翻译。总共分为三部分,transformer模型的搭建,数据集的处理,训练(train)的构建。

完整的代码地址:[基于transformer的机器翻译](Branches · wolalala/-transformer-pytorch-)如果代码对你有帮助麻烦给个star谢谢。文章对你有帮助麻烦给个三连鼓励谢谢。

transformer模型的搭建

带着大家一步一步的搭建出完整的transformer。然后又回到这个看了很多遍的图了,我们会按照这个图,一步步搭建我们的transformer。

Input embedding

首先是对输入的embedding操作,大家注意有个细节,embedding操作后乘了个sqrt(d_model)。为什么这么操作,是为了让其尺度范围与后续的位置编码的尺度范围相同,避免梯度爆炸梯度消失现象。

class InputEmbeddings(nn.Module):

    def __init__(self, d_model: int, vocab_size: int) -> None:
super().__init__()
self.d_model = d_model
self.vocab_size = vocab_size
self.embedding = nn.Embedding(vocab_size, d_model) def forward(self, x):
# (batch, seq_len) --> (batch, seq_len, d_model)
# Multiply by sqrt(d_model) to scale the embeddings according to the paper
return self.embedding(x) * math.sqrt(self.d_model)

然后就是位置编码了,奇数维度cos,偶数维度sin。并且我们不需要其梯度的运算,我们只需要将其位置信息传入模型,让模型知道即可。

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
super().__init__()
self.d_model = d_model
self.seq_len = seq_len
self.dropout = nn.Dropout(dropout)
# Create a matrix of shape (seq_len, d_model)
pe = torch.zeros(seq_len, d_model)
# Create a vector of shape (seq_len)
position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1) # (seq_len, 1)
# Create a vector of shape (d_model)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # (d_model / 2)
# Apply sine to even indices
pe[:, 0::2] = torch.sin(position * div_term) # sin(position * (10000 ** (2i / d_model))
# Apply cosine to odd indices
pe[:, 1::2] = torch.cos(position * div_term) # cos(position * (10000 ** (2i / d_model))
# Add a batch dimension to the positional encoding
pe = pe.unsqueeze(0) # (1, seq_len, d_model)
# Register the positional encoding as a buffer
self.register_buffer('pe', pe) def forward(self, x):
x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False) # (batch, seq_len, d_model)
return self.dropout(x)

Encoder

然后就是Encoder的搭建,首先我们需要知道我们的思路去如何构建Encoder部分,我们先完成每个小组件的搭建,然后再将每个小组件组合起来构成整个Encoder部分。

首先来看norm部分,我们使用的是Layernorm(层归一化)。这个就没什么问题,求均值,方差即可,注意我们加入了个参数eps,设置极小的数,是为了避免方差为0无意义的情况。

class LayerNormalization(nn.Module):

    def __init__(self, features: int, eps: float = 10 ** -6) -> None:
super().__init__()
self.eps = eps
self.alpha = nn.Parameter(torch.ones(features)) # alpha is a learnable parameter
self.bias = nn.Parameter(torch.zeros(features)) # bias is a learnable parameter def forward(self, x):
# x: (batch, seq_len, hidden_size)
# Keep the dimension for broadcasting
mean = x.mean(dim=-1, keepdim=True) # (batch, seq_len, 1)
# Keep the dimension for broadcasting
std = x.std(dim=-1, keepdim=True) # (batch, seq_len, 1)
# eps is to prevent dividing by zero or when std is very small
return self.alpha * (x - mean) / (std + self.eps) + self.bias

然后是残差连接的部分了,这部分也很容易理解,本质上很简单就是该层与上一层相加。

class ResidualConnection(nn.Module):

    def __init__(self, features: int, dropout: float) -> None:
super().__init__()
self.dropout = nn.Dropout(dropout)
self.norm = LayerNormalization(features) def forward(self, x, sublayer):
return x + self.dropout(sublayer(self.norm(x)))

然后是前馈网络,其实本质上就是一个线性层,直接写就行了。注意维度的转变。

class FeedForwardBlock(nn.Module):

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
super().__init__()
self.linear_1 = nn.Linear(d_model, d_ff) # w1 and b1
self.dropout = nn.Dropout(dropout)
self.linear_2 = nn.Linear(d_ff, d_model) # w2 and b2 def forward(self, x):
# (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model)
return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))

然后是最重要的部分,多头注意力机制。多头的数量我们设置为h,这里注意就是d_model必须能够被h整除。然后就是注意力机制中q,k,v的设置。如果是初学者的话,这里就理解为其都是个线性层即可。其原理需要比较长的篇幅,如果有需要也可以评论区说下,下篇给大家讲讲什么是注意力机制。很多步骤维度的变化都给大家标出来了,一定要注意理解。

class MultiHeadAttentionBlock(nn.Module):

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
super().__init__()
self.d_model = d_model # Embedding vector size
self.h = h # Number of heads
# Make sure d_model is divisible by h
assert d_model % h == 0, "d_model is not divisible by h" self.d_k = d_model // h # Dimension of vector seen by each head
self.w_q = nn.Linear(d_model, d_model, bias=False) # Wq
self.w_k = nn.Linear(d_model, d_model, bias=False) # Wk
self.w_v = nn.Linear(d_model, d_model, bias=False) # Wv
self.w_o = nn.Linear(d_model, d_model, bias=False) # Wo
self.dropout = nn.Dropout(dropout) @staticmethod
def attention(query, key, value, mask, dropout: nn.Dropout):
d_k = query.shape[-1]
# Just apply the formula from the paper
# (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len)
attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
# Write a very low value (indicating -inf) to the positions where mask == 0
attention_scores.masked_fill_(mask == 0, -1e9)
attention_scores = attention_scores.softmax(dim=-1) # (batch, h, seq_len, seq_len) # Apply softmax
if dropout is not None:
attention_scores = dropout(attention_scores)
# (batch, h, seq_len, seq_len) --> (batch, h, seq_len, d_k)
# return attention scores which can be used for visualization
return (attention_scores @ value), attention_scores def forward(self, q, k, v, mask):
query = self.w_q(q) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
key = self.w_k(k) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
value = self.w_v(v) # (batch, seq_len, d_model) --> (batch, seq_len, d_model) # (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2) # Calculate attention
x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout) # Combine all the heads together
# (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model)
x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k) # Multiply by Wo
# (batch, seq_len, d_model) --> (batch, seq_len, d_model)
return self.w_o(x)

最后就是大家整个Encoder模块了。注意我们首先是EncoderBlock的建立,然后根据需要搭建几个就循环即可。

class EncoderBlock(nn.Module):

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock,
feed_forward_block: FeedForwardBlock, dropout: float) -> None:
super().__init__()
self.self_attention_block = self_attention_block
self.feed_forward_block = feed_forward_block
self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)]) def forward(self, x, src_mask):
x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
x = self.residual_connections[1](x, self.feed_forward_block)
return x class Encoder(nn.Module): def __init__(self, features: int, layers: nn.ModuleList) -> None:
super().__init__()
self.layers = layers
self.norm = LayerNormalization(features) def forward(self, x, mask):
for layer in self.layers:
x = layer(x, mask)
return self.norm(x)

Decoder

Decoder模块的模块和Encoder是相同的,所以我们直接搭建就可以了。

class DecoderBlock(nn.Module):

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock,
cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock,
dropout: float) -> None:
super().__init__()
self.self_attention_block = self_attention_block
self.cross_attention_block = cross_attention_block
self.feed_forward_block = feed_forward_block
self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)]) def forward(self, x, encoder_output, src_mask, tgt_mask):
x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output,
src_mask))
x = self.residual_connections[2](x, self.feed_forward_block)
return x class Decoder(nn.Module): def __init__(self, features: int, layers: nn.ModuleList) -> None:
super().__init__()
self.layers = layers
self.norm = LayerNormalization(features) def forward(self, x, encoder_output, src_mask, tgt_mask):
for layer in self.layers:
x = layer(x, encoder_output, src_mask, tgt_mask)
return self.norm(x)

output

最后就是输出层的构建了

class ProjectionLayer(nn.Module):

    def __init__(self, d_model, vocab_size) -> None:
super().__init__()
self.proj = nn.Linear(d_model, vocab_size) def forward(self, x) -> None:
# (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
return self.proj(x)

transformer构建

根据之前搭建的每个小模块集成成完整的transformer,同时初始化相关的参数。这里的各项参数都是参照原论文所设置的。

class Transformer(nn.Module):

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings,
src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.src_embed = src_embed
self.tgt_embed = tgt_embed
self.src_pos = src_pos
self.tgt_pos = tgt_pos
self.projection_layer = projection_layer def encode(self, src, src_mask):
# (batch, seq_len, d_model)
src = self.src_embed(src)
src = self.src_pos(src)
return self.encoder(src, src_mask) def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
# (batch, seq_len, d_model)
tgt = self.tgt_embed(tgt)
tgt = self.tgt_pos(tgt)
return self.decoder(tgt, encoder_output, src_mask, tgt_mask) def project(self, x):
# (batch, seq_len, vocab_size)
return self.projection_layer(x) def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int = 512,
N: int = 6, h: int = 8, dropout: float = 0.1, d_ff: int = 2048) -> Transformer:
# Create the embedding layers
src_embed = InputEmbeddings(d_model, src_vocab_size)
tgt_embed = InputEmbeddings(d_model, tgt_vocab_size) # Create the positional encoding layers
src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout) # Create the encoder blocks
encoder_blocks = []
for _ in range(N):
encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
encoder_block = EncoderBlock(d_model, encoder_self_attention_block, feed_forward_block, dropout)
encoder_blocks.append(encoder_block) # Create the decoder blocks
decoder_blocks = []
for _ in range(N):
decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
decoder_block = DecoderBlock(d_model, decoder_self_attention_block, decoder_cross_attention_block,
feed_forward_block, dropout)
decoder_blocks.append(decoder_block) # Create the encoder and decoder
encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
decoder = Decoder(d_model, nn.ModuleList(decoder_blocks)) # Create the projection layer
projection_layer = ProjectionLayer(d_model, tgt_vocab_size) # Create the transformer
transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer) # Initialize the parameters
for p in transformer.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p) return transformer

data数据集处理

双语数据集的处理,这里使用的是英语意大利语数据集。注意处理的方式。填充的处理,以及整个输入的构建,encoder_input有开始符、内容、终止符、填冲构成。其mask掩码就是掩盖后续的填充内容即可。decoder_input由开始符、内容、填冲构成,其mask掩码需要掩盖预测后面的信息,所以我们不仅需要掩盖填充内容,还需掩盖预测时后续的内容,这里采用因果掩码的设置。就是causal_mask函数。还有label有内容、终止符、填冲构成,注意无开始符合,我们只需要其什么时候终止即可。

import torch
import torch.nn as nn
from torch.utils.data import Dataset class BilingualDataset(Dataset): def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
super().__init__()
self.seq_len = seq_len self.ds = ds
self.tokenizer_src = tokenizer_src
self.tokenizer_tgt = tokenizer_tgt
self.src_lang = src_lang
self.tgt_lang = tgt_lang self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64) def __len__(self):
return len(self.ds) def __getitem__(self, idx):
src_target_pair = self.ds[idx]
src_text = src_target_pair['translation'][self.src_lang]
tgt_text = src_target_pair['translation'][self.tgt_lang] # Transform the text into tokens
enc_input_tokens = self.tokenizer_src.encode(src_text).ids
dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids # Add sos, eos and padding to each sentence
enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2 # We will add <s> and </s>
# We will only add <s>, and </s> only on the label
dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1 # Make sure the number of padding tokens is not negative. If it is, the sentence is too long
if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
raise ValueError("Sentence is too long") # Add <s> and </s> token
encoder_input = torch.cat(
[
self.sos_token,
torch.tensor(enc_input_tokens, dtype=torch.int64),
self.eos_token,
torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64),
],
dim=0,
) # Add only <s> token
decoder_input = torch.cat(
[
self.sos_token,
torch.tensor(dec_input_tokens, dtype=torch.int64),
torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
],
dim=0,
) # Add only </s> token
label = torch.cat(
[
torch.tensor(dec_input_tokens, dtype=torch.int64),
self.eos_token,
torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
],
dim=0,
) # Double check the size of the tensors to make sure they are all seq_len long
assert encoder_input.size(0) == self.seq_len
assert decoder_input.size(0) == self.seq_len
assert label.size(0) == self.seq_len return {
"encoder_input": encoder_input, # (seq_len)
"decoder_input": decoder_input, # (seq_len)
"encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), # (1, 1, seq_len)
"decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),
# (1, seq_len) & (1, seq_len, seq_len),
"label": label, # (seq_len)
"src_text": src_text,
"tgt_text": tgt_text,
} def causal_mask(size):
mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
return mask == 0

train

train loop的全过程,导入模型以及处理的相关数据,代码最好是能够在gpu上运行,要不然太慢了。损失函数采用交叉熵损失函数,学习策略采用adam。然后每个epoch完成都会进行一次验证评估。

from model import build_transformer
from dataset import BilingualDataset, causal_mask
from config import get_config, get_weights_file_path, latest_weights_file_path import torchtext.datasets as datasets
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import LambdaLR import warnings
from tqdm import tqdm
import os
from pathlib import Path # Huggingface datasets and tokenizers
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace import torchmetrics
from torch.utils.tensorboard import SummaryWriter os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890' def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
sos_idx = tokenizer_tgt.token_to_id('[SOS]')
eos_idx = tokenizer_tgt.token_to_id('[EOS]') # Precompute the encoder output and reuse it for every step
encoder_output = model.encode(source, source_mask)
# Initialize the decoder input with the sos token
decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
while True:
if decoder_input.size(1) == max_len:
break # build mask for target
decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device) # calculate output
out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask) # get next token
prob = model.project(out[:, -1])
_, next_word = torch.max(prob, dim=1)
decoder_input = torch.cat(
[decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1
) if next_word == eos_idx:
break return decoder_input.squeeze(0) def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer,
num_examples=2):
model.eval()
count = 0 source_texts = []
expected = []
predicted = [] try:
# get the console window width
with os.popen('stty size', 'r') as console:
_, console_width = console.read().split()
console_width = int(console_width)
except:
# If we can't get the console width, use 80 as default
console_width = 80 with torch.no_grad():
for batch in validation_ds:
count += 1
encoder_input = batch["encoder_input"].to(device) # (b, seq_len)
encoder_mask = batch["encoder_mask"].to(device) # (b, 1, 1, seq_len) # check that the batch size is 1
assert encoder_input.size(
0) == 1, "Batch size must be 1 for validation" model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device) source_text = batch["src_text"][0]
target_text = batch["tgt_text"][0]
model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy()) source_texts.append(source_text)
expected.append(target_text)
predicted.append(model_out_text) # Print the source, target and model output
print_msg('-' * console_width)
print_msg(f"{f'SOURCE: ':>12}{source_text}")
print_msg(f"{f'TARGET: ':>12}{target_text}")
print_msg(f"{f'PREDICTED: ':>12}{model_out_text}") if count == num_examples:
print_msg('-' * console_width)
break if writer:
# Evaluate the character error rate
# Compute the char error rate
metric = torchmetrics.CharErrorRate()
cer = metric(predicted, expected)
writer.add_scalar('validation cer', cer, global_step)
writer.flush() # Compute the word error rate
metric = torchmetrics.WordErrorRate()
wer = metric(predicted, expected)
writer.add_scalar('validation wer', wer, global_step)
writer.flush() # Compute the BLEU metric
metric = torchmetrics.BLEUScore()
bleu = metric(predicted, expected)
writer.add_scalar('validation BLEU', bleu, global_step)
writer.flush() def get_all_sentences(ds, lang):
for item in ds:
yield item['translation'][lang] def get_or_build_tokenizer(config, ds, lang):
tokenizer_path = Path(config['tokenizer_file'].format(lang))
if not Path.exists(tokenizer_path):
# Most code taken from: https://huggingface.co/docs/tokenizers/quicktour
tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
tokenizer.save(str(tokenizer_path))
else:
tokenizer = Tokenizer.from_file(str(tokenizer_path))
return tokenizer def get_ds(config):
# It only has the train split, so we divide it overselves
#ds_raw = load_dataset(f"{config['datasource']}", f"{config['lang_src']}-{config['lang_tgt']}", split='train')
# ds_raw=load_dataset('./opus-100-corpus-en-zh-v1.0') ds_raw = load_dataset("Helsinki-NLP/opus-100", "en-zh",split='train') # Build tokenizers
tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt']) # Keep 90% for training, 10% for validation
train_ds_size = int(0.9 * len(ds_raw))
val_ds_size = len(ds_raw) - train_ds_size
train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size]) train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'],
config['seq_len'])
val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'],
config['seq_len']) # Find the maximum length of each sentence in the source and target sentence
max_len_src = 0
max_len_tgt = 0 for item in ds_raw:
src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
max_len_src = max(max_len_src, len(src_ids))
max_len_tgt = max(max_len_tgt, len(tgt_ids)) print(f'Max length of source sentence: {max_len_src}')
print(f'Max length of target sentence: {max_len_tgt}') train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True) return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt def get_model(config, vocab_src_len, vocab_tgt_len):
model = build_transformer(vocab_src_len, vocab_tgt_len, config["seq_len"], config['seq_len'],
d_model=config['d_model'])
return model def train_model(config):
# Define the device
device = "cuda" if torch.cuda.is_available() else "mps" if torch.has_mps or torch.backends.mps.is_available() else "cpu"
print("Using device:", device)
if (device == 'cuda'):
print(f"Device name: {torch.cuda.get_device_name(device.index)}")
print(f"Device memory: {torch.cuda.get_device_properties(device.index).total_memory / 1024 ** 3} GB")
elif (device == 'mps'):
print(f"Device name: <mps>")
else:
print("NOTE: If you have a GPU, consider using it for training.")
print(
" On a Windows machine with NVidia GPU, check this video: https://www.youtube.com/watch?v=GMSjDTU8Zlc")
print(
" On a Mac machine, run: pip3 install --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/cpu")
device = torch.device(device) # Make sure the weights folder exists
Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True) train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
# Tensorboard
writer = SummaryWriter(config['experiment_name']) optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9) # If the user specified a model to preload before training, load it
initial_epoch = 0
global_step = 0
preload = config['preload']
model_filename = latest_weights_file_path(config) if preload == 'latest' else get_weights_file_path(config,
preload) if preload else None
if model_filename:
print(f'Preloading model {model_filename}')
state = torch.load(model_filename)
model.load_state_dict(state['model_state_dict'])
initial_epoch = state['epoch'] + 1
optimizer.load_state_dict(state['optimizer_state_dict'])
global_step = state['global_step']
else:
print('No model to preload, starting from scratch') loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1).to(device) for epoch in range(initial_epoch, config['num_epochs']):
torch.cuda.empty_cache()
model.train()
batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
for batch in batch_iterator:
encoder_input = batch['encoder_input'].to(device) # (b, seq_len)
decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len) # Run the tensors through the encoder, decoder and the projection layer
encoder_output = model.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
decoder_output = model.decode(encoder_output, encoder_mask, decoder_input,
decoder_mask) # (B, seq_len, d_model)
proj_output = model.project(decoder_output) # (B, seq_len, vocab_size) # Compare the output with the label
label = batch['label'].to(device) # (B, seq_len) # Compute the loss using a simple cross entropy
loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"}) # Log the loss
writer.add_scalar('train loss', loss.item(), global_step)
writer.flush() # Backpropagate the loss
loss.backward() # Update the weights
optimizer.step()
optimizer.zero_grad(set_to_none=True) global_step += 1 # Run validation at the end of every epoch
run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device,
lambda msg: batch_iterator.write(msg), global_step, writer) # Save the model at the end of every epoch
model_filename = get_weights_file_path(config, f"{epoch:02d}")
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'global_step': global_step
}, model_filename) if __name__ == '__main__':
warnings.filterwarnings("ignore")
config = get_config()
train_model(config)

config

设置batchsize为8,然后seq_len设置要注意,我们需要超过句子最大的长度,这个注意train输出信息时会有输出最大长度的,如果更换数据集的话这里注意就行。

from pathlib import Path

def get_config():
return {
"batch_size": 8,
"num_epochs": 20,
"lr": 10**-4,
"seq_len": 350,
"d_model": 512,
"datasource": 'opus_books',
"lang_src": "en",
"lang_tgt": "it",
"model_folder": "weights",
"model_basename": "tmodel_",
"preload": "latest",
"tokenizer_file": "tokenizer_{0}.json",
"experiment_name": "runs/tmodel"
} def get_weights_file_path(config, epoch: str):
model_folder = f"{config['datasource']}_{config['model_folder']}"
model_filename = f"{config['model_basename']}{epoch}.pt"
return str(Path('.') / model_folder / model_filename) # Find the latest weights file in the weights folder
def latest_weights_file_path(config):
model_folder = f"{config['datasource']}_{config['model_folder']}"
model_filename = f"{config['model_basename']}*"
weights_files = list(Path(model_folder).glob(model_filename))
if len(weights_files) == 0:
return None
weights_files.sort()
return str(weights_files[-1])

参考

transformer

基于transformer的机器翻译:手把手教你实现的更多相关文章

  1. 用Python手把手教你搭一个Transformer!

    来源商业新知网,原标题:百闻不如一码!手把手教你用Python搭一个Transformer 与基于RNN的方法相比,Transformer 不需要循环,主要是由Attention 机制组成,因而可以充 ...

  2. 网络编程懒人入门(八):手把手教你写基于TCP的Socket长连接

    本文原作者:“水晶虾饺”,原文由“玉刚说”写作平台提供写作赞助,原文版权归“玉刚说”微信公众号所有,即时通讯网收录时有改动. 1.引言 好多小白初次接触即时通讯(比如:IM或者消息推送应用)时,总是不 ...

  3. 手把手教你写基于C++ Winsock的图片下载的网络爬虫

    手把手教你写基于C++ Winsock的图片下载的网络爬虫 先来说一下主要的技术点: 1. 输入起始网址,使用ssacnf函数解析出主机号和路径(仅处理http协议网址) 2. 使用socket套接字 ...

  4. Delphi - 手把手教你基于D7+Access常用管理系统架构的设计与实现 (更新中)

    前言 从事软件开发工作好多年了,学的越深入越觉得自己无知,所以还是要对知识保持敬畏之心,活到老,学到老! 健身和代码一样都不能少,身体是革命的本钱,特别是我们这种高危工种,所以小伙伴们运动起来!有没有 ...

  5. 庐山真面目之十一微服务架构手把手教你搭建基于Jenkins的企业级CI/CD环境

    庐山真面目之十一微服务架构手把手教你搭建基于Jenkins的企业级CI/CD环境 一.介绍 说起微服务架构来,有一个环节是少不了的,那就是CI/CD持续集成的环境.当然,搭建CI/CD环境的工具很多, ...

  6. 手把手教从零开始在GitHub上使用Hexo搭建博客教程(三)-使用Travis自动部署Hexo(1)

    前言 前面两篇文章介绍了在github上使用hexo搭建博客的基本环境和hexo相关参数设置等. 基于目前,博客基本上是可以完美运行了. 但是,有一点是不太好,就是源码同步问题,如果在不同的电脑上写文 ...

  7. 手把手教从零开始在GitHub上使用Hexo搭建博客教程(二)-Hexo参数设置

    前言 前文手把手教从零开始在GitHub上使用Hexo搭建博客教程(一)-附GitHub注册及配置介绍了github注册.git相关设置以及hexo基本操作. 本文主要介绍一下hexo的常用参数设置. ...

  8. 手把手教从零开始在GitHub上使用Hexo搭建博客教程(一)-附GitHub注册及配置

    前言 有朋友问了我关于博客系统搭建相关的问题,由于是做开发相关的工作,我给他推荐的是使用github的gh-pages服务搭建个人博客. 推荐理由: 免费:github提供gh-pages服务是免费的 ...

  9. 手把手教你调试Linux C++ 代码(一步到位包含静态库和动态库调试)

    手把手教你调试Linux C++ 代码 软件调试本身就是一项相对复杂的活动,他不仅要求调试者有着清晰的思路,而且对调试者本身的技能也有很高的要求.Windows下Visual Studio为我们做了很 ...

  10. 手把手教你接口自动化测试 – SoapUI & Groovy

    手把手教你接口自动化测试 – SoapUI & Groovy http://www.cnblogs.com/wade-xu/p/4236295.html 关键词:SoapUI接口测试,接口自动 ...

随机推荐

  1. 无网环境Docker Rpm离线安装

    总体思路:找一台可以联网的linux,下载docker的RPM依赖包而不进行安装(yum localinstall),将所有依赖的rpm环境打包好,再在无网环境中解压逐一安装(rpm: --force ...

  2. uni-app根据不同的类型绑定不同类名

    <template> <view class="page-demo"> <view class="demo" v-for=&quo ...

  3. CPU的指令周期

    本文分享自天翼云开发者社区<CPU的指令周期>,作者:冯****怡 指令周期(Instruction Cycle) CPU中会有 存器.指令寄存器.控制器等多类单元.指令集,就是CPU中用 ...

  4. Java连接数据库 CreateStatement 和 PrepareStatement 的区别与优劣

    一.简介 先说下CreateStatement 和 PrepareStatement 这俩到底是干啥的吧. 作用:其实这俩干的活儿都一样,就是创建了一个对象然后去通过对象调用executeQuery方 ...

  5. 流程控制之Scanner

    Scanner对象 可以通过scanner类(java.util.Scanner)来获取用户的输入 基本语法: Scanner s = new Scanner(System.in); 通过Scanne ...

  6. C# Windows Service 安装与卸载

    安装与卸载的使用工具 C:\Windows\Microsoft.NET\Framework64\v4.0.30319\InstallUtil.exe (一般安装了.NetFramework 后就会有该 ...

  7. SignalR 外部调用自定义Hub类的方法,Clients为null

    这是因为外部调用的类的对象和你连接的Hub类的对象,这两个对象 不!一!样! 解决方法 在自定义的Hub类中,注入IHubContext对象,然后在方法中调用IHubContext对象来向前端推送数据 ...

  8. Typecho 数据备份及程序升级详细步骤教程

    数据库备份看自己,习惯性更新前都备份,出错直接滚回去 数据库备份 直接在宝塔数据库那个模块备份即可,备份完建议下载本地或者保存到OSS 备份网站文件 理论上只需要备份/usr/目录即可,因为这个目录包 ...

  9. ES - 概述

    前言 Q1:ElasticSearch 是什么? 为什么要学习? ElasticSearch 是一个分布式.可扩展.实时的搜索和分析引擎,基于 Lucene 构建.它可以用于全文搜索.结构化搜索.分析 ...

  10. Vuex:让状态管理不再头疼的“管家”

    如果你正在开发一个 Vue.js 应用程序,但发现自己被各种组件之间的状态共享问题搞得焦头烂额,那么 Vuex 就是你需要的"超级管家".Vuex 是专门为 Vue.js 设计的状 ...