Source code for onmt.modules.global_attention

"""Global attention modules (Luong / Bahdanau)"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from onmt.modules.sparse_activations import sparsemax
from onmt.utils.misc import sequence_mask

# This class is mainly used by decoder.py for RNNs but also
# by the CNN / transformer decoder when copy attention is used
# CNN has its own attention mechanism ConvMultiStepAttention
# Transformer has its own MultiHeadedAttention


class GlobalAttention(nn.Module):
    r"""
    Global attention takes a matrix and a query vector. It
    then computes a parameterized convex combination of the matrix
    based on the input query.

    Constructs a unit mapping a query `q` of size `dim`
    and a source matrix `H` of size `n x dim`, to an output
    of size `dim`.

    .. mermaid::

       graph BT
          A[Query]
          subgraph RNN
            C[H 1]
            D[H 2]
            E[H N]
          end
          F[Attn]
          G[Output]
          A --> F
          C --> F
          D --> F
          E --> F
          C -.-> G
          D -.-> G
          E -.-> G
          F --> G

    All models compute the output as
    :math:`c = \sum_{j=1}^{\text{SeqLength}} a_j H_j` where
    :math:`a_j` is the softmax of a score function.
    Then they apply a projection layer to [q, c].

    However, they differ in how they compute the attention score.

    * Luong Attention (dot, general):

      * dot: :math:`\text{score}(H_j, q) = H_j^T q`
      * general: :math:`\text{score}(H_j, q) = H_j^T W_a q`

    * Bahdanau Attention (mlp):

      * :math:`\text{score}(H_j, q) = v_a^T \text{tanh}(W_a q + U_a h_j)`

    Args:
        dim (int): dimensionality of query and key
        coverage (bool): use coverage term
        attn_type (str): type of attention to use, options [dot,general,mlp]
        attn_func (str): attention function to use, options [softmax,sparsemax]
    """

    def __init__(self, dim, coverage=False, attn_type="dot", attn_func="softmax"):
        super(GlobalAttention, self).__init__()

        self.dim = dim
        assert attn_type in [
            "dot",
            "general",
            "mlp",
        ], "Please select a valid attention type (got {:s}).".format(attn_type)
        self.attn_type = attn_type
        assert attn_func in [
            "softmax",
            "sparsemax",
        ], "Please select a valid attention function."
        self.attn_func = attn_func

        if self.attn_type == "general":
            self.linear_in = nn.Linear(dim, dim, bias=False)
        elif self.attn_type == "mlp":
            self.linear_context = nn.Linear(dim, dim, bias=False)
            self.linear_query = nn.Linear(dim, dim, bias=True)
            self.v = nn.Linear(dim, 1, bias=False)
        # mlp wants it with bias
        out_bias = self.attn_type == "mlp"
        self.linear_out = nn.Linear(dim * 2, dim, bias=out_bias)

        if coverage:
            self.linear_cover = nn.Linear(1, dim, bias=False)
    def score(self, h_t, h_s):
        """
        Args:
            h_t (FloatTensor): sequence of queries ``(batch, tgt_len, dim)``
            h_s (FloatTensor): sequence of sources ``(batch, src_len, dim)``

        Returns:
            FloatTensor: raw attention scores (unnormalized) for each src index
                ``(batch, tgt_len, src_len)``
        """
        src_batch, src_len, src_dim = h_s.size()
        tgt_batch, tgt_len, tgt_dim = h_t.size()

        if self.attn_type in ["general", "dot"]:
            if self.attn_type == "general":
                h_t = self.linear_in(h_t)
            h_s_ = h_s.transpose(1, 2)
            # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
            return torch.bmm(h_t, h_s_)
        else:
            dim = self.dim
            wq = self.linear_query(h_t)
            wq = wq.view(tgt_batch, tgt_len, 1, dim)
            wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

            uh = self.linear_context(h_s.contiguous())
            uh = uh.view(src_batch, 1, src_len, dim)
            uh = uh.expand(src_batch, tgt_len, src_len, dim)

            # (batch, t_len, s_len, d)
            wquh = torch.tanh(wq + uh)

            return self.v(wquh).view(tgt_batch, tgt_len, src_len)
    def forward(self, src, enc_out, src_len=None, coverage=None):
        """
        Args:
            src (FloatTensor): query vectors ``(batch, tgt_len, dim)``
            enc_out (FloatTensor): encoder output vectors ``(batch, src_len, dim)``
            src_len (LongTensor): source context lengths ``(batch,)``
            coverage (FloatTensor): None (not supported yet)

        Returns:
            (FloatTensor, FloatTensor):

            * Computed vector ``(batch, tgt_len, dim)``
            * Attention distributions for each query
              ``(batch, tgt_len, src_len)``
        """
        # one step input
        if src.dim() == 2:
            one_step = True
            src = src.unsqueeze(1)
        else:
            one_step = False

        batch, src_l, dim = enc_out.size()
        batch_, target_l, dim_ = src.size()
        if coverage is not None:
            batch_, src_l_ = coverage.size()

        if coverage is not None:
            cover = coverage.view(-1).unsqueeze(1)
            enc_out += self.linear_cover(cover).view_as(enc_out)
            enc_out = torch.tanh(enc_out)

        # compute attention scores, as in Luong et al.
        align = self.score(src, enc_out)

        if src_len is not None:
            mask = ~sequence_mask(src_len, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # Make it broadcastable.
            align.masked_fill_(mask, -float("inf"))

        # Softmax or sparsemax to normalize attention weights
        if self.attn_func == "softmax":
            align_vectors = F.softmax(align.view(batch * target_l, src_l), -1)
        else:
            align_vectors = sparsemax(align.view(batch * target_l, src_l), -1)
        align_vectors = align_vectors.view(batch, target_l, src_l)

        # each context vector c_t is the weighted average
        # over all the source hidden states
        c = torch.bmm(align_vectors, enc_out)

        # concatenate
        concat_c = torch.cat([c, src], 2).view(batch * target_l, dim * 2)
        attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
        if self.attn_type in ["general", "dot"]:
            attn_h = torch.tanh(attn_h)

        if one_step:
            attn_h = attn_h.squeeze(1)
            align_vectors = align_vectors.squeeze(1)

        return attn_h, align_vectors
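
# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original module).
# It shows one plausible way to drive GlobalAttention with random tensors and
# to sanity-check the "general" score against the bilinear form given in the
# class docstring. The tensor sizes, variable names, and the __main__ guard
# below are assumptions made for demonstration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    attn = GlobalAttention(dim=16, attn_type="general", attn_func="softmax")

    enc_out = torch.randn(4, 7, 16)       # encoder states (batch, src_len, dim)
    query = torch.randn(4, 3, 16)         # decoder states (batch, tgt_len, dim)
    lengths = torch.tensor([7, 5, 6, 4])  # true source lengths (batch,)

    attn_h, align = attn(query, enc_out, src_len=lengths)
    print(attn_h.shape)  # torch.Size([4, 3, 16])
    print(align.shape)   # torch.Size([4, 3, 7])

    # Padded source positions receive zero attention mass thanks to the mask.
    print(align[1, :, 5:].abs().max())  # tensor(0.) for the sequence of length 5

    # For attn_type="general", score(H_j, q) = H_j^T W_a q; reproduce it by hand.
    # nn.Linear stores W_a as (out_features, in_features) and applies x @ W_a^T.
    scores = attn.score(query, enc_out)
    W_a = attn.linear_in.weight
    manual = torch.einsum("btd,ed,bse->bts", query, W_a, enc_out)
    print(torch.allclose(scores, manual, atol=1e-5))  # True, up to float precision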