import torch
from torch import nn
import math
class RelativePositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, max_len: int = 5000):
        super(RelativePositionalEncoding, self).__init__()
        self.emb_size = emb_size
        self.max_len = max_len
        # Sinusoidal encodings for every relative offset in [-max_len, max_len].
        relative_positions = torch.arange(-max_len, max_len + 1, dtype=torch.long)
        scales = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        relative_positional_encodings = torch.zeros((2 * max_len + 1, emb_size))
        relative_positional_encodings[:, 0::2] = torch.sin(relative_positions.unsqueeze(-1) * scales)
        relative_positional_encodings[:, 1::2] = torch.cos(relative_positions.unsqueeze(-1) * scales)
        self.register_buffer('relative_positional_encodings', relative_positional_encodings)

    def forward(self, length: int):
        # Return one row per position of a sequence of the given length,
        # covering relative offsets -(length - 1) .. 0; shape (length, emb_size).
        center_pos = self.max_len
        return self.relative_positional_encodings[center_pos - length + 1 : center_pos + 1]
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size: int):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: torch.Tensor):
        # Scale embeddings by sqrt(d_model), as in "Attention Is All You Need".
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
class TransformerModelRelative(nn.Module):
    def __init__(self, num_tokens_en, num_tokens_fr, embed_size, nhead, dim_feedforward, max_seq_length):
        super(TransformerModelRelative, self).__init__()
        self.embed_size = embed_size
        self.src_tok_emb = TokenEmbedding(num_tokens_en, embed_size)
        self.tgt_tok_emb = TokenEmbedding(num_tokens_fr, embed_size)
        self.positional_encoding = RelativePositionalEncoding(embed_size, max_len=max_seq_length)
        # batch_first=True so inputs are (batch, seq_len); the positional encodings
        # of shape (seq_len, embed_size) then broadcast over the batch dimension.
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_size, nhead=nhead, dim_feedforward=dim_feedforward, dropout=0.1, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)
        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_size, nhead=nhead, dim_feedforward=dim_feedforward, dropout=0.1, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=3)
        self.generator = nn.Linear(embed_size, num_tokens_fr)

    def encode(self, src, src_key_padding_mask):
        src_emb = self.src_tok_emb(src) + self.positional_encoding(src.size(1))
        return self.transformer_encoder(src_emb, src_key_padding_mask=src_key_padding_mask)

    def decode(self, tgt, memory, tgt_mask, tgt_key_padding_mask):
        tgt_emb = self.tgt_tok_emb(tgt) + self.positional_encoding(tgt.size(1))
        return self.transformer_decoder(tgt_emb, memory, tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_padding_mask=None, tgt_padding_mask=None):
        memory = self.encode(src, src_padding_mask)
        output = self.decode(tgt, memory, tgt_mask, tgt_padding_mask)
        return self.generator(output)

    def generate_square_subsequent_mask(self, sz):
        # Causal mask: each target position may attend only to itself and earlier positions.
        mask = torch.triu(torch.ones(sz, sz)).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
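

# Minimal usage sketch, not part of the original module: the hyperparameters and
# tensor sizes below (vocab sizes, embed_size, sequence lengths, etc.) are
# illustrative assumptions, not values taken from any accompanying training code.
if __name__ == "__main__":
    model = TransformerModelRelative(
        num_tokens_en=1000, num_tokens_fr=1200,
        embed_size=64, nhead=4, dim_feedforward=256, max_seq_length=128,
    )
    src = torch.randint(0, 1000, (2, 10))   # (batch, src_len) of English token ids
    tgt = torch.randint(0, 1200, (2, 12))   # (batch, tgt_len) of French token ids
    tgt_mask = model.generate_square_subsequent_mask(tgt.size(1))
    src_padding_mask = torch.zeros(2, 10, dtype=torch.bool)   # no padded positions
    tgt_padding_mask = torch.zeros(2, 12, dtype=torch.bool)
    logits = model(src, tgt, tgt_mask=tgt_mask,
                   src_padding_mask=src_padding_mask, tgt_padding_mask=tgt_padding_mask)
    print(logits.shape)  # expected: torch.Size([2, 12, 1200])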