Source code for opennmt.decoders.decoder

"""Base class and functions for dynamic decoders."""

import abc

import tensorflow as tf

from opennmt import constants
from opennmt.inputters import text_inputter
from opennmt.layers import common
from opennmt.utils import decoding, misc


def get_sampling_probability(step, read_probability=None, schedule_type=None, k=None):
    """Returns the sampling probability as described in
    https://arxiv.org/abs/1506.03099.

    Args:
      step: The training step.
      read_probability: The probability to read from the inputs.
      schedule_type: The type of schedule: "constant", "linear", "exponential",
        or "inverse_sigmoid".
      k: The convergence constant.

    Returns:
      The probability to sample from the output ids as a 0D ``tf.Tensor`` or
      ``None`` if scheduled sampling is not configured.

    Raises:
      ValueError: if :obj:`schedule_type` is set but not :obj:`k`, or if
        :obj:`schedule_type` is ``linear`` but an initial
        :obj:`read_probability` is not set.
      TypeError: if :obj:`schedule_type` is invalid.
    """
    if read_probability is None and schedule_type is None:
        return None
    if schedule_type is not None and schedule_type != "constant":
        if k is None:
            raise ValueError(
                "scheduled_sampling_k is required when scheduled_sampling_type is set"
            )
        step = tf.cast(step, tf.float32)
        k = tf.constant(k, tf.float32)
        if schedule_type == "linear":
            if read_probability is None:
                raise ValueError("Linear schedule requires an initial read probability")
            read_probability = min(read_probability, 1.0)
            read_probability = tf.maximum(read_probability - k * step, 0.0)
        elif schedule_type == "exponential":
            read_probability = tf.pow(k, step)
        elif schedule_type == "inverse_sigmoid":
            read_probability = k / (k + tf.exp(step / k))
        else:
            raise TypeError("Unknown scheduled sampling type: {}".format(schedule_type))
    return 1.0 - read_probability
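
# A minimal usage sketch (illustrative, not part of the original module). With
# the "inverse_sigmoid" schedule, read_probability = k / (k + exp(step / k)),
# so at step 5000 with k = 500:
#
#     probability = get_sampling_probability(
#         tf.constant(5000, tf.int64), schedule_type="inverse_sigmoid", k=500
#     )
#     # read_probability = 500 / (500 + exp(10)) ~= 0.022
#     # returned probability = 1 - 0.022 ~= 0.98
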
class Decoder(tf.keras.layers.Layer):
    """Base class for decoders."""

    def __init__(
        self,
        num_sources=1,
        vocab_size=None,
        output_layer=None,
        output_layer_bias=True,
        **kwargs
    ):
        """Initializes the decoder parameters.

        If you don't set one of :obj:`vocab_size` or :obj:`output_layer` here,
        you should later call the method
        :meth:`opennmt.decoders.Decoder.initialize` to initialize this decoder
        instance.

        Args:
          num_sources: The number of source contexts expected by this decoder.
          vocab_size: The output vocabulary size (optional if
            :obj:`output_layer` is set).
          output_layer: The output projection layer (optional).
          output_layer_bias: Add bias after the output projection layer.
          **kwargs: Additional layer arguments.

        Raises:
          ValueError: if the number of source contexts :obj:`num_sources` is
            not supported by this decoder.
        """
        if num_sources < self.minimum_sources or num_sources > self.maximum_sources:
            raise ValueError(
                "This decoder accepts between %d and %d source contexts, "
                "but received %d"
                % (self.minimum_sources, self.maximum_sources, num_sources)
            )
        super().__init__(**kwargs)
        self.num_sources = num_sources
        self.output_layer = None
        self.output_layer_bias = output_layer_bias
        self.memory = None
        self.memory_sequence_length = None
        if vocab_size is not None or output_layer is not None:
            self.initialize(vocab_size=vocab_size, output_layer=output_layer)

    @property
    def minimum_sources(self):
        """The minimum number of source contexts supported by this decoder."""
        return 1

    @property
    def maximum_sources(self):
        """The maximum number of source contexts supported by this decoder."""
        return 1

    @property
    def support_alignment_history(self):
        """Returns ``True`` if this decoder can return the attention as
        alignment history."""
        return False

    @property
    def initialized(self):
        """Returns ``True`` if this decoder is initialized."""
        return self.output_layer is not None

    def initialize(self, vocab_size=None, output_layer=None):
        """Initializes the decoder configuration.

        Args:
          vocab_size: The target vocabulary size.
          output_layer: The output layer to use.

        Raises:
          ValueError: if neither :obj:`vocab_size` nor :obj:`output_layer` is
            set.
        """
        if self.initialized:
            return
        if output_layer is not None:
            self.output_layer = output_layer
        else:
            if vocab_size is None:
                raise ValueError("One of vocab_size and output_layer must be set")
            self.output_layer = common.Dense(
                vocab_size, use_bias=self.output_layer_bias
            )
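
    # A construction sketch (illustrative): a concrete subclass such as the
    # hypothetical ``MyDecoder`` below can be initialized eagerly, or in two
    # steps when the vocabulary size is only known later:
    #
    #     decoder = MyDecoder(vocab_size=32000)   # eager initialization
    #
    #     decoder = MyDecoder()                   # deferred until the
    #     decoder.initialize(vocab_size=32000)    # vocabulary size is known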

    def reuse_embeddings(self, embeddings):
        """Reuses embeddings in the decoder output layer.

        Args:
          embeddings: The embeddings matrix to reuse.

        Raises:
          RuntimeError: if the decoder was not initialized.
        """
        self._assert_is_initialized()
        self.output_layer.set_kernel(embeddings, transpose=True)
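
    # Weight tying sketch (illustrative): share the target embedding matrix
    # with the output projection. ``embedder`` is assumed to be an
    # already-built target opennmt.inputters.WordEmbedder exposing its
    # ``vocabulary_size`` and ``embedding`` attributes:
    #
    #     decoder.initialize(vocab_size=embedder.vocabulary_size)
    #     decoder.reuse_embeddings(embedder.embedding)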

    def initial_state(
        self,
        memory=None,
        memory_sequence_length=None,
        initial_state=None,
        batch_size=None,
        dtype=None,
    ):
        """Returns the initial decoder state.

        Args:
          memory: Memory values to query.
          memory_sequence_length: Memory values length.
          initial_state: An initial state to start from, e.g. the last encoder
            state.
          batch_size: The batch size to use.
          dtype: The dtype of the state.

        Returns:
          A nested structure of tensors representing the decoder state.

        Raises:
          RuntimeError: if the decoder was not initialized.
          ValueError: if one of :obj:`batch_size` or :obj:`dtype` is not set
            and neither :obj:`initial_state` nor :obj:`memory` is passed.
          ValueError: if the number of source contexts (:obj:`memory`) does not
            match the number defined at the decoder initialization.
        """
        self._assert_is_initialized()
        self._assert_memory_is_compatible(memory, memory_sequence_length)
        self.memory = memory
        self.memory_sequence_length = memory_sequence_length
        if batch_size is None or dtype is None:
            sentinel = tf.nest.flatten(memory)[0]
            if sentinel is None:
                sentinel = tf.nest.flatten(initial_state)[0]
            if sentinel is None:
                raise ValueError(
                    "If batch_size or dtype are not set, then either "
                    "memory or initial_state should be set"
                )
            if batch_size is None:
                batch_size = tf.shape(sentinel)[0]
            if dtype is None:
                dtype = sentinel.dtype
        return self._get_initial_state(batch_size, dtype, initial_state=initial_state)
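
    # Typical wiring sketch (illustrative): query the encoder output as the
    # single source context. ``encoder_outputs``, ``encoder_state`` and
    # ``encoder_sequence_length`` are assumed to come from an encoder call:
    #
    #     state = decoder.initial_state(
    #         memory=encoder_outputs,
    #         memory_sequence_length=encoder_sequence_length,
    #         initial_state=encoder_state,
    #     )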

    def call(
        self,
        inputs,
        length_or_step=None,
        state=None,
        input_fn=None,
        sampling_probability=None,
        training=None,
    ):
        """Runs the decoder layer on either a complete sequence (e.g. for
        training or scoring), or a single timestep (e.g. for iterative
        decoding).

        Args:
          inputs: The inputs to decode, can be a 3D (training) or 2D (iterative
            decoding) tensor.
          length_or_step: For 3D :obj:`inputs`, the length of each sequence.
            For 2D :obj:`inputs`, the current decoding timestep.
          state: The decoder state.
          input_fn: A callable taking sampled ids and returning the decoding
            inputs.
          sampling_probability: When :obj:`inputs` is the full sequence, the
            probability to read from the last sample instead of the true
            target.
          training: Run in training mode.

        Returns:
          A tuple with the logits, the decoder state, and an attention vector.

        Raises:
          RuntimeError: if the decoder was not initialized.
          ValueError: if the :obj:`inputs` rank is different than 2 or 3.
          ValueError: if :obj:`length_or_step` is invalid.
        """
        self._assert_is_initialized()
        rank = inputs.shape.ndims
        if rank == 2:
            if length_or_step.shape.ndims != 0:
                raise ValueError(
                    "length_or_step should be a scalar with the current timestep"
                )
            outputs, state, attention = self.step(
                inputs,
                length_or_step,
                state=state,
                memory=self.memory,
                memory_sequence_length=self.memory_sequence_length,
                training=training,
            )
            logits = self.output_layer(outputs)
        elif rank == 3:
            if length_or_step.shape.ndims != 1:
                raise ValueError(
                    "length_or_step should contain the length of each sequence"
                )
            logits, state, attention = self.forward(
                inputs,
                sequence_length=length_or_step,
                initial_state=state,
                memory=self.memory,
                memory_sequence_length=self.memory_sequence_length,
                input_fn=input_fn,
                sampling_probability=sampling_probability,
                training=training,
            )
        else:
            raise ValueError("Unsupported input rank %d" % rank)
        return logits, state, attention
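
    # Training-mode sketch (illustrative): with 3D inputs, ``length_or_step``
    # carries the length of each target sequence and the layer runs
    # ``forward`` on the full batch. ``target_embeddings`` and
    # ``target_length`` are assumed tensors of shape [B, T, D] and [B]:
    #
    #     logits, state, attention = decoder(
    #         target_embeddings, target_length, state=state, training=True
    #     )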

    def forward(
        self,
        inputs,
        sequence_length=None,
        initial_state=None,
        memory=None,
        memory_sequence_length=None,
        input_fn=None,
        sampling_probability=None,
        training=None,
    ):
        """Runs the decoder on full sequences.

        Args:
          inputs: The 3D decoder input.
          sequence_length: The length of each input sequence.
          initial_state: The initial decoder state.
          memory: Memory values to query.
          memory_sequence_length: Memory values length.
          input_fn: A callable taking sampled ids and returning the decoding
            inputs.
          sampling_probability: The probability to read from the last sample
            instead of the true target.
          training: Run in training mode.

        Returns:
          A tuple with the logits, the decoder state, and the attention vector.
        """
        _ = sequence_length
        fused_projection = True
        if sampling_probability is not None:
            if input_fn is None:
                raise ValueError(
                    "input_fn is required when a sampling probability is set"
                )
            if not tf.is_tensor(sampling_probability) and sampling_probability == 0:
                sampling_probability = None
            else:
                fused_projection = False

        batch_size, max_step, _ = misc.shape_list(inputs)
        inputs_ta = tf.TensorArray(inputs.dtype, size=max_step)
        inputs_ta = inputs_ta.unstack(tf.transpose(inputs, perm=[1, 0, 2]))

        def _maybe_sample(true_inputs, logits):
            # Read from samples with a probability.
            draw = tf.random.uniform([batch_size])
            read_sample = tf.less(draw, sampling_probability)
            sampled_ids = tf.random.categorical(logits, 1)
            sampled_inputs = input_fn(tf.squeeze(sampled_ids, 1))
            inputs = tf.where(
                tf.broadcast_to(
                    tf.expand_dims(read_sample, -1), tf.shape(true_inputs)
                ),
                x=sampled_inputs,
                y=true_inputs,
            )
            return inputs

        def _body(step, state, inputs, outputs_ta, attention_ta):
            outputs, state, attention = self.step(
                inputs,
                step,
                state=state,
                memory=memory,
                memory_sequence_length=memory_sequence_length,
                training=training,
            )
            next_inputs = tf.cond(
                step + 1 < max_step,
                true_fn=lambda: inputs_ta.read(step + 1),
                false_fn=lambda: tf.zeros_like(inputs),
            )
            if not fused_projection:
                outputs = self.output_layer(outputs)
            if sampling_probability is not None:
                next_inputs = _maybe_sample(next_inputs, outputs)
            outputs_ta = outputs_ta.write(step, outputs)
            if attention is not None:
                attention_ta = attention_ta.write(step, attention)
            return step + 1, state, next_inputs, outputs_ta, attention_ta

        step = tf.constant(0, dtype=tf.int32)
        outputs_ta = tf.TensorArray(inputs.dtype, size=max_step)
        attention_ta = tf.TensorArray(inputs.dtype, size=max_step)
        _, state, _, outputs_ta, attention_ta = tf.while_loop(
            lambda *arg: True,
            _body,
            loop_vars=(
                step,
                initial_state,
                inputs_ta.read(0),
                outputs_ta,
                attention_ta,
            ),
            parallel_iterations=32,
            swap_memory=True,
            maximum_iterations=max_step,
        )
        outputs = tf.transpose(outputs_ta.stack(), perm=[1, 0, 2])
        logits = self.output_layer(outputs) if fused_projection else outputs
        attention = None
        if self.support_alignment_history:
            attention = tf.transpose(attention_ta.stack(), perm=[1, 0, 2])
        return logits, state, attention
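
    # Scheduled sampling sketch (illustrative): ``forward`` replaces some gold
    # inputs with model samples when ``sampling_probability`` is non-zero, and
    # ``input_fn`` maps the sampled ids back to embeddings. ``embedder`` is an
    # assumed target opennmt.inputters.WordEmbedder and ``step`` the current
    # training step:
    #
    #     probability = get_sampling_probability(
    #         step, read_probability=0.9, schedule_type="linear", k=1e-5
    #     )
    #     logits, state, attention = decoder(
    #         target_embeddings,
    #         target_length,
    #         state=state,
    #         input_fn=lambda ids: embedder({"ids": ids}),
    #         sampling_probability=probability,
    #         training=True,
    #     )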

    @abc.abstractmethod
    def step(
        self,
        inputs,
        timestep,
        state=None,
        memory=None,
        memory_sequence_length=None,
        training=None,
    ):
        """Runs one decoding step.

        Args:
          inputs: The 2D decoder input.
          timestep: The current decoding step.
          state: The decoder state.
          memory: Memory values to query.
          memory_sequence_length: Memory values length.
          training: Run in training mode.

        Returns:
          A tuple with the decoder outputs, the decoder state, and the
          attention vector.
        """
        raise NotImplementedError()
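
    # Subclassing sketch (illustrative, not a real OpenNMT-tf decoder): a
    # minimal recurrent decoder only has to implement ``step`` and
    # ``_get_initial_state``:
    #
    #     class TinyDecoder(Decoder):
    #         def __init__(self, units, **kwargs):
    #             super().__init__(**kwargs)
    #             self.cell = tf.keras.layers.LSTMCell(units)
    #
    #         def _get_initial_state(self, batch_size, dtype, initial_state=None):
    #             return self.cell.get_initial_state(
    #                 batch_size=batch_size, dtype=dtype
    #             )
    #
    #         def step(self, inputs, timestep, state=None, memory=None,
    #                  memory_sequence_length=None, training=None):
    #             outputs, state = self.cell(inputs, state, training=training)
    #             return outputs, state, None  # no alignment history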

    def dynamic_decode(
        self,
        embeddings,
        start_ids,
        end_id=constants.END_OF_SENTENCE_ID,
        initial_state=None,
        decoding_strategy=None,
        sampler=None,
        maximum_iterations=None,
        minimum_iterations=0,
        tflite_output_size=None,
    ):
        """Decodes dynamically from :obj:`start_ids`.

        Args:
          embeddings: Target embeddings or
            :class:`opennmt.inputters.WordEmbedder` to apply on decoded ids.
          start_ids: Initial input IDs of shape :math:`[B]`.
          end_id: ID of the end of sequence token.
          initial_state: Initial decoder state.
          decoding_strategy: A :class:`opennmt.utils.DecodingStrategy` instance
            that defines the decoding logic. Defaults to a greedy search.
          sampler: A :class:`opennmt.utils.Sampler` instance that samples
            predictions from the model output. Defaults to an argmax sampling.
          maximum_iterations: The maximum number of iterations to decode for.
          minimum_iterations: The minimum number of iterations to decode for.
          tflite_output_size: If set, decode in a TFLite compatible mode and
            use this value as the size of the 1D output tensor.

        Returns:
          A :class:`opennmt.utils.DecodingResult` instance.

        See Also:
          :func:`opennmt.utils.dynamic_decode`
        """
        if isinstance(embeddings, text_inputter.WordEmbedder):
            input_fn = lambda ids: embeddings({"ids": ids})
        else:
            input_fn = lambda ids: tf.nn.embedding_lookup(embeddings, ids)
        # TODO: find a better way to pass the state reorder flags.
        if hasattr(decoding_strategy, "_set_state_reorder_flags"):
            state_reorder_flags = self._get_state_reorder_flags()
            decoding_strategy._set_state_reorder_flags(state_reorder_flags)
        return decoding.dynamic_decode(
            lambda ids, step, state: self(input_fn(ids), step, state),
            start_ids,
            end_id=end_id,
            initial_state=initial_state,
            decoding_strategy=decoding_strategy,
            sampler=sampler,
            maximum_iterations=maximum_iterations,
            minimum_iterations=minimum_iterations,
            attention_history=self.support_alignment_history,
            attention_size=tf.shape(self.memory)[1]
            if self.support_alignment_history
            else None,
            tflite_output_size=tflite_output_size,
        )
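
    # Inference sketch (illustrative): greedily decode from the start token.
    # ``target_embeddings`` is the target embedding matrix, ``batch_size`` an
    # assumed scalar, and ``state`` comes from ``initial_state``:
    #
    #     start_ids = tf.fill([batch_size], constants.START_OF_SENTENCE_ID)
    #     result = decoder.dynamic_decode(
    #         target_embeddings,
    #         start_ids,
    #         initial_state=state,
    #         maximum_iterations=200,
    #     )
    #     predicted_ids = result.ids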

    def map_v1_weights(self, weights):
        return self.output_layer.map_v1_weights(weights["dense"])

    @abc.abstractmethod
    def _get_initial_state(self, batch_size, dtype, initial_state=None):
        """Returns the decoder initial state.

        Args:
          batch_size: The batch size of the returned state.
          dtype: The data type of the state.
          initial_state: A state to start from.

        Returns:
          The decoder state as a nested structure of tensors.
        """
        raise NotImplementedError()

    def _get_state_reorder_flags(self):
        """Returns a structure that marks states that should be reordered
        during beam search.

        By default all states are reordered.

        Returns:
          The same structure as the decoder state with tensors replaced by
          booleans.
        """
        return None

    def _assert_is_initialized(self):
        """Raises an exception if the decoder was not initialized."""
        if not self.initialized:
            raise RuntimeError("The decoder was not initialized")

    def _assert_memory_is_compatible(self, memory, memory_sequence_length):
        """Raises an exception if the memory layout is not compatible with
        this decoder."""

        def _num_elements(obj):
            if obj is None:
                return 0
            elif isinstance(obj, (list, tuple)):
                return len(obj)
            else:
                return 1

        num_memory = _num_elements(memory)
        num_length = _num_elements(memory_sequence_length)
        if num_memory != num_length and memory_sequence_length is not None:
            raise ValueError(
                "got %d memory values but %d length vectors" % (num_memory, num_length)
            )
        if num_memory != self.num_sources:
            raise ValueError(
                "expected %d source contexts, but got %d"
                % (self.num_sources, num_memory)
            )