""" Translation main class """
import os
import torch
from onmt.constants import DefaultTokens
from onmt.inputters.text_dataset import TextMultiField
from onmt.utils.alignment import build_align_pharaoh

class TranslationBuilder(object):
    """
    Build a word-based translation from the batch output
    of translator and the underlying dictionaries.

    Replacement based on "Addressing the Rare Word
    Problem in Neural Machine Translation" :cite:`Luong2015b`

    Args:
       data (onmt.inputters.Dataset): Data.
       fields (List[Tuple[str, Field]]): data fields, as ``(name, field)``
           pairs; must contain a ``"src"`` and a ``"tgt"`` entry
       n_best (int): number of translations produced
       replace_unk (bool): replace unknown words using attention
       has_tgt (bool): will the batch have gold targets
       phrase_table (str): path to a phrase table file; each line holds a
           source phrase and a target phrase separated by
           ``DefaultTokens.PHRASE_TABLE_SEPARATOR``. Ignored when empty or
           when the path does not exist.
    """

    def __init__(self, data, fields, n_best=1, replace_unk=False,
                 has_tgt=False, phrase_table=""):
        self.data = data
        self.fields = fields
        self._has_text_src = isinstance(
            dict(self.fields)["src"], TextMultiField)
        self.n_best = n_best
        self.replace_unk = replace_unk
        self.phrase_table_dict = {}
        if phrase_table != "" and os.path.exists(phrase_table):
            with open(phrase_table) as phrase_table_fd:
                for line in phrase_table_fd:
                    # A line that does not split into exactly two parts
                    # raises ValueError here.
                    phrase_src, phrase_trg = line.rstrip("\n").split(
                        DefaultTokens.PHRASE_TABLE_SEPARATOR)
                    self.phrase_table_dict[phrase_src] = phrase_trg
        self.has_tgt = has_tgt

    def _build_target_tokens(self, src, src_vocab, src_raw, pred, attn):
        """Convert predicted ids into a list of target-side tokens.

        Args:
            src: source ids of the example; only checked against ``None``
                to decide whether unknown-word replacement can run.
            src_vocab: per-example source vocabulary; used for ids that
                fall beyond the target vocabulary.
            src_raw (List[str]): raw source tokens of the example.
            pred: iterable of predicted token ids.
            attn: per-target-position scores over source positions, or
                ``None``.

        Returns:
            List[str]: tokens up to, but excluding, the first EOS token.
        """
        tgt_field = dict(self.fields)["tgt"].base_field
        vocab = tgt_field.vocab
        vocab_size = len(vocab)
        tokens = []
        for tok in pred:
            if tok < vocab_size:
                tokens.append(vocab.itos[tok])
            else:
                # Ids past the target vocabulary index the source vocabulary.
                tokens.append(src_vocab.itos[tok - vocab_size])
            if tokens[-1] == tgt_field.eos_token:
                tokens = tokens[:-1]
                break
        if self.replace_unk and attn is not None and src is not None:
            for i, token in enumerate(tokens):
                if token != tgt_field.unk_token:
                    continue
                # Replace the unknown token by the highest-scoring source
                # word, looking only at real (non-padded) source positions.
                _, max_index = attn[i][:len(src_raw)].max(0)
                src_tok = src_raw[max_index.item()]
                tokens[i] = src_tok
                # A phrase-table entry for that source word takes precedence.
                if self.phrase_table_dict and \
                        src_tok in self.phrase_table_dict:
                    tokens[i] = self.phrase_table_dict[src_tok]
        return tokens

    def from_batch(self, translation_batch):
        """Build one ``Translation`` per example of a translated batch.

        Args:
            translation_batch (dict): reads the keys ``"batch"``,
                ``"predictions"``, ``"scores"``, ``"attention"``,
                ``"alignment"`` and ``"gold_score"``.

        Returns:
            List[Translation]: one entry per example, ordered by
            ascending ``batch.indices``.
        """
        batch = translation_batch["batch"]
        assert (len(translation_batch["gold_score"]) ==
                len(translation_batch["predictions"]))
        batch_size = batch.batch_size

        # Reorder every per-example output by the example's index; the
        # index is the last zipped element and serves as the sort key.
        preds, pred_score, attn, align, gold_score, _indices = list(zip(
            *sorted(zip(translation_batch["predictions"],
                        translation_batch["scores"],
                        translation_batch["attention"],
                        translation_batch["alignment"],
                        translation_batch["gold_score"],
                        batch.indices.data),
                    key=lambda x: x[-1])))

        if not any(align):  # when align is an empty nested list
            align = [None] * batch_size

        # Sorting
        inds, perm = torch.sort(batch.indices)
        if self._has_text_src:
            src = batch.src[0][:, :, 0].index_select(1, perm)
        else:
            src = None
        tgt = batch.tgt[:, :, 0].index_select(1, perm) \
            if self.has_tgt else None

        translations = []
        for b in range(batch_size):
            if self._has_text_src:
                # NOTE(review): the attribute names `src_vocabs` and
                # `examples` on `self.data` were restored from a garbled
                # source - confirm against the dataset class in use.
                src_vocab = self.data.src_vocabs[inds[b]] \
                    if self.data.src_vocabs else None
                src_raw = self.data.examples[inds[b]].src[0]
            else:
                src_vocab = None
                src_raw = None
            src_b = src[:, b] if src is not None else None
            pred_sents = [self._build_target_tokens(
                src_b, src_vocab, src_raw,
                preds[b][n],
                align[b][n] if align[b] is not None else attn[b][n])
                for n in range(self.n_best)]
            gold_sent = None
            if tgt is not None:
                # Skip the first target position when rebuilding the gold
                # sentence.
                gold_sent = self._build_target_tokens(
                    src_b, src_vocab, src_raw,
                    tgt[1:, b], None)

            translation = Translation(
                src_b, src_raw, pred_sents, attn[b], pred_score[b],
                gold_sent, gold_score[b], align[b]
            )
            translations.append(translation)

        return translations
class Translation(object):
    """Container for a translated sentence.

    Attributes:
        src (LongTensor): Source word IDs.
        src_raw (List[str]): Raw source words.
        pred_sents (List[List[str]]): Words from the n-best translations.
        pred_scores (List[List[float]]): Log-probs of n-best translations.
        attns (List[FloatTensor]) : Attention distribution for each
            translation.
        gold_sent (List[str]): Words from gold translation.
        gold_score (List[float]): Log-prob of gold translation.
        word_aligns (List[FloatTensor]): Words Alignment distribution for
            each translation.
    """

    __slots__ = ["src", "src_raw", "pred_sents", "attns", "pred_scores",
                 "gold_sent", "gold_score", "word_aligns"]

    def __init__(self, src, src_raw, pred_sents,
                 attn, pred_scores, tgt_sent, gold_score, word_aligns):
        self.src = src
        self.src_raw = src_raw
        self.pred_sents = pred_sents
        self.attns = attn
        self.pred_scores = pred_scores
        self.gold_sent = tgt_sent
        self.gold_score = gold_score
        self.word_aligns = word_aligns

    def log(self, sent_number):
        """
        Log translation.

        Builds the report as a string and returns it; nothing is printed
        or written here.

        Args:
            sent_number (int): sentence number shown in the report lines.

        Returns:
            str: the source, the best prediction and its score, followed
            by the alignment (when ``word_aligns`` is set), the gold
            sentence and score (when ``gold_sent`` is set) and the full
            n-best list (when there is more than one prediction).
        """
        msg = ['\nSENT {}: {}\n'.format(sent_number, self.src_raw)]

        # The first entry of the n-best list is reported as the prediction.
        best_pred = self.pred_sents[0]
        best_score = self.pred_scores[0]
        pred_sent = ' '.join(best_pred)
        msg.append('PRED {}: {}\n'.format(sent_number, pred_sent))
        msg.append("PRED SCORE: {:.4f}\n".format(best_score))

        if self.word_aligns is not None:
            pred_align = self.word_aligns[0]
            pred_align_pharaoh = build_align_pharaoh(pred_align)
            pred_align_sent = ' '.join(pred_align_pharaoh)
            msg.append("ALIGN: {}\n".format(pred_align_sent))

        if self.gold_sent is not None:
            tgt_sent = ' '.join(self.gold_sent)
            msg.append('GOLD {}: {}\n'.format(sent_number, tgt_sent))
            msg.append(("GOLD SCORE: {:.4f}\n".format(self.gold_score)))
        if len(self.pred_sents) > 1:
            msg.append('\nBEST HYP:\n')
            for score, sent in zip(self.pred_scores, self.pred_sents):
                msg.append("[{:.4f}] {}\n".format(score, sent))

        return "".join(msg)