# Source code for opennmt.layers.bridge

"""Define bridges: logic of passing the encoder state to the decoder."""

import abc

import tensorflow as tf


def assert_state_is_compatible(expected_state, state):
    """Checks that :obj:`state` can be used in place of :obj:`expected_state`.

    The two states must share the same nested structure. In addition, for each
    leaf of :obj:`expected_state` that is a tensor, the matching leaf of
    :obj:`state` must have the same last dimension. Leaves of
    :obj:`expected_state` that are not tensors are not shape-checked.

    Args:
      expected_state: The reference state.
      state: The state that must be compatible with :obj:`expected_state`.

    Raises:
      ValueError: if the states are incompatible.
    """
    # Structure is validated first so that the flattened leaves below line up
    # one-to-one.
    tf.nest.assert_same_structure(expected_state, state)

    leaf_pairs = zip(tf.nest.flatten(expected_state), tf.nest.flatten(state))
    for reference, candidate in leaf_pairs:
        if not tf.is_tensor(reference):
            continue

        # Only the innermost (last) dimension is compared.
        reference_depth = reference.shape[-1]
        candidate_depth = candidate.shape[-1]
        if candidate_depth != reference_depth:
            raise ValueError(
                "Tensor in state has shape %s which is incompatible "
                "with the target shape %s" % (candidate.shape, reference.shape)
            )


class Bridge(tf.keras.layers.Layer):
    """Base class for bridges.

    Subclasses implement :meth:`call`, which receives the encoder state and
    the default decoder state packed together in a single list.
    """

    def __call__(self, encoder_state, decoder_zero_state):
        """Returns the initial decoder state.

        Args:
          encoder_state: The encoder state.
          decoder_zero_state: The default decoder state.

        Returns:
          The decoder initial state.
        """
        # Both states are bundled into one list and handed to the parent
        # ``__call__`` as its single input argument.
        packed_states = [encoder_state, decoder_zero_state]
        return super().__call__(packed_states)

    @abc.abstractmethod
    def call(self, states):
        """Computes the decoder initial state; must be overridden.

        Args:
          states: The list ``[encoder_state, decoder_zero_state]`` built by
            :meth:`__call__`.
        """
        raise NotImplementedError()
class CopyBridge(Bridge):
    """A bridge that passes the encoder state as is."""

    def call(self, states):
        """Repacks the encoder state into the decoder state structure.

        Args:
          states: A pair ``(encoder_state, decoder_state)``.

        Returns:
          The flattened encoder state values arranged in the nested structure
          of the decoder state.
        """
        source_state, target_state = states
        # Validate before copying: the encoder state is the reference and the
        # decoder state is checked against it.
        assert_state_is_compatible(source_state, target_state)
        return tf.nest.pack_sequence_as(target_state, tf.nest.flatten(source_state))
class ZeroBridge(Bridge):
    """A bridge that does not pass information from the encoder."""

    def call(self, states):
        """Returns the default decoder state untouched.

        Args:
          states: A sequence whose second element is the default decoder
            state; the first element (the encoder state) is ignored.

        Returns:
          The second element of :obj:`states`.
        """
        decoder_zero_state = states[1]
        return decoder_zero_state
class DenseBridge(Bridge):
    """A bridge that applies a parameterized linear transformation from the
    encoder state to the decoder state size.
    """

    def __init__(self, activation=None):
        """Initializes the bridge.

        Args:
          activation: Activation function (a callable). Set it to ``None`` to
            maintain a linear activation.
        """
        super().__init__()
        # Both attributes below stay ``None`` until :meth:`build` runs.
        self.linear = None
        self.decoder_state_sizes = None
        self.activation = activation

    def build(self, input_shape):
        """Creates the projection layer from the decoder state shapes.

        Args:
          input_shape: An indexable whose element at position 1 is the
            (possibly nested) shape structure of the decoder state.
        """
        decoder_leaf_shapes = tf.nest.flatten(input_shape[1])
        # Last dimension of every decoder state leaf, in flattening order.
        sizes = [leaf_shape[-1] for leaf_shape in decoder_leaf_shapes]
        self.decoder_state_sizes = sizes
        # A single dense layer produces all decoder leaves at once; its output
        # is sliced back into per-leaf pieces in ``call``.
        self.linear = tf.keras.layers.Dense(sum(sizes), activation=self.activation)

    def call(self, states):
        """Projects the encoder state onto the decoder state structure.

        Args:
          states: A pair ``(encoder_state, decoder_state)``.

        Returns:
          The transformed encoder state, packed in the nested structure of the
          decoder state.
        """
        encoder_state, decoder_state = states
        # Join all encoder leaves along axis 1.
        # NOTE(review): presumably leaves are rank-2 (batch, depth) - confirm.
        merged = tf.concat(tf.nest.flatten(encoder_state), axis=1)
        projected = self.linear(merged)
        # Cut the projection into one slice per decoder leaf, using the sizes
        # recorded in ``build``.
        pieces = tf.split(projected, self.decoder_state_sizes, axis=1)
        return tf.nest.pack_sequence_as(decoder_state, pieces)