Source code for opennmt.layers.reducer

"""Define reducers: objects that merge inputs."""

import abc
import functools

import tensorflow as tf

from opennmt.utils import tensor as tensor_util


def pad_in_time(x, padding_length):
    """Appends padding at the end of the time dimension of a 3-D tensor.

    Args:
      x: A ``tf.Tensor`` with three dimensions; the second one (axis 1) is
        treated as time.
      padding_length: The number of steps to append after the last time step.

    Returns:
      The result of ``tf.pad`` on :obj:`x`, where only the end of axis 1 is
      extended.
    """
    # One [before, after] pair per dimension: only the end of axis 1 grows.
    paddings = [
        [0, 0],
        [0, padding_length],
        [0, 0],
    ]
    return tf.pad(x, paddings)


def align_in_time(x, length):
    """Makes the time dimension (axis 1) of :obj:`x` equal to :obj:`length`.

    Args:
      x: The tensor to align.
      length: The target size of the time dimension.

    Returns:
      :obj:`x` padded in time when it is shorter than :obj:`length`, otherwise
      :obj:`x` sliced to its first :obj:`length` steps.
    """
    current_time = tf.shape(x)[1]

    def _extend():
        return pad_in_time(x, length - current_time)

    def _truncate():
        return x[:, :length]

    is_shorter = tf.less(current_time, length)
    return tf.cond(is_shorter, true_fn=_extend, false_fn=_truncate)


def pad_with_identity(
    x, sequence_length, max_sequence_length, identity_values=0, maxlen=None
):
    """Pads a tensor with identity values up to :obj:`max_sequence_length`.

    Steps before :obj:`sequence_length` keep the values of :obj:`x`, steps
    from :obj:`sequence_length` up to :obj:`max_sequence_length` are set to
    :obj:`identity_values`, and any remaining step is zero.

    Args:
      x: A ``tf.Tensor`` of shape ``[batch_size, time, depth]``.
      sequence_length: The true sequence length of :obj:`x`.
      max_sequence_length: The sequence length up to which the tensor must
        contain :obj:`identity_values`.
      identity_values: The identity value.
      maxlen: Size of the output time dimension. Default is the maximum value
        in :obj:`max_sequence_length`.

    Returns:
      A ``tf.Tensor`` of shape ``[batch_size, maxlen, depth]``.
    """
    if maxlen is None:
        maxlen = tf.reduce_max(max_sequence_length)

    def _time_mask(lengths):
        # Sequence mask with a trailing unit axis so that it broadcasts over
        # the depth dimension.
        flat_mask = tf.sequence_mask(lengths, maxlen=maxlen, dtype=x.dtype)
        return tf.expand_dims(flat_mask, axis=-1)

    valid_mask = _time_mask(sequence_length)
    target_mask = _time_mask(max_sequence_length)

    # Steps covered by the target length but not by the actual sequence.
    fill_mask = target_mask * (1.0 - valid_mask)

    padded = pad_in_time(x, maxlen - tf.shape(x)[1])
    return padded * valid_mask + (fill_mask * identity_values)


def pad_n_with_identity(inputs, sequence_lengths, identity_values=0):
    """Pads every input tensor with identity values up to the largest
    sequence length found across the inputs, for each batch entry.

    Args:
      inputs: A list of ``tf.Tensor``.
      sequence_lengths: A list of sequence lengths, one per input.
      identity_values: The identity value.

    Returns:
      A tuple ``(padded, max_sequence_length)``: the list of input tensors
      padded with identity values, and the combined sequence length.
    """
    # Largest length across the inputs (reduction over the list axis).
    combined_length = tf.reduce_max(sequence_lengths, axis=0)

    # All outputs share the largest time dimension found among the inputs.
    time_dims = [tf.shape(tensor)[1] for tensor in inputs]
    output_time = tf.reduce_max(time_dims)

    padded = [
        pad_with_identity(
            tensor,
            length,
            combined_length,
            identity_values=identity_values,
            maxlen=output_time,
        )
        for tensor, length in zip(inputs, sequence_lengths)
    ]
    return padded, combined_length


class Reducer(tf.keras.layers.Layer):
    """Base class for reducers."""

    def zip_and_reduce(self, x, y):
        """Zips the :obj:`x` with :obj:`y` structures together and reduces all
        elements. If the structures are nested, they will be flattened first.

        Args:
          x: The first structure.
          y: The second structure.

        Returns:
          The same structure as :obj:`x` and :obj:`y` where each element from
          :obj:`x` is reduced with the corresponding element from :obj:`y`.

        Raises:
          ValueError: if the two structures are not the same.
        """
        tf.nest.assert_same_structure(x, y)
        pairs = zip(tf.nest.flatten(x), tf.nest.flatten(y))
        # Each pair is passed to the layer itself as a single inputs argument.
        reduced = [self(pair) for pair in pairs]
        return tf.nest.pack_sequence_as(x, reduced)

    def call(self, inputs, sequence_length=None):
        """Reduces all input elements.

        Args:
          inputs: A list of ``tf.Tensor``.
          sequence_length: The length of each input, if reducing sequences.

        Returns:
          If :obj:`sequence_length` is set, a tuple
          ``(reduced_input, reduced_length)``, otherwise a reduced ``tf.Tensor``
          only.
        """
        if sequence_length is not None:
            return self.reduce_sequence(inputs, sequence_lengths=sequence_length)
        return self.reduce(inputs)

    @abc.abstractmethod
    def reduce(self, inputs):
        """See :meth:`opennmt.layers.Reducer.call`."""
        raise NotImplementedError()

    @abc.abstractmethod
    def reduce_sequence(self, inputs, sequence_lengths):
        """See :meth:`opennmt.layers.Reducer.call`."""
        raise NotImplementedError()
class SumReducer(Reducer):
    """A reducer that sums the inputs."""

    def reduce(self, inputs):
        """Returns the element-wise sum of :obj:`inputs`."""
        count = len(inputs)
        # One or two inputs are handled without going through ``tf.add_n``.
        if count == 1:
            return inputs[0]
        if count == 2:
            return inputs[0] + inputs[1]
        return tf.add_n(inputs)

    def reduce_sequence(self, inputs, sequence_lengths):
        """Sums sequences after padding them with zeros, the identity of the
        addition.

        Returns:
          A tuple ``(summed, combined_length)``.
        """
        padded, combined_length = pad_n_with_identity(
            inputs, sequence_lengths, identity_values=0
        )
        summed = self.reduce(padded)
        return summed, combined_length
class MultiplyReducer(Reducer):
    """A reducer that multiplies the inputs."""

    def reduce(self, inputs):
        """Returns the product of all elements of :obj:`inputs`, accumulated
        from left to right.
        """

        def _multiply(product, factor):
            return product * factor

        return functools.reduce(_multiply, inputs)

    def reduce_sequence(self, inputs, sequence_lengths):
        """Multiplies sequences after padding them with ones, the identity of
        the multiplication.

        Returns:
          A tuple ``(product, combined_length)``.
        """
        padded, combined_length = pad_n_with_identity(
            inputs, sequence_lengths, identity_values=1
        )
        product = self.reduce(padded)
        return product, combined_length
class ConcatReducer(Reducer):
    """A reducer that concatenates the inputs."""

    def __init__(self, axis=-1, **kwargs):
        """Initializes the concat reducer.

        Args:
          axis: Dimension along which to concatenate. This reducer supports
            concatenating in depth or in time.
          **kwargs: Additional layer arguments.
        """
        super().__init__(**kwargs)
        self.axis = axis

    def reduce(self, inputs):
        """Concatenates :obj:`inputs` along the configured axis."""
        return tf.concat(inputs, self.axis)

    def reduce_sequence(self, inputs, sequence_lengths):
        """Concatenates sequences in depth (axis 2) or in time (axis 1).

        Args:
          inputs: A list of ``tf.Tensor``.
          sequence_lengths: The length of each input.

        Returns:
          A tuple ``(reduced_input, reduced_length)``.

        Raises:
          ValueError: if the configured axis does not resolve to 1 or 2.
        """
        # Resolve a negative axis against the rank of the first input.
        axis = self.axis % inputs[0].shape.ndims

        if axis == 2:
            padded, combined_length = pad_n_with_identity(inputs, sequence_lengths)
            return self.reduce(padded), combined_length

        if axis != 1:
            raise ValueError("Unsupported concatenation on axis {}".format(axis))

        # Align all input tensors to the maximum combined length.
        combined_length = tf.add_n(sequence_lengths)
        maxlen = tf.reduce_max(combined_length)
        aligned = [align_in_time(x, maxlen) for x in inputs]

        offset = None
        accumulator = None
        for elem, length in zip(aligned, sequence_lengths):
            # Make sure padding are 0 vectors as it is required for the next
            # step: shifted sequences are merged by addition.
            mask = tf.sequence_mask(length, maxlen=maxlen, dtype=elem.dtype)
            elem = elem * tf.expand_dims(mask, -1)
            if accumulator is None:
                accumulator = elem
                offset = length
            else:
                accumulator = accumulator + tensor_util.roll_sequence(elem, offset)
                # Rebind instead of using ``+=``: ``offset`` starts as an alias
                # of the caller's first length, and an in-place add would
                # modify that object when it is mutable (e.g. a NumPy array).
                offset = offset + length

        return accumulator, combined_length
class JoinReducer(Reducer):
    """A reducer that joins its inputs in a single tuple."""

    def reduce(self, inputs):
        """Flattens one level of plain tuples in :obj:`inputs` and returns the
        result as a tuple.
        """
        joined = []
        for elem in inputs:
            # Tuples exposing ``_fields`` (as namedtuples do) are kept whole;
            # only plain tuples are unpacked.
            is_plain_tuple = isinstance(elem, tuple) and not hasattr(elem, "_fields")
            if is_plain_tuple:
                joined.extend(elem)
            else:
                joined.append(elem)
        return tuple(joined)

    def reduce_sequence(self, inputs, sequence_lengths):
        """Joins the inputs and their lengths independently.

        Returns:
          A tuple ``(joined_inputs, joined_lengths)``.
        """
        joined_inputs = self.reduce(inputs)
        joined_lengths = self.reduce(sequence_lengths)
        return joined_inputs, joined_lengths
class DenseReducer(ConcatReducer):
    """A reducer that concatenates its inputs in depth and applies a linear
    transformation.
    """

    def __init__(self, output_size, activation=None, **kwargs):
        """Initializes the reducer.

        Args:
          output_size: The output size of the linear transformation.
          activation: Activation function (a callable). Set it to ``None`` to
            maintain a linear activation.
          **kwargs: Additional layer arguments.
        """
        # The concatenation is always done on the last axis.
        super().__init__(axis=-1, **kwargs)
        self.dense = tf.keras.layers.Dense(output_size, activation=activation)

    def reduce(self, inputs):
        """Concatenates :obj:`inputs` and projects the result with the dense
        layer.
        """
        concatenated = super().reduce(inputs)
        return self.dense(concatenated)