GreedySearch

class opennmt.utils.GreedySearch

A basic greedy search strategy.

Inherits from: opennmt.utils.DecodingStrategy
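The sketch below is a minimal plain-Python version of greedy decoding for a single sequence. At each step it appends the most likely token. It stops at the end token or after a step limit. The function `greedy_decode` and its `next_log_probs` callback are hypothetical stand-ins for a model. This is not the TensorFlow implementation behind `opennmt.utils.GreedySearch`, which works on batches and delegates token selection to a sampler.

```python
from typing import Callable, List


def greedy_decode(
    next_log_probs: Callable[[List[int]], List[float]],
    start_id: int,
    end_id: int,
    max_steps: int,
) -> List[int]:
    """Greedy decoding for one sequence: always take the most likely next token."""
    ids = [start_id]
    for _ in range(max_steps):
        row = next_log_probs(ids)  # log probabilities over the vocabulary
        best = max(range(len(row)), key=row.__getitem__)  # argmax
        ids.append(best)
        if best == end_id:
            break
    return ids[1:]  # drop the start token


if __name__ == "__main__":
    def toy_model(ids: List[int]) -> List[float]:
        # Hypothetical 4-token model that prefers (last id + 1), capped at 3.
        target = min(ids[-1] + 1, 3)
        return [0.0 if i == target else -5.0 for i in range(4)]

    print(greedy_decode(toy_model, start_id=0, end_id=3, max_steps=10))  # [1, 2, 3]
```

Only one hypothesis is kept per sequence, so a poor early choice is never revisited. That is the main trade-off against beam search.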

initialize(start_ids, attention_size=None)

Initializes the strategy.

Parameters
  • start_ids – The start decoding ids.

  • attention_size – If known, the size of the attention vectors (i.e. the maximum source length).

Returns

A tuple containing:

  • The (possibly transformed) start decoding ids.

  • The tensor of finished flags.

  • The initial log probabilities per batch.

  • A dictionary of additional tensors used during decoding.
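The plain-Python sketch below pictures the returned tuple. It is not the library code. `greedy_initialize` is a hypothetical name, and lists stand in for tensors. The sketch also assumes four things about greedy search:

- it leaves the start ids unchanged;
- every sequence starts unfinished;
- every sequence starts with a log probability of 0;
- it needs no extra tensors.

```python
from typing import Dict, List, Optional, Tuple


def greedy_initialize(
    start_ids: List[int], attention_size: Optional[int] = None
) -> Tuple[List[int], List[bool], List[float], Dict[str, object]]:
    """Sketch of the (ids, finished, log_probs, extra) tuple for greedy search."""
    # attention_size only mirrors the documented signature; unused in this sketch.
    batch_size = len(start_ids)
    finished = [False] * batch_size  # no sequence has ended yet
    initial_log_probs = [0.0] * batch_size  # log(1) = 0 before any token is chosen
    extra: Dict[str, object] = {}  # assumption: greedy search keeps no extra state
    return list(start_ids), finished, initial_log_probs, extra


if __name__ == "__main__":
    print(greedy_initialize([1, 1, 1]))
    # ([1, 1, 1], [False, False, False], [0.0, 0.0, 0.0], {})
```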

step(step, sampler, log_probs, cum_log_probs, finished, state=None, attention=None, **kwargs)

Updates the strategy state.

Parameters
  • step – The current decoding step.

  • sampler – The sampler that produces predictions.

  • log_probs – The model log probabilities.

  • cum_log_probs – The cumulative log probabilities per batch.

  • finished – The current finished flags.

  • state – The decoder state.

  • attention – The attention vector for the current step.

  • **kwargs – Additional tensors used by this decoding strategy.

Returns

A tuple containing:

  • The predicted word ids.

  • The new cumulative log probabilities.

  • The updated finished flags.

  • The updated decoder state.

  • A dictionary with additional tensors used by this decoding strategy.
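One greedy update can be sketched in plain Python as below. This is an illustration under several assumptions, not the library's implementation:

- `greedy_step` is hypothetical.
- An argmax over each row of log probabilities stands in for the `sampler`.
- An `end_id` argument is added so the sketch can update the finished flags itself.
- Sequences that are already finished keep their cumulative score.
- The decoder state, attention and extra tensors are omitted.

```python
import math
from typing import List, Tuple


def greedy_step(
    log_probs: List[List[float]],
    cum_log_probs: List[float],
    finished: List[bool],
    end_id: int,
) -> Tuple[List[int], List[float], List[bool]]:
    """One greedy decoding step over a batch (plain-Python sketch)."""
    word_ids: List[int] = []
    new_cum: List[float] = []
    new_finished: List[bool] = []
    for row, cum, done in zip(log_probs, cum_log_probs, finished):
        # Argmax over the vocabulary plays the role of the sampler.
        best = max(range(len(row)), key=row.__getitem__)
        word_ids.append(best)
        # Assumption: finished sequences stop accumulating log probability.
        new_cum.append(cum if done else cum + row[best])
        # A sequence becomes finished once it produces the end token.
        new_finished.append(done or best == end_id)
    return word_ids, new_cum, new_finished


if __name__ == "__main__":
    probs = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]
    log_probs = [[math.log(p) for p in row] for row in probs]
    ids, cum, fin = greedy_step(log_probs, [0.0, -1.0], [False, False], end_id=0)
    print(ids, fin)  # [1, 0] [False, True]
```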

finalize(outputs, end_id, attention=None, **kwargs)

Finalizes the predictions.

Parameters
  • outputs – The array of sampled ids.

  • end_id – The end token id.

  • attention – The array of attention outputs.

  • **kwargs – Additional tensors used by this decoding strategy.

Returns

A tuple containing:

  • The final predictions as a tensor of shape \([B, H, T_t]\).

  • The final attention history of shape \([B, H, T_t, T_s]\).

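  • The final sequence lengths of shape \([B, H]\).

Here \(B\) is the batch size, \(H\) the number of hypotheses per batch element (1 for greedy search), \(T_t\) the target length and \(T_s\) the source length.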
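The reshaping into predictions and lengths can be sketched in plain Python for the greedy case, where each batch element has a single hypothesis (H = 1). This sketch rests on several assumptions:

- `greedy_finalize` is hypothetical.
- `outputs` is a time-major list with one list of ids per decoding step.
- Attention is omitted.
- A sequence's length is the number of tokens before the first `end_id`. The library's exact length convention may differ.

```python
from typing import List, Tuple


def greedy_finalize(
    outputs: List[List[int]], end_id: int
) -> Tuple[List[List[List[int]]], List[List[int]]]:
    """Turn [T_t, B] sampled ids into [B, 1, T_t] predictions and [B, 1] lengths."""
    num_steps = len(outputs)
    batch_size = len(outputs[0]) if outputs else 0
    # Transpose from time-major [T_t, B] to batch-major [B, T_t].
    per_batch = [[outputs[t][b] for t in range(num_steps)] for b in range(batch_size)]
    predictions: List[List[List[int]]] = []
    lengths: List[List[int]] = []
    for ids in per_batch:
        # Assumption: length counts tokens before the first end_id.
        length = ids.index(end_id) if end_id in ids else len(ids)
        predictions.append([ids])  # add a hypothesis dimension of size 1
        lengths.append([length])
    return predictions, lengths


if __name__ == "__main__":
    # Three decoding steps for a batch of two sequences, with end_id = 0.
    print(greedy_finalize([[5, 7], [2, 0], [0, 0]], end_id=0))
    # ([[[5, 2, 0]], [[7, 0, 0]]], [[2], [1]])
```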