opennmt.layers.rnn module

RNN functions and classes for TensorFlow 2.0.

class opennmt.layers.rnn.RNNCellWrapper(cell, input_dropout=0, output_dropout=0, residual_connection=False, **kwargs)

Bases: opennmt.layers.common.LayerWrapper

A wrapper for RNN cells.

__init__(cell, input_dropout=0, output_dropout=0, residual_connection=False, **kwargs)

Initializes the wrapper.

Parameters:
  • cell – The cell to wrap.
  • input_dropout – The probability of dropping units in the cell input.
  • output_dropout – The probability of dropping units in the cell output.
  • residual_connection – Add the inputs to the cell outputs (if their shapes are compatible).
  • kwargs – Additional layer arguments.
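
To make the order of operations concrete, here is a minimal NumPy sketch of what one wrapped step might do. It is not the OpenNMT-tf implementation. The helpers `dropout`, `wrapped_step` and `toy_cell` are hypothetical. The order shown is an assumption based on the parameter descriptions above: input dropout, then the cell call, then output dropout, then adding the undropped inputs.

```python
import numpy as np


def dropout(x, rate, rng, training):
    """Inverted dropout: zero each unit with probability `rate`, rescale the rest."""
    if not training or rate == 0:
        return x
    keep = rng.random(x.shape) >= rate
    return np.where(keep, x / (1.0 - rate), 0.0)


def wrapped_step(cell_fn, inputs, state, input_dropout=0.0, output_dropout=0.0,
                 residual_connection=False, training=False, rng=None):
    """One step through a hypothetical cell wrapper (conceptual sketch only)."""
    rng = rng if rng is not None else np.random.default_rng(0)
    x = dropout(inputs, input_dropout, rng, training)          # 1. input dropout
    outputs, new_state = cell_fn(x, state)                     # 2. wrapped cell
    outputs = dropout(outputs, output_dropout, rng, training)  # 3. output dropout
    # 4. Residual connection, applied only when the last dimensions match.
    if residual_connection and outputs.shape[-1] == inputs.shape[-1]:
        outputs = outputs + inputs
    return outputs, new_state


def toy_cell(x, state):
    """Hypothetical stand-in for a real cell: h = tanh(x + h_prev)."""
    h = np.tanh(x + state)
    return h, h


x = np.array([[0.5, -0.5]])
h0 = np.zeros((1, 2))
out, h1 = wrapped_step(toy_cell, x, h0, input_dropout=0.3, output_dropout=0.3,
                       residual_connection=True, training=False)
# With training=False dropout is a no-op, so out == tanh(x) + x.
```

Note that the residual is skipped silently when the shapes differ, for example when a cell changes the depth of its input.
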

state_size

The cell state size.

output_size

The cell output size.

get_initial_state(inputs=None, batch_size=None, dtype=None)

Returns the initial cell state.
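
Keras cells commonly return a zero-filled initial state with one tensor per entry of state_size. The sketch below assumes the wrapper delegates to the wrapped cell and reproduces that convention in NumPy. The helper `zero_initial_state` is hypothetical and is not part of OpenNMT-tf.

```python
import numpy as np


def zero_initial_state(state_size, batch_size, dtype=np.float32):
    """Zero-filled state mirroring the (possibly nested) structure of state_size."""
    if isinstance(state_size, int):
        return np.zeros((batch_size, state_size), dtype=dtype)
    return [zero_initial_state(size, batch_size, dtype) for size in state_size]


# An LSTM cell with 4 units has two state tensors (h and c) of size 4 each.
lstm_state = zero_initial_state([4, 4], batch_size=3)
# A stack of two such cells nests one list per layer.
stacked_state = zero_initial_state([[4, 4], [4, 4]], batch_size=3)
```
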

opennmt.layers.rnn.make_rnn_cell(num_layers, num_units, dropout=0, residual_connections=False, cell_class=None, **kwargs)

Convenience function to build a multi-layer RNN cell.

Parameters:
  • num_layers – The number of layers.
  • num_units – The number of units in each layer.
  • dropout – The probability of dropping units in each layer output.
  • residual_connections – If True, each layer input will be added to its output.
  • cell_class – The inner cell class, or a callable that takes num_units as an argument and returns a cell. Defaults to an LSTM cell.
  • kwargs – Additional arguments passed to the cell constructor.

Returns:

A tf.keras.layers.StackedRNNCells instance.
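
The stacking behavior can be sketched in NumPy for a single time step. Each layer's output, after optional dropout and residual addition, becomes the next layer's input. This is a conceptual illustration, not the library code. The names `build_toy_stack` and `stacked_step` and the simple tanh cells are hypothetical stand-ins for the cells that cell_class would build.

```python
import numpy as np


def build_toy_stack(num_layers, num_units, input_depth, seed=0):
    """Hypothetical weights for a stack of simple tanh RNN layers."""
    rng = np.random.default_rng(seed)
    layers, depth = [], input_depth
    for _ in range(num_layers):
        w_x = rng.normal(scale=0.1, size=(depth, num_units))
        w_h = rng.normal(scale=0.1, size=(num_units, num_units))
        layers.append((w_x, w_h))
        depth = num_units  # layers after the first consume the previous output
    return layers


def stacked_step(layers, inputs, states, dropout=0.0,
                 residual_connections=False, training=False, rng=None):
    """One time step through the stack; returns the top output and new states."""
    rng = rng if rng is not None else np.random.default_rng(0)
    x, new_states = inputs, []
    for (w_x, w_h), h in zip(layers, states):
        h_new = np.tanh(x @ w_x + h @ w_h)
        out = h_new
        if training and dropout > 0:  # inverted dropout on the layer output
            keep = rng.random(out.shape) >= dropout
            out = np.where(keep, out / (1.0 - dropout), 0.0)
        if residual_connections and out.shape[-1] == x.shape[-1]:
            out = out + x  # add the layer input to its output
        new_states.append(h_new)
        x = out  # the (possibly residual) output feeds the next layer
    return x, new_states


layers = build_toy_stack(num_layers=2, num_units=4, input_depth=3)
inputs = np.ones((2, 3))
states = [np.zeros((2, 4)) for _ in layers]
plain, plain_states = stacked_step(layers, inputs, states)
res, _ = stacked_step(layers, inputs, states, residual_connections=True)
# The first layer maps depth 3 -> 4, so only the second layer gets a residual.
```

Because the first layer changes the depth, its residual connection is skipped. This is why residual connections are usually most useful when the input depth already equals num_units.
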

class opennmt.layers.rnn.RNN(cell, bidirectional=False, reducer=None, **kwargs)

Bases: tensorflow.python.keras.engine.base_layer.Layer

A generic RNN layer.

__init__(cell, bidirectional=False, reducer=None, **kwargs)

Initializes the layer.

Parameters:
  • cell – The RNN cell to use.
  • bidirectional – Make this layer bidirectional.
  • reducer – An opennmt.layers.reducer.Reducer instance used to merge the bidirectional states and outputs.
  • kwargs – Additional layer arguments.

call(*args, **kwargs)

Forwards the arguments to the RNN layer.

Returns:

A tuple with the output sequences and the state.
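
The bidirectional mode can be illustrated with a NumPy sketch. One cell runs forward and another runs over the time-reversed sequence. The backward outputs are then realigned, and both directions are merged by a reducer. The function names and the string-valued `reducer` argument ("concat" or "sum") are hypothetical stand-ins for an opennmt.layers.reducer.Reducer instance. The sketch also assumes single-array states and ignores padding masks and sequence lengths, which a real implementation must handle.

```python
import numpy as np


def run_rnn(step, inputs, state):
    """Unroll `step` over time; inputs has shape [batch, time, depth]."""
    outputs = []
    for t in range(inputs.shape[1]):
        out, state = step(inputs[:, t], state)
        outputs.append(out)
    return np.stack(outputs, axis=1), state


def bidirectional_rnn(fw_step, bw_step, inputs, fw_state, bw_state, reducer="concat"):
    """Return (outputs, state), merging both directions with a toy reducer."""
    fw_out, fw_final = run_rnn(fw_step, inputs, fw_state)
    bw_out, bw_final = run_rnn(bw_step, inputs[:, ::-1], bw_state)
    bw_out = bw_out[:, ::-1]  # realign backward outputs with forward time steps
    if reducer == "concat":
        return (np.concatenate([fw_out, bw_out], axis=-1),
                np.concatenate([fw_final, bw_final], axis=-1))
    if reducer == "sum":
        return fw_out + bw_out, fw_final + bw_final
    raise ValueError("unknown reducer: %s" % reducer)


def cumsum_step(x, state):
    """Hypothetical toy cell whose output is the running sum of its inputs."""
    state = state + x
    return state, state


x = np.array([[[1.0], [2.0], [3.0]]])  # [batch=1, time=3, depth=1]
zero = np.zeros((1, 1))
outputs, state = bidirectional_rnn(cumsum_step, cumsum_step, x, zero, zero)
# outputs[0] -> [[1, 6], [3, 5], [6, 3]]; state -> [[6, 6]]
```

A concatenating reducer doubles the output depth, while a summing reducer keeps it unchanged. This matters when the layer feeds a component that expects a fixed depth.
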