import math

import torch
from torch import nn


class RelativePositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, max_len: int = 5000):
        super(RelativePositionalEncoding, self).__init__()
        self.emb_size = emb_size
        self.max_len = max_len

        # Relative offsets range from -max_len to +max_len, giving 2 * max_len + 1 rows.
        relative_positions = torch.arange(-max_len, max_len + 1, dtype=torch.long)
        scales = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)

        relative_positional_encodings = torch.zeros((2 * max_len + 1, emb_size))
        relative_positional_encodings[:, 0::2] = torch.sin(relative_positions.unsqueeze(-1) * scales)
        relative_positional_encodings[:, 1::2] = torch.cos(relative_positions.unsqueeze(-1) * scales)
        self.register_buffer('relative_positional_encodings', relative_positional_encodings)

    def forward(self, length: int):
        # Return the rows for offsets -(length - 1) .. 0, i.e. each position
        # measured relative to the last token in the sequence.
        center_pos = self.max_len
        return self.relative_positional_encodings[center_pos - length + 1 : center_pos + 1]


class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size: int):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: torch.Tensor):
        # Scale embeddings by sqrt(d_model), as in the original Transformer.
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


class TransformerModelRelative(nn.Module):
    def __init__(self, num_tokens_en, num_tokens_fr, embed_size, nhead,
                 dim_feedforward, max_seq_length):
        super(TransformerModelRelative, self).__init__()
        self.embed_size = embed_size
        self.src_tok_emb = TokenEmbedding(num_tokens_en, embed_size)
        self.tgt_tok_emb = TokenEmbedding(num_tokens_fr, embed_size)
        self.positional_encoding = RelativePositionalEncoding(embed_size, max_len=max_seq_length)

        # batch_first=True so inputs are (batch, seq_len) and the sequence length
        # is dim 1, matching the src.size(1) / tgt.size(1) calls below.
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_size, nhead=nhead,
                                                   dim_feedforward=dim_feedforward,
                                                   dropout=0.1, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)

        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_size, nhead=nhead,
                                                   dim_feedforward=dim_feedforward,
                                                   dropout=0.1, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=3)

        self.generator = nn.Linear(embed_size, num_tokens_fr)

    def encode(self, src, src_padding_mask):
        src_emb = self.src_tok_emb(src) + self.positional_encoding(src.size(1))
        return self.transformer_encoder(src_emb, src_key_padding_mask=src_padding_mask)

    def decode(self, tgt, memory, tgt_mask, tgt_key_padding_mask):
        tgt_emb = self.tgt_tok_emb(tgt) + self.positional_encoding(tgt.size(1))
        return self.transformer_decoder(tgt_emb, memory, tgt_mask,
                                        tgt_key_padding_mask=tgt_key_padding_mask)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None,
                src_padding_mask=None, tgt_padding_mask=None):
        memory = self.encode(src, src_padding_mask)
        output = self.decode(tgt, memory, tgt_mask, tgt_padding_mask)
        return self.generator(output)

    def generate_square_subsequent_mask(self, sz):
        # Causal mask: 0.0 where a position may attend (itself and earlier tokens),
        # -inf where it may not (future tokens).
        mask = torch.triu(torch.ones(sz, sz)).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
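

# Minimal usage sketch (not part of the original source): the vocabulary sizes,
# hyperparameters, and random token batches below are illustrative assumptions,
# chosen only to show the expected batch-first tensor shapes and mask types.
if __name__ == "__main__":
    model = TransformerModelRelative(num_tokens_en=8000, num_tokens_fr=10000,
                                     embed_size=256, nhead=8,
                                     dim_feedforward=512, max_seq_length=128)

    batch_size, src_len, tgt_len = 2, 10, 9
    src = torch.randint(0, 8000, (batch_size, src_len))    # (batch, src_len) token ids
    tgt = torch.randint(0, 10000, (batch_size, tgt_len))   # (batch, tgt_len) token ids

    tgt_mask = model.generate_square_subsequent_mask(tgt_len)              # (tgt_len, tgt_len)
    src_padding_mask = torch.zeros(batch_size, src_len, dtype=torch.bool)  # True marks padding
    tgt_padding_mask = torch.zeros(batch_size, tgt_len, dtype=torch.bool)

    logits = model(src, tgt, tgt_mask=tgt_mask,
                   src_padding_mask=src_padding_mask,
                   tgt_padding_mask=tgt_padding_mask)
    print(logits.shape)  # torch.Size([2, 9, 10000])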