# Source code for opennmt.utils.decoding

"""Dynamic decoding utilities."""

import abc
import collections

import tensorflow as tf
import tensorflow_addons as tfa

from opennmt import constants
from opennmt.utils import misc


class Sampler(abc.ABC):
    """Abstract interface of the objects picking predictions from scores."""

    @abc.abstractmethod
    def __call__(self, scores, num_samples=1):
        """Samples predictions.

        Args:
          scores: The scores to sample from, a tensor of shape
            ``[batch_size, vocab_size]``.
          num_samples: The number of samples per batch to produce.

        Returns:
          A tuple ``(sample_ids, sample_scores)``.
        """
        raise NotImplementedError()

    @staticmethod
    def from_params(params):
        """Builds the sampler selected by the user parameters.

        Args:
          params: A dictionary of user parameters. The key ``sampling_topk``
            (1 when missing) selects the sampler and ``sampling_temperature``
            is forwarded to the random sampler.

        Returns:
          A :class:`opennmt.utils.Sampler` instance.
        """
        top_k = params.get("sampling_topk", 1)
        # A top-1 sampling is exactly the argmax, no randomness is involved.
        if top_k == 1:
            return BestSampler()
        return RandomSampler(
            from_top_k=top_k,
            temperature=params.get("sampling_temperature"))
class RandomSampler(Sampler):
    """Sampler that draws predictions at random from the model scores."""

    def __init__(self, from_top_k=None, temperature=None):
        """Initializes the random sampler.

        Args:
          from_top_k: Sample from the top K predictions instead of the full
            distribution. ``None`` or a non-positive value disables this
            restriction.
          temperature: Divide logits by this value. High temperatures generate
            more random samples.
        """
        top_k_disabled = from_top_k is not None and from_top_k <= 0
        self.from_top_k = None if top_k_disabled else from_top_k
        self.temperature = temperature

    def __call__(self, scores, num_samples=1):
        if self.from_top_k is not None:
            # Draw positions among the K best scores, then translate these
            # positions back to ids in the full vocabulary.
            best_scores, best_ids = tf.nn.top_k(scores, k=self.from_top_k)
            positions = _sample_from(
                best_scores, num_samples, temperature=self.temperature)
            sample_ids = _gather_from_word_indices(best_ids, positions)
        else:
            sample_ids = _sample_from(
                scores, num_samples, temperature=self.temperature)
        # The returned scores are always read from the original scores.
        sample_scores = _gather_from_word_indices(scores, sample_ids)
        return sample_ids, sample_scores
class BestSampler(Sampler):
    """Sampler returning the highest scoring predictions."""

    def __call__(self, scores, num_samples=1):
        # tf.nn.top_k returns (values, indices) while samplers are expected
        # to return (ids, scores).
        best_scores, best_ids = tf.nn.top_k(scores, k=num_samples)
        return best_ids, best_scores
class DecodingStrategy(abc.ABC):
    """Abstract interface of the decoding strategies."""

    @property
    def num_hypotheses(self):
        """The number of hypotheses returned by this strategy."""
        return 1

    @staticmethod
    def from_params(params):
        """Builds the decoding strategy selected by the user parameters.

        Args:
          params: A dictionary of user parameters. The key ``beam_width``
            (1 when missing) selects the strategy; ``length_penalty`` and
            ``coverage_penalty`` (0 when missing) are forwarded to the beam
            search.

        Returns:
          A :class:`opennmt.utils.DecodingStrategy` instance.
        """
        beam_size = params.get("beam_width", 1)
        if beam_size > 1:
            return BeamSearch(
                beam_size,
                length_penalty=params.get("length_penalty", 0),
                coverage_penalty=params.get("coverage_penalty", 0))
        return GreedySearch()

    @abc.abstractmethod
    def _initialize(self, batch_size, start_ids, attention_size=None):
        """Initializes the strategy.

        Args:
          batch_size: The batch size.
          start_ids: The start decoding ids.
          attention_size: If known, the size of the attention vectors (i.e. the
            maximum source length).

        Returns:
          A tuple containing,

          - The (possibly transformed) start decoding ids.
          - The tensor of finished flags.
          - Initial log probabilities per batch.
          - A sequence of additional tensors used during the decoding.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _step(self,
              step,
              sampler,
              log_probs,
              cum_log_probs,
              finished,
              state,
              extra_vars,
              attention=None):
        """Updates the strategy state.

        Args:
          step: The current decoding step.
          sampler: The sampler that produces predictions.
          log_probs: The model log probabilities.
          cum_log_probs: The cumulated log probabilities per batch.
          finished: The current finished flags.
          state: The decoder state.
          extra_vars: Additional tensors from this decoding strategy.
          attention: The attention vector for the current step.

        Returns:
          A tuple containing,

          - The predicted word ids.
          - The new cumulated log probabilities.
          - The updated finished flags.
          - The updated decoder state.
          - Additional tensors from this decoding strategy.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _finalize(self, outputs, end_id, extra_vars, attention=None):
        """Finalize the predictions.

        Args:
          outputs: The array of sampled ids.
          end_id: The end token id.
          extra_vars: Additional tensors from this decoding strategy.
          attention: The array of attention outputs.

        Returns:
          A tuple containing,

          - The final predictions as a tensor of shape [B, H, T].
          - The final attention history of shape [B, H, T, S].
          - The final sequence lengths of shape [B, H].
        """
        raise NotImplementedError()
class GreedySearch(DecodingStrategy):
    """Decoding strategy following a single hypothesis per batch entry."""

    def _initialize(self, batch_size, start_ids, attention_size=None):
        # Nothing is finished and nothing is accumulated before the first step.
        not_finished = tf.zeros([batch_size], dtype=tf.bool)
        zero_log_probs = tf.zeros([batch_size], dtype=tf.float32)
        return start_ids, not_finished, zero_log_probs, []

    def _step(self,
              step,
              sampler,
              log_probs,
              cum_log_probs,
              finished,
              state,
              extra_vars,
              attention=None):
        ids, scores = sampler(log_probs)
        # The sampler output is flattened to one value per batch entry.
        flat_ids = tf.reshape(ids, [-1])
        flat_scores = tf.reshape(scores, [-1])
        return flat_ids, cum_log_probs + flat_scores, finished, state, extra_vars

    def _finalize(self, outputs, end_id, extra_vars, attention=None):
        # The stacked array is transposed, then a hypothesis dimension of
        # size 1 is inserted at axis 1.
        ids = tf.expand_dims(tf.transpose(outputs.stack()), 1)
        lengths = _lengths_from_ids(ids, end_id)
        if attention is not None:
            batch_major = tf.transpose(attention.stack(), perm=[1, 0, 2])
            attention = tf.expand_dims(batch_major, 1)
        return ids, attention, lengths
class BeamSearch(DecodingStrategy):
    """Decoding strategy tracking several hypotheses per batch entry."""

    def __init__(self, beam_size, length_penalty=0, coverage_penalty=0):
        """Initializes the decoding strategy.

        Args:
          beam_size: The number of paths to consider per batch.
          length_penalty: Length penalty, see https://arxiv.org/abs/1609.08144.
          coverage_penalty: Coverage penalty, see
            https://arxiv.org/abs/1609.08144.
        """
        self.beam_size = beam_size
        self.length_penalty = length_penalty
        self.coverage_penalty = coverage_penalty
        self._state_reorder_flags = None

    @property
    def num_hypotheses(self):
        return self.beam_size

    def _set_state_reorder_flags(self, state_reorder_flags):
        """Sets state reorder flags, a structure matching the decoder state that
        indicates which tensor should be reordered during beam search.
        """
        self._state_reorder_flags = state_reorder_flags

    def _initialize(self, batch_size, start_ids, attention_size=None):
        beam_size = self.beam_size
        tiled_start_ids = tfa.seq2seq.tile_batch(start_ids, beam_size)
        flat_size = batch_size * beam_size
        finished = tf.zeros([flat_size], dtype=tf.bool)
        # Only the first beam of each batch entry is alive for the first
        # iteration, the other beams start with a log probability of -inf.
        first_beam_only = [0.] + [-float("inf")] * (beam_size - 1)
        initial_log_probs = tf.tile(first_beam_only, [batch_size])
        parent_ids = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
        sequence_lengths = tf.zeros([flat_size], dtype=tf.int32)
        extra_vars = (parent_ids, sequence_lengths)
        if self.coverage_penalty != 0:
            if attention_size is None:
                raise ValueError(
                    "The attention size should be known to support coverage penalty")
            # The coverage penalty needs the attention accumulated per beam.
            extra_vars += (tf.zeros([flat_size, attention_size]),)
        return tiled_start_ids, finished, initial_log_probs, extra_vars

    def _get_scores(self, log_probs, sequence_lengths, finished, accumulated_attention=None):
        """Applies the length and coverage penalties to the log probabilities."""
        scores = log_probs
        if self.length_penalty != 0:
            # Length normalization: ((5 + length) / 6) ^ penalty, with the
            # length counting the token about to be added.
            next_lengths = tf.cast(
                tf.expand_dims(sequence_lengths, 1) + 1, scores.dtype)
            length_norm = tf.pow((5. + next_lengths) / 6., self.length_penalty)
            scores = scores / length_norm
        if self.coverage_penalty != 0:
            # Mask out of range steps with ones (log(1) == 0).
            is_zero = tf.equal(accumulated_attention, 0.0)
            coverage = tf.where(
                is_zero,
                x=tf.ones_like(accumulated_attention),
                y=accumulated_attention)
            penalty = tf.reduce_sum(tf.math.log(tf.minimum(coverage, 1.0)), 1)
            # The coverage penalty only applies to finished predictions.
            penalty = penalty * tf.cast(finished, penalty.dtype)
            scores = scores + self.coverage_penalty * tf.expand_dims(penalty, 1)
        return scores

    def _step(self,
              step,
              sampler,
              log_probs,
              cum_log_probs,
              finished,
              state,
              extra_vars,
              attention=None):
        parent_ids, sequence_lengths = extra_vars[0], extra_vars[1]
        beam_size = self.beam_size

        accumulated_attention = None
        if self.coverage_penalty != 0:
            if attention is None:
                raise ValueError("Coverage penalty is enabled but the model did not "
                                 "return an attention vector")
            # Finished hypotheses no longer accumulate attention.
            active = tf.cast(tf.math.logical_not(finished), attention.dtype)
            accumulated_attention = extra_vars[2] + attention * tf.expand_dims(active, 1)

        # Score every (beam, word) continuation: the cumulated log probability
        # of the beam is added to the log probability of each candidate word.
        vocab_size = log_probs.shape[-1]
        total_probs = log_probs + tf.expand_dims(cum_log_probs, 1)
        scores = self._get_scores(
            total_probs,
            sequence_lengths,
            finished,
            accumulated_attention=accumulated_attention)
        # All the continuations of a batch entry compete together.
        flat_shape = [-1, beam_size * vocab_size]
        scores = tf.reshape(scores, flat_shape)
        total_probs = tf.reshape(total_probs, flat_shape)

        # Select the continuations. The cumulated log probabilities are read
        # from the unpenalized totals, not from the penalized scores.
        sample_ids, _ = sampler(scores, num_samples=beam_size)
        next_cum_log_probs = tf.reshape(
            _gather_from_word_indices(total_probs, sample_ids), [-1])
        flat_ids = tf.reshape(sample_ids, [-1])

        # A flat id encodes both the originating beam and the word.
        word_ids = flat_ids % vocab_size
        beam_ids = flat_ids // vocab_size
        batch_offsets = (tf.range(tf.shape(word_ids)[0]) // beam_size) * beam_size
        beam_indices = batch_offsets + beam_ids

        # Only unfinished sequences grow by one token.
        grown_lengths = tf.where(finished, x=sequence_lengths, y=sequence_lengths + 1)

        # Align every per-beam value with the selected parent beams.
        finished = tf.gather(finished, beam_indices)
        sequence_lengths = tf.gather(grown_lengths, beam_indices)
        parent_ids = parent_ids.write(step, beam_ids)
        next_extra_vars = [parent_ids, sequence_lengths]
        if accumulated_attention is not None:
            next_extra_vars.append(tf.gather(accumulated_attention, beam_indices))
        state = _reorder_state(
            state, beam_indices, reorder_flags=self._state_reorder_flags)
        return word_ids, next_cum_log_probs, finished, state, tuple(next_extra_vars)

    def _finalize(self, outputs, end_id, extra_vars, attention=None):
        parent_ids, sequence_lengths = extra_vars[0], extra_vars[1]
        beam_size = self.beam_size
        per_batch_lengths = tf.reshape(sequence_lengths, [-1, beam_size])
        maximum_lengths = tf.reduce_max(per_batch_lengths, axis=-1)
        max_time = outputs.size()
        time_major_shape = [max_time, -1, beam_size]
        step_ids = tf.reshape(outputs.stack(), time_major_shape)
        parent_ids = tf.reshape(parent_ids.stack(), time_major_shape)
        # Follow the parent pointers to rebuild the full hypotheses.
        ids = tfa.seq2seq.gather_tree(step_ids, parent_ids, maximum_lengths, end_id)
        ids = tf.transpose(ids, perm=[1, 2, 0])
        lengths = _lengths_from_ids(ids, end_id)
        if attention is not None:
            attention = tfa.seq2seq.gather_tree_from_array(
                attention.stack(), parent_ids, lengths)
            attention = tf.transpose(attention, perm=[1, 0, 2])
            attention = tf.reshape(
                attention, [tf.shape(ids)[0], beam_size, max_time, -1])
        return ids, attention, lengths
class DecodingResult(
        collections.namedtuple(
            "DecodingResult",
            ["ids", "lengths", "log_probs", "attention", "state"])):
    """Named tuple holding the outputs of a decoding.

    Args:
      ids: The predicted ids of shape :math:`[B, H, T]`.
      lengths: The produced sequences length of shape :math:`[B, H]`.
      log_probs: The cumulated log probabilities of shape :math:`[B, H]`.
      attention: The attention history of shape :math:`[B, H, T_t, T_s]`.
      state: The final decoding state.
    """
def dynamic_decode(symbols_to_logits_fn,
                   start_ids,
                   end_id=constants.END_OF_SENTENCE_ID,
                   initial_state=None,
                   decoding_strategy=None,
                   sampler=None,
                   maximum_iterations=None,
                   minimum_iterations=0,
                   attention_history=False,
                   attention_size=None):
    """Dynamic decoding.

    Args:
      symbols_to_logits_fn: A callable taking ``(symbols, step, state)`` and
        returning ``(logits, state, attention)`` (``attention`` is optional).
      start_ids: Initial input IDs of shape :math:`[B]`.
      end_id: ID of the end of sequence token.
      initial_state: Initial decoder state.
      decoding_strategy: A :class:`opennmt.utils.DecodingStrategy` instance that
        defines the decoding logic. Defaults to a greedy search.
      sampler: A :class:`opennmt.utils.Sampler` instance that samples
        predictions from the model output. Defaults to an argmax sampling.
      maximum_iterations: The maximum number of iterations to decode for.
      minimum_iterations: The minimum number of iterations to decode for.
      attention_history: Gather attention history during the decoding.
      attention_size: If known, the size of the attention vectors (i.e. the
        maximum source length).

    Returns:
      A :class:`opennmt.utils.DecodingResult` instance.

    Raises:
      ValueError: if :obj:`attention_history` is set but
        :obj:`symbols_to_logits_fn` does not return an attention vector.
    """
    if initial_state is None:
        initial_state = {}
    if decoding_strategy is None:
        decoding_strategy = GreedySearch()
    if sampler is None:
        sampler = BestSampler()

    def _cond(step, finished, state, inputs, outputs, attention, cum_log_probs, extra_vars):  # pylint: disable=unused-argument
        # Keep decoding while at least one hypothesis is not finished.
        return tf.reduce_any(tf.logical_not(finished))

    def _body(step, finished, state, inputs, outputs, attention, cum_log_probs, extra_vars):
        # Get log probs from the model.
        result = symbols_to_logits_fn(inputs, step, state)
        logits, state = result[0], result[1]
        attn = result[2] if len(result) > 2 else None
        logits = tf.cast(logits, tf.float32)

        # Penalize or force EOS: before the minimum number of iterations the
        # end token is penalized, afterwards finished hypotheses can only
        # produce the end token.
        batch_size, vocab_size = misc.shape_list(logits)
        eos_max_prob = tf.one_hot(
            tf.fill([batch_size], end_id),
            vocab_size,
            on_value=logits.dtype.max,
            off_value=logits.dtype.min)
        logits = tf.cond(
            step < minimum_iterations,
            true_fn=lambda: _penalize_token(logits, end_id),
            false_fn=lambda: tf.where(
                tf.broadcast_to(tf.expand_dims(finished, -1), tf.shape(logits)),
                x=eos_max_prob,
                y=logits))
        log_probs = tf.nn.log_softmax(logits)

        # Run one decoding strategy step.
        output, next_cum_log_probs, finished, state, extra_vars = (
            decoding_strategy._step(  # pylint: disable=protected-access
                step,
                sampler,
                log_probs,
                cum_log_probs,
                finished,
                state,
                extra_vars,
                attention=attn))

        # Update loop vars.
        if attention_history:
            if attn is None:
                raise ValueError("attention_history is set but the model did not return attention")
            attention = attention.write(step, tf.cast(attn, tf.float32))
        outputs = outputs.write(step, output)
        # The strategy may reorder the hypotheses (beam search does), so the
        # finished flags it returns are aligned with the cumulated log
        # probabilities it returns, not with the values received by this
        # iteration. Selecting the previous values based on the new flags would
        # assign to a finished hypothesis the score of whichever hypothesis
        # previously occupied its slot. No explicit freeze is needed: once the
        # minimum number of iterations is reached, finished hypotheses are
        # forced to emit the end token with a log probability of exactly 0,
        # which leaves their cumulated log probability unchanged.
        cum_log_probs = next_cum_log_probs
        finished = tf.logical_or(finished, tf.equal(output, end_id))
        return step + 1, finished, state, output, outputs, attention, cum_log_probs, extra_vars

    start_ids = tf.convert_to_tensor(start_ids)
    batch_size = tf.shape(start_ids)[0]
    # Ids are decoded as int32 and cast back to the input dtype at the end.
    ids_dtype = start_ids.dtype
    start_ids = tf.cast(start_ids, tf.int32)
    start_ids, finished, initial_log_probs, extra_vars = (
        decoding_strategy._initialize(  # pylint: disable=protected-access
            batch_size, start_ids, attention_size=attention_size))
    step = tf.constant(0, dtype=tf.int32)
    outputs = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
    attention = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

    _, _, state, _, outputs, attention, log_probs, extra_vars = tf.while_loop(
        _cond,
        _body,
        loop_vars=(
            step,
            finished,
            initial_state,
            start_ids,
            outputs,
            attention,
            initial_log_probs,
            extra_vars),
        shape_invariants=(
            step.shape,
            finished.shape,
            tf.nest.map_structure(_get_shape_invariants, initial_state),
            start_ids.shape,
            tf.TensorShape(None),
            tf.TensorShape(None),
            initial_log_probs.shape,
            tf.nest.map_structure(_get_shape_invariants, extra_vars)),
        parallel_iterations=1,
        maximum_iterations=maximum_iterations)

    ids, attention, lengths = decoding_strategy._finalize(  # pylint: disable=protected-access
        outputs,
        end_id,
        extra_vars,
        attention=attention if attention_history else None)
    if attention is not None:
        attention = attention[:, :, :-1]  # Ignore attention for </s>.
    log_probs = tf.reshape(log_probs, [batch_size, decoding_strategy.num_hypotheses])
    ids = tf.cast(ids, ids_dtype)
    return DecodingResult(
        ids=ids,
        lengths=lengths,
        log_probs=log_probs,
        attention=attention,
        state=state)
def _reorder_state(state, indices, reorder_flags=None):
    """Gathers the batch entries ``indices`` from the tensors of the state.

    When ``reorder_flags`` is set, it must have the same structure as
    ``state`` and tensors whose flag is false are returned untouched.
    """

    def _maybe_gather(tensor, reorder=True):
        # Tensor arrays and scalars have no batch dimension to reorder.
        keep_as_is = (
            not reorder
            or isinstance(tensor, tf.TensorArray)
            or tensor.shape.ndims == 0)
        return tensor if keep_as_is else tf.gather(tensor, indices)

    if reorder_flags is None:
        return tf.nest.map_structure(_maybe_gather, state)
    tf.nest.assert_same_structure(state, reorder_flags)
    return tf.nest.map_structure(_maybe_gather, state, reorder_flags)


def _get_shape_invariants(tensor):
    """Returns the shape of the tensor but sets middle dims to None."""
    if isinstance(tensor, tf.TensorArray):
        return tf.TensorShape(None)
    dims = tensor.shape.as_list()
    # Only the first and the last dimensions are kept.
    dims[1:-1] = [None] * len(dims[1:-1])
    return tf.TensorShape(dims)


def _penalize_token(log_probs, token_id, penalty=-1e7):
    """Adds ``penalty`` to the scores of ``token_id``."""
    vocab_size = log_probs.shape[-1]
    # The mask holds the penalty at the token position and 0 elsewhere, it
    # has a leading dimension of 1 that broadcasts against the input.
    penalty_value = tf.cast(penalty, log_probs.dtype)
    penalty_mask = tf.one_hot([token_id], vocab_size, on_value=penalty_value)
    return log_probs + penalty_mask


def _sample_from(logits, num_samples, temperature=None):
    """Sample N values from the unscaled probability distribution."""
    if temperature is not None:
        logits = logits / tf.cast(temperature, logits.dtype)
    return tf.random.categorical(logits, num_samples, dtype=tf.int32)


def _gather_from_word_indices(tensor, indices):
    """Index the depth dim of a 2D tensor."""
    return tf.gather(tensor, indices, batch_dims=1, axis=-1)


def _lengths_from_ids(ids, end_id):
    """Computes sequence lengths as the number of ids different from ``end_id``."""
    not_end = tf.cast(tf.not_equal(ids, end_id), tf.int32)
    return tf.reduce_sum(not_end, axis=-1)