"""Define self-attention decoder."""
import tensorflow as tf
from opennmt.decoders import decoder
from opennmt.layers import common, transformer
from opennmt.layers.position import SinusoidalPositionEncoder
class SelfAttentionDecoder(decoder.Decoder):
"""Encoder using self-attention as described in
https://arxiv.org/abs/1706.03762.
"""
    def __init__(
self,
num_layers,
num_units=512,
num_heads=8,
ffn_inner_dim=2048,
dropout=0.1,
attention_dropout=0.1,
ffn_dropout=0.1,
ffn_activation=tf.nn.relu,
mha_bias=True,
position_encoder_class=SinusoidalPositionEncoder,
num_sources=1,
maximum_relative_position=None,
attention_reduction=transformer.MultiHeadAttentionReduction.FIRST_HEAD_LAST_LAYER,
pre_norm=True,
**kwargs
):
"""Initializes the parameters of the decoder.
Args:
num_layers: The number of layers.
num_units: The number of hidden units.
num_heads: The number of heads in the multi-head attention.
ffn_inner_dim: The number of units of the inner linear transformation
in the feed forward layer.
dropout: The probability to drop units from the outputs.
attention_dropout: The probability to drop units from the attention.
ffn_dropout: The probability to drop units from the activation output in
the feed forward layer.
ffn_activation: The activation function to apply between the two linear
transformations of the feed forward layer.
      mha_bias: Whether to add a bias term to the linear layers of the
        multi-head attention.
position_encoder_class: The :class:`opennmt.layers.PositionEncoder`
class to use for position encoding (or a callable that returns an
instance).
num_sources: The number of source contexts expected by this decoder.
maximum_relative_position: Maximum relative position representation
(from https://arxiv.org/abs/1803.02155).
attention_reduction: A :class:`opennmt.layers.MultiHeadAttentionReduction`
value to specify how to reduce multi-head attention matrices.
pre_norm: If ``True``, layer normalization is applied before each
sub-layer. Otherwise it is applied after.
**kwargs: Additional layer arguments.
"""
super().__init__(num_sources=num_sources, **kwargs)
self.num_units = num_units
self.num_heads = num_heads
self.dropout = dropout
self.attention_reduction = attention_reduction
self.position_encoder = None
if position_encoder_class is not None:
self.position_encoder = position_encoder_class()
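        # With pre-norm residual connections, a final layer normalization is
        # applied to the output of the last layer; with post-norm connections,
        # each sub-layer already normalizes its output, so no extra norm is needed.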
self.layer_norm = common.LayerNorm() if pre_norm else None
self.layers = [
transformer.SelfAttentionDecoderLayer(
self.num_units,
self.num_heads,
ffn_inner_dim,
num_sources=num_sources,
dropout=dropout,
attention_dropout=attention_dropout,
ffn_dropout=ffn_dropout,
ffn_activation=ffn_activation,
mha_bias=mha_bias,
maximum_relative_position=maximum_relative_position,
pre_norm=pre_norm,
)
for i in range(num_layers)
]
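
    # Configuration note: the defaults above (num_units=512, num_heads=8,
    # ffn_inner_dim=2048, dropout=0.1) match the base Transformer of
    # https://arxiv.org/abs/1706.03762 when num_layers=6. A hypothetical
    # construction for a small model (illustrative values only) could be:
    #
    #   decoder = SelfAttentionDecoder(
    #       num_layers=2, num_units=64, num_heads=4, ffn_inner_dim=128,
    #       pre_norm=False)  # post-norm residual connections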
@property
def minimum_sources(self):
return 0
@property
def maximum_sources(self):
return 1e6 # An arbitrary large number.
@property
def support_alignment_history(self):
return True
    def map_v1_weights(self, weights):
m = super().map_v1_weights(weights)
m += self.layer_norm.map_v1_weights(weights["LayerNorm"])
for i, layer in enumerate(self.layers):
m += layer.map_v1_weights(weights["layer_%d" % i])
return m
def _run(
self,
inputs,
sequence_length=None,
cache=None,
memory=None,
memory_sequence_length=None,
step=None,
training=None,
):
# Process inputs.
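        # Scale the input embeddings by sqrt(num_units), as in the original
        # Transformer, before adding the position encodings.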
inputs *= self.num_units**0.5
if self.position_encoder is not None:
inputs = self.position_encoder(
inputs, position=step + 1 if step is not None else None
)
inputs = common.dropout(inputs, self.dropout, training=training)
# Prepare query mask.
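        # During full-sequence decoding (step is None), build a causal mask that
        # also hides padded positions; during step-wise decoding the single query
        # may attend to all cached keys, so no mask is required.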
mask = None
if step is None:
maximum_length = tf.shape(inputs)[1]
if sequence_length is None:
batch_size = tf.shape(inputs)[0]
sequence_length = tf.fill([batch_size], maximum_length)
mask = transformer.future_mask(
sequence_length, maximum_length=maximum_length
)
# Prepare memory mask.
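        # The memory mask hides padded encoder positions from the cross-attention;
        # one mask is built per source context.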
memory_mask = None
if memory is not None:
if not isinstance(memory, (list, tuple)):
memory = (memory,)
if memory_sequence_length is not None:
if not isinstance(memory_sequence_length, (list, tuple)):
memory_sequence_length = (memory_sequence_length,)
memory_mask = [
tf.sequence_mask(mem_length, maxlen=tf.shape(mem)[1])
for mem, mem_length in zip(memory, memory_sequence_length)
]
else:
memory_mask = tuple(None for _ in memory)
# Run each layer.
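        # Each layer returns its output, an updated key/value cache, and the
        # attention weights to every source context.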
new_cache = []
attention = []
for i, layer in enumerate(self.layers):
inputs, layer_cache, layer_attention = layer(
inputs,
mask=mask,
memory=memory,
memory_mask=memory_mask,
cache=cache[i] if cache is not None else None,
training=training,
)
attention.append(layer_attention)
new_cache.append(layer_cache)
outputs = self.layer_norm(inputs) if self.layer_norm is not None else inputs
# Convert list of shape num_layers x num_sources to num_sources x num_layers
attention = list(map(list, zip(*attention)))
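        # Keep only the attention to the first source and reduce the per-layer,
        # per-head matrices to a single alignment matrix (by default, the first
        # head of the last layer).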
if attention:
attention = transformer.MultiHeadAttentionReduction.reduce(
attention[0], # Get attention to the first source.
self.attention_reduction,
)
else:
attention = None
return outputs, new_cache, attention
    def forward(
self,
inputs,
sequence_length=None,
initial_state=None,
memory=None,
memory_sequence_length=None,
input_fn=None,
sampling_probability=None,
training=None,
):
_ = initial_state
_ = input_fn
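        # The whole target sequence is decoded in a single pass (teacher forcing),
        # so there is no per-step loop in which sampled tokens could be fed back;
        # scheduled sampling is therefore rejected below.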
if sampling_probability is not None:
raise ValueError("Scheduled sampling is not supported by this decoder")
outputs, state, attention = self._run(
inputs,
sequence_length=sequence_length,
memory=memory,
memory_sequence_length=memory_sequence_length,
training=training,
)
logits = self.output_layer(outputs)
return logits, state, attention
    def step(
self,
inputs,
timestep,
state=None,
memory=None,
memory_sequence_length=None,
training=None,
):
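        # Add a singleton time dimension: step-wise decoding processes one
        # timestep per call and reuses the cached keys/values carried in `state`.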
inputs = tf.expand_dims(inputs, 1)
outputs, state, attention = self._run(
inputs,
cache=state,
memory=memory,
memory_sequence_length=memory_sequence_length,
step=timestep,
training=training,
)
outputs = tf.squeeze(outputs, axis=1)
if attention is not None:
attention = tf.squeeze(attention, axis=1)
return outputs, state, attention
def _get_initial_state(self, batch_size, dtype, initial_state=None):
# The decoder state contains the keys and values projections of the previous timesteps.
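        # Each layer caches self-attention keys/values of shape
        # [batch, num_heads, time, num_units // num_heads], where the time
        # dimension starts empty and grows by one per decoding step, plus one
        # key/value pair per source for the cross-attention.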
_ = initial_state
cache = []
for _ in self.layers:
shape = [batch_size, self.num_heads, 0, self.num_units // self.num_heads]
self_kv = (tf.zeros(shape, dtype=dtype), tf.zeros(shape, dtype=dtype))
memory_kv = [
(tf.zeros(shape, dtype=dtype), tf.zeros(shape, dtype=dtype))
for _ in range(self.num_sources)
]
cache.append(dict(self_kv=self_kv, memory_kv=memory_kv))
return cache
def _get_state_reorder_flags(self):
# We don't need to reorder memory_kv as it is the same for all beams.
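        # During beam search, the self-attention cache must be gathered to follow
        # the selected beams, while the cross-attention projections are identical
        # across beams and can be left in place.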
return [
{
"self_kv": (True, True),
"memory_kv": [(False, False) for _ in range(self.num_sources)],
}
for _ in self.layers
]
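

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library. It assumes the base
    # ``opennmt.decoders.decoder.Decoder`` exposes ``initialize(vocab_size=...)``
    # to create the output projection; shapes and hyperparameters below are
    # illustrative assumptions, not recommended settings.
    demo_decoder = SelfAttentionDecoder(
        num_layers=2, num_units=64, num_heads=4, ffn_inner_dim=128
    )
    demo_decoder.initialize(vocab_size=100)
    batch_size, target_length, source_length = 2, 5, 7
    target_embeddings = tf.random.normal([batch_size, target_length, 64])
    memory = tf.random.normal([batch_size, source_length, 64])
    logits, _, attention = demo_decoder.forward(
        target_embeddings,
        sequence_length=tf.constant([target_length, target_length - 1]),
        memory=memory,
        memory_sequence_length=tf.constant([source_length, source_length]),
        training=True,
    )
    print(logits.shape)  # Expected: (2, 5, 100).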