BeamSearch
- class opennmt.utils.BeamSearch(beam_size, length_penalty=0, coverage_penalty=0, tflite_output_size=None)
A beam search strategy.
Inherits from:
opennmt.utils.DecodingStrategy
- __init__(beam_size, length_penalty=0, coverage_penalty=0, tflite_output_size=None)
Initializes the decoding strategy.
- Parameters
beam_size – The number of paths to consider per batch.
length_penalty – Length penalty, see https://arxiv.org/abs/1609.08144.
coverage_penalty – Coverage penalty, see https://arxiv.org/abs/1609.08144.
tflite_output_size – The output size of the TFLite model, or None when not exporting to TFLite.
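To make the two penalty parameters concrete, the sketch below implements the formulas from the referenced paper (https://arxiv.org/abs/1609.08144) in plain Python. It is an illustration under stated assumptions, not this class's actual code: the function names are hypothetical, and whether the library applies exactly these formulas internally is an assumption. With the default values of 0, the length penalty evaluates to 1 and the coverage penalty to 0, so scores are left unchanged.

```python
import math


def length_penalty(length, alpha):
    """Length penalty from the paper: lp(Y) = ((5 + |Y|) / 6) ** alpha."""
    return ((5.0 + length) / 6.0) ** alpha


def coverage_penalty(attention, beta):
    """Coverage penalty from the paper: beta * sum_i log(min(sum_j p_ij, 1)).

    attention[j][i] is the attention weight that target step j puts on
    source position i.
    """
    if beta == 0:
        return 0.0
    num_source = len(attention[0])
    total = 0.0
    for i in range(num_source):
        coverage = sum(step[i] for step in attention)
        # The tiny floor is an addition of this sketch to avoid log(0).
        total += math.log(min(max(coverage, 1e-9), 1.0))
    return beta * total


def rescore(cum_log_prob, length, attention, alpha=0.0, beta=0.0):
    """Hypothesis score: log P(Y|X) / lp(Y) + cp(X; Y)."""
    return cum_log_prob / length_penalty(length, alpha) + coverage_penalty(attention, beta)
```

For example, a hypothesis of length 7 with a cumulated log probability of -6.0 and alpha=1.0 is rescored to -3.0, because its length penalty is (5 + 7) / 6 = 2.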
- property num_hypotheses
The number of hypotheses returned by this strategy.
- initialize(start_ids, attention_size=None)
Initializes the strategy.
- Parameters
start_ids – The start decoding ids.
attention_size – If known, the size of the attention vectors (i.e. the maximum source length).
- Returns
A tuple containing:
The (possibly transformed) start decoding ids.
The tensor of finished flags.
Initial log probabilities per batch.
A dictionary of additional tensors used during the decoding.
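The returned tuple can be pictured with a small plain-Python sketch of how a beam search typically sets up its state. This is an assumption about the general technique, not the library's implementation: the helper name `init_beam_state` and the `parent_ids` key are hypothetical, and lists stand in for tensors.

```python
def init_beam_state(start_ids, beam_size):
    """Builds an initial beam search state from a list of start ids of size B."""
    # Repeat each batch entry beam_size times: [B] -> [B * beam_size].
    tiled_ids = [i for i in start_ids for _ in range(beam_size)]
    # No hypothesis is finished before decoding starts.
    finished = [False] * len(tiled_ids)
    # Only the first beam of each batch entry starts with log probability 0.
    # The others start at -inf so that the first step does not select
    # beam_size copies of the same token.
    per_entry = [0.0] + [float("-inf")] * (beam_size - 1)
    initial_log_probs = per_entry * len(start_ids)
    # Hypothetical extra state: history of parent beam ids, filled in per step.
    extra = {"parent_ids": []}
    return tiled_ids, finished, initial_log_probs, extra
```

With `start_ids=[1, 1]` and `beam_size=3`, the tiled ids have six entries and the initial log probabilities are `[0, -inf, -inf, 0, -inf, -inf]`.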
- step(step, sampler, log_probs, cum_log_probs, finished, state=None, attention=None, **kwargs)
Updates the strategy state.
- Parameters
step – The current decoding step.
sampler – The sampler that produces predictions.
log_probs – The model log probabilities.
cum_log_probs – The cumulated log probabilities per batch.
finished – The current finished flags.
state – The decoder state.
attention – The attention vector for the current step.
**kwargs – Additional tensors used by this decoding strategy.
- Returns
A tuple containing:
The predicted word ids.
The new cumulated log probabilities.
The updated finished flags.
The updated decoder state.
A dictionary with additional tensors used by this decoding strategy.
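The core of one beam search step can be sketched in plain Python for a single batch entry. This is a simplified illustration and not the library's code: `beam_step` is a hypothetical name, and the sketch leaves out the sampler, the finished flags, the decoder state reordering and the penalties. It only shows how the cumulated log probabilities and the model log probabilities combine to select the next beams.

```python
def beam_step(log_probs, cum_log_probs, beam_size):
    """Selects the beam_size best continuations for one batch entry.

    log_probs: next-token log probabilities, shape [beam][vocab].
    cum_log_probs: cumulated log probability of each beam, shape [beam].
    """
    vocab_size = len(log_probs[0])
    candidates = []
    for beam, row in enumerate(log_probs):
        for word, log_prob in enumerate(row):
            # Score of extending this beam with this word, stored with the
            # flattened index beam * vocab_size + word.
            candidates.append((cum_log_probs[beam] + log_prob, beam * vocab_size + word))
    # Keep the beam_size highest scoring candidates across all beams.
    candidates.sort(key=lambda c: c[0], reverse=True)
    top = candidates[:beam_size]
    new_cum_log_probs = [score for score, _ in top]
    word_ids = [flat % vocab_size for _, flat in top]
    parent_ids = [flat // vocab_size for _, flat in top]
    return word_ids, new_cum_log_probs, parent_ids
```

The parent ids record which previous beam each selected word extends. A real implementation would also use them to reorder the decoder state and the finished flags.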
- finalize(outputs, end_id, attention=None, **kwargs)
Finalizes the predictions.
- Parameters
outputs – The array of sampled ids.
end_id – The end token id.
attention – The array of attention outputs.
**kwargs – Additional tensors used by this decoding strategy.
- Returns
A tuple containing:
The final predictions as a tensor of shape \([B, H, T_t]\).
The final attention history of shape \([B, H, T_t, T_s]\).
The final sequence lengths of shape \([B, H]\).
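In a beam search, the ids sampled at each step belong to beams that may be reordered at later steps, so final hypotheses are usually rebuilt by following parent pointers backwards. The plain-Python sketch below shows that backtracking for a single batch entry. It assumes that the per-step parent beam ids are available among the additional tensors, which is not confirmed by this page, and `backtrack` is a hypothetical name.

```python
def backtrack(step_word_ids, step_parent_ids):
    """Rebuilds full hypotheses from per-step beam search outputs.

    step_word_ids[t][k]: word chosen for beam k at step t.
    step_parent_ids[t][k]: beam at step t - 1 that beam k at step t extends.
    Returns a list of hypotheses of shape [beam][time].
    """
    num_steps = len(step_word_ids)
    beam_size = len(step_word_ids[0])
    hypotheses = []
    for k in range(beam_size):
        beam = k
        ids = []
        # Walk from the last step to the first, following parent pointers.
        for t in range(num_steps - 1, -1, -1):
            ids.append(step_word_ids[t][beam])
            beam = step_parent_ids[t][beam]
        hypotheses.append(ids[::-1])
    return hypotheses
```

Stacking the result over batch entries gives the \([B, H, T_t]\) layout described above. Sequence lengths can then be derived from the position of `end_id` in each hypothesis.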