"""Implementation of the CNN Decoder part of
"Convolutional Sequence to Sequence Learning"
"""
import torch
import torch.nn as nn
from onmt.modules import ConvMultiStepAttention, GlobalAttention
from onmt.utils.cnn_factory import shape_transform, GatedConv
from onmt.decoders.decoder import DecoderBase
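
# Residual connections are scaled by sqrt(0.5) so that the variance of the
# sum of two like-scaled terms stays roughly constant, as in the paper.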
SCALE_WEIGHT = 0.5**0.5


class CNNDecoder(DecoderBase):
"""Decoder based on "Convolutional Sequence to Sequence Learning"
:cite:`DBLP:journals/corr/GehringAGYD17`.
Consists of residual convolutional layers, with ConvMultiStepAttention.
"""
def __init__(
self,
num_layers,
hidden_size,
attn_type,
copy_attn,
cnn_kernel_width,
dropout,
embeddings,
copy_attn_type,
):
super(CNNDecoder, self).__init__()
self.cnn_kernel_width = cnn_kernel_width
self.embeddings = embeddings
# Decoder State
self.state = {}
input_size = self.embeddings.embedding_size
self.linear = nn.Linear(input_size, hidden_size)
        self.conv_layers = nn.ModuleList(
            [
                GatedConv(hidden_size, cnn_kernel_width, dropout, True)
                for _ in range(num_layers)
            ]
        )
        self.attn_layers = nn.ModuleList(
            [ConvMultiStepAttention(hidden_size) for _ in range(num_layers)]
        )
        # CNNDecoder has its own attention mechanism; a separate copy
        # attention layer would only be needed for the copy mechanism,
        # which has not yet been tested with conv2conv. The assert rejects
        # it up front, so the branch below is kept only for when the
        # restriction is lifted.
        assert not copy_attn, "Copy mechanism not yet tested in conv2conv"
        if copy_attn:
            self.copy_attn = GlobalAttention(hidden_size, attn_type=copy_attn_type)
        else:
            self.copy_attn = None

    @classmethod
def from_opt(cls, opt, embeddings):
"""Alternate constructor."""
return cls(
opt.dec_layers,
opt.dec_hid_size,
opt.global_attention,
opt.copy_attn,
opt.cnn_kernel_width,
            opt.dropout[0] if isinstance(opt.dropout, list) else opt.dropout,
embeddings,
opt.copy_attn_type,
)
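
    # Usage sketch (hypothetical values): given the parsed OpenNMT-py option
    # namespace `opt` and a target-side embeddings module,
    #
    #     decoder = CNNDecoder.from_opt(opt, embeddings)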

    def init_state(self, _, enc_out, enc_hidden):
"""Init decoder state."""
self.state["src"] = (enc_out + enc_hidden) * SCALE_WEIGHT
self.state["previous_input"] = None

    def map_state(self, fn):
        """Apply `fn` (e.g. a beam-search index_select) along the batch
        dimension of each cached state tensor."""
self.state["src"] = fn(self.state["src"], 0)
if self.state["previous_input"] is not None:
self.state["previous_input"] = fn(self.state["previous_input"], 0)

    def detach_state(self):
        """Detach the cached input from the computation graph."""
self.state["previous_input"] = self.state["previous_input"].detach()

    def forward(self, tgt, enc_out, step=None, **kwargs):
"""See :obj:`onmt.modules.RNNDecoderBase.forward()`"""
if self.state["previous_input"] is not None:
tgt = torch.cat([self.state["previous_input"], tgt], 1)
dec_outs = []
attns = {"std": []}
if self.copy_attn is not None:
attns["copy"] = []
emb = self.embeddings(tgt)
assert emb.dim() == 3 # batch x len x embedding_dim
tgt_emb = emb
        # The CNNEncoder output, used as the attention keys.
        enc_out_t = enc_out
        # The scaled combination of encoder output and source embeddings,
        # used as the attention values.
        enc_out_c = self.state["src"]
emb_reshape = tgt_emb.view(tgt_emb.size(0) * tgt_emb.size(1), -1)
linear_out = self.linear(emb_reshape)
x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1)
x = shape_transform(x)
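        # `x` is now (batch, hidden_size, tgt_len, 1): channels-first with a
        # trailing width axis, the layout the 2D gated convolutions expect.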
pad = torch.zeros(x.size(0), x.size(1), self.cnn_kernel_width - 1, 1)
pad = pad.type_as(x)
base_target_emb = x
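        # Each layer left-pads the target sequence with kernel_width - 1
        # zeros so the convolution is causal (no position sees future target
        # words), applies a gated convolution, attends over the source, and
        # adds scaled residual connections.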
for conv, attention in zip(self.conv_layers, self.attn_layers):
new_target_input = torch.cat([pad, x], 2)
out = conv(new_target_input)
c, attn = attention(base_target_emb, out, enc_out_t, enc_out_c)
x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT
dec_outs = x.squeeze(3).transpose(1, 2)
        # If a decoded prefix was prepended, slice its positions off so only
        # the outputs and attentions for the new steps are returned.
if self.state["previous_input"] is not None:
dec_outs = dec_outs[:, self.state["previous_input"].size(1) :, :]
attn = attn[:, self.state["previous_input"].size(1) :].squeeze()
attn = torch.stack([attn])
attns["std"] = attn
if self.copy_attn is not None:
attns["copy"] = attn
# Update the state.
self.state["previous_input"] = tgt
# TODO change the way attns is returned dict => list or tuple (onnx)
return dec_outs, attns

    def update_dropout(self, dropout, attention_dropout=None):
        """Update the dropout probability of every gated conv layer."""
        # `attention_dropout` is accepted for interface compatibility with
        # other decoders but is not used by the CNN decoder.
        for layer in self.conv_layers:
            layer.dropout.p = dropout
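

# A minimal end-to-end usage sketch (not part of the original module; all
# sizes and names are illustrative). `enc_out` and `enc_hidden` are the two
# batch-first tensors produced by the matching CNNEncoder, and `tgt` is a
# target-token tensor accepted by the embeddings module:
#
#     decoder = CNNDecoder(
#         num_layers=2, hidden_size=256, attn_type="general", copy_attn=False,
#         cnn_kernel_width=3, dropout=0.1, embeddings=tgt_embeddings,
#         copy_attn_type="general",
#     )
#     decoder.init_state(src, enc_out, enc_hidden)
#     dec_outs, attns = decoder(tgt, enc_out)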