"""Define RNN-based encoders."""
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from onmt.encoders.encoder import EncoderBase
from onmt.utils.rnn_factory import rnn_factory
class RNNEncoder(EncoderBase):
"""A generic recurrent neural network encoder.
Args:
rnn_type (str):
style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU]
bidirectional (bool) : use a bidirectional RNN
num_layers (int) : number of stacked layers
hidden_size (int) : hidden size of each layer
dropout (float) : dropout value for :class:`torch.nn.Dropout`
embeddings (onmt.modules.Embeddings): embedding module to use
"""
    def __init__(
        self,
        rnn_type,
        bidirectional,
        num_layers,
        hidden_size,
        dropout=0.0,
        embeddings=None,
        use_bridge=False,
    ):
        super(RNNEncoder, self).__init__()
        assert embeddings is not None
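        # For a bidirectional RNN, each direction gets half of the requested
        # hidden size so that the concatenated outputs are hidden_size wide.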
        num_directions = 2 if bidirectional else 1
        assert hidden_size % num_directions == 0
        hidden_size = hidden_size // num_directions
        self.embeddings = embeddings

        self.rnn, self.no_pack_padded_seq = rnn_factory(
            rnn_type,
            input_size=embeddings.embedding_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
        )
        # Initialize the bridge layer
        self.use_bridge = use_bridge
        if self.use_bridge:
            self._initialize_bridge(rnn_type, hidden_size, num_layers)
    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        return cls(
            opt.rnn_type,
            opt.brnn,
            opt.enc_layers,
            opt.enc_hid_size,
            opt.dropout[0] if isinstance(opt.dropout, list) else opt.dropout,
            embeddings,
            opt.bridge,
        )
    def forward(self, src, src_len=None):
        """See :func:`EncoderBase.forward()`"""
        emb = self.embeddings(src)
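        # emb: (batch, src_len, emb_dim) -- the batch-first layout assumed by
        # the pack()/unpack() calls below.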
        packed_emb = emb
        if src_len is not None and not self.no_pack_padded_seq:
            # src lengths data is wrapped inside a Tensor.
            src_len_list = src_len.view(-1).tolist()
            packed_emb = pack(emb, src_len_list, batch_first=True, enforce_sorted=False)
        enc_out, enc_final_hs = self.rnn(packed_emb)
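        # For LSTM, enc_final_hs is an (h_n, c_n) tuple; for GRU/RNN it is a
        # single tensor.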
        if src_len is not None and not self.no_pack_padded_seq:
            enc_out = unpack(enc_out, batch_first=True)[0]

        if self.use_bridge:
            enc_final_hs = self._bridge(enc_final_hs)
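        # enc_out: (batch, src_len, hidden_size), where hidden_size is the
        # constructor argument; each final state tensor is
        # (num_layers * num_directions, batch, hidden_size // num_directions).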
        return enc_out, enc_final_hs, src_len
    def _initialize_bridge(self, rnn_type, hidden_size, num_layers):
        # LSTM has hidden and cell state; the other RNN types have only one.
        number_of_states = 2 if rnn_type == "LSTM" else 1
        # Input/output width of each bridge linear layer
        self.total_hidden_dim = hidden_size * num_layers

        # Build a linear layer for each state
        self.bridge = nn.ModuleList(
            [
                nn.Linear(self.total_hidden_dim, self.total_hidden_dim, bias=True)
                for _ in range(number_of_states)
            ]
        )
    def _bridge(self, hidden):
        """Forward hidden state through bridge.

        Args:
            hidden: final hidden state ``(num_layers x dir, batch, hidden_size)``
        """

        def bottle_hidden(linear, states):
            """
            Transform from 3D to 2D, apply linear and return initial size.
            """
            states = states.permute(1, 0, 2).contiguous()
            size = states.size()
            result = linear(states.view(-1, self.total_hidden_dim))
            result = F.relu(result).view(size)
            return result.permute(1, 0, 2).contiguous()
        if isinstance(hidden, tuple):  # LSTM
            outs = tuple(
                [
                    bottle_hidden(layer, hidden[ix])
                    for ix, layer in enumerate(self.bridge)
                ]
            )
        else:
            outs = bottle_hidden(self.bridge[0], hidden)
        return outs
    def update_dropout(self, dropout, attention_dropout=None):
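        # nn.RNN/LSTM/GRU read ``self.dropout`` at each forward call, so
        # updating the attribute is sufficient to change the rate in place.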
        self.rnn.dropout = dropout
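

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the library. It uses a toy
    # stand-in embedding module (``_ToyEmbeddings`` below is hypothetical);
    # anything that exposes ``embedding_size`` and maps (batch, src_len)
    # token ids to (batch, src_len, emb_dim) tensors works here, e.g. the
    # real onmt.modules.Embeddings.
    import torch

    class _ToyEmbeddings(nn.Module):
        def __init__(self, vocab_size, embedding_size):
            super().__init__()
            self.embedding_size = embedding_size
            self.lut = nn.Embedding(vocab_size, embedding_size)

        def forward(self, src):
            return self.lut(src)

    encoder = RNNEncoder(
        rnn_type="LSTM",
        bidirectional=True,
        num_layers=2,
        hidden_size=64,  # 32 per direction; outputs concatenate back to 64
        dropout=0.1,
        embeddings=_ToyEmbeddings(vocab_size=100, embedding_size=16),
        use_bridge=True,
    )
    src = torch.randint(0, 100, (4, 7))  # (batch, src_len) token ids
    src_len = torch.tensor([7, 6, 5, 3])
    enc_out, enc_final_hs, _ = encoder(src, src_len)
    print(enc_out.shape)  # expected: torch.Size([4, 7, 64])
    print(enc_final_hs[0].shape)  # expected: torch.Size([4, 4, 32])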