opennmt.utils.max_margin_loss

opennmt.utils.max_margin_loss(true_logits, true_labels, true_sequence_length, negative_logits, negative_labels, negative_sequence_length, eta=0.1)

Computes the max-margin loss between true and negative examples, as described in https://www.aclweb.org/anthology/P19-1623.
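One way to write such a sequence-level max-margin (hinge) loss is shown below. It is a sketch under stated assumptions, not a transcription of the library code. The assumptions are: $B$ is the batch size, $\mathbf{z}_{b,t}$ are the logits at step $t$ of pair $b$, $T_b$ is the corresponding sequence length, $\mathbf{y}^{+}_b$ and $\mathbf{y}^{-}_b$ are the true and negative label sequences, a sequence is scored by the average softmax probability of its labels, and per-pair losses are averaged over the batch. The term inside $\max(0,\cdot)$ vanishes once the true score exceeds the negative score by at least $\eta$.

```latex
\mathcal{L} = \frac{1}{B} \sum_{b=1}^{B}
  \max\Bigl(0,\; \eta - \bigl(s(\mathbf{y}^{+}_b) - s(\mathbf{y}^{-}_b)\bigr)\Bigr),
\qquad
s(\mathbf{y}_b) = \frac{1}{T_b} \sum_{t=1}^{T_b}
  \operatorname{softmax}(\mathbf{z}_{b,t})_{y_{b,t}}
```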

Parameters
  • true_logits – The unscaled log probabilities (logits) from the true example.

  • true_labels – The true labels.

  • true_sequence_length – The length of each true sequence.

  • negative_logits – The unscaled log probabilities (logits) from the negative example.

  • negative_labels – The negative labels.

  • negative_sequence_length – The length of each negative sequence.

  • eta – The margin to enforce between the true and negative examples; the loss penalizes any pair whose margin is lower than this value.

Returns

The max-margin loss.
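The following minimal pure-Python sketch illustrates how a loss with this signature could be computed. It rests on several assumptions. Logits are shaped [batch, time, vocab], labels [batch, time], and lengths [batch], represented here as nested lists. Each sequence is scored by the average softmax probability of its labels over its valid steps. The hinge max(0, eta - (true_score - negative_score)) is averaged over the batch. The names `max_margin_loss_sketch` and `sequence_score` are hypothetical, and the scoring choice is an assumption, not a description of the OpenNMT-tf source.

```python
import math
from typing import List, Sequence

# One sequence of logits: [time, vocab]. A batch is a list of these.
Matrix = List[List[float]]


def softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one vector of logits."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def sequence_score(logits: Matrix, labels: Sequence[int], length: int) -> float:
    """Average probability of each label over the first `length` steps.

    Steps at index `length` and beyond are treated as padding and skipped,
    which plays the role of a sequence mask. Assumes length >= 1.
    """
    probs = [softmax(logits[t])[labels[t]] for t in range(length)]
    return sum(probs) / length


def max_margin_loss_sketch(
    true_logits: List[Matrix],
    true_labels: List[List[int]],
    true_sequence_length: List[int],
    negative_logits: List[Matrix],
    negative_labels: List[List[int]],
    negative_sequence_length: List[int],
    eta: float = 0.1,
) -> float:
    """Hypothetical illustration only; not the OpenNMT-tf implementation."""
    losses = []
    for b in range(len(true_logits)):
        s_true = sequence_score(true_logits[b], true_labels[b], true_sequence_length[b])
        s_neg = sequence_score(
            negative_logits[b], negative_labels[b], negative_sequence_length[b]
        )
        # Hinge: the pair costs nothing once the true score leads by at least eta.
        losses.append(max(0.0, eta - (s_true - s_neg)))
    # Average the per-pair losses over the batch.
    return sum(losses) / len(losses)


# Toy batch with one pair, two time steps, and a vocabulary of three tokens.
# The same logits are reused for brevity; in practice the negative logits
# would come from a separate forward pass on the negative target.
logits = [[[10.0, 0.0, 0.0], [0.0, 10.0, 0.0]]]
loss = max_margin_loss_sketch(logits, [[0, 1]], [2], logits, [[2, 2]], [2])
print(loss)  # 0.0
```

Skipping steps past each sequence length is what lets padded batches be scored correctly; a tensor implementation would typically do the same with a mask such as `tf.sequence_mask`.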