GreedySearch
- class opennmt.utils.GreedySearch
A basic greedy search strategy.
Inherits from:
opennmt.utils.DecodingStrategy
- initialize(start_ids, attention_size=None)
Initializes the strategy.
- Parameters
start_ids – The start decoding ids.
attention_size – If known, the size of the attention vectors (i.e. the maximum source length).
- Returns
A tuple containing:
The (possibly transformed) start decoding ids.
The tensor of finished flags.
The initial log probabilities per batch.
A dictionary of additional tensors used during decoding.
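To make the shape of this return value concrete, here is a minimal sketch in plain Python. It is not the library's implementation: `greedy_initialize` is a hypothetical helper, lists stand in for tensors, and the zero initial log probabilities and empty dictionary are assumptions about what a greedy strategy would need.

```python
def greedy_initialize(start_ids):
    """Hypothetical stand-in that mirrors the four documented return values."""
    batch_size = len(start_ids)
    finished = [False] * batch_size         # no sequence is finished at the start
    initial_log_probs = [0.0] * batch_size  # assumed: log(1) = 0 for each batch entry
    extra_vars = {}                         # assumed: greedy search keeps no extra state
    return list(start_ids), finished, initial_log_probs, extra_vars


start_ids, finished, initial_log_probs, extra_vars = greedy_initialize([1, 1, 1])
```

Each returned value has one entry per batch element, which is what lets the decoding loop update them in lockstep.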
- step(step, sampler, log_probs, cum_log_probs, finished, state=None, attention=None, **kwargs)
Updates the strategy state.
- Parameters
step – The current decoding step.
sampler – The sampler that produces predictions.
log_probs – The model log probabilities.
cum_log_probs – The cumulative log probabilities per batch.
finished – The current finished flags.
state – The decoder state.
attention – The attention vector for the current step.
**kwargs – Additional tensors used by this decoding strategy.
- Returns
A tuple containing:
The predicted word ids.
The new cumulative log probabilities.
The updated finished flags.
The updated decoder state.
A dictionary with additional tensors used by this decoding strategy.
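The general idea of a greedy step can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the library's code: `greedy_step` is a hypothetical function, an argmax stands in for the `sampler`, the `end_id` argument is added here only so the sketch can update the finished flags on its own, and the decoder state and extra tensors are omitted.

```python
def greedy_step(log_probs, cum_log_probs, finished, end_id):
    """Pick the most likely id per batch entry and update the running totals."""
    word_ids, new_cum_log_probs, new_finished = [], [], []
    for row, total, done in zip(log_probs, cum_log_probs, finished):
        # Argmax over the vocabulary (assumed stand-in for the sampler).
        best = max(range(len(row)), key=row.__getitem__)
        word_ids.append(best)
        # Finished sequences keep their score; others accumulate the new log prob.
        new_cum_log_probs.append(total if done else total + row[best])
        # A sequence becomes finished once it emits the end token.
        new_finished.append(done or best == end_id)
    return word_ids, new_cum_log_probs, new_finished


word_ids, cum_log_probs, finished = greedy_step(
    log_probs=[[-2.0, -0.5, -1.0], [-0.25, -3.0, -2.0]],
    cum_log_probs=[-1.0, -0.5],
    finished=[False, False],
    end_id=0,
)
# word_ids == [1, 0], cum_log_probs == [-1.5, -0.75], finished == [False, True]
```

The second batch entry selects id 0, which is the assumed end token here, so its finished flag flips while the first entry continues decoding.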
- finalize(outputs, end_id, attention=None, **kwargs)
Finalizes the predictions.
- Parameters
outputs – The array of sampled ids.
end_id – The end token id.
attention – The array of attention outputs.
**kwargs – Additional tensors used by this decoding strategy.
- Returns
A tuple containing:
The final predictions as a tensor of shape \([B, H, T_t]\).
The final attention history of shape \([B, H, T_t, T_s]\).
The final sequence lengths of shape \([B, H]\).
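For greedy search there is a single hypothesis per batch entry, so \(H = 1\) in the shapes above. The sketch below illustrates how such outputs could be assembled with plain lists. It is a hypothetical helper, not the library's implementation: it assumes `outputs` is time-major (one list of ids per decoding step), it assumes the length counts the tokens before the first `end_id`, and it leaves out the attention history.

```python
def greedy_finalize(outputs, end_id):
    """Turn time-major sampled ids [T, B] into predictions [B, 1, T] and lengths [B, 1]."""
    # Transpose from time-major [T, B] to batch-major [B, T].
    batch_major = [list(seq) for seq in zip(*outputs)]
    predictions, lengths = [], []
    for seq in batch_major:
        # Assumed convention: length is the number of tokens before the first end token.
        length = seq.index(end_id) if end_id in seq else len(seq)
        predictions.append([seq])   # wrap in a hypothesis dimension of size 1
        lengths.append([length])
    return predictions, lengths


predictions, lengths = greedy_finalize(
    outputs=[[5, 7], [6, 2], [2, 2]],  # 3 decoding steps, batch of 2
    end_id=2,
)
# predictions == [[[5, 6, 2]], [[7, 2, 2]]], lengths == [[2], [1]]
```

Keeping the hypothesis dimension, even at size 1, gives the result the same rank as the documented \([B, H, T_t]\) and \([B, H]\) shapes.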