# translation_service/relation_model.py
import torch
from torch import nn
import math


class RelativePositionalEncoding(nn.Module):
    """Sinusoidal encodings indexed by relative position in [-max_len, max_len]."""

    def __init__(self, emb_size: int, max_len: int = 5000):
        super(RelativePositionalEncoding, self).__init__()
        self.emb_size = emb_size
        self.max_len = max_len
        # Offsets -max_len .. max_len, encoded with the usual sin/cos frequency scheme.
        relative_positions = torch.arange(-max_len, max_len + 1, dtype=torch.long)
        scales = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        relative_positional_encodings = torch.zeros((2 * max_len + 1, emb_size))
        relative_positional_encodings[:, 0::2] = torch.sin(relative_positions.unsqueeze(-1) * scales)
        relative_positional_encodings[:, 1::2] = torch.cos(relative_positions.unsqueeze(-1) * scales)
        self.register_buffer('relative_positional_encodings', relative_positional_encodings)

    def forward(self, length: int):
        # Return encodings for offsets -(length - 1) .. 0, shape (length, emb_size);
        # row max_len of the buffer corresponds to offset 0.
        center_pos = self.max_len
        return self.relative_positional_encodings[center_pos - length + 1 : center_pos + 1]


class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size: int):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: torch.Tensor):
        # Scale embeddings by sqrt(emb_size), as in the original Transformer.
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


class TransformerModelRelative(nn.Module):
    def __init__(self, num_tokens_en, num_tokens_fr, embed_size, nhead, dim_feedforward, max_seq_length):
        super(TransformerModelRelative, self).__init__()
        self.embed_size = embed_size
        self.src_tok_emb = TokenEmbedding(num_tokens_en, embed_size)
        self.tgt_tok_emb = TokenEmbedding(num_tokens_fr, embed_size)
        self.positional_encoding = RelativePositionalEncoding(embed_size, max_len=max_seq_length)
        # batch_first=True so inputs are (batch, seq, emb), matching the (batch, seq)
        # token tensors assumed by the positional-encoding broadcast and padding masks below.
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_size, nhead=nhead, dim_feedforward=dim_feedforward, dropout=0.1, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)
        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_size, nhead=nhead, dim_feedforward=dim_feedforward, dropout=0.1, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=3)
        self.generator = nn.Linear(embed_size, num_tokens_fr)

    def encode(self, src, src_mask):
        # Add positional encodings for the source length to the scaled token embeddings.
        src_emb = self.src_tok_emb(src) + self.positional_encoding(src.size(1))
        return self.transformer_encoder(src_emb, src_key_padding_mask=src_mask)

    def decode(self, tgt, memory, tgt_mask, tgt_key_padding_mask):
        tgt_emb = self.tgt_tok_emb(tgt) + self.positional_encoding(tgt.size(1))
        return self.transformer_decoder(tgt_emb, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_padding_mask=None, tgt_padding_mask=None):
        # Note: src_mask is accepted but not used; only tgt_mask and the padding masks apply.
        memory = self.encode(src, src_padding_mask)
        output = self.decode(tgt, memory, tgt_mask, tgt_padding_mask)
        return self.generator(output)

    def generate_square_subsequent_mask(self, sz):
        # Causal mask: 0.0 on and below the diagonal, -inf above it.
        mask = torch.triu(torch.ones(sz, sz)).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
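

# --- Minimal usage sketch (illustrative; not part of the original module) ---
# The vocabulary sizes, hyperparameters, and tensor shapes below are placeholder
# assumptions chosen only to demonstrate the expected (batch, seq) call convention.
if __name__ == "__main__":
    model = TransformerModelRelative(
        num_tokens_en=8000,      # assumed English vocabulary size
        num_tokens_fr=8000,      # assumed French vocabulary size
        embed_size=256,
        nhead=8,
        dim_feedforward=512,
        max_seq_length=128,
    )
    src = torch.randint(0, 8000, (2, 10))   # (batch, src_len) token ids
    tgt = torch.randint(0, 8000, (2, 9))    # (batch, tgt_len) token ids
    tgt_mask = model.generate_square_subsequent_mask(tgt.size(1))
    src_padding_mask = torch.zeros(2, 10, dtype=torch.bool)  # True would mark padding
    tgt_padding_mask = torch.zeros(2, 9, dtype=torch.bool)
    logits = model(src, tgt,
                   tgt_mask=tgt_mask,
                   src_padding_mask=src_padding_mask,
                   tgt_padding_mask=tgt_padding_mask)
    print(logits.shape)  # expected: torch.Size([2, 9, 8000])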