ashwinradhe commited on
Commit
b8fa43f
·
verified ·
1 Parent(s): c72ce8e

Update relation_model.py

Browse files
Files changed (1) hide show
  1. relation_model.py +1 -1
relation_model.py CHANGED
@@ -26,7 +26,7 @@ class TokenEmbedding(nn.Module):
26
  def forward(self, tokens: torch.Tensor):
27
  return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
28
 
29
- class TransformerModel(nn.Module):
30
  def __init__(self, num_tokens_en, num_tokens_fr, embed_size, nhead, dim_feedforward, max_seq_length):
31
  super(TransformerModel, self).__init__()
32
  self.embed_size = embed_size
 
26
  def forward(self, tokens: torch.Tensor):
27
  return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
28
 
29
+ class TransformerModelRelative(nn.Module):
30
  def __init__(self, num_tokens_en, num_tokens_fr, embed_size, nhead, dim_feedforward, max_seq_length):
31
  super(TransformerModel, self).__init__()
32
  self.embed_size = embed_size