"""RNN functions and classes for TensorFlow 2.0."""
import numpy as np
import tensorflow as tf
from opennmt.layers import common
from opennmt.layers import reducer as reducer_lib
class RNNCellWrapper(common.LayerWrapper):
"""A wrapper for RNN cells."""
    def __init__(
self,
cell,
input_dropout=0,
output_dropout=0,
residual_connection=False,
**kwargs
):
"""Initializes the wrapper.
Args:
cell: The cell to wrap.
input_dropout: The probability to drop units in the cell input.
output_dropout: The probability to drop units in the cell output.
          residual_connection: Add the inputs to the cell outputs (if their shapes
            are compatible).
kwargs: Additional layer arguments.
"""
super().__init__(
cell,
input_dropout=input_dropout,
output_dropout=output_dropout,
residual_connection=residual_connection,
**kwargs,
)
self.cell = cell
@property
def state_size(self):
"""The cell state size."""
return self.cell.state_size
@property
def output_size(self):
"""The cell output size."""
return self.cell.output_size
    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
"""Returns the initial cell state."""
return self.cell.get_initial_state(
inputs=inputs, batch_size=batch_size, dtype=dtype
)
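
# Usage sketch (added for illustration, not part of the original module): wrapping a
# Keras LSTMCell so that dropout and a residual connection are applied to its output.
# The batch size and depth below are illustrative assumptions; the residual connection
# only applies when the input depth matches the cell output depth.
def _example_rnn_cell_wrapper():
    cell = RNNCellWrapper(
        tf.keras.layers.LSTMCell(64), output_dropout=0.3, residual_connection=True
    )
    inputs = tf.random.normal([8, 64])  # [batch_size, depth]
    states = cell.get_initial_state(batch_size=8, dtype=inputs.dtype)
    outputs, states = cell(inputs, states, training=True)
    return outputs, states
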
def make_rnn_cell(
num_layers,
num_units,
dropout=0,
residual_connections=False,
cell_class=None,
**kwargs
):
"""Convenience function to build a multi-layer RNN cell.
Args:
num_layers: The number of layers.
num_units: The number of units in each layer.
dropout: The probability to drop units in each layer output.
residual_connections: If ``True``, each layer input will be added to its output.
cell_class: The inner cell class or a callable taking :obj:`num_units` as
        argument and returning a cell. Defaults to an LSTM cell.
kwargs: Additional arguments passed to the cell constructor.
Returns:
A ``tf.keras.layers.StackedRNNCells`` instance.
See Also:
:class:`opennmt.layers.RNNCellWrapper`
"""
if cell_class is None:
cell_class = tf.keras.layers.LSTMCell
cells = []
for _ in range(num_layers):
cell = cell_class(num_units, **kwargs)
if dropout > 0 or residual_connections:
cell = RNNCellWrapper(
cell, output_dropout=dropout, residual_connection=residual_connections
)
cells.append(cell)
return tf.keras.layers.StackedRNNCells(cells)
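
# Usage sketch (added for illustration, not part of the original module): a 2-layer
# stacked LSTM cell with output dropout and residual connections, driven by a standard
# Keras RNN layer. Shapes are illustrative assumptions.
def _example_make_rnn_cell():
    cell = make_rnn_cell(2, 128, dropout=0.3, residual_connections=True)
    layer = tf.keras.layers.RNN(cell, return_sequences=True)
    outputs = layer(tf.random.normal([4, 10, 128]), training=True)  # [batch, time, units]
    return outputs
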
class _RNNWrapper(tf.keras.layers.Layer):
"""Extend a RNN layer to possibly make it bidirectional and format its outputs."""
def __init__(
self, rnn, bidirectional=False, reducer=reducer_lib.ConcatReducer(), **kwargs
):
"""Initializes the layer.
Args:
rnn: The RNN layer to extend, built with ``return_sequences`` and
``return_state`` enabled.
bidirectional: Make this layer bidirectional.
reducer: A :class:`opennmt.layers.Reducer` instance to merge
bidirectional states and outputs.
**kwargs: Additional layer arguments.
"""
super().__init__(**kwargs)
self.rnn = rnn
self.reducer = reducer
self.bidirectional = bidirectional
if bidirectional:
self.rnn = tf.keras.layers.Bidirectional(self.rnn, merge_mode=None)
def call(self, *args, **kwargs):
"""Forwards the arguments to the RNN layer.
Args:
*args: Positional arguments of the RNN layer.
**kwargs: Keyword arguments of the RNN layer.
Returns:
A tuple with the output sequences and the states.
"""
outputs = self.rnn(*args, **kwargs)
if self.bidirectional:
sequences = outputs[0:2]
states = outputs[2:]
fwd_states = states[: len(states) // 2]
bwd_states = states[len(states) // 2 :]
if self.reducer is not None:
sequences = self.reducer(sequences)
states = tuple(self.reducer.zip_and_reduce(fwd_states, bwd_states))
else:
sequences = tuple(sequences)
states = (fwd_states, bwd_states)
else:
sequences = outputs[0]
states = tuple(outputs[1:])
return sequences, states
class RNN(_RNNWrapper):
"""A simple RNN layer."""
    def __init__(
self, cell, bidirectional=False, reducer=reducer_lib.ConcatReducer(), **kwargs
):
"""Initializes the layer.
Args:
cell: The RNN cell to use.
bidirectional: Make this layer bidirectional.
reducer: A :class:`opennmt.layers.Reducer` instance to merge
bidirectional states and outputs.
**kwargs: Additional layer arguments.
See Also:
:func:`opennmt.layers.make_rnn_cell`
"""
rnn = tf.keras.layers.RNN(cell, return_sequences=True, return_state=True)
super().__init__(rnn, bidirectional=bidirectional, reducer=reducer, **kwargs)
    def map_v1_weights(self, weights):
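        """Maps the V1 weights of this layer to V2 variables.

        Args:
          weights: A nested dictionary of V1 weights for this layer.

        Returns:
          A list of ``(v2_variable, v1_weight)`` pairs.
        """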
m = []
if self.bidirectional:
weights = weights["bidirectional_rnn"]
m += map_v1_weights_to_cell(self.rnn.forward_layer.cell, weights["fw"])
m += map_v1_weights_to_cell(self.rnn.backward_layer.cell, weights["bw"])
else:
weights = weights["rnn"]
m += map_v1_weights_to_cell(self.rnn.cell, weights)
return m
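
# Usage sketch (added for illustration, not part of the original module): a
# bidirectional RNN over a padded batch. Sequence lengths and shapes are illustrative
# assumptions; with the default ConcatReducer the output depth is twice the cell size.
def _example_rnn_layer():
    cell = make_rnn_cell(1, 64)
    layer = RNN(cell, bidirectional=True)
    inputs = tf.random.normal([2, 7, 32])
    mask = tf.sequence_mask([7, 5], maxlen=7)
    outputs, states = layer(inputs, mask=mask, training=True)  # outputs: [2, 7, 128]
    return outputs, states
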
class LSTM(tf.keras.layers.Layer):
"""A multi-layer LSTM.
    This differs from using :class:`opennmt.layers.RNN` with an ``LSTMCell`` in two
    ways:

    - It uses ``tf.keras.layers.LSTM``, which may be accelerated by cuDNN on
      GPU.
- Bidirectional outputs of each layer are reduced before feeding them to the
next layer.
"""
    def __init__(
self,
num_layers,
num_units,
bidirectional=False,
reducer=reducer_lib.ConcatReducer(),
dropout=0,
residual_connections=False,
**kwargs
):
"""Initializes the layer.
Args:
num_layers: Number of stacked LSTM layers.
num_units: Dimension of the output space of each LSTM.
bidirectional: Make each layer bidirectional.
reducer: A :class:`opennmt.layers.Reducer` instance to merge
the bidirectional states and outputs of each layer.
dropout: The probability to drop units in each layer output.
residual_connections: If ``True``, each layer input will be added to its
output.
**kwargs: Additional layer arguments.
"""
super().__init__(**kwargs)
rnn_layers = [
_RNNWrapper(
tf.keras.layers.LSTM(
num_units, return_sequences=True, return_state=True
),
bidirectional=bidirectional,
reducer=reducer,
)
for _ in range(num_layers)
]
self.layers = [
common.LayerWrapper(
layer, output_dropout=dropout, residual_connection=residual_connections
)
for layer in rnn_layers
]
    def call(self, inputs, mask=None, training=None, initial_state=None):
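        """Runs the stacked LSTM layers.

        Args:
          inputs: The sequence of inputs.
          mask: Optional boolean mask of valid timesteps.
          training: Run in training mode.
          initial_state: Optional initial state, a sequence with one entry per layer.

        Returns:
          A tuple with the output sequences and a tuple of per-layer states.
        """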
all_states = []
for i, layer in enumerate(self.layers):
outputs, states = layer(
inputs,
mask=mask,
training=training,
initial_state=initial_state[i] if initial_state is not None else None,
)
all_states.append(states)
inputs = outputs
return outputs, tuple(all_states)
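
# Usage sketch (added for illustration, not part of the original module): a 2-layer
# bidirectional LSTM. Shapes are illustrative assumptions; each layer reduces its
# forward and backward outputs (concatenation by default) before feeding the next layer.
def _example_lstm_layer():
    layer = LSTM(2, 64, bidirectional=True, dropout=0.2)
    outputs, all_states = layer(tf.random.normal([3, 12, 32]), training=True)
    return outputs, all_states  # outputs: [3, 12, 128]; one state tuple per layer
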
def map_v1_weights_to_cell(cell, weights):
"""Maps V1 weights to V2 RNN cell."""
if isinstance(cell, RNNCellWrapper):
cell = cell.cell
if isinstance(cell, tf.keras.layers.StackedRNNCells):
return _map_v1_weights_to_stacked_cells(cell, weights)
elif isinstance(
cell, (tf.keras.layers.LSTMCell, tf.compat.v1.keras.layers.LSTMCell)
):
return _map_v1_weights_to_lstmcell(cell, weights)
else:
raise ValueError("Cannot restore V1 weights for cell %s" % str(cell))
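
# Usage sketch (added for illustration, not part of the original module): the mapping
# returned above is a list of (V2 variable, V1 weight) pairs; a checkpoint converter
# could consume it by assigning each weight value to its matching variable.
def _example_apply_v1_mapping(mapping):
    for variable, weight in mapping:
        variable.assign(weight)
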
def _map_v1_weights_to_stacked_cells(stacked_cells, weights):
weights = weights["multi_rnn_cell"]
m = []
for i, cell in enumerate(stacked_cells.cells):
m += map_v1_weights_to_cell(cell, weights["cell_%d" % i])
return m
def _map_v1_weights_to_lstmcell(cell, weights):
weights = weights["lstm_cell"]
def _upgrade_weight(weight):
is_bias = len(weight.shape) == 1
i, j, f, o = np.split(weight, 4, axis=-1)
        if is_bias:
            # Add forget_bias, which is part of the LSTM formula in TensorFlow 1.
            f += 1
        return np.concatenate((i, f, j, o), axis=-1)  # Swap 2nd and 3rd projections.
def _split_kernel(index):
# TensorFlow 1 had a single kernel of shape [input_dim + units, 4 * units],
# but TensorFlow 2 splits it into "kernel" and "recurrent_kernel".
return tf.nest.map_structure(
lambda w: np.split(w, [w.shape[0] - cell.units])[index], weights["kernel"]
)
weights = tf.nest.map_structure(_upgrade_weight, weights)
m = []
m.append((cell.kernel, _split_kernel(0)))
m.append((cell.recurrent_kernel, _split_kernel(1)))
if cell.use_bias:
m.append((cell.bias, weights["bias"]))
return m
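
# Worked sketch (added for illustration, not part of the original module): the V1
# LSTMCell packs its projections in (i, j, f, o) order, where j is the cell candidate,
# while the Keras LSTMCell expects (i, f, c, o). _upgrade_weight above therefore swaps
# the 2nd and 3rd blocks and folds the V1 forget_bias of 1 into the bias term.
def _example_gate_reordering():
    units = 2
    v1_bias = np.arange(4 * units, dtype=np.float32)  # blocks: [i, j, f, o]
    i, j, f, o = np.split(v1_bias, 4, axis=-1)
    v2_bias = np.concatenate((i, f + 1.0, j, o), axis=-1)  # blocks: [i, f + 1, c, o]
    return v2_bias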