Source code for opennmt.utils.losses

"""Define losses."""

import tensorflow as tf


def _smooth_one_hot_labels(logits, labels, label_smoothing):
    num_classes = logits.shape[-1]
    on_value = 1.0 - label_smoothing
    off_value = label_smoothing / (num_classes - 1)
    return tf.one_hot(
        labels,
        num_classes,
        on_value=tf.cast(on_value, logits.dtype),
        off_value=tf.cast(off_value, logits.dtype),
    )
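

# Illustrative sketch (assumption, not part of the original module): with 5
# classes and label_smoothing=0.1, the true class keeps probability 0.9 and the
# remaining 0.1 is spread evenly over the other 4 classes, so every row still
# sums to 1. The `_example_*` helper and its tensor values are hypothetical.
def _example_smoothed_labels():
    logits = tf.zeros([2, 5])  # [B, V]; only the dtype and class count matter here
    labels = tf.constant([1, 3])
    # First row becomes [0.025, 0.9, 0.025, 0.025, 0.025].
    return _smooth_one_hot_labels(logits, labels, 0.1)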


def _softmax_cross_entropy(logits, labels, label_smoothing, training):
    # Computes the softmax in full precision.
    logits = tf.cast(logits, tf.float32)
    if training and label_smoothing > 0.0:
        smoothed_labels = _smooth_one_hot_labels(logits, labels, label_smoothing)
        return tf.nn.softmax_cross_entropy_with_logits(smoothed_labels, logits)
    else:
        return tf.nn.sparse_softmax_cross_entropy_with_logits(labels, logits)

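
# Illustrative sketch (assumption, not part of the original module): the helper
# dispatches between the dense smoothed-label loss (training with smoothing > 0)
# and the sparse loss (inference or no smoothing), always computed in float32
# even when the logits are float16. The `_example_*` helper is hypothetical.
def _example_softmax_cross_entropy():
    logits = tf.random.normal([2, 5], dtype=tf.float16)
    labels = tf.constant([1, 3])
    train_loss = _softmax_cross_entropy(logits, labels, 0.1, True)  # smoothed path
    eval_loss = _softmax_cross_entropy(logits, labels, 0.1, False)  # sparse path
    return train_loss, eval_loss  # both have shape [B] and dtype float32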


def cross_entropy_sequence_loss(
    logits,
    labels,
    sequence_length=None,
    label_smoothing=0.0,
    average_in_time=False,
    training=None,
    sequence_weight=None,
    mask_outliers=False,
):
    """Computes the cross entropy loss of sequences.

    Args:
      logits: The unscaled probabilities with shape :math:`[B, T, V]`.
      labels: The true labels with shape :math:`[B, T]`.
      sequence_length: The length of each sequence with shape :math:`[B]`.
      label_smoothing: The label smoothing value.
      average_in_time: If ``True``, also average the loss in the time dimension.
      training: Compute training loss.
      sequence_weight: The weight of each sequence with shape :math:`[B]`.
      mask_outliers: Mask large training loss values considered as outliers.

    Returns:
      A tuple (cumulated loss, loss normalizer, token-level normalizer).
    """
    cross_entropy = _softmax_cross_entropy(logits, labels, label_smoothing, training)
    dtype = cross_entropy.dtype
    shape = tf.shape(logits)
    batch_size = shape[0]
    max_time = shape[1]

    # Mask time steps beyond each sequence length.
    if sequence_length is None:
        sequence_length = tf.fill([batch_size], max_time)
    weight = tf.sequence_mask(sequence_length, maxlen=max_time, dtype=dtype)

    if training and mask_outliers:
        import tensorflow_probability as tfp

        # Outliers are detected using the interquartile range (IQR): examples with
        # a mean loss above Q3 + 1.5 * IQR get a zero weight.
        examples_loss = tf.reduce_sum(cross_entropy * weight, axis=-1)
        examples_score = examples_loss / tf.reduce_sum(weight, axis=-1)
        percentiles = tfp.stats.percentile(examples_score, [25, 75])
        iqr = percentiles[1] - percentiles[0]
        threshold = percentiles[1] + 1.5 * iqr
        if sequence_weight is None:
            sequence_weight = tf.ones([batch_size], dtype=dtype)
        sequence_weight = tf.where(
            examples_score > threshold,
            x=tf.zeros_like(sequence_weight),
            y=sequence_weight,
        )

    if sequence_weight is not None:
        sequence_weight = tf.cast(sequence_weight, dtype)
        weight *= tf.expand_dims(sequence_weight, 1)

    loss = tf.reduce_sum(cross_entropy * weight)
    loss_token_normalizer = tf.reduce_sum(weight)

    if average_in_time or not training:
        loss_normalizer = loss_token_normalizer
    elif sequence_weight is not None:
        loss_normalizer = tf.reduce_sum(sequence_weight)
    else:
        loss_normalizer = tf.cast(batch_size, dtype)

    return loss, loss_normalizer, loss_token_normalizer
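

# Illustrative usage sketch (assumption, not part of the original module): a toy
# batch of 2 sequences with padded length T=4 and vocabulary V=6. The per-token
# training loss is the cumulated loss divided by the token-level normalizer;
# with average_in_time=True the first normalizer is that same value.
def _example_sequence_loss():
    logits = tf.random.normal([2, 4, 6])  # [B, T, V]
    labels = tf.random.uniform([2, 4], maxval=6, dtype=tf.int64)  # [B, T]
    lengths = tf.constant([4, 2])  # [B]
    loss, normalizer, token_normalizer = cross_entropy_sequence_loss(
        logits,
        labels,
        sequence_length=lengths,
        label_smoothing=0.1,
        average_in_time=True,
        training=True,
    )
    return loss / normalizer, loss / token_normalizer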


def cross_entropy_loss(
    logits, labels, label_smoothing=0.0, training=None, weight=None
):
    """Computes the cross entropy loss.

    Args:
      logits: The unscaled probabilities with shape :math:`[B, V]`.
      labels: The true labels with shape :math:`[B]`.
      label_smoothing: The label smoothing value.
      training: Compute training loss.
      weight: The weight of each example with shape :math:`[B]`.

    Returns:
      The cumulated loss and the loss normalizer.
    """
    cross_entropy = _softmax_cross_entropy(logits, labels, label_smoothing, training)
    if weight is not None:
        weight = tf.cast(weight, cross_entropy.dtype)
        cross_entropy *= weight
        loss_normalizer = tf.reduce_sum(weight)
    else:
        batch_size = tf.shape(cross_entropy)[0]
        loss_normalizer = tf.cast(batch_size, cross_entropy.dtype)
    loss = tf.reduce_sum(cross_entropy)
    return loss, loss_normalizer
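

# Illustrative usage sketch (assumption, not part of the original module): a
# per-example weighted classification loss over a batch of 3 examples; the mean
# loss is the cumulated loss divided by the returned normalizer (here 2.5).
def _example_classification_loss():
    logits = tf.random.normal([3, 10])  # [B, V]
    labels = tf.constant([0, 7, 2])  # [B]
    weights = tf.constant([1.0, 0.5, 1.0])  # [B], down-weight the second example
    loss, normalizer = cross_entropy_loss(
        logits, labels, label_smoothing=0.1, training=True, weight=weights
    )
    return loss / normalizer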


def guided_alignment_cost(
    attention_probs, gold_alignment, sequence_length=None, cost_type="ce", weight=1
):
    """Computes the guided alignment cost.

    Args:
      attention_probs: The attention probabilities, a float ``tf.Tensor`` of shape
        :math:`[B, T_t, T_s]`.
      gold_alignment: The true alignment matrix, a float ``tf.Tensor`` of shape
        :math:`[B, T_t, T_s]`.
      sequence_length: The length of each sequence.
      cost_type: The type of the cost function to compute (can be: ce, mse).
      weight: The weight applied to the cost.

    Returns:
      The guided alignment cost.

    Raises:
      ValueError: if :obj:`cost_type` is invalid.
    """
    if cost_type == "ce":
        loss = tf.keras.losses.CategoricalCrossentropy(
            reduction=tf.keras.losses.Reduction.SUM
        )
    elif cost_type == "mse":
        loss = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.SUM
        )
    else:
        raise ValueError("invalid guided alignment cost: %s" % cost_type)

    if sequence_length is not None:
        sample_weight = tf.sequence_mask(
            sequence_length,
            maxlen=tf.shape(attention_probs)[1],
            dtype=attention_probs.dtype,
        )
        sample_weight = tf.expand_dims(sample_weight, -1)
        normalizer = tf.reduce_sum(sequence_length)
    else:
        sample_weight = None
        normalizer = tf.size(attention_probs)

    attention_probs = tf.cast(attention_probs, tf.float32)
    cost = loss(gold_alignment, attention_probs, sample_weight=sample_weight)
    cost /= tf.cast(normalizer, cost.dtype)
    return weight * cost
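

# Illustrative usage sketch (assumption, not part of the original module): the
# attention matrix of a toy batch (B=2, T_t=3 target steps, T_s=4 source steps)
# is pushed toward a one-hot gold alignment; "ce" treats each target row as a
# distribution over source positions, "mse" penalizes element-wise differences.
def _example_alignment_cost():
    attention_probs = tf.nn.softmax(tf.random.normal([2, 3, 4]), axis=-1)
    gold_alignment = tf.one_hot([[0, 1, 2], [1, 2, 3]], depth=4)
    target_length = tf.constant([3, 2])  # padded target steps are masked out
    return guided_alignment_cost(
        attention_probs,
        gold_alignment,
        sequence_length=target_length,
        cost_type="ce",
    )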


def regularization_penalty(regularization_type, scale, weights):
    """Computes the weights regularization penalty.

    Args:
      regularization_type: The regularization type: ``l1``, ``l2``, or ``l1_l2``.
      scale: The regularization multiplier. If :obj:`regularization_type` is
        ``l1_l2``, this should be a list or tuple containing the L1 regularization
        scale and the L2 regularization scale.
      weights: The list of weights.

    Returns:
      The regularization penalty.

    Raises:
      ValueError: if :obj:`regularization_type` is invalid or is ``l1_l2`` but
        :obj:`scale` is not a sequence.
    """
    regularization_type = regularization_type.lower()
    if regularization_type == "l1":
        regularizer = tf.keras.regularizers.l1(l=float(scale))
    elif regularization_type == "l2":
        regularizer = tf.keras.regularizers.l2(l=float(scale))
    elif regularization_type == "l1_l2":
        if not isinstance(scale, (list, tuple)) or len(scale) != 2:
            raise ValueError("l1_l2 regularization requires 2 scale values")
        regularizer = tf.keras.regularizers.l1_l2(
            l1=float(scale[0]), l2=float(scale[1])
        )
    else:
        raise ValueError("invalid regularization type %s" % regularization_type)

    weights = list(filter(lambda v: not _is_bias(v), weights))
    penalty = tf.add_n([regularizer(w) for w in weights])
    return penalty
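

# Illustrative usage sketch (assumption, not part of the original module): an L2
# penalty over a layer's weights; rank-1 variables named "...bias:0" are dropped
# by the `_is_bias` filter below before the penalty is summed.
def _example_l2_penalty():
    layer = tf.keras.layers.Dense(8)  # a toy layer built only for the example
    layer.build([None, 16])  # creates the kernel and bias variables
    return regularization_penalty("l2", 1e-4, layer.trainable_weights)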


def _is_bias(variable):
    return len(variable.shape) == 1 and variable.name.endswith("bias:0")


def _negative_log_likelihood(logits, labels, sequence_length):
    nll_num, nll_den, _ = cross_entropy_sequence_loss(
        logits, labels, sequence_length, average_in_time=True
    )
    return nll_num / nll_den


def max_margin_loss(
    true_logits,
    true_labels,
    true_sequence_length,
    negative_logits,
    negative_labels,
    negative_sequence_length,
    eta=0.1,
):
    """Computes the max-margin loss described in
    https://www.aclweb.org/anthology/P19-1623.

    Args:
      true_logits: The unscaled probabilities from the true example.
      negative_logits: The unscaled probabilities from the negative example.
      true_labels: The true labels.
      true_sequence_length: The length of each true sequence.
      negative_labels: The negative labels.
      negative_sequence_length: The length of each negative sequence.
      eta: Ensure that the margin is higher than this value.

    Returns:
      The max-margin loss.
    """
    true_nll = _negative_log_likelihood(true_logits, true_labels, true_sequence_length)
    negative_nll = _negative_log_likelihood(
        negative_logits, negative_labels, negative_sequence_length
    )
    margin = true_nll - negative_nll + eta
    return tf.maximum(margin, 0)
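

# Illustrative usage sketch (assumption, not part of the original module): the
# loss is positive only when the per-token NLL of the true sequence is not at
# least `eta` lower than the NLL of the negative (e.g. model-sampled) sequence.
def _example_max_margin():
    true_logits = tf.random.normal([2, 4, 6])
    true_labels = tf.random.uniform([2, 4], maxval=6, dtype=tf.int64)
    true_lengths = tf.constant([4, 3])
    negative_logits = tf.random.normal([2, 5, 6])
    negative_labels = tf.random.uniform([2, 5], maxval=6, dtype=tf.int64)
    negative_lengths = tf.constant([5, 4])
    return max_margin_loss(
        true_logits,
        true_labels,
        true_lengths,
        negative_logits,
        negative_labels,
        negative_lengths,
        eta=0.1,
    )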