Source code for onmt.modules.copy_generator

import torch
import torch.nn as nn


def collapse_copy_scores(scores, batch, tgt_vocab, batch_dim=1):
    """
    Merge copy scores into the matching target-vocabulary scores.

    Given scores over an expanded dictionary (target vocabulary followed by
    the per-example source vocabulary), every source token that also exists
    in ``tgt_vocab`` has its copy score added to the score of the target
    entry, and the copy slot is then reset to a tiny constant.

    Args:
        scores (FloatTensor): 3-D scores whose last dimension is the
            expanded dictionary; the batch lives on dimension ``batch_dim``.
            Modified in place.
        batch (dict): must provide ``batch["src_ex_vocab"]``, one source
            vocabulary per batch entry, each exposing ``ids_to_tokens`` and
            ``len()``.
        tgt_vocab: target vocabulary supporting ``len()`` and token lookup
            via ``tgt_vocab[token]``.
        batch_dim (int): ``1`` selects ``scores[:, b]``, any other value
            selects ``scores[b]``.

    Returns:
        FloatTensor: the same ``scores`` object, updated in place.
    """
    offset = len(tgt_vocab)
    for b in range(scores.size(batch_dim)):
        src_vocab = batch["src_ex_vocab"][b]

        blank = []  # copy slots (expanded-dictionary columns) to clear
        fill = []  # target-vocabulary columns receiving the copy score
        # Source id 0 is skipped; a lookup result of 0 is treated as
        # "not in the target vocabulary" (presumably the unk id -- confirm).
        for i in range(1, len(src_vocab)):
            ti = tgt_vocab[src_vocab.ids_to_tokens[i]]
            if ti != 0:
                blank.append(offset + i)
                fill.append(ti)

        if not blank:
            continue

        # Basic indexing yields a view, so in-place ops update ``scores``.
        score = scores[:, b] if batch_dim == 1 else scores[b]
        # Build integer indices directly on the scores' device: going
        # through a float tensor can round large indices, and CPU indices
        # cannot be used to index a tensor living on another device.
        blank_idx = torch.tensor(blank, dtype=torch.int64, device=score.device)
        fill_idx = torch.tensor(fill, dtype=torch.int64, device=score.device)
        score.index_add_(1, fill_idx, score.index_select(1, blank_idx))
        score.index_fill_(1, blank_idx, 1e-10)
    return scores


class CopyGenerator(nn.Module):
    """An implementation of pointer-generator networks
    :cite:`DBLP:journals/corr/SeeLM17`.

    These networks consider copying words
    directly from the source sequence.

    The copy generator is an extended version of the standard
    generator that computes three values.

    * :math:`p_{softmax}` the standard softmax over `tgt_dict`
    * :math:`p(z)` the probability of copying a word from
      the source
    * :math:`p_{copy}` the probability of copying a particular word,
      taken from the attention distribution directly.

    The model returns a distribution over the extended dictionary,
    computed as

    :math:`p(w) = p(z=1)  p_{copy}(w)  +  p(z=0)  p_{softmax}(w)`

    Args:
       input_size (int): size of input representation
       output_size (int): size of output vocabulary
       pad_idx (int): index of the padding token in the output vocabulary;
           its probability is forced to zero
    """

    def __init__(self, input_size, output_size, pad_idx):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)
        self.linear_copy = nn.Linear(input_size, 1)
        self.pad_idx = pad_idx

    def forward(self, hidden, attn, src_map):
        """
        Compute a distribution over the target dictionary
        extended by the dynamic dictionary implied by copying
        source words.

        Rows of ``hidden`` and ``attn`` must be ordered batch-major, i.e.
        all ``tlen`` steps of example 0 first, then example 1, and so on,
        so that row ``b * tlen + t`` is paired with ``src_map[b]``.

        Args:
           hidden (FloatTensor): hidden output ``(batch x tlen, input_size)``
           attn (FloatTensor): attn for each ``(batch x tlen, slen)``
           src_map (FloatTensor):
               A sparse indicator matrix mapping each source word to
               its index in the "extended" vocab.
               ``(batch, src_len, extra_words)``

        Returns:
           FloatTensor: ``(batch x tlen, output_size + extra_words)``, the
           generation probabilities followed by the copy probabilities,
           in the same row order as ``hidden``.
        """
        _, slen = attn.size()
        batch, _, cvocab = src_map.size()

        # Original probabilities; padding can never be generated.
        logits = self.linear(hidden)
        logits[:, self.pad_idx] = -float("inf")
        prob = torch.softmax(logits, 1)

        # Probability of copying p(z=1) batch.
        p_copy = torch.sigmoid(self.linear_copy(hidden))
        # Probability of not copying: p_{word}(w) * (1 - p(z))
        out_prob = torch.mul(prob, 1 - p_copy)

        # Project the gated attention onto the extended vocabulary. The rows
        # are grouped per example (batch-major) so that each example is
        # multiplied with its own src_map and the flattened result keeps the
        # row order of ``out_prob``. Viewing as (tlen, batch) and transposing
        # would only line up when batch == 1 or tlen == 1.
        mul_attn = torch.mul(attn, p_copy)
        copy_prob = torch.bmm(mul_attn.view(batch, -1, slen), src_map)
        copy_prob = copy_prob.reshape(-1, cvocab)
        return torch.cat([out_prob, copy_prob], 1)
class CopyGeneratorLoss(nn.Module):
    """Copy generator criterion.

    Args:
        vocab_size (int): size of the fixed target vocabulary; copy scores
            start at this column offset in ``scores``.
        force_copy (bool): when true, a token that can be copied
            (``align != unk_index``) is scored by its copy probability only.
        unk_index (int): index used for "unknown" in both ``align`` and
            ``target``.
        ignore_index (int): target value marking positions whose loss is
            set to zero.
        eps (float): small constant added to the copy probability to avoid
            ``log(0)``.
    """

    def __init__(
        self, vocab_size, force_copy, unk_index=0, ignore_index=-100, eps=1e-20
    ):
        super().__init__()
        self.force_copy = force_copy
        self.eps = eps
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index
        self.unk_index = unk_index

    def forward(self, scores, align, target):
        """
        Args:
            scores (FloatTensor): ``(batch_size*tgt_len)`` x dynamic vocab size
                whose sum along dim 1 is less than or equal to 1, i.e. cols
                softmaxed.
            align (LongTensor): ``(batch_size x tgt_len)``
            target (LongTensor): ``(batch_size x tgt_len)``

        Returns:
            FloatTensor: per-position negative log-likelihood, with zeros at
            positions where ``target == ignore_index``.
        """
        ignored = target == self.ignore_index
        # ``ignore_index`` need not be a valid column (the default is -100),
        # so gather from column 0 at ignored positions; their loss is zeroed
        # below, which makes the placeholder value irrelevant.
        gather_target = target.masked_fill(ignored, 0)

        # probabilities assigned by the model to the gold targets
        vocab_probs = scores.gather(1, gather_target.unsqueeze(1)).squeeze(1)

        # probability of tokens copied from source
        unk_align = align == self.unk_index
        copy_ix = align.unsqueeze(1) + self.vocab_size
        copy_tok_probs = scores.gather(1, copy_ix).squeeze(1)
        # Set scores for unk to 0 and add eps to avoid -inf logs
        copy_tok_probs = copy_tok_probs.masked_fill(unk_align, 0) + self.eps

        # find the indices in which you do not use the copy mechanism
        non_copy = unk_align
        if not self.force_copy:
            non_copy = non_copy | (target != self.unk_index)

        probs = torch.where(non_copy, copy_tok_probs + vocab_probs, copy_tok_probs)

        loss = -probs.log()  # just NLLLoss; can the module be incorporated?
        # Drop padding.
        return loss.masked_fill(ignored, 0)