Source code for opennmt.decoders.rnn_decoder

"""Define RNN-based decoders."""

import tensorflow as tf
import tensorflow_addons as tfa

from opennmt.decoders import decoder
from opennmt.layers import bridge, common, rnn, transformer
from opennmt.layers.rnn import map_v1_weights_to_cell


class RNNDecoder(decoder.Decoder):
    """A basic RNN decoder."""

    def __init__(
        self,
        num_layers,
        num_units,
        bridge_class=None,
        cell_class=None,
        dropout=0.3,
        residual_connections=False,
        **kwargs
    ):
        """Initializes the decoder parameters.

        Args:
          num_layers: The number of layers.
          num_units: The number of units in each layer.
          bridge_class: A :class:`opennmt.layers.Bridge` class to pass the
            encoder state to the decoder. Defaults to
            :class:`opennmt.layers.ZeroBridge`.
          cell_class: The inner cell class or a callable taking :obj:`num_units`
            as argument and returning a cell. Defaults to an LSTM cell.
          dropout: The probability to drop units in each layer output.
          residual_connections: If ``True``, each layer input will be added to
            its output.
          **kwargs: Additional layer arguments.
        """
        super().__init__(**kwargs)
        self.dropout = dropout
        self.cell = rnn.make_rnn_cell(
            num_layers,
            num_units,
            dropout=dropout,
            residual_connections=residual_connections,
            cell_class=cell_class,
        )
        if bridge_class is None:
            bridge_class = bridge.ZeroBridge
        self.bridge = bridge_class()

    def _get_initial_state(self, batch_size, dtype, initial_state=None):
        cell_initial_state = self.cell.get_initial_state(
            batch_size=batch_size, dtype=dtype
        )
        if initial_state is not None:
            cell_initial_state = self.bridge(initial_state, cell_initial_state)
        return cell_initial_state

    def step(
        self,
        inputs,
        timestep,
        state=None,
        memory=None,
        memory_sequence_length=None,
        training=None,
    ):
        outputs, state = self.cell(inputs, state, training=training)
        return outputs, state, None
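A minimal usage sketch (not part of the module source): it builds a two-layer decoder and runs a single decoding step by calling the private _get_initial_state directly. In a full model these methods are driven by the opennmt.decoders.decoder.Decoder base class; the batch size and depth below are illustrative assumptions.

    import tensorflow as tf
    from opennmt.decoders.rnn_decoder import RNNDecoder

    decoder = RNNDecoder(num_layers=2, num_units=64)

    # One target embedding per batch entry (batch of 4, depth 64).
    inputs = tf.random.normal([4, 64])
    state = decoder._get_initial_state(batch_size=4, dtype=tf.float32)

    # A plain RNN step attends to no memory, so the attention output is None.
    outputs, state, attention = decoder.step(inputs, 0, state=state, training=False)
    print(outputs.shape)  # (4, 64)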
class AttentionalRNNDecoder(RNNDecoder):
    """An RNN decoder with attention."""

    def __init__(
        self,
        num_layers,
        num_units,
        bridge_class=None,
        attention_mechanism_class=None,
        cell_class=None,
        dropout=0.3,
        residual_connections=False,
        first_layer_attention=False,
        attention_layer_activation=tf.math.tanh,
        **kwargs
    ):
        """Initializes the decoder parameters.

        Args:
          num_layers: The number of layers.
          num_units: The number of units in each layer.
          bridge_class: A :class:`opennmt.layers.Bridge` class to pass the
            encoder state to the decoder. Defaults to
            :class:`opennmt.layers.ZeroBridge`.
          attention_mechanism_class: A class inheriting from
            ``tfa.seq2seq.AttentionMechanism``. Defaults to
            ``tfa.seq2seq.LuongAttention``.
          cell_class: The inner cell class or a callable taking :obj:`num_units`
            as argument and returning a cell.
          dropout: The probability to drop units in each layer output.
          residual_connections: If ``True``, each layer input will be added to
            its output.
          first_layer_attention: If ``True``, output attention after the first
            layer.
          attention_layer_activation: The activation to produce the attentional
            hidden state. Defaults to tanh following the Luong paper
            (equation (5) in https://arxiv.org/abs/1508.04025).
          **kwargs: Additional layer arguments.
        """
        super().__init__(
            num_layers,
            num_units,
            bridge_class=bridge_class,
            cell_class=cell_class,
            dropout=dropout,
            residual_connections=residual_connections,
            **kwargs,
        )
        if attention_mechanism_class is None:
            attention_mechanism_class = tfa.seq2seq.LuongAttention
        self.attention_mechanism = attention_mechanism_class(self.cell.output_size)

        def _add_attention(cell):
            # Produce Luong-style attentional hidden states.
            attention_layer = common.Dense(
                cell.output_size, use_bias=False, activation=attention_layer_activation
            )
            wrapper = tfa.seq2seq.AttentionWrapper(
                cell, self.attention_mechanism, attention_layer=attention_layer
            )
            return wrapper

        if first_layer_attention:
            self.cell.cells[0] = _add_attention(self.cell.cells[0])
        else:
            self.cell = _add_attention(self.cell)
        self.dropout = dropout
        self.first_layer_attention = first_layer_attention

    @property
    def support_alignment_history(self):
        return True

    def _get_initial_state(self, batch_size, dtype, initial_state=None):
        # Reset memory of attention mechanism.
        self.attention_mechanism.setup_memory(
            self.memory, memory_sequence_length=self.memory_sequence_length
        )
        decoder_state = self.cell.get_initial_state(batch_size=batch_size, dtype=dtype)
        if initial_state is not None:
            if self.first_layer_attention:
                cell_state = list(decoder_state)
                cell_state[0] = decoder_state[0].cell_state
                cell_state = self.bridge(initial_state, cell_state)
                cell_state[0] = decoder_state[0].clone(cell_state=cell_state[0])
                decoder_state = tuple(cell_state)
            else:
                cell_state = self.bridge(initial_state, decoder_state.cell_state)
                decoder_state = decoder_state.clone(cell_state=cell_state)
        return decoder_state

    def step(
        self,
        inputs,
        timestep,
        state=None,
        memory=None,
        memory_sequence_length=None,
        training=None,
    ):
        outputs, state = self.cell(inputs, state, training=training)
        outputs = common.dropout(outputs, self.dropout, training=training)
        if self.first_layer_attention:
            attention = state[0].alignments
        else:
            attention = state.alignments
        return outputs, state, attention

    def map_v1_weights(self, weights):
        if self.first_layer_attention or not isinstance(
            self.attention_mechanism, tfa.seq2seq.LuongAttention
        ):
            raise ValueError(
                "Can only map V1 weights for RNN decoder with Luong attention "
                "on the last layer"
            )
        m = super().map_v1_weights(weights)
        m += common.Dense.map_v1_weights(
            self.attention_mechanism.memory_layer, weights["memory_layer"]
        )
        weights = weights["decoder"]["attention_wrapper"]
        m += common.Dense.map_v1_weights(
            self.cell._attention_layers[0], weights["attention_layer"]
        )
        m += map_v1_weights_to_cell(self.cell._cell, weights)
        return m
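The sketch below (also not part of the module) drives AttentionalRNNDecoder through the public entry points of the opennmt.decoders.decoder.Decoder base class, assuming its initialize() and initial_state() methods as exposed by OpenNMT-tf; the vocabulary size, shapes, and random encoder outputs are illustrative.

    import tensorflow as tf
    from opennmt.decoders.rnn_decoder import AttentionalRNNDecoder

    decoder = AttentionalRNNDecoder(num_layers=2, num_units=64)
    decoder.initialize(vocab_size=100)  # assumed base-class call building the output projection

    # Fake encoder outputs: a batch of 4 source sequences of up to 10 timesteps.
    memory = tf.random.normal([4, 10, 64])
    memory_length = tf.constant([10, 8, 10, 6])

    # initial_state() stores the memory later read by _get_initial_state above.
    state = decoder.initial_state(memory=memory, memory_sequence_length=memory_length)
    outputs, state, attention = decoder.step(
        tf.random.normal([4, 64]), 0, state=state, training=False
    )
    print(attention.shape)  # (4, 10): one alignment weight per source position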
class RNMTPlusDecoder(decoder.Decoder):
    """The RNMT+ decoder described in https://arxiv.org/abs/1804.09849."""

    def __init__(
        self, num_layers, num_units, num_heads, dropout=0.3, cell_class=None, **kwargs
    ):
        """Initializes the decoder parameters.

        Args:
          num_layers: The number of layers.
          num_units: The number of units in each layer.
          num_heads: The number of attention heads.
          dropout: The probability to drop units from the decoder input and in
            each layer output.
          cell_class: The inner cell class or a callable taking :obj:`num_units`
            as argument and returning a cell. Defaults to a layer normalized
            LSTM cell.
          **kwargs: Additional layer arguments.
        """
        super().__init__(**kwargs)
        if cell_class is None:
            cell_class = tfa.rnn.LayerNormLSTMCell
        self.num_heads = num_heads
        self.num_units = num_units
        self.dropout = dropout
        self.cells = [cell_class(num_units) for _ in range(num_layers)]
        self.multi_head_attention = transformer.MultiHeadAttention(
            num_heads, num_units, dropout=dropout, return_attention=True
        )

    @property
    def support_alignment_history(self):
        return True

    def _get_initial_state(self, batch_size, dtype, initial_state=None):
        return tuple(
            cell.get_initial_state(batch_size=batch_size, dtype=dtype)
            for cell in self.cells
        )

    def step(
        self,
        inputs,
        timestep,
        state=None,
        memory=None,
        memory_sequence_length=None,
        training=None,
    ):
        inputs = common.dropout(inputs, rate=self.dropout, training=training)
        new_states = []
        last_outputs, state_0 = self.cells[0](inputs, state[0])
        new_states.append(state_0)
        if memory_sequence_length is not None:
            memory_mask = tf.sequence_mask(
                memory_sequence_length, maxlen=tf.shape(memory)[1]
            )
        else:
            memory_mask = None
        context, _, attention = self.multi_head_attention(
            tf.expand_dims(last_outputs, 1),
            memory=memory,
            mask=memory_mask,
            training=training,
        )
        attention = attention[:, 0, 0]  # Use the first head for the attention vector.
        context = tf.squeeze(context, axis=1)
        for i in range(1, len(self.cells)):
            inputs = tf.concat([last_outputs, context], axis=-1)
            outputs, state_i = self.cells[i](inputs, state[i], training=training)
            new_states.append(state_i)
            outputs = common.dropout(outputs, rate=self.dropout, training=training)
            if i >= 2:
                outputs += last_outputs
            last_outputs = outputs
        final = tf.concat([last_outputs, context], -1)
        return final, tuple(new_states), attention
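Finally, a sketch of a single RNMT+ decoding step (again illustrative, not part of the module): the returned output concatenates the top-layer state with the multi-head attention context, so its depth is twice num_units, and the attention vector is taken from the first head only. The direct _get_initial_state call and the shapes are assumptions for this example.

    import tensorflow as tf
    from opennmt.decoders.rnn_decoder import RNMTPlusDecoder

    decoder = RNMTPlusDecoder(num_layers=3, num_units=64, num_heads=4)

    # Fake encoder outputs to attend to.
    memory = tf.random.normal([4, 10, 64])
    memory_length = tf.constant([10, 8, 10, 6])

    state = decoder._get_initial_state(batch_size=4, dtype=tf.float32)
    outputs, state, attention = decoder.step(
        tf.random.normal([4, 64]),
        0,
        state=state,
        memory=memory,
        memory_sequence_length=memory_length,
        training=False,
    )
    print(outputs.shape)    # (4, 128): top-layer output concatenated with the context
    print(attention.shape)  # (4, 10): first-head attention over source positions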