opennmt.utils.guided_alignment_cost

opennmt.utils.guided_alignment_cost(attention_probs, gold_alignment, sequence_length=None, cost_type='ce', weight=1)

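Computes the guided alignment cost, which penalizes the divergence between the model's attention probabilities and a reference (gold) alignment.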

Parameters
  • attention_probs – The attention probabilities, a float tf.Tensor of shape \([B, T_t, T_s]\), where \(B\) is the batch size, \(T_t\) the target length, and \(T_s\) the source length.

  • gold_alignment – The true alignment matrix, a float tf.Tensor of shape \([B, T_t, T_s]\).

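  • sequence_length – The length of each target sequence, an integer tf.Tensor of shape \([B]\) (optional). If set, padded target positions do not contribute to the cost.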

  • cost_type – The type of cost function to compute: 'ce' (cross-entropy) or 'mse' (mean squared error).

  • weight – A scalar factor by which the cost is multiplied.

Returns

The weighted guided alignment cost, a scalar tf.Tensor.

Raises

ValueError – if cost_type is neither 'ce' nor 'mse'.
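To make the two cost types concrete, here is a minimal pure-Python sketch of the computation for illustration only; it is not the OpenNMT-tf implementation. It assumes that 'ce' is the cross-entropy of each attention row against the matching gold row, that 'mse' averages squared errors over the source positions of a row, and that the result is averaged over the unmasked target positions. The library may normalize differently. The function name guided_alignment_cost_sketch and the eps clipping parameter are hypothetical.

```python
import math
from typing import List, Optional, Sequence

# One batch item: a [T_t, T_s] matrix as nested lists.
Matrix = List[List[float]]


def guided_alignment_cost_sketch(
    attention_probs: Sequence[Matrix],
    gold_alignment: Sequence[Matrix],
    sequence_length: Optional[Sequence[int]] = None,
    cost_type: str = "ce",
    weight: float = 1.0,
    eps: float = 1e-7,  # hypothetical: clips probabilities to avoid log(0)
) -> float:
    """Hypothetical reference sketch, not the OpenNMT-tf implementation."""
    if cost_type not in ("ce", "mse"):
        raise ValueError("invalid cost_type: %s" % cost_type)
    total = 0.0
    count = 0
    for b, (probs, gold) in enumerate(zip(attention_probs, gold_alignment)):
        # Assumption: only the first sequence_length[b] target rows are real tokens.
        length = len(probs) if sequence_length is None else sequence_length[b]
        for t in range(length):
            if cost_type == "ce":
                # Cross-entropy between the gold row and the attention row.
                total -= sum(
                    g * math.log(max(p, eps)) for g, p in zip(gold[t], probs[t])
                )
            else:
                # Mean squared error over the source positions of this row.
                total += sum(
                    (g - p) ** 2 for g, p in zip(gold[t], probs[t])
                ) / len(probs[t])
            count += 1
    # Assumption: average over unmasked target positions, then apply the weight.
    return weight * total / max(count, 1)
```

For example, if a target position's gold row is one-hot on the first source token and the attention row is [0.5, 0.5], the 'ce' term for that position is log 2 ≈ 0.693; passing sequence_length lets trailing padded rows be skipped entirely.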