max_margin_loss
- opennmt.utils.max_margin_loss(true_logits, true_labels, true_sequence_length, negative_logits, negative_labels, negative_sequence_length, eta=0.1)
Computes the max-margin loss described in https://www.aclweb.org/anthology/P19-1623.
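As a sketch of the objective (our reading of the cited paper, treating each sequence's score as its log-likelihood; take the exact form as an assumption rather than the library's definition), for a source x, a true target y and a negative target ỹ, the per-example loss is a hinge with margin η (the `eta` parameter):

```latex
\mathcal{L}(x, y, \tilde{y}) = \max\bigl(0,\; \eta + \log P(\tilde{y} \mid x) - \log P(y \mid x)\bigr)
```

The loss is zero once the true target's log-likelihood exceeds the negative target's by at least η, and grows linearly as the gap shrinks or reverses.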
- Parameters
true_logits – The unscaled log probabilities (logits) from the true example.
true_labels – The target labels of the true example.
true_sequence_length – The length of each true sequence.
negative_logits – The unscaled log probabilities (logits) from the negative example.
negative_labels – The target labels of the negative example.
negative_sequence_length – The length of each negative sequence.
eta – The minimum margin to enforce between the true and negative examples; pairs whose margin is already higher than this value contribute no loss.
- Returns
The max-margin loss.
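The parameters above can be read as a hinge loss on sequence log-likelihoods. The plain-Python sketch below illustrates that computation under stated assumptions and is not OpenNMT-tf's implementation (which takes TensorFlow tensors): it assumes logits shaped [batch][time][vocab], labels shaped [batch][time], that each sequence is scored by its summed token negative log-likelihood over its first `length` steps (padding ignored), and that per-example losses are averaged over the batch. `max_margin_loss_sketch`, `sequence_nll` and `log_softmax` are hypothetical names.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary distribution."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def sequence_nll(logits, labels, length):
    """Summed token negative log-likelihood of one sequence.

    Only the first `length` time steps are scored; later (padded) steps are ignored.
    """
    nll = 0.0
    for t in range(length):
        nll -= log_softmax(logits[t])[labels[t]]
    return nll


def max_margin_loss_sketch(true_logits, true_labels, true_sequence_length,
                           negative_logits, negative_labels, negative_sequence_length,
                           eta=0.1):
    """Hypothetical sketch: mean over the batch of max(0, eta + NLL(true) - NLL(negative)).

    Since NLL = -log P, this equals max(0, eta + log P(negative) - log P(true)).
    """
    losses = []
    for b in range(len(true_logits)):
        true_nll = sequence_nll(true_logits[b], true_labels[b], true_sequence_length[b])
        neg_nll = sequence_nll(negative_logits[b], negative_labels[b],
                               negative_sequence_length[b])
        losses.append(max(0.0, eta + true_nll - neg_nll))
    return sum(losses) / len(losses)
```

When the true and negative examples are identical, each pair contributes exactly `eta`; when the true sequence is far more likely than the negative one, the pair contributes nothing, which is the "margin higher than eta" behaviour the parameter describes.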