MultiHeadAttention

class opennmt.layers.MultiHeadAttention(*args, **kwargs)[source]

Computes the multi-head attention as described in https://arxiv.org/abs/1706.03762.
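For reference, each head applies the scaled dot-product attention from the cited paper, and the head outputs are concatenated and linearly projected back to num_units. In the notation below, \(H\) is num_heads and \(d_k\) is the per-head depth (num_units divided by num_heads):

\[
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V
\]

\[
\operatorname{MultiHead}(Q, K, V) = \operatorname{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_H)\, W^O,
\qquad \mathrm{head}_i = \operatorname{Attention}(Q W_i^Q,\, K W_i^K,\, V W_i^V)
\]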

Inherits from: tf.keras.layers.Layer

__init__(num_heads, num_units, bias=True, dropout=0.1, return_attention=False, maximum_relative_position=None, **kwargs)[source]

Initializes this layer.

Parameters
  • num_heads – The number of attention heads.

  • num_units – The number of hidden units.

  • bias – Whether to add a bias term after the linear transformations.

  • dropout – The probability to drop units from the inputs.

  • return_attention – If True, also return the attention weights.

  • maximum_relative_position – Maximum relative position representation (from https://arxiv.org/abs/1803.02155).

  • kwargs – Additional layer arguments.

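A minimal construction sketch based on the signature above; the hyperparameter values are illustrative choices, not library defaults.

import opennmt

# 8 heads over a 512-unit hidden size, with relative position
# representations clipped at distance 20 (illustrative values).
attention = opennmt.layers.MultiHeadAttention(
    num_heads=8,
    num_units=512,
    dropout=0.1,
    return_attention=True,            # call() will also return attention weights
    maximum_relative_position=20,
)
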
map_v1_weights(weights)[source]
build(input_shape)[source]

Creates the variables of the layer (for subclass implementers).

This is a method that implementers of subclasses of Layer or Model can override if they need a state-creation step between layer instantiation and layer call. It is invoked automatically before the first execution of call().

This is typically used to create the weights of Layer subclasses (at the discretion of the subclass implementer).

Parameters

input_shape – Instance of TensorShape, or list of instances of TensorShape if the layer expects a list of inputs (one instance per input).

call(inputs, memory=None, mask=None, cache=None, training=None)[source]

Runs the layer.

Parameters
  • inputs – The sequence of queries. A tensor of shape \([B, T_1, ...]\).

  • memory – The sequence to attend. A tensor of shape \([B, T_2, ...]\). If None, computes self-attention.

  • mask – The dot product mask. A boolean tensor of shape \([B, T_2]\) or \([B, T_1, T_2]\).

  • cache – An optional tuple containing projected keys and values from the previous step. Tensors of shape \([B, H, T_2, D / H]\).

  • training – Run in training mode.

Returns

A tuple with the attention context, the updated cache, and, if return_attention is True, the attention weights.
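
A short usage sketch, assuming the return value unpacks to (context, cache) when return_attention is left at its default of False; shapes and keyword arguments follow the parameter descriptions above, and the batch size, lengths, and depth are illustrative.

import tensorflow as tf
import opennmt

attention = opennmt.layers.MultiHeadAttention(num_heads=8, num_units=512)

# Self-attention: with memory=None, the queries attend to themselves.
queries = tf.random.normal([4, 10, 512])      # [B, T_1, depth]
context, cache = attention(queries, training=True)

# Cross-attention over a separate memory, masked by per-example lengths.
memory = tf.random.normal([4, 20, 512])       # [B, T_2, depth]
mask = tf.sequence_mask([20, 18, 15, 12])     # [B, T_2] boolean mask
context, cache = attention(queries, memory=memory, mask=mask)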