import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import rnn


def duplicate(output, mask, lens, act_sizes):
    """
    Duplicate the output based on the action sizes: row i of each tensor is
    repeated act_sizes[i] times, so every action candidate is paired with a
    copy of its state encoding.
    """
    output = torch.cat([output[i:i+1].repeat(j, 1, 1) for i, j in enumerate(act_sizes)], dim=0)
    mask = torch.cat([mask[i:i+1].repeat(j, 1) for i, j in enumerate(act_sizes)], dim=0)
    lens = list(itertools.chain.from_iterable([lens[i:i+1] * j for i, j in enumerate(act_sizes)]))
    return output, mask, lens
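

# Illustrative usage (added sketch, not part of the original model code):
# `duplicate` repeats row i act_sizes[i] times, so every action candidate
# gets its own copy of the state encoding. Shapes below are made up for
# the example.
def _demo_duplicate():
    output = torch.randn(2, 5, 8)   # 2 states, seq len 5, hidden size 8
    mask = torch.ones(2, 5)
    lens = [5, 3]
    act_sizes = [3, 2]              # state 0 has 3 actions, state 1 has 2
    out, m, l = duplicate(output, mask, lens, act_sizes)
    assert out.shape == (5, 5, 8)   # 3 + 2 = 5 duplicated rows
    assert l == [5, 5, 5, 3, 3]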


def get_aggregated(output, lens, method):
    """
    Aggregate the encoder hidden states over the time dimension.
    output: B x T x D, lens: valid length per sequence; returns B x D.
    """
    if method == 'mean':
        # average the first lens[i] (non-padded) steps of each sequence
        return torch.stack([output[i, :j, :].mean(0) for i, j in enumerate(lens)], dim=0)
    elif method == 'last':
        # take the hidden state at the last valid position
        return torch.stack([output[i, j-1, :] for i, j in enumerate(lens)], dim=0)
    elif method == 'first':
        return output[:, 0, :]
    raise ValueError(f'Unknown aggregation method: {method}')
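

# Illustrative usage (added sketch, not part of the original model code):
# the three modes reduce a (B, T, D) output to (B, D): 'mean' averages the
# first lens[i] valid steps, 'last' picks the final valid step, 'first'
# picks position 0.
def _demo_get_aggregated():
    out = torch.arange(12.).view(1, 4, 3)   # B=1, T=4, D=3
    lens = [2]                               # only the first 2 steps are valid
    mean = get_aggregated(out, lens, 'mean')
    last = get_aggregated(out, lens, 'last')
    assert torch.allclose(mean, torch.tensor([[1.5, 2.5, 3.5]]))
    assert torch.allclose(last, out[:, 1, :])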


class EncoderRNN(nn.Module):
    def __init__(self, input_size, num_units, nlayers, concat,
                 bidir, layernorm, return_last):
        super().__init__()
        self.layernorm = (layernorm == 'layer')
        if self.layernorm:
            self.norm = nn.LayerNorm(input_size)
        self.rnns = []
        for i in range(nlayers):
            if i == 0:
                input_size_ = input_size
                output_size_ = num_units
            else:
                # later layers consume the previous layer's output, which is
                # doubled in width when the GRU is bidirectional
                input_size_ = num_units if not bidir else num_units * 2
                output_size_ = num_units
            self.rnns.append(
                nn.GRU(input_size_, output_size_, 1,
                       bidirectional=bidir, batch_first=True))
        self.rnns = nn.ModuleList(self.rnns)
        # one learnable initial hidden state per layer, broadcast over the batch
        self.init_hidden = nn.ParameterList(
            [nn.Parameter(
                torch.zeros(size=(2 if bidir else 1, 1, num_units)),
                requires_grad=True) for _ in range(nlayers)])
        self.concat = concat
        self.nlayers = nlayers
        self.return_last = return_last
        self.reset_parameters()

    def reset_parameters(self):
        # Xavier for input-to-hidden weights, orthogonal for hidden-to-hidden,
        # zeros for biases.
        with torch.no_grad():
            for rnn_layer in self.rnns:
                for name, p in rnn_layer.named_parameters():
                    if 'weight_ih' in name:
                        torch.nn.init.xavier_uniform_(p.data)
                    elif 'weight_hh' in name:
                        torch.nn.init.orthogonal_(p.data)
                    elif 'bias' in name:
                        p.data.fill_(0.0)
                    else:
                        p.data.normal_(std=0.1)

    def get_init(self, bsz, i):
        # expand the layer-i learnable initial state to the batch size
        return self.init_hidden[i].expand(-1, bsz, -1).contiguous()

    def forward(self, inputs, input_lengths=None):
        bsz, slen = inputs.size(0), inputs.size(1)
        if self.layernorm:
            inputs = self.norm(inputs)
        output = inputs
        outputs = []
        lens = 0
        if input_lengths is not None:
            lens = input_lengths
        for i in range(self.nlayers):
            hidden = self.get_init(bsz, i)
            if input_lengths is not None:
                output = rnn.pack_padded_sequence(output, lens,
                                                  batch_first=True,
                                                  enforce_sorted=False)
            output, hidden = self.rnns[i](output, hidden)
            if input_lengths is not None:
                output, _ = rnn.pad_packed_sequence(output, batch_first=True)
                if output.size(1) < slen:
                    # pad back to the full sequence length; the packed output
                    # can come back shorter, e.g. under data-parallel execution
                    padding = torch.zeros(
                        size=(1, 1, 1), dtype=output.dtype,
                        device=output.device)
                    output = torch.cat(
                        [output,
                         padding.expand(
                             output.size(0),
                             slen - output.size(1),
                             output.size(2))
                        ], dim=1)
            if self.return_last:
                # (num_directions, B, H) -> (B, num_directions * H)
                outputs.append(
                    hidden.permute(1, 0, 2).contiguous().view(bsz, -1))
            else:
                outputs.append(output)
        if self.concat:
            return torch.cat(outputs, dim=2)
        return outputs[-1]
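

# Illustrative usage (added sketch, not part of the original model code):
# run the encoder over a padded batch. Hyperparameters are assumptions for
# the example only. With bidir=True the output width is 2 * num_units.
def _demo_encoder_rnn():
    enc = EncoderRNN(input_size=16, num_units=32, nlayers=2, concat=False,
                     bidir=True, layernorm='layer', return_last=False)
    x = torch.randn(4, 10, 16)              # B=4, T=10, D=16
    lengths = torch.tensor([10, 7, 5, 3])   # valid length per sequence
    out = enc(x, lengths)
    assert out.shape == (4, 10, 64)         # last layer, both directions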


class BiAttention(nn.Module):
    def __init__(self, input_size, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.input_linear = nn.Linear(input_size, 1, bias=False)
        self.memory_linear = nn.Linear(input_size, 1, bias=False)
        # per-dimension scale for the trilinear term, drawn from
        # U(1/sqrt(input_size), 1)
        self.dot_scale = nn.Parameter(
            torch.zeros(size=(input_size,)).uniform_(1. / (input_size ** 0.5)),
            requires_grad=True)
        self.init_parameters()

    def init_parameters(self):
        # no-op: dot_scale is initialized directly in __init__
        return

    def forward(self, context, memory, mask):
        bsz, input_len = context.size(0), context.size(1)
        memory_len = memory.size(1)
        context = self.dropout(context)
        memory = self.dropout(memory)
        # trilinear similarity: w1.c + w2.m + c . diag(dot_scale) . m^T
        input_dot = self.input_linear(context)
        memory_dot = self.memory_linear(memory).view(bsz, 1, memory_len)
        cross_dot = torch.bmm(
            context * self.dot_scale,
            memory.permute(0, 2, 1).contiguous())
        att = input_dot + memory_dot + cross_dot
        # mask out invalid memory positions with a large negative value
        att = att - 1e30 * (1 - mask[:, None])
        # context-to-memory attention
        weight_one = F.softmax(att, dim=-1)
        output_one = torch.bmm(weight_one, memory)
        # memory-to-context attention (max over memory, softmax over context)
        weight_two = (F.softmax(att.max(dim=-1)[0], dim=-1)
                      .view(bsz, 1, input_len))
        output_two = torch.bmm(weight_two, context)
        return torch.cat(
            [context, output_one, context * output_one,
             output_two * output_one],
            dim=-1)
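

# Illustrative usage (added sketch, not part of the original model code):
# bi-attention between a context and a memory sequence; `mask` marks valid
# memory positions. The output concatenates four (B, T, D) terms, so its
# width is 4 * input_size.
def _demo_bi_attention():
    attn = BiAttention(input_size=8, dropout=0.0)
    context = torch.randn(2, 6, 8)   # B=2, context length 6
    memory = torch.randn(2, 4, 8)    # memory length 4
    mask = torch.ones(2, 4)          # all memory positions are valid
    out = attn(context, memory, mask)
    assert out.shape == (2, 6, 32)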