Source code for opennmt.encoders.self_attention_encoder

"""Define the self-attention encoder."""

import tensorflow as tf

from opennmt.encoders.encoder import Encoder
from opennmt.layers import common, transformer
from opennmt.layers.position import SinusoidalPositionEncoder


class SelfAttentionEncoder(Encoder):
    """Encoder using self-attention as described in
    https://arxiv.org/abs/1706.03762.
    """
    def __init__(
        self,
        num_layers,
        num_units=512,
        num_heads=8,
        ffn_inner_dim=2048,
        dropout=0.1,
        attention_dropout=0.1,
        ffn_dropout=0.1,
        ffn_activation=tf.nn.relu,
        mha_bias=True,
        position_encoder_class=SinusoidalPositionEncoder,
        maximum_relative_position=None,
        pre_norm=True,
        **kwargs
    ):
        """Initializes the parameters of the encoder.

        Args:
          num_layers: The number of layers.
          num_units: The number of hidden units.
          num_heads: The number of heads in the multi-head attention.
          ffn_inner_dim: The number of units of the inner linear transformation
            in the feed forward layer.
          dropout: The probability to drop units from the outputs.
          attention_dropout: The probability to drop units from the attention.
          ffn_dropout: The probability to drop units from the activation output
            in the feed forward layer.
          ffn_activation: The activation function to apply between the two linear
            transformations of the feed forward layer.
          mha_bias: Add bias after linear layers in the multi-head attention.
          position_encoder_class: The :class:`opennmt.layers.PositionEncoder`
            class to use for position encoding (or a callable that returns an
            instance).
          maximum_relative_position: Maximum relative position representation
            (from https://arxiv.org/abs/1803.02155).
          pre_norm: If ``True``, layer normalization is applied before each
            sub-layer. Otherwise it is applied after.
          **kwargs: Additional layer arguments.
        """
        super().__init__(**kwargs)
        self.num_units = num_units
        self.dropout = dropout
        self.position_encoder = None
        if position_encoder_class is not None:
            self.position_encoder = position_encoder_class()
        self.layer_norm = common.LayerNorm() if pre_norm else None
        self.layers = [
            transformer.SelfAttentionEncoderLayer(
                num_units,
                num_heads,
                ffn_inner_dim,
                dropout=dropout,
                attention_dropout=attention_dropout,
                ffn_dropout=ffn_dropout,
                ffn_activation=ffn_activation,
                mha_bias=mha_bias,
                maximum_relative_position=maximum_relative_position,
                pre_norm=pre_norm,
            )
            for i in range(num_layers)
        ]
    def call(self, inputs, sequence_length=None, training=None):
        # Scale the embedded inputs and add position information before the
        # first self-attention layer.
        inputs *= self.num_units**0.5
        if self.position_encoder is not None:
            inputs = self.position_encoder(inputs)
        inputs = common.dropout(inputs, self.dropout, training=training)
        # Mask padding positions so they are ignored by the attention layers.
        mask = self.build_mask(inputs, sequence_length=sequence_length)
        for layer in self.layers:
            inputs = layer(inputs, mask=mask, training=training)
        outputs = self.layer_norm(inputs) if self.layer_norm is not None else inputs
        return outputs, None, sequence_length
    def map_v1_weights(self, weights):
        m = []
        m += self.layer_norm.map_v1_weights(weights["LayerNorm"])
        for i, layer in enumerate(self.layers):
            m += layer.map_v1_weights(weights["layer_%d" % i])
        return m
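
A minimal usage sketch follows. It is not part of the module source: the toy dimensions and the random inputs are assumptions made for illustration, and in a real model the inputs would come from a trained embedding layer. It only relies on what the source above shows, namely the constructor arguments and the fact that call() returns (outputs, state, sequence_length) with a None state.

import tensorflow as tf

from opennmt.encoders.self_attention_encoder import SelfAttentionEncoder

# Toy dimensions for illustration only; the input depth must equal num_units
# because of the residual connections inside each self-attention layer.
batch_size, max_time, num_units = 4, 20, 64

encoder = SelfAttentionEncoder(
    num_layers=2,
    num_units=num_units,
    num_heads=4,
    ffn_inner_dim=128,
    dropout=0.1,
)

# Already-embedded inputs of shape [batch_size, max_time, num_units] and the
# true length of each sequence (used to mask padding positions).
inputs = tf.random.normal([batch_size, max_time, num_units])
sequence_length = tf.constant([20, 17, 9, 13], dtype=tf.int32)

# Calling the layer forwards sequence_length and training to call() above.
outputs, _, out_length = encoder(
    inputs, sequence_length=sequence_length, training=True
)
assert outputs.shape == (batch_size, max_time, num_units)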