"""Define RNN-based encoders."""

import tensorflow as tf
import tensorflow_addons as tfa

from opennmt.encoders.encoder import Encoder, SequentialEncoder
from opennmt.layers import common, rnn
from opennmt.layers.reducer import ConcatReducer, JoinReducer, pad_in_time


class _RNNEncoderBase(Encoder):
    """Base class for RNN encoders."""

    def __init__(self, rnn_layer, **kwargs):
        """Initializes the encoder.

        Args:
          rnn_layer: The RNN layer used to encode the inputs.
          **kwargs: Additional layer arguments.
        """
        super().__init__(**kwargs)
        self.rnn = rnn_layer

    def call(self, inputs, sequence_length=None, training=None):
        mask = self.build_mask(inputs, sequence_length=sequence_length)
        outputs, states = self.rnn(inputs, mask=mask, training=training)
        return outputs, states, sequence_length


class RNNEncoder(_RNNEncoderBase):
    """An RNN sequence encoder."""

    def __init__(
        self,
        num_layers,
        num_units,
        bidirectional=False,
        residual_connections=False,
        dropout=0.3,
        reducer=ConcatReducer(),
        cell_class=None,
        **kwargs
    ):
        """Initializes the parameters of the encoder.

        Args:
          num_layers: The number of layers.
          num_units: The number of units in each layer.
          bidirectional: Use a bidirectional RNN.
          residual_connections: If ``True``, each layer input will be added to its
            output.
          dropout: The probability to drop units in each layer output.
          reducer: A :class:`opennmt.layers.Reducer` instance to merge
            bidirectional state and outputs.
          cell_class: The inner cell class or a callable taking :obj:`num_units` as
            argument and returning a cell. Defaults to an LSTM cell.
          **kwargs: Additional layer arguments.
        """
        cell = rnn.make_rnn_cell(
            num_layers,
            num_units,
            dropout=dropout,
            residual_connections=residual_connections,
            cell_class=cell_class,
        )
        rnn_layer = rnn.RNN(cell, bidirectional=bidirectional, reducer=reducer)
        super().__init__(rnn_layer, **kwargs)

    def map_v1_weights(self, weights):
        return self.rnn.map_v1_weights(weights)
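

# Usage sketch: a hypothetical helper (not a library function) showing how an
# RNNEncoder might be constructed and called; the hyperparameters and tensor
# shapes below are arbitrary assumptions chosen for illustration.
def _rnn_encoder_usage_sketch():
    encoder = RNNEncoder(num_layers=2, num_units=64, bidirectional=True)
    inputs = tf.random.uniform([3, 5, 10])  # batch of 3 sequences, 5 steps, depth 10
    lengths = tf.constant([5, 3, 4])
    outputs, states, lengths = encoder(inputs, sequence_length=lengths, training=True)
    # With the default ConcatReducer, forward and backward outputs are merged by
    # concatenation, so the output depth should be 2 * num_units (128 here).
    return outputs, states, lengths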


class LSTMEncoder(_RNNEncoderBase):
    """An LSTM sequence encoder.

    See Also:
      :class:`opennmt.layers.LSTM` for differences between this encoder and
      :class:`opennmt.encoders.RNNEncoder` with an `LSTMCell`.
    """

    def __init__(
        self,
        num_layers,
        num_units,
        bidirectional=False,
        residual_connections=False,
        dropout=0.3,
        reducer=ConcatReducer(),
        **kwargs
    ):
        """Initializes the parameters of the encoder.

        Args:
          num_layers: The number of layers.
          num_units: The number of units in each layer output.
          bidirectional: Make each LSTM layer bidirectional.
          residual_connections: If ``True``, each layer input will be added to its
            output.
          dropout: The probability to drop units in each layer output.
          reducer: A :class:`opennmt.layers.Reducer` instance to merge
            bidirectional state and outputs.
          **kwargs: Additional layer arguments.
        """
        lstm_layer = rnn.LSTM(
            num_layers,
            num_units,
            bidirectional=bidirectional,
            reducer=reducer,
            dropout=dropout,
            residual_connections=residual_connections,
        )
        super().__init__(lstm_layer, **kwargs)
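

# Usage sketch: a hypothetical helper (not a library function). LSTMEncoder exposes
# the same call interface as RNNEncoder and returns (outputs, states,
# sequence_length); the depth and lengths below are assumptions for illustration.
def _lstm_encoder_usage_sketch():
    encoder = LSTMEncoder(
        num_layers=4, num_units=128, residual_connections=True, dropout=0.1
    )
    inputs = tf.random.uniform([2, 7, 128])  # input depth matches num_units
    lengths = tf.constant([7, 6])
    return encoder(inputs, sequence_length=lengths, training=False)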


class GNMTEncoder(SequentialEncoder):
    """The RNN encoder used in GNMT as described in
    https://arxiv.org/abs/1609.08144.
    """

    def __init__(self, num_layers, num_units, dropout=0.3):
        """Initializes the parameters of the encoder.

        Args:
          num_layers: The number of layers.
          num_units: The number of units in each layer.
          dropout: The probability to drop units in each layer output.

        Raises:
          ValueError: if :obj:`num_layers` < 2.
        """
        if num_layers < 2:
            raise ValueError("GNMTEncoder requires at least 2 layers")
        bidirectional = LSTMEncoder(1, num_units, bidirectional=True, dropout=dropout)
        unidirectional = LSTMEncoder(
            num_layers - 1, num_units, dropout=dropout, residual_connections=True
        )
        super().__init__([bidirectional, unidirectional])
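

# Usage sketch: a hypothetical helper (not a library function). As built above, the
# encoder runs one bidirectional LSTM layer followed by num_layers - 1 unidirectional
# LSTM layers with residual connections; shapes and sizes here are assumptions.
def _gnmt_encoder_usage_sketch():
    encoder = GNMTEncoder(num_layers=4, num_units=512)
    inputs = tf.random.uniform([2, 9, 512])
    lengths = tf.constant([9, 5])
    return encoder(inputs, sequence_length=lengths, training=True)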


class RNMTPlusEncoder(SequentialEncoder):
    """The RNMT+ encoder described in https://arxiv.org/abs/1804.09849."""

    def __init__(self, num_layers=6, num_units=1024, cell_class=None, dropout=0.3):
        """Initializes the parameters of the encoder.

        Args:
          num_layers: The number of layers.
          num_units: The number of units in each RNN layer and the final output.
          cell_class: The inner cell class or a callable taking :obj:`num_units` as
            argument and returning a cell. Defaults to a layer normalized LSTM cell.
          dropout: The probability to drop units in each layer output.
        """
        if cell_class is None:
            cell_class = tfa.rnn.LayerNormLSTMCell
        layers = [
            RNNEncoder(
                1, num_units, bidirectional=True, dropout=0.0, cell_class=cell_class
            )
            for _ in range(num_layers)
        ]
        layers = [
            common.LayerWrapper(layer, output_dropout=dropout, residual_connection=True)
            for layer in layers
        ]
        super().__init__(layers)
        self.dropout = dropout
        self.projection = tf.keras.layers.Dense(num_units)

    def call(self, inputs, sequence_length=None, training=None):
        inputs = common.dropout(inputs, self.dropout, training=training)
        outputs, state, sequence_length = super().call(
            inputs, sequence_length=sequence_length, training=training
        )
        projected = self.projection(outputs)
        return (projected, state, sequence_length)
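

# Usage sketch: a hypothetical helper (not a library function). ``call`` applies
# dropout to the inputs and projects the final outputs back to num_units, so the
# returned outputs have depth num_units. The input depth below is set to
# 2 * num_units, an assumption made so that the residual connection around each
# bidirectional layer sees matching shapes.
def _rnmt_plus_encoder_usage_sketch():
    encoder = RNMTPlusEncoder(num_layers=2, num_units=32)
    inputs = tf.random.uniform([2, 6, 64])
    lengths = tf.constant([6, 4])
    return encoder(inputs, sequence_length=lengths, training=True)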


class PyramidalRNNEncoder(Encoder):
    """An encoder that reduces the time dimension after each bidirectional layer."""

    def __init__(
        self, num_layers, num_units, reduction_factor=2, cell_class=None, dropout=0.3
    ):
        """Initializes the parameters of the encoder.

        Args:
          num_layers: The number of layers.
          num_units: The number of units in each layer.
          reduction_factor: The time reduction factor.
          cell_class: The inner cell class or a callable taking :obj:`num_units` as
            argument and returning a cell. Defaults to an LSTM cell.
          dropout: The probability to drop units in each layer output.
        """
        super().__init__()
        self.reduction_factor = reduction_factor
        self.state_reducer = JoinReducer()
        self.layers = [
            RNNEncoder(
                1,
                num_units // 2,
                bidirectional=True,
                reducer=ConcatReducer(),
                cell_class=cell_class,
                dropout=dropout,
            )
            for _ in range(num_layers)
        ]

    def call(self, inputs, sequence_length=None, training=None):
        encoder_state = []
        for layer_index, layer in enumerate(self.layers):
            input_depth = inputs.shape[-1]
            if layer_index == 0:
                # For the first input, make the number of timesteps a multiple of the
                # total reduction factor.
                total_reduction_factor = pow(
                    self.reduction_factor, len(self.layers) - 1
                )
                current_length = tf.shape(inputs)[1]
                factor = tf.cast(current_length, tf.float32) / total_reduction_factor
                new_length = (
                    tf.cast(tf.math.ceil(factor), tf.int32) * total_reduction_factor
                )
                inputs = pad_in_time(inputs, new_length - current_length)
                # Lengths should not be smaller than the total reduction factor.
                sequence_length = tf.maximum(sequence_length, total_reduction_factor)
            else:
                # In other cases, reduce the time dimension.
                inputs = tf.reshape(
                    inputs,
                    [tf.shape(inputs)[0], -1, input_depth * self.reduction_factor],
                )
                if sequence_length is not None:
                    sequence_length //= self.reduction_factor
            outputs, state, sequence_length = layer(
                inputs, sequence_length=sequence_length, training=training
            )
            encoder_state.append(state)
            inputs = outputs
        return (outputs, self.state_reducer(encoder_state), sequence_length)
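

# Usage sketch: a hypothetical helper (not a library function); shapes and sizes are
# assumptions. Each layer is a bidirectional RNN over num_units // 2 units per
# direction (concatenated back to num_units), and every layer after the first reduces
# the time dimension by reduction_factor, so the padded input length is divided by
# reduction_factor ** (num_layers - 1) overall.
def _pyramidal_rnn_encoder_usage_sketch():
    encoder = PyramidalRNNEncoder(num_layers=3, num_units=64, reduction_factor=2)
    inputs = tf.random.uniform([2, 11, 40])  # 11 steps, padded to 12 inside call()
    lengths = tf.constant([11, 8])
    outputs, state, lengths = encoder(inputs, sequence_length=lengths, training=True)
    # The output should have 12 / 2 ** 2 = 3 timesteps, with ``lengths`` reduced
    # accordingly.
    return outputs, state, lengths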