Source code for onmt.translate.greedy_search

import torch
from torch.nn.functional import softmax
from onmt.translate.decode_strategy import DecodeStrategy


def sample_topp(logits, keep_topp):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=1)

    cumulative_probs = torch.cumsum(softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_keep = cumulative_probs.lt(keep_topp)

    # keep indices until overflowing p
    cumsum_mask = sorted_indices_to_keep.cumsum(dim=1)
    last_included = cumsum_mask[:, -1:]
    last_included.clamp_(0, sorted_indices_to_keep.size()[1] - 1)
    sorted_indices_to_keep = sorted_indices_to_keep.scatter_(1, last_included, 1)

    # Set all logits that are not in the top-p to -10000.
    # This puts the probabilities close to 0.
    keep_indices = sorted_indices_to_keep.scatter(
        1,
        sorted_indices,
        sorted_indices_to_keep,
    )
    return logits.masked_fill(~keep_indices, -10000)


def sample_topk(logits, keep_topk):
    top_values, _ = torch.topk(logits, keep_topk, dim=1)
    kth_best = top_values[:, -1].view([-1, 1])
    kth_best = kth_best.repeat([1, logits.shape[1]]).float()

    # Set all logits that are not in the top-k to -10000.
    # This puts the probabilities close to 0.
    ignore = torch.lt(logits, kth_best)
    return logits.masked_fill(ignore, -10000)


[docs]def sample_with_temperature(logits, sampling_temp, keep_topk, keep_topp): """Select next tokens randomly from the top k possible next tokens. Samples from a categorical distribution over the ``keep_topk`` words using the category probabilities ``logits / sampling_temp``. Args: logits (FloatTensor): Shaped ``(batch_size, vocab_size)``. These can be logits (``(-inf, inf)``) or log-probs (``(-inf, 0]``). (The distribution actually uses the log-probabilities ``logits - logits.logsumexp(-1)``, which equals the logits if they are log-probabilities summing to 1.) sampling_temp (float): Used to scale down logits. The higher the value, the more likely it is that a non-max word will be sampled. keep_topk (int): This many words could potentially be chosen. The other logits are set to have probability 0. keep_topp (float): Keep most likely words until the cumulated probability is greater than p. If used with keep_topk: both conditions will be applied Returns: (LongTensor, FloatTensor): * topk_ids: Shaped ``(batch_size, 1)``. These are the sampled word indices in the output vocab. * topk_scores: Shaped ``(batch_size, 1)``. These are essentially ``(logits / sampling_temp)[topk_ids]``. """ if sampling_temp == 0.0 or keep_topk == 1: # For temp=0.0, take the argmax to avoid divide-by-zero errors. # keep_topk=1 is also equivalent to argmax. topk_scores, topk_ids = logits.topk(1, dim=-1) if sampling_temp > 0: topk_scores /= sampling_temp else: logits = torch.div(logits, sampling_temp) if keep_topp > 0: logits = sample_topp(logits, keep_topp) if keep_topk > 0: logits = sample_topk(logits, keep_topk) dist = torch.distributions.Categorical(logits=logits) topk_ids = dist.sample().view(-1, 1) topk_scores = logits.gather(dim=1, index=topk_ids) return topk_ids, topk_scores
[docs]class GreedySearch(DecodeStrategy): """Select next tokens randomly from the top k possible next tokens. The ``scores`` attribute's lists are the score, after applying temperature, of the final prediction (either EOS or the final token in the event that ``max_length`` is reached) Args: pad (int): See base. bos (int): See base. eos (int): See base. unk (int): See base. start (int): See base. n_best (int): Don't stop until at least this many beams have reached EOS. batch_size (int): See base. global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance. min_length (int): See base. max_length (int): See base. ban_unk_token (Boolean): See base. block_ngram_repeat (int): See base. exclusion_tokens (set[int]): See base. return_attention (bool): See base. max_length (int): See base. sampling_temp (float): See :func:`~onmt.translate.greedy_search.sample_with_temperature()`. keep_topk (int): See :func:`~onmt.translate.greedy_search.sample_with_temperature()`. keep_topp (float): See :func:`~onmt.translate.greedy_search.sample_with_temperature()`. beam_size (int): Number of beams to use. """ def __init__( self, pad, bos, eos, unk, start, n_best, batch_size, global_scorer, min_length, block_ngram_repeat, exclusion_tokens, return_attention, max_length, sampling_temp, keep_topk, keep_topp, beam_size, ban_unk_token, ): super(GreedySearch, self).__init__( pad, bos, eos, unk, start, batch_size, beam_size, global_scorer, min_length, block_ngram_repeat, exclusion_tokens, return_attention, max_length, ban_unk_token, ) self.sampling_temp = sampling_temp self.keep_topk = keep_topk self.keep_topp = keep_topp self.topk_scores = None self.beam_size = beam_size self.n_best = n_best
[docs] def initialize( self, enc_out, src_len, src_map=None, device=None, target_prefix=None ): """Initialize for decoding.""" (fn_map_state, enc_out, src_map, target_prefix) = self.initialize_tile( enc_out, src_len, src_map, target_prefix ) if device is None: device = self.get_device_from_enc_out(enc_out) super(GreedySearch, self).initialize(device, target_prefix) self.select_indices = torch.arange( self.batch_size * self.beam_size, dtype=torch.long, device=device ) self.original_batch_idx = fn_map_state( torch.arange(self.batch_size, dtype=torch.long, device=device), dim=0 ) self.beams_scores = torch.zeros( (self.batch_size * self.beam_size, 1), dtype=torch.float, device=device ) return fn_map_state, enc_out, src_map
@property def current_predictions(self): return self.alive_seq[:, -1] @property def batch_offset(self): return self.select_indices def _pick(self, log_probs): """Function used to pick next tokens. Args: log_probs (FloatTensor): ``(batch_size, vocab_size)``. """ # maybe fix some prediction at this step by modifying log_probs log_probs = self.target_prefixing(log_probs) topk_ids, topk_scores = sample_with_temperature( log_probs, self.sampling_temp, self.keep_topk, self.keep_topp ) return topk_ids, topk_scores def align_select_indices(self): nb_finished_beams = len(self.is_finished_list) - self.select_indices.size(0) if nb_finished_beams: self.select_indices = torch.arange( self.select_indices.size(0), dtype=torch.long, device=self.select_indices.device, )
[docs] def advance(self, log_probs, attn): """Select next tokens randomly from the top k possible next tokens. Args: log_probs (FloatTensor): Shaped ``(batch_size, vocab_size)``. These can be logits (``(-inf, inf)``) or log-probs (``(-inf, 0]``). (The distribution actually uses the log-probabilities ``logits - logits.logsumexp(-1)``, which equals the logits if they are log-probabilities summing to 1.) attn (FloatTensor): Shaped ``(1, B, inp_seq_len)``. """ self.align_select_indices() self.ensure_min_length(log_probs) self.ensure_unk_removed(log_probs) self.block_ngram_repeats(log_probs) topk_ids, self.topk_scores = self._pick(log_probs) self.beams_scores += self.topk_scores self.is_finished_list = topk_ids.eq(self.eos).tolist() self.alive_seq = torch.cat([self.alive_seq, topk_ids], -1) if self.return_attention: if self.alive_attn is None: self.alive_attn = attn else: self.alive_attn = torch.cat([self.alive_attn, attn], 1) self.ensure_max_length()
[docs] def update_finished(self): """Finalize scores and predictions.""" # shape: (sum(~ self.is_finished), 1) step = len(self) non_finished_batch = [ b for b, fin in enumerate(self.is_finished_list) if not fin[0] ] length_penalty = self.global_scorer.length_penalty( step, alpha=self.global_scorer.alpha ) for b in [i for i, fin in enumerate(self.is_finished_list) if fin[0]]: b_orig = self.original_batch_idx[b] score = self.beams_scores[b, 0] / length_penalty pred = self.alive_seq[b, 1:] attention = ( self.alive_attn[ b, :, : self.src_len[b], ] if self.alive_attn is not None else [] ) self.hypotheses[b_orig].append((score, pred, attention)) self.done = len(non_finished_batch) == 0 if self.done: for b in range(self.batch_size): best_hyp = sorted(self.hypotheses[b], key=lambda x: x[0], reverse=True)[ : self.n_best ] for score, pred, attn in best_hyp: self.scores[b].append(score) self.predictions[b].append(pred) self.attention[b].append(attn) return self.select_indices = torch.tensor( non_finished_batch, device=self.alive_seq.device ) self.alive_seq = self.alive_seq[self.select_indices] self.beams_scores = self.beams_scores[self.select_indices] self.src_len = self.src_len[self.select_indices] if self.alive_attn is not None: self.alive_attn = self.alive_attn[self.select_indices] self.original_batch_idx = self.original_batch_idx[self.select_indices] self.maybe_update_target_prefix(self.select_indices)
class GreedySearchLM(GreedySearch): def update_finished(self): super(GreedySearchLM, self).update_finished() def initialize(self, src, src_len, src_map=None, device=None, target_prefix=None): """Initialize for decoding.""" if device is None: device = src.device (fn_map_state, _, src_map) = super(GreedySearchLM, self).initialize( None, src_len, src_map, device, target_prefix ) return fn_map_state, src, src_map def advance(self, log_probs, attn): super(GreedySearchLM, self).advance(log_probs, attn) # in LM task src_len is associated with currently generated src # and therefore needs to follow the generation self.src_len += 1