RNMTPlusDecoder
- class opennmt.decoders.RNMTPlusDecoder(*args, **kwargs)
The RNMT+ decoder described in https://arxiv.org/abs/1804.09849.
Inherits from:
opennmt.decoders.Decoder
- __init__(num_layers, num_units, num_heads, dropout=0.3, cell_class=None, **kwargs)
Initializes the decoder parameters.
- Parameters
num_layers – The number of layers.
num_units – The number of units in each layer.
num_heads – The number of attention heads.
dropout – The probability of dropping units from the decoder input and from each layer's output.
cell_class – The inner cell class, or a callable that takes num_units as an argument and returns a cell. Defaults to a layer-normalized LSTM cell.
**kwargs – Additional layer arguments.
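The cell_class parameter accepts either a class or any callable that maps num_units to a cell. The pure-Python sketch below illustrates that contract only; ToyCell, make_cell and build_cells are hypothetical names, not part of OpenNMT-tf, and the sketch assumes (as a simplification) that the decoder builds one cell per layer by calling cell_class(num_units).

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class ToyCell:
    """Hypothetical stand-in for a recurrent cell; it only records its size."""
    num_units: int
    layer_norm: bool = True  # mimics the layer-normalized default


def make_cell(num_units: int) -> ToyCell:
    """A plain function satisfying the same contract as a cell class."""
    return ToyCell(num_units, layer_norm=False)


def build_cells(num_layers: int, num_units: int,
                cell_class: Optional[Callable[[int], ToyCell]] = None) -> List[ToyCell]:
    # Assumption: one independent cell per layer, each created as cell_class(num_units).
    if cell_class is None:
        cell_class = ToyCell  # stands in for the default layer-normalized LSTM cell
    return [cell_class(num_units) for _ in range(num_layers)]


default_cells = build_cells(num_layers=4, num_units=8)
custom_cells = build_cells(num_layers=2, num_units=16, cell_class=make_cell)
print(len(default_cells), default_cells[0])  # → 4 ToyCell(num_units=8, layer_norm=True)
print(len(custom_cells), custom_cells[0])    # → 2 ToyCell(num_units=16, layer_norm=False)
```

Because a class is itself a callable, the decoder can treat both forms uniformly by simply calling cell_class(num_units).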
- property support_alignment_history
Returns True if this decoder can return the attention as alignment history.
- step(inputs, timestep, state=None, memory=None, memory_sequence_length=None, training=None)
Runs one decoding step.
- Parameters
inputs – The 2D decoder input.
timestep – The current decoding step.
state – The decoder state.
memory – Memory values to query.
memory_sequence_length – The length of each memory sequence.
training – Whether to run in training mode.
- Returns
A tuple with the decoder outputs, the decoder state, and the attention vector.
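Since step runs a single decoding step, a caller typically invokes it in a loop, passing the returned state back in at the next timestep. The pure-Python sketch below shows that loop with a hypothetical ToyDecoder whose step follows the documented signature and return tuple (outputs, state, attention). Its internal arithmetic is invented for illustration and has nothing to do with RNMT+; it also drops the batch dimension and treats state=None as a zero state. The loop records attention vectors only when support_alignment_history is True.

```python
from typing import List, Optional, Sequence, Tuple


class ToyDecoder:
    """Hypothetical decoder mimicking the documented step() interface."""

    @property
    def support_alignment_history(self) -> bool:
        return True

    def step(self, inputs: Sequence[float], timestep: int, state: Optional[float] = None,
             memory: Optional[List[float]] = None,
             memory_sequence_length: Optional[int] = None,
             training: Optional[bool] = None) -> Tuple[List[float], float, List[float]]:
        # timestep and training are accepted but unused in this toy.
        state = 0.0 if state is None else state
        length = len(memory) if memory_sequence_length is None else memory_sequence_length
        # Invented "attention": uniform weights over the valid memory positions.
        attention = [1.0 / length] * length
        context = sum(w * m for w, m in zip(attention, memory))
        new_state = state + sum(inputs)
        outputs = [x + context + new_state for x in inputs]
        return outputs, new_state, attention


def run_decoding(decoder, first_input, memory, num_steps):
    """Calls step() once per timestep, threading the state through."""
    state = None
    inputs = first_input
    alignments = []
    for t in range(num_steps):
        outputs, state, attention = decoder.step(
            inputs, t, state=state, memory=memory, memory_sequence_length=len(memory))
        if decoder.support_alignment_history:
            alignments.append(attention)
        # A real model would embed the predicted token here instead.
        inputs = outputs
    return inputs, state, alignments


final_outputs, final_state, alignments = run_decoding(ToyDecoder(), [0.5, 0.5], [1.0, 3.0], 3)
print(final_outputs, final_state, len(alignments))  # → [50.5, 50.5] 35.0 3
```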