Source code for onmt.modules.average_attn

# -*- coding: utf-8 -*-
"""Average Attention module."""

import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional
from onmt.modules.position_ffn import PositionwiseFeedForward
from onmt.modules.position_ffn import ActivationFunction


def cumulative_average_mask(
    batch_size: int, t_len: int, device: Optional[torch.device] = None
) -> Tensor:
    """
    Builds the mask to compute the cumulative average as described in
    :cite:`DBLP:journals/corr/abs-1805-00631` -- Figure 3
    Args:
        batch_size (int): batch size
        t_len (int): length of layer_in

    Returns:
        (Tensor):
        * A Tensor of shape ``(batch_size, t_len, t_len)``
    """

    triangle = torch.tril(torch.ones(t_len, t_len, dtype=torch.float, device=device))
    weights = torch.ones(1, t_len, dtype=torch.float, device=device) / torch.arange(
        1, t_len + 1, dtype=torch.float, device=device
    )
    mask = triangle * weights.transpose(0, 1)

    return mask.unsqueeze(0).expand(batch_size, t_len, t_len)

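
# Illustrative sketch, not part of the original module: for t_len = 3 the mask
# rows carry the running-average weights 1/1, 1/2, 1/3, so a matmul with the
# mask yields, at row i, the mean of positions 0..i. The helper name
# `_demo_cumulative_average_mask` is hypothetical, added here for illustration.
def _demo_cumulative_average_mask():
    mask = cumulative_average_mask(batch_size=1, t_len=3)
    # mask[0] is approximately:
    # [[1.0000, 0.0000, 0.0000],
    #  [0.5000, 0.5000, 0.0000],
    #  [0.3333, 0.3333, 0.3333]]
    x = torch.arange(3, dtype=torch.float).view(1, 3, 1)  # values 0, 1, 2
    averaged = torch.matmul(mask, x)
    # Cumulative means of (0, 1, 2) are (0.0, 0.5, 1.0).
    assert torch.allclose(averaged.squeeze(), torch.tensor([0.0, 0.5, 1.0]))
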

def cumulative_average(
    layer_in: Tensor, layer_cache: tuple, mask=None, step=None
) -> Tensor:
    """
    Computes the cumulative average as described in
    :cite:`DBLP:journals/corr/abs-1805-00631` -- Equations (1) (5) (6)

    Args:
        layer_in (FloatTensor): sequence to average
            ``(batch_size, input_len, dimension)``
        layer_cache: tuple(bool, dict)
            if layer_cache[0] is True, use step; otherwise use mask
        mask: mask matrix used to compute the cumulative average
        step: current step of the dynamic decoding

    Returns:
        a tensor of the same shape and type as ``layer_in``.
    """

    if layer_cache[0]:
        average_attention = (layer_in + step * layer_cache[1]["prev_g"]) / (step + 1)
        layer_cache[1]["prev_g"] = average_attention
        return average_attention
    else:
        return torch.matmul(mask.to(layer_in.dtype), layer_in)

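
# Illustrative sketch, not part of the original module: the cached, step-wise
# branch of `cumulative_average` reproduces the masked-matmul branch one
# position at a time. `_demo_cumulative_average` is a hypothetical helper, and
# the cache is initialized with zeros of the right shape for the demo.
def _demo_cumulative_average():
    batch_size, t_len, dim = 2, 4, 8
    layer_in = torch.randn(batch_size, t_len, dim)

    # Training-style path: one matmul with the cumulative-average mask.
    mask = cumulative_average_mask(batch_size, t_len, layer_in.device)
    full = cumulative_average(layer_in, (False, {}), mask=mask)

    # Decoding-style path: feed one position at a time with a running cache.
    cache = True, {"prev_g": torch.zeros(batch_size, 1, dim)}
    for step in range(t_len):
        step_out = cumulative_average(layer_in[:, step : step + 1], cache, step=step)
        assert torch.allclose(step_out, full[:, step : step + 1], atol=1e-5)
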

class AverageAttention(nn.Module):
    # class AverageAttention(torch.jit.ScriptModule):
    """
    Average Attention module from
    "Accelerating Neural Transformer via an Average Attention Network"
    :cite:`DBLP:journals/corr/abs-1805-00631`.

    Args:
        model_dim (int): the dimension of keys/values/queries,
            must be divisible by head_count
        dropout (float): dropout parameter
        pos_ffn_activation_fn (ActivationFunction):
            activation function choice for PositionwiseFeedForward layer
    """

    def __init__(
        self,
        model_dim,
        dropout=0.1,
        aan_useffn=False,
        pos_ffn_activation_fn=ActivationFunction.relu,
    ):
        self.model_dim = model_dim
        self.aan_useffn = aan_useffn
        super(AverageAttention, self).__init__()
        if aan_useffn:
            self.average_layer = PositionwiseFeedForward(
                model_dim, model_dim, dropout, pos_ffn_activation_fn
            )
        self.gating_layer = nn.Linear(model_dim * 2, model_dim * 2)
        self.layer_cache = False, {"prev_g": torch.tensor([])}

    # @torch.jit.script
    def forward(self, layer_in, mask=None, step=None):
        """
        Args:
            layer_in (FloatTensor): ``(batch, t_len, dim)``

        Returns:
            (FloatTensor, FloatTensor):

            * gating_out ``(batch, t_len, dim)``
            * average_out average attention ``(batch, t_len, dim)``
        """

        batch_size = layer_in.size(0)
        t_len = layer_in.size(1)
        mask = (
            cumulative_average_mask(batch_size, t_len, layer_in.device)
            if not self.layer_cache[0]
            else None
        )
        average_out = cumulative_average(layer_in, self.layer_cache, mask, step)
        if self.aan_useffn:
            average_out = self.average_layer(average_out)
        gating_out = self.gating_layer(torch.cat((layer_in, average_out), -1))
        input_gate, forget_gate = torch.chunk(gating_out, 2, dim=2)
        gating_out = (
            torch.sigmoid(input_gate) * layer_in
            + torch.sigmoid(forget_gate) * average_out
        )
        return gating_out, average_out
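

# Illustrative sketch, not part of the original module: a single forward pass
# through AverageAttention with the default cache (layer_cache[0] is False, so
# the cumulative-average mask path is used). Shapes follow the docstrings
# above; the helper name `_demo_average_attention` is hypothetical.
def _demo_average_attention():
    model_dim = 16
    attn = AverageAttention(model_dim, dropout=0.1, aan_useffn=False)
    layer_in = torch.randn(4, 10, model_dim)  # (batch, t_len, dim)
    gating_out, average_out = attn(layer_in)
    # Both outputs keep the input shape (batch, t_len, dim).
    assert gating_out.shape == layer_in.shape
    assert average_out.shape == layer_in.shape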