""" Translation main class """
import os
from onmt.constants import DefaultTokens
from onmt.utils.alignment import build_align_pharaoh
[docs]class TranslationBuilder(object):
"""
Build a word-based translation from the batch output
of translator and the underlying dictionaries.
Replacement based on "Addressing the Rare Word
Problem in Neural Machine Translation" :cite:`Luong2015b`
Args:
data ():
vocabs ():
n_best (int): number of translations produced
replace_unk (bool): replace unknown words using attention
"""
def __init__(self, vocabs, n_best=1, replace_unk=False, phrase_table=""):
self.vocabs = vocabs
self.n_best = n_best
self.replace_unk = replace_unk
self.phrase_table_dict = {}
if phrase_table != "" and os.path.exists(phrase_table):
with open(phrase_table) as phrase_table_fd:
for line in phrase_table_fd:
phrase_src, phrase_trg = line.rstrip("\n").split(
DefaultTokens.PHRASE_TABLE_SEPARATOR
)
self.phrase_table_dict[phrase_src] = phrase_trg
def _build_target_tokens(self, src, srclen, pred, attn, voc, dyn_voc):
if dyn_voc is None:
tokens = [voc[tok] for tok in pred.tolist()]
else:
tokens = [
voc[tok]
if tok < len(voc)
else dyn_voc.ids_to_tokens[tok - len(self.vocabs["src"].ids_to_tokens)]
for tok in pred.tolist()
]
if tokens[-1] == DefaultTokens.EOS:
tokens = tokens[:-1]
if self.replace_unk and attn is not None and src is not None:
for i in range(len(tokens)):
if tokens[i] == DefaultTokens.UNK:
_, max_index = attn[i][:srclen].max(0)
src_tok = self.vocabs["src"].ids_to_tokens[src[max_index.item()]]
tokens[i] = src_tok
if self.phrase_table_dict:
if src_tok in self.phrase_table_dict:
tokens[i] = self.phrase_table_dict[src_tok]
return tokens
def from_batch(self, translation_batch):
batch = translation_batch["batch"]
if "src_ex_vocab" in batch.keys():
dyn_voc_batch = batch["src_ex_vocab"]
else:
dyn_voc_batch = None
assert len(translation_batch["gold_score"]) == len(
translation_batch["predictions"]
)
batch_size = len(batch["srclen"])
preds, pred_score, attn, align, gold_score, ind = (
translation_batch["predictions"],
translation_batch["scores"],
translation_batch["attention"],
translation_batch["alignment"],
translation_batch["gold_score"],
batch["ind_in_bucket"],
)
if not any(align): # when align is a empty nested list
align = [None] * batch_size
src = batch["src"][:, :, 0]
srclen = batch["srclen"][:]
if "tgt" in batch.keys():
tgt = batch["tgt"][:, :, 0]
else:
tgt = None
translations = []
voc_tgt = self.vocabs["tgt"].ids_to_tokens
# These comp lists are costy but less than for loops
for b in range(batch_size):
if dyn_voc_batch is not None:
dyn_voc = dyn_voc_batch[b]
else:
dyn_voc = None
pred_sents = [
self._build_target_tokens(
src[b, :] if src is not None else None,
srclen[b],
preds[b][n],
align[b][n] if align[b] is not None else attn[b][n],
voc_tgt,
dyn_voc,
)
for n in range(self.n_best)
]
gold_sent = None
if tgt is not None:
gold_sent = self._build_target_tokens(
src[b, :] if src is not None else None,
srclen[b],
tgt[b, 1:] if tgt is not None else None,
None,
voc_tgt,
dyn_voc,
)
translation = Translation(
src[b, :] if src is not None else None,
srclen[b],
pred_sents,
attn[b],
pred_score[b],
gold_sent,
gold_score[b],
align[b],
ind[b],
)
translations.append(translation)
return translations
[docs]class Translation(object):
"""Container for a translated sentence.
Attributes:
src (LongTensor): Source word IDs.
srclen (List[int]): Source lengths.
pred_sents (List[List[str]]): Words from the n-best translations.
pred_scores (List[List[float]]): Log-probs of n-best translations.
attns (List[FloatTensor]) : Attention distribution for each
translation.
gold_sent (List[str]): Words from gold translation.
gold_score (List[float]): Log-prob of gold translation.
word_aligns (List[FloatTensor]): Words Alignment distribution for
each translation.
"""
__slots__ = [
"src",
"srclen",
"pred_sents",
"attns",
"pred_scores",
"gold_sent",
"gold_score",
"word_aligns",
"ind_in_bucket",
]
def __init__(
self,
src,
srclen,
pred_sents,
attn,
pred_scores,
tgt_sent,
gold_score,
word_aligns,
ind_in_bucket,
):
self.src = src
self.srclen = srclen
self.pred_sents = pred_sents
self.attns = attn
self.pred_scores = pred_scores
self.gold_sent = tgt_sent
self.gold_score = gold_score
self.word_aligns = word_aligns
self.ind_in_bucket = ind_in_bucket
[docs] def log(self, sent_number, src_raw=""):
"""
Log translation.
"""
msg = ["\nSENT {}: {}\n".format(sent_number, src_raw)]
best_pred = self.pred_sents[0]
best_score = self.pred_scores[0]
pred_sent = " ".join(best_pred)
msg.append("PRED {}: {}\n".format(sent_number, pred_sent))
msg.append("PRED SCORE: {:.4f}\n".format(best_score))
if self.word_aligns is not None:
pred_align = self.word_aligns[0]
pred_align_pharaoh, _ = build_align_pharaoh(pred_align)
pred_align_sent = " ".join(pred_align_pharaoh)
msg.append("ALIGN: {}\n".format(pred_align_sent))
if self.gold_sent is not None:
tgt_sent = " ".join(self.gold_sent)
msg.append("GOLD {}: {}\n".format(sent_number, tgt_sent))
msg.append(("GOLD SCORE: {:.4f}\n".format(self.gold_score)))
if len(self.pred_sents) > 1:
msg.append("\nBEST HYP:\n")
for score, sent in zip(self.pred_scores, self.pred_sents):
msg.append("[{:.4f}] {}\n".format(score, sent))
return "".join(msg)