"""Global attention modules (Luong / Bahdanau)"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from onmt.modules.sparse_activations import sparsemax
from onmt.utils.misc import sequence_mask
# This class is mainly used by decoder.py for RNNs but also
# by the CNN / transformer decoder when copy attention is used
# CNN has its own attention mechanism ConvMultiStepAttention
# Transformer has its own MultiHeadedAttention
class GlobalAttention(nn.Module):
r"""
Global attention takes a matrix and a query vector. It
then computes a parameterized convex combination of the matrix
based on the input query.
Constructs a unit mapping a query `q` of size `dim`
and a source matrix `H` of size `n x dim`, to an output
of size `dim`.
.. mermaid::
graph BT
A[Query]
subgraph RNN
C[H 1]
D[H 2]
E[H N]
end
F[Attn]
G[Output]
A --> F
C --> F
D --> F
E --> F
C -.-> G
D -.-> G
E -.-> G
F --> G
All models compute the output as
:math:`c = \sum_{j=1}^{\text{SeqLength}} a_j H_j` where
:math:`a_j` is the softmax of a score function.
    They then apply a projection layer to :math:`[q, c]`.
    However, they differ in how they compute the attention score.
* Luong Attention (dot, general):
* dot: :math:`\text{score}(H_j,q) = H_j^T q`
* general: :math:`\text{score}(H_j, q) = H_j^T W_a q`
* Bahdanau Attention (mlp):
       * :math:`\text{score}(H_j, q) = v_a^T \text{tanh}(W_a q + U_a H_j)`
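    For instance (toy values, purely illustrative): with ``dim = 2``,
    :math:`q = [1, 2]^T` and :math:`H_j = [3, 4]^T`, the ``dot`` score is
    :math:`H_j^T q = 1 \cdot 3 + 2 \cdot 4 = 11`; ``general`` and ``mlp``
    additionally pass the vectors through the learned weights :math:`W_a`,
    :math:`U_a`, :math:`v_a` before reducing to a scalar.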
Args:
dim (int): dimensionality of query and key
coverage (bool): use coverage term
attn_type (str): type of attention to use, options [dot,general,mlp]
attn_func (str): attention function to use, options [softmax,sparsemax]
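
    Example (a minimal shape sketch; the sizes below are assumed toy values,
    not defaults)::

        >>> attn = GlobalAttention(8, attn_type="general")
        >>> query = torch.randn(2, 3, 8)    # (batch, tgt_len, dim)
        >>> memory = torch.randn(2, 5, 8)   # (batch, src_len, dim)
        >>> out, scores = attn(query, memory, src_len=torch.tensor([5, 4]))
        >>> out.shape, scores.shape
        (torch.Size([2, 3, 8]), torch.Size([2, 3, 5]))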
"""
def __init__(self, dim, coverage=False, attn_type="dot", attn_func="softmax"):
super(GlobalAttention, self).__init__()
self.dim = dim
assert attn_type in [
"dot",
"general",
"mlp",
], "Please select a valid attention type (got {:s}).".format(attn_type)
self.attn_type = attn_type
assert attn_func in [
"softmax",
"sparsemax",
], "Please select a valid attention function."
self.attn_func = attn_func
if self.attn_type == "general":
self.linear_in = nn.Linear(dim, dim, bias=False)
elif self.attn_type == "mlp":
self.linear_context = nn.Linear(dim, dim, bias=False)
self.linear_query = nn.Linear(dim, dim, bias=True)
self.v = nn.Linear(dim, 1, bias=False)
# mlp wants it with bias
out_bias = self.attn_type == "mlp"
self.linear_out = nn.Linear(dim * 2, dim, bias=out_bias)
if coverage:
self.linear_cover = nn.Linear(1, dim, bias=False)
    def score(self, h_t, h_s):
"""
Args:
h_t (FloatTensor): sequence of queries ``(batch, tgt_len, dim)``
          h_s (FloatTensor): sequence of sources ``(batch, src_len, dim)``
Returns:
FloatTensor: raw attention scores (unnormalized) for each src index
``(batch, tgt_len, src_len)``
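
        Example (illustrative toy shapes only)::

            >>> attn = GlobalAttention(4, attn_type="dot")
            >>> attn.score(torch.randn(2, 3, 4), torch.randn(2, 5, 4)).shape
            torch.Size([2, 3, 5])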
"""
src_batch, src_len, src_dim = h_s.size()
tgt_batch, tgt_len, tgt_dim = h_t.size()
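        # Luong (dot/general) scoring reduces to one batched matrix product;
        # Bahdanau (mlp) broadcasts W_a q + U_a H_j over (tgt_len, src_len)
        # and applies v^T tanh(.) to get one scalar per (query, source) pair.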
if self.attn_type in ["general", "dot"]:
if self.attn_type == "general":
h_t = self.linear_in(h_t)
h_s_ = h_s.transpose(1, 2)
# (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
return torch.bmm(h_t, h_s_)
else:
dim = self.dim
wq = self.linear_query(h_t)
wq = wq.view(tgt_batch, tgt_len, 1, dim)
wq = wq.expand(tgt_batch, tgt_len, src_len, dim)
uh = self.linear_context(h_s.contiguous())
uh = uh.view(src_batch, 1, src_len, dim)
uh = uh.expand(src_batch, tgt_len, src_len, dim)
# (batch, t_len, s_len, d)
wquh = torch.tanh(wq + uh)
return self.v(wquh).view(tgt_batch, tgt_len, src_len)
    def forward(self, src, enc_out, src_len=None, coverage=None):
"""
Args:
src (FloatTensor): query vectors ``(batch, tgt_len, dim)``
enc_out (FloatTensor): encoder out vectors ``(batch, src_len, dim)``
src_len (LongTensor): source context lengths ``(batch,)``
coverage (FloatTensor): None (not supported yet)
Returns:
(FloatTensor, FloatTensor):
* Computed vector ``(batch, tgt_len, dim)``
          * Attention distributions for each query
``(batch, tgt_len, src_len)``
"""
        # one step input: a 2D query is a single decoder step; add a tgt_len dim of 1
if src.dim() == 2:
one_step = True
src = src.unsqueeze(1)
else:
one_step = False
batch, src_l, dim = enc_out.size()
batch_, target_l, dim_ = src.size()
        if coverage is not None:
            # fold the (projected) accumulated attention into the encoder
            # states before scoring
cover = coverage.view(-1).unsqueeze(1)
enc_out += self.linear_cover(cover).view_as(enc_out)
enc_out = torch.tanh(enc_out)
# compute attention scores, as in Luong et al.
align = self.score(src, enc_out)
if src_len is not None:
            # mask out padded source positions so they receive zero weight
            mask = sequence_mask(src_len, max_len=align.size(-1))
            mask = mask.unsqueeze(1)  # make it broadcastable over tgt_len
            align.masked_fill_(~mask, -float("inf"))
# Softmax or sparsemax to normalize attention weights
if self.attn_func == "softmax":
align_vectors = F.softmax(align.view(batch * target_l, src_l), -1)
else:
align_vectors = sparsemax(align.view(batch * target_l, src_l), -1)
align_vectors = align_vectors.view(batch, target_l, src_l)
# each context vector c_t is the weighted average
# over all the source hidden states
c = torch.bmm(align_vectors, enc_out)
# concatenate
concat_c = torch.cat([c, src], 2).view(batch * target_l, dim * 2)
attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
if self.attn_type in ["general", "dot"]:
attn_h = torch.tanh(attn_h)
if one_step:
attn_h = attn_h.squeeze(1)
align_vectors = align_vectors.squeeze(1)
return attn_h, align_vectors
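

# The block below is an optional usage sketch, not part of the module's API:
# it exercises the one-step (2D query) decoder path with assumed toy sizes.
if __name__ == "__main__":
    attn = GlobalAttention(dim=8, attn_type="mlp", attn_func="softmax")
    dec_state = torch.randn(4, 8)       # one decoder step: (batch, dim)
    enc_out = torch.randn(4, 10, 8)     # encoder states: (batch, src_len, dim)
    lengths = torch.tensor([10, 7, 9, 3])
    attn_h, align = attn(dec_state, enc_out, src_len=lengths)
    # one-step inputs come back squeezed to (batch, dim) / (batch, src_len),
    # and each attention row sums to 1 over the unpadded source positions
    assert attn_h.shape == (4, 8) and align.shape == (4, 10)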