"""Define the self-attention encoder."""
import tensorflow as tf
from opennmt.encoders.encoder import Encoder
from opennmt.layers import common, transformer
from opennmt.layers.position import SinusoidalPositionEncoder
class SelfAttentionEncoder(Encoder):
    """Encoder using self-attention as described in
    https://arxiv.org/abs/1706.03762.
    """

    def __init__(
        self,
        num_layers,
        num_units=512,
        num_heads=8,
        ffn_inner_dim=2048,
        dropout=0.1,
        attention_dropout=0.1,
        ffn_dropout=0.1,
        ffn_activation=tf.nn.relu,
        mha_bias=True,
        position_encoder_class=SinusoidalPositionEncoder,
        maximum_relative_position=None,
        pre_norm=True,
        **kwargs
    ):
        """Initializes the parameters of the encoder.

        Args:
          num_layers: The number of layers.
          num_units: The number of hidden units.
          num_heads: The number of heads in the multi-head attention.
          ffn_inner_dim: The number of units of the inner linear transformation
            in the feed forward layer.
          dropout: The probability to drop units from the outputs.
          attention_dropout: The probability to drop units from the attention.
          ffn_dropout: The probability to drop units from the activation output in
            the feed forward layer.
          ffn_activation: The activation function to apply between the two linear
            transformations of the feed forward layer.
          mha_bias: Add bias after linear layers in the multi-head attention.
          position_encoder_class: The :class:`opennmt.layers.PositionEncoder`
            class to use for position encoding (or a callable that returns an
            instance). If ``None``, no position encoding is applied.
          maximum_relative_position: Maximum relative position representation
            (from https://arxiv.org/abs/1803.02155).
          pre_norm: If ``True``, layer normalization is applied before each
            sub-layer. Otherwise it is applied after.
          **kwargs: Additional layer arguments.
        """
        super().__init__(**kwargs)
        self.num_units = num_units
        self.dropout = dropout
        self.position_encoder = None
        if position_encoder_class is not None:
            self.position_encoder = position_encoder_class()
        # A final normalization of the encoder output only exists in pre-norm
        # mode; the attribute is None otherwise and must be checked before use.
        self.layer_norm = common.LayerNorm() if pre_norm else None
        self.layers = [
            transformer.SelfAttentionEncoderLayer(
                num_units,
                num_heads,
                ffn_inner_dim,
                dropout=dropout,
                attention_dropout=attention_dropout,
                ffn_dropout=ffn_dropout,
                ffn_activation=ffn_activation,
                mha_bias=mha_bias,
                maximum_relative_position=maximum_relative_position,
                pre_norm=pre_norm,
            )
            for _ in range(num_layers)
        ]

    def call(self, inputs, sequence_length=None, training=None):
        """Encodes the inputs with the stack of self-attention layers.

        Args:
          inputs: The input embeddings to encode.
          sequence_length: The length of each input sequence, forwarded to
            ``self.build_mask`` to build the attention mask (optional).
          training: Forwarded to the dropout and to each layer.

        Returns:
          A tuple ``(outputs, None, sequence_length)``: the encoded sequence,
          ``None`` in the second position, and :obj:`sequence_length`
          returned unchanged.
        """
        # Scale by sqrt(num_units). A non in-place multiplication is used so a
        # mutable array-like passed by the caller is never modified.
        inputs = inputs * self.num_units**0.5
        if self.position_encoder is not None:
            inputs = self.position_encoder(inputs)
        inputs = common.dropout(inputs, self.dropout, training=training)
        mask = self.build_mask(inputs, sequence_length=sequence_length)
        for layer in self.layers:
            inputs = layer(inputs, mask=mask, training=training)
        outputs = self.layer_norm(inputs) if self.layer_norm is not None else inputs
        return outputs, None, sequence_length

    def map_v1_weights(self, weights):
        """Collects the V1 weights mapping of this encoder.

        Args:
          weights: A dictionary of V1 weights. The entry ``"LayerNorm"`` is
            read when the encoder has an output layer normalization, and the
            entries ``"layer_0"``, ``"layer_1"``, ... are read for each layer.

        Returns:
          The concatenation of the mappings returned by the output layer
          normalization (if any) and by each encoder layer, in this order.
        """
        m = []
        # The output normalization is None when pre_norm=False: there is
        # nothing to map in that case and calling it would raise.
        if self.layer_norm is not None:
            m += self.layer_norm.map_v1_weights(weights["LayerNorm"])
        for i, layer in enumerate(self.layers):
            m += layer.map_v1_weights(weights["layer_%d" % i])
        return m