#!/usr/bin/env python
""" Translator Class and builder """
import torch
from torch.nn.functional import log_softmax
from torch.nn.utils.rnn import pad_sequence
import codecs
from time import time
from math import exp
from itertools import count, zip_longest
from copy import deepcopy
import onmt.model_builder
import onmt.decoders.ensemble
from onmt.constants import DefaultTokens
from onmt.translate.beam_search import BeamSearch, BeamSearchLM
from onmt.translate.greedy_search import GreedySearch, GreedySearchLM
from onmt.utils.misc import tile, set_random_seed, report_matrix
from onmt.utils.alignment import extract_alignment, build_align_pharaoh
from onmt.modules.copy_generator import collapse_copy_scores
from onmt.constants import ModelTask
from onmt.transforms import TransformPipe
def build_translator(opt, device_id=0, report_score=True, logger=None, out_file=None):
if out_file is None:
out_file = codecs.open(opt.output, "w+", "utf-8")
load_test_model = (
onmt.decoders.ensemble.load_test_model
if len(opt.models) > 1
else onmt.model_builder.load_test_model
)
vocabs, model, model_opt = load_test_model(opt, device_id)
scorer = onmt.translate.GNMTGlobalScorer.from_opt(opt)
if model_opt.model_task == ModelTask.LANGUAGE_MODEL:
translator = GeneratorLM.from_opt(
model,
vocabs,
opt,
model_opt,
global_scorer=scorer,
out_file=out_file,
report_align=opt.report_align,
report_score=report_score,
logger=logger,
)
else:
translator = Translator.from_opt(
model,
vocabs,
opt,
model_opt,
global_scorer=scorer,
out_file=out_file,
report_align=opt.report_align,
report_score=report_score,
logger=logger,
)
return translator
class Inference(object):
"""Translate a batch of sentences with a saved model.
Args:
model (onmt.modules.NMTModel): NMT model to use for translation
        vocabs (dict[str, Vocab]): A dict
            mapping each side ('src'/'tgt') to its Vocab.
gpu (int): GPU device. Set to negative for no GPU.
n_best (int): How many beams to wait for.
min_length (int): See
:class:`onmt.translate.decode_strategy.DecodeStrategy`.
max_length (int): See
:class:`onmt.translate.decode_strategy.DecodeStrategy`.
beam_size (int): Number of beams.
random_sampling_topk (int): See
:class:`onmt.translate.greedy_search.GreedySearch`.
random_sampling_temp (float): See
:class:`onmt.translate.greedy_search.GreedySearch`.
        stepwise_penalty (bool): Whether the coverage penalty is applied at
            every step.
dump_beam (bool): Debugging option.
block_ngram_repeat (int): See
:class:`onmt.translate.decode_strategy.DecodeStrategy`.
ignore_when_blocking (set or frozenset): See
:class:`onmt.translate.decode_strategy.DecodeStrategy`.
replace_unk (bool): Replace unknown token.
        tgt_file_prefix (bool): Force the predictions to begin with the provided -tgt.
data_type (str): Source data type.
verbose (bool): Print/log every translation.
report_time (bool): Print/log total time/frequency.
copy_attn (bool): Use copy attention.
global_scorer (onmt.translate.GNMTGlobalScorer): Translation
scoring/reranking object.
out_file (TextIO or codecs.StreamReaderWriter): Output file.
        report_score (bool): Whether to report scores.
logger (logging.Logger or NoneType): Logger.
"""
def __init__(
self,
model,
vocabs,
gpu=-1,
n_best=1,
min_length=0,
max_length=100,
max_length_ratio=1.5,
ratio=0.0,
beam_size=30,
random_sampling_topk=0,
random_sampling_topp=0.0,
random_sampling_temp=1.0,
stepwise_penalty=None,
dump_beam=False,
block_ngram_repeat=0,
ignore_when_blocking=frozenset(),
replace_unk=False,
ban_unk_token=False,
tgt_file_prefix=False,
phrase_table="",
data_type="text",
verbose=False,
report_time=False,
copy_attn=False,
global_scorer=None,
out_file=None,
report_align=False,
gold_align=False,
report_score=True,
logger=None,
seed=-1,
with_score=False,
return_gold_log_probs=False,
):
self.model = model
self.vocabs = vocabs
self._tgt_vocab = vocabs["tgt"]
self._tgt_eos_idx = vocabs["tgt"].lookup_token(DefaultTokens.EOS)
self._tgt_pad_idx = vocabs["tgt"].lookup_token(DefaultTokens.PAD)
self._tgt_bos_idx = vocabs["tgt"].lookup_token(DefaultTokens.BOS)
self._tgt_unk_idx = vocabs["tgt"].lookup_token(DefaultTokens.UNK)
self._tgt_sep_idx = vocabs["tgt"].lookup_token(DefaultTokens.SEP)
self._tgt_start_with = vocabs["tgt"].lookup_token(vocabs["decoder_start_token"])
self._tgt_vocab_len = len(self._tgt_vocab)
self._gpu = gpu
self._use_cuda = gpu > -1
self._dev = (
torch.device("cuda", self._gpu) if self._use_cuda else torch.device("cpu")
)
self.n_best = n_best
self.max_length = max_length
self.max_length_ratio = max_length_ratio
self.beam_size = beam_size
self.random_sampling_temp = random_sampling_temp
self.sample_from_topk = random_sampling_topk
self.sample_from_topp = random_sampling_topp
self.min_length = min_length
self.ban_unk_token = ban_unk_token
self.ratio = ratio
self.stepwise_penalty = stepwise_penalty
self.dump_beam = dump_beam
self.block_ngram_repeat = block_ngram_repeat
self.ignore_when_blocking = ignore_when_blocking
self._exclusion_idxs = {self._tgt_vocab[t] for t in self.ignore_when_blocking}
self.replace_unk = replace_unk
if self.replace_unk and not self.model.decoder.attentional:
raise ValueError("replace_unk requires an attentional decoder.")
self.tgt_file_prefix = tgt_file_prefix
self.phrase_table = phrase_table
self.data_type = data_type
self.verbose = verbose
self.report_time = report_time
self.copy_attn = copy_attn
self.global_scorer = global_scorer
if self.global_scorer.has_cov_pen and not self.model.decoder.attentional:
raise ValueError("Coverage penalty requires an attentional decoder.")
self.out_file = out_file
self.report_align = report_align
self.gold_align = gold_align
self.report_score = report_score
self.logger = logger
self.use_filter_pred = False
self._filter_pred = None
# for debugging
self.beam_trace = self.dump_beam != ""
self.beam_accum = None
if self.beam_trace:
self.beam_accum = {
"predicted_ids": [],
"beam_parent_ids": [],
"scores": [],
"log_probs": [],
}
set_random_seed(seed, self._use_cuda)
self.with_score = with_score
self.return_gold_log_probs = return_gold_log_probs
@classmethod
def from_opt(
cls,
model,
vocabs,
opt,
model_opt,
global_scorer=None,
out_file=None,
report_align=False,
report_score=True,
logger=None,
):
"""Alternate constructor.
Args:
model (onmt.modules.NMTModel): See :func:`__init__()`.
vocabs (dict[str, Vocab]): See
:func:`__init__()`.
opt (argparse.Namespace): Command line options
model_opt (argparse.Namespace): Command line options saved with
the model checkpoint.
            global_scorer (onmt.translate.GNMTGlobalScorer): See
                :func:`__init__()`.
            out_file (TextIO or codecs.StreamReaderWriter): See
                :func:`__init__()`.
            report_align (bool): See :func:`__init__()`.
            report_score (bool): See :func:`__init__()`.
logger (logging.Logger or NoneType): See :func:`__init__()`.
"""
# TODO: maybe add dynamic part
cls.validate_task(model_opt.model_task)
return cls(
model,
vocabs,
gpu=opt.gpu,
n_best=opt.n_best,
min_length=opt.min_length,
max_length=opt.max_length,
max_length_ratio=opt.max_length_ratio,
ratio=opt.ratio,
beam_size=opt.beam_size,
random_sampling_topk=opt.random_sampling_topk,
random_sampling_topp=opt.random_sampling_topp,
random_sampling_temp=opt.random_sampling_temp,
stepwise_penalty=opt.stepwise_penalty,
dump_beam=opt.dump_beam,
block_ngram_repeat=opt.block_ngram_repeat,
ignore_when_blocking=set(opt.ignore_when_blocking),
replace_unk=opt.replace_unk,
ban_unk_token=opt.ban_unk_token,
tgt_file_prefix=opt.tgt_file_prefix,
phrase_table=opt.phrase_table,
data_type=opt.data_type,
verbose=opt.verbose,
report_time=opt.report_time,
copy_attn=model_opt.copy_attn,
global_scorer=global_scorer,
out_file=out_file,
report_align=report_align,
gold_align=opt.gold_align,
report_score=report_score,
logger=logger,
seed=opt.seed,
with_score=opt.with_score,
)
def _log(self, msg):
if self.logger:
self.logger.info(msg)
else:
print(msg)
def _gold_score(
self, batch, enc_out, src_len, use_src_map, enc_final_hs, batch_size, src
):
if "tgt" in batch.keys() and not self.tgt_file_prefix:
gs, glp = self._score_target(
batch, enc_out, src_len, batch["src_map"] if use_src_map else None
)
self.model.decoder.init_state(src, enc_out, enc_final_hs)
else:
gs = [0] * batch_size
glp = None
return gs, glp
def _translate(
self,
infer_iter,
transform=None,
attn_debug=False,
align_debug=False,
phrase_table="",
):
"""Translate content of ``src`` and get gold scores from ``tgt``.
Args:
infer_iter: tensored batch iterator from DynamicDatasetIter
attn_debug (bool): enables the attention logging
align_debug (bool): enables the word alignment logging
Returns:
(`list`, `list`)
* all_scores is a list of `batch_size` lists of `n_best` scores
* all_predictions is a list of `batch_size` lists
of `n_best` predictions
"""
transform_pipe = (
TransformPipe.build_from([transform[name] for name in transform])
if transform
else None
)
xlation_builder = onmt.translate.TranslationBuilder(
self.vocabs,
self.n_best,
self.replace_unk,
self.phrase_table,
)
# Statistics
counter = count(1)
pred_score_total, pred_words_total = 0, 0
gold_score_total, gold_words_total = 0, 0
all_scores = []
all_predictions = []
start_time = time()
def _maybe_retranslate(translations, batch):
"""Here we handle the cases of mismatch in number of segments
between source and target. We re-translate seg by seg."""
inds, perm = torch.sort(batch["ind_in_bucket"])
trans_copy = deepcopy(translations)
inserted_so_far = 0
for j, trans in enumerate(translations):
if (trans.src == self._tgt_sep_idx).sum().item() != trans.pred_sents[
0
].count(DefaultTokens.SEP):
self._log("Mismatch in number of ((newline))")
# those two should be the same except feat dim
# batch['src'][perm[j], :, :])
# trans.src
# we rebuild a small batch made of the sub-segments
# in the long segment.
idx = (trans.src == self._tgt_sep_idx).nonzero()
sub_src = []
start_idx = 0
for i in range(len(idx)):
end_idx = idx[i]
sub_src.append(batch["src"][perm[j], start_idx:end_idx, :])
start_idx = end_idx + 1
end_idx = (
batch["src"][perm[j], :, 0].ne(self._tgt_pad_idx).sum() - 1
)
sub_src.append(batch["src"][perm[j], start_idx:end_idx, :])
t_sub_src = pad_sequence(
sub_src, batch_first=True, padding_value=self._tgt_pad_idx
)
t_sub_src_len = t_sub_src[:, :, 0].ne(self._tgt_pad_idx).sum(1)
                t_sub_src_ind = torch.arange(len(sub_src), dtype=torch.int16)
device = batch["src"].device
t_sub_batch = {
"src": t_sub_src.to(device),
"srclen": t_sub_src_len.to(device),
"ind_in_bucket": t_sub_src_ind.to(device),
}
# new sub-batch ready to be translated
sub_data = self.translate_batch(t_sub_batch, attn_debug)
sub_trans = xlation_builder.from_batch(sub_data)
# we re-insert the sub-batch in the initial translations
trans_copy[j + inserted_so_far] = sub_trans[0]
for i in range(1, len(sub_src)):
trans_copy.insert(j + i + inserted_so_far, sub_trans[i])
inserted_so_far += len(sub_src) - 1
return trans_copy
def _process_bucket(bucket_translations):
bucket_scores = []
bucket_predictions = []
bucket_score = 0
bucket_words = 0
bucket_gold_score = 0
bucket_gold_words = 0
voc_src = self.vocabs["src"].ids_to_tokens
bucket_translations = sorted(
bucket_translations, key=lambda x: x.ind_in_bucket
)
for trans in bucket_translations:
bucket_scores += [trans.pred_scores[: self.n_best]]
bucket_score += trans.pred_scores[0]
bucket_words += len(trans.pred_sents[0])
if "tgt" in batch.keys():
bucket_gold_score += trans.gold_score
bucket_gold_words += len(trans.gold_sent) + 1
n_best_preds = [
" ".join(pred) for pred in trans.pred_sents[: self.n_best]
]
if self.report_align:
align_pharaohs = [
build_align_pharaoh(align)
for align in trans.word_aligns[: self.n_best]
]
n_best_preds_align = [
" ".join(align[0]) for align in align_pharaohs
]
n_best_preds = [
pred + DefaultTokens.ALIGNMENT_SEPARATOR + align
for pred, align in zip(n_best_preds, n_best_preds_align)
]
if transform_pipe is not None:
n_best_preds = transform_pipe.batch_apply_reverse(n_best_preds)
bucket_predictions += [n_best_preds]
if self.with_score:
n_best_scores = [
score.item() for score in trans.pred_scores[: self.n_best]
]
out_all = [
pred + "\t" + str(score)
for (pred, score) in zip(n_best_preds, n_best_scores)
]
self.out_file.write("\n".join(out_all) + "\n")
else:
self.out_file.write("\n".join(n_best_preds) + "\n")
self.out_file.flush()
if self.verbose:
srcs = [voc_src[tok] for tok in trans.src[: trans.srclen]]
sent_number = next(counter)
output = trans.log(sent_number, src_raw=srcs)
self._log(output)
if attn_debug:
preds = trans.pred_sents[0]
preds.append(DefaultTokens.EOS)
attns = trans.attns[0].tolist()
if self.data_type == "text":
srcs = [
voc_src[tok] for tok in trans.src[: trans.srclen].tolist()
]
else:
srcs = [str(item) for item in range(len(attns[0]))]
output = report_matrix(srcs, preds, attns)
self._log(output)
if align_debug:
if self.gold_align:
tgts = trans.gold_sent
else:
tgts = trans.pred_sents[0]
align = trans.word_aligns[0].tolist()
if self.data_type == "text":
srcs = [
voc_src[tok] for tok in trans.src[: trans.srclen].tolist()
]
else:
srcs = [str(item) for item in range(len(align[0]))]
output = report_matrix(srcs, tgts, align)
self._log(output)
return (
bucket_scores,
bucket_predictions,
bucket_score,
bucket_words,
bucket_gold_score,
bucket_gold_words,
)
bucket_translations = []
prev_idx = 0
for batch, bucket_idx in infer_iter:
batch_data = self.translate_batch(batch, attn_debug)
translations = xlation_builder.from_batch(batch_data)
if (
not isinstance(self, GeneratorLM)
and self._tgt_sep_idx != self._tgt_unk_idx
and (batch["src"] == self._tgt_sep_idx).any().item()
):
                # For seq2seq, when the doc must be forced to emit the same
                # number of sentences as the source
translations = _maybe_retranslate(translations, batch)
bucket_translations += translations
if (
not isinstance(infer_iter, list)
and len(bucket_translations) >= infer_iter.bucket_size
):
bucket_idx += 1
if bucket_idx != prev_idx:
prev_idx = bucket_idx
(
bucket_scores,
bucket_predictions,
bucket_score,
bucket_words,
bucket_gold_score,
bucket_gold_words,
) = _process_bucket(bucket_translations)
all_scores += bucket_scores
all_predictions += bucket_predictions
pred_score_total += bucket_score
pred_words_total += bucket_words
gold_score_total += bucket_gold_score
gold_words_total += bucket_gold_words
bucket_translations = []
if len(bucket_translations) > 0:
(
bucket_scores,
bucket_predictions,
bucket_score,
bucket_words,
bucket_gold_score,
bucket_gold_words,
) = _process_bucket(bucket_translations)
all_scores += bucket_scores
all_predictions += bucket_predictions
pred_score_total += bucket_score
pred_words_total += bucket_words
gold_score_total += bucket_gold_score
gold_words_total += bucket_gold_words
end_time = time()
if self.report_score:
msg = self._report_score("PRED", pred_score_total, len(all_scores))
self._log(msg)
if "tgt" in batch.keys() and not self.tgt_file_prefix:
msg = self._report_score("GOLD", gold_score_total, len(all_scores))
self._log(msg)
if self.report_time:
total_time = end_time - start_time
self._log("Total translation time (s): %.1f" % total_time)
self._log(
"Average translation time (ms): %.1f"
% (total_time / len(all_predictions) * 1000)
)
self._log("Tokens per second: %.1f" % (pred_words_total / total_time))
if self.dump_beam:
import json
            json.dump(
                self.beam_accum,
                codecs.open(self.dump_beam, "w", "utf-8"),
            )
return all_scores, all_predictions
def _score(self, infer_iter):
        self.with_score = True
score_res = []
processed_bucket = {}
prev_bucket_idx = 0
for batch, bucket_idx in infer_iter:
if bucket_idx != prev_bucket_idx:
prev_bucket_idx += 1
score_res += [item for _, item in sorted(processed_bucket.items())]
processed_bucket = {}
batch_data = self.translate_batch(batch, attn_debug=False, scoring=True)
batch_gold_scores = batch_data["gold_score"].cpu().numpy().tolist()
batch_tgt_lengths = batch["tgtlen"].cpu().numpy().tolist()
batch_inds_in_bucket = batch["ind_in_bucket"]
if self.return_gold_log_probs:
batch_gold_log_probs = (
batch_data["gold_log_probs"].cpu().numpy().tolist()
)
else:
batch_gold_log_probs = [
None for i, _ in enumerate(batch_inds_in_bucket)
]
for i, ind in enumerate(batch_inds_in_bucket):
processed_bucket[ind] = [
batch_gold_scores[i],
batch_gold_log_probs[i],
batch_tgt_lengths[i],
]
if processed_bucket:
score_res += [item for _, item in sorted(processed_bucket.items())]
return score_res
def _align_pad_prediction(self, predictions, bos, pad):
"""
        Pad the predictions in the batch and prepend BOS.
        Args:
            predictions (List[List[Tensor]]): `(batch, n_best,)`; for each src
                sequence, the n_best tgt predictions, all ending with
                the eos id.
            bos (int): bos index to be used.
            pad (int): pad index to be used.
        Returns:
batched_nbest_predict (torch.LongTensor): `(batch, n_best, tgt_l)`
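
        Example (illustrative, not a doctest; with bos=2, pad=1, eos=3):

            preds = [[torch.tensor([5, 6, 3]), torch.tensor([5, 3])],
                     [torch.tensor([7, 3]), torch.tensor([8, 9, 3])]]
            out = self._align_pad_prediction(preds, bos=2, pad=1)
            # out.shape == (2, 2, 4), i.e. (batch, n_best, 1 + max_tgt_len)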
"""
dtype, device = predictions[0][0].dtype, predictions[0][0].device
flatten_tgt = [best.tolist() for bests in predictions for best in bests]
        padded_tgt = torch.tensor(
            list(zip_longest(*flatten_tgt, fillvalue=pad)),
            dtype=dtype,
            device=device,
        ).T
        bos_tensor = torch.full(
            [padded_tgt.size(0), 1], bos, dtype=dtype, device=device
        )
        full_tgt = torch.cat((bos_tensor, padded_tgt), dim=-1)
batched_nbest_predict = full_tgt.view(
len(predictions), -1, full_tgt.size(-1)
) # (batch, n_best, tgt_l)
return batched_nbest_predict
def _report_score(self, name, score_total, nb_sentences):
        # When length_penalty is "none", we report the total logprobs divided
        # by the number of sentences to approximate the per-sentence logprob,
        # along with the corresponding ppl.
        # When a length penalty is used (e.g. "avg" or "wu"), logprobs are
        # already normalized per token, so we report the per-line per-token
        # logprob and the corresponding "per word perplexity".
if nb_sentences == 0:
msg = "%s No translations" % (name,)
else:
score = score_total / nb_sentences
try:
ppl = exp(-score_total / nb_sentences)
except OverflowError:
ppl = float("inf")
msg = "%s SCORE: %.4f, %s PPL: %.2f NB SENTENCES: %d" % (
name,
score,
name,
ppl,
nb_sentences,
)
return msg
def _decode_and_generate(
self,
decoder_in,
enc_out,
batch,
src_len,
src_map=None,
step=None,
batch_offset=None,
return_attn=False,
):
if self.copy_attn:
# Turn any copied words into UNKs.
decoder_in = decoder_in.masked_fill(
decoder_in.gt(self._tgt_vocab_len - 1), self._tgt_unk_idx
)
        # Decoder forward, takes [batch, tgt_len, nfeats] as input
        # and [batch, src_len, hidden] as enc_out.
        # During inference, tgt_len = 1 and batch = beam_size * batch_size.
        # For gold scoring, tgt_len = actual length and it is a single batch.
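        # e.g. (illustrative): with beam_size=5 and batch_size=2, a decoding
        # step receives decoder_in of shape [10, 1, 1] and enc_out of shape
        # [10, src_len, hidden].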
dec_out, dec_attn = self.model.decoder(
decoder_in,
enc_out,
src_len=src_len,
step=step,
return_attn=self.global_scorer.has_cov_pen or return_attn,
)
# Generator forward.
if not self.copy_attn:
if "std" in dec_attn:
attn = dec_attn["std"]
else:
attn = None
scores = self.model.generator(dec_out.squeeze(1))
log_probs = log_softmax(scores, dim=-1) # we keep float16 if FP16
            # returns [(batch_size x beam_size), vocab] when decoding one step
            # or [batch_size, tgt_len, vocab] when scoring a full sentence
else:
attn = dec_attn["copy"]
scores = self.model.generator(
dec_out.view(-1, dec_out.size(2)),
attn.view(-1, attn.size(2)),
src_map,
)
            # here scores is [tgt_len x batch, vocab] or [beam x batch, vocab]
if batch_offset is None:
scores = scores.view(-1, len(batch["srclen"]), scores.size(-1))
scores = scores.transpose(0, 1).contiguous()
else:
scores = scores.view(-1, self.beam_size, scores.size(-1))
# at this point scores is batch first (dim=0)
scores = collapse_copy_scores(
scores,
batch,
self._tgt_vocab,
batch_dim=0,
)
scores = scores.view(-1, decoder_in.size(1), scores.size(-1))
log_probs = scores.squeeze(1).log()
        # returns [(batch_size x beam_size), vocab] when decoding one step
        # or [batch_size, tgt_len, vocab] when scoring a full sentence
return log_probs, attn
def translate_batch(self, batch, attn_debug):
"""Translate a batch of sentences."""
raise NotImplementedError
def _score_target(self, batch, enc_out, src_len, src_map):
raise NotImplementedError
def report_results(
self,
gold_score,
gold_log_probs,
batch,
batch_size,
decode_strategy,
):
results = {
"predictions": None,
"scores": None,
"attention": None,
"batch": batch,
"gold_score": gold_score,
"gold_log_probs": gold_log_probs,
}
results["scores"] = decode_strategy.scores
results["predictions"] = decode_strategy.predictions
results["attention"] = decode_strategy.attention
if self.report_align:
results["alignment"] = self._align_forward(
batch, decode_strategy.predictions
)
else:
results["alignment"] = [[] for _ in range(batch_size)]
return results
class Translator(Inference):
@classmethod
def validate_task(cls, task):
if task != ModelTask.SEQ2SEQ:
raise ValueError(
f"Translator does not support task {task}."
f" Tasks supported: {ModelTask.SEQ2SEQ}"
)
def _align_forward(self, batch, predictions):
"""
        For a batch of inputs and its predictions, return a list of
        alignment src-index Tensors of size ``(batch, n_best,)``.
"""
# (0) add BOS and padding to tgt prediction
if "tgt" in batch.keys() and self.gold_align:
self._log("Computing alignments with gold target")
batch_tgt_idxs = batch["tgt"].transpose(1, 2)
else:
batch_tgt_idxs = self._align_pad_prediction(
predictions, bos=self._tgt_bos_idx, pad=self._tgt_pad_idx
)
tgt_mask = (
batch_tgt_idxs.eq(self._tgt_pad_idx)
| batch_tgt_idxs.eq(self._tgt_eos_idx)
| batch_tgt_idxs.eq(self._tgt_bos_idx)
)
n_best = batch_tgt_idxs.size(1)
# (1) Encoder forward.
src, enc_states, enc_out, src_len = self._run_encoder(batch)
# (2) Repeat src objects `n_best` times.
# We use batch_size x n_best, get ``(batch * n_best, src_len, nfeat)``
src = tile(src, n_best, dim=0)
if enc_states is not None:
            # Quick fix: Transformers return None as enc_states.
            # enc_states are only used later to init the decoder's state
            # but are never used in the Transformer decoder, so we can skip.
enc_states = tile(enc_states, n_best, dim=0)
if isinstance(enc_out, tuple):
enc_out = tuple(tile(x, n_best, dim=0) for x in enc_out)
else:
enc_out = tile(enc_out, n_best, dim=0)
src_len = tile(src_len, n_best) # ``(batch * n_best,)``
# (3) Init decoder with n_best src,
self.model.decoder.init_state(src, enc_out, enc_states)
# reshape tgt to ``(len, batch * n_best, nfeat)``
# it should be done in a better way
tgt = batch_tgt_idxs.view(-1, batch_tgt_idxs.size(-1)).T.unsqueeze(-1)
dec_in = tgt[:-1].transpose(0, 1) # exclude last target from inputs
# here dec_in is batch first
_, attns = self.model.decoder(dec_in, enc_out, src_len=src_len, with_align=True)
alignment_attn = attns["align"] # ``(B, tgt_len-1, src_len)``
# masked_select
align_tgt_mask = tgt_mask.view(-1, tgt_mask.size(-1))
prediction_mask = align_tgt_mask[:, 1:] # exclude bos to match pred
# get aligned src id for each prediction's valid tgt tokens
        alignment = extract_alignment(
            alignment_attn, prediction_mask, src_len, n_best
        )
        return alignment
    def translate_batch(self, batch, attn_debug):
"""Translate a batch of sentences."""
if self.max_length_ratio > 0:
max_length = int(
min(self.max_length, batch["src"].size(1) * self.max_length_ratio + 5)
)
else:
max_length = self.max_length
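        # e.g. (illustrative): with max_length=100, max_length_ratio=1.5 and
        # a 50-token source, generation is capped at
        # int(min(100, 50 * 1.5 + 5)) = 80.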
with torch.no_grad():
if self.sample_from_topk != 0 or self.sample_from_topp != 0:
decode_strategy = GreedySearch(
pad=self._tgt_pad_idx,
bos=self._tgt_bos_idx,
eos=self._tgt_eos_idx,
unk=self._tgt_unk_idx,
start=self._tgt_start_with,
n_best=self.n_best,
batch_size=len(batch["srclen"]),
global_scorer=self.global_scorer,
min_length=self.min_length,
max_length=max_length,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=self._exclusion_idxs,
return_attention=attn_debug or self.replace_unk,
sampling_temp=self.random_sampling_temp,
keep_topk=self.sample_from_topk,
keep_topp=self.sample_from_topp,
beam_size=self.beam_size,
ban_unk_token=self.ban_unk_token,
)
else:
# TODO: support these blacklisted features
assert not self.dump_beam
decode_strategy = BeamSearch(
self.beam_size,
batch_size=len(batch["srclen"]),
pad=self._tgt_pad_idx,
bos=self._tgt_bos_idx,
eos=self._tgt_eos_idx,
unk=self._tgt_unk_idx,
start=self._tgt_start_with,
n_best=self.n_best,
global_scorer=self.global_scorer,
min_length=self.min_length,
max_length=max_length,
return_attention=attn_debug or self.replace_unk,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=self._exclusion_idxs,
stepwise_penalty=self.stepwise_penalty,
ratio=self.ratio,
ban_unk_token=self.ban_unk_token,
)
return self._translate_batch_with_strategy(batch, decode_strategy)
def _run_encoder(self, batch):
src = batch["src"]
src_len = batch["srclen"]
batch_size = len(batch["srclen"])
enc_out, enc_final_hs, src_len = self.model.encoder(src, src_len)
if src_len is None:
assert not isinstance(
enc_out, tuple
), "Ensemble decoding only supported for text data"
src_len = (
torch.Tensor(batch_size).type_as(enc_out).long().fill_(enc_out.size(1))
)
return src, enc_final_hs, enc_out, src_len
def _translate_batch_with_strategy(self, batch, decode_strategy):
"""Translate a batch of sentences step by step using cache.
Args:
            batch: a batch of sentences, yielded by the data iterator.
            decode_strategy (DecodeStrategy): A decode strategy to use for
                generating translations step by step.
Returns:
results (dict): The translation results.
"""
# (0) Prep the components of the search.
use_src_map = self.copy_attn
parallel_paths = decode_strategy.parallel_paths # beam_size
batch_size = len(batch["srclen"])
# (1) Run the encoder on the src.
src, enc_final_hs, enc_out, src_len = self._run_encoder(batch)
self.model.decoder.init_state(src, enc_out, enc_final_hs)
gold_score, gold_log_probs = self._gold_score(
batch,
enc_out,
src_len,
use_src_map,
enc_final_hs,
batch_size,
src,
)
# (2) prep decode_strategy. Possibly repeat src objects.
src_map = batch["src_map"] if use_src_map else None
target_prefix = batch["tgt"] if self.tgt_file_prefix else None
(fn_map_state, enc_out, src_map) = decode_strategy.initialize(
enc_out, src_len, src_map, target_prefix=target_prefix
)
if fn_map_state is not None:
self.model.decoder.map_state(fn_map_state)
# (3) Begin decoding step by step:
for step in range(decode_strategy.max_length):
decoder_input = decode_strategy.current_predictions.view(-1, 1, 1)
log_probs, attn = self._decode_and_generate(
decoder_input,
enc_out,
batch,
src_len=decode_strategy.src_len,
src_map=src_map,
step=step,
batch_offset=decode_strategy.batch_offset,
return_attn=decode_strategy.return_attention,
)
decode_strategy.advance(log_probs, attn)
            any_finished = any(
                any(sublist) for sublist in decode_strategy.is_finished_list
            )
if any_finished:
decode_strategy.update_finished()
if decode_strategy.done:
break
select_indices = decode_strategy.select_indices
if any_finished:
# Reorder states.
if isinstance(enc_out, tuple):
enc_out = tuple(x[select_indices] for x in enc_out)
else:
enc_out = enc_out[select_indices]
if src_map is not None:
src_map = src_map[select_indices]
if parallel_paths > 1 or any_finished:
self.model.decoder.map_state(lambda state, dim: state[select_indices])
return self.report_results(
gold_score,
gold_log_probs,
batch,
batch_size,
decode_strategy,
)
def _score_target(self, batch, enc_out, src_len, src_map):
tgt = batch["tgt"]
tgt_in = tgt[:, :-1, :]
log_probs, attn = self._decode_and_generate(
tgt_in,
enc_out,
batch,
src_len=src_len,
src_map=src_map,
)
log_probs[:, :, self._tgt_pad_idx] = 0
gold = tgt[:, 1:, :]
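        # pick the log prob assigned to each gold token:
        # gold_scores[b, t, 0] = log_probs[b, t, gold[b, t, 0]]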
gold_scores = log_probs.gather(2, gold)
gold_scores = gold_scores.sum(dim=1).view(-1)
return gold_scores, None
class GeneratorLM(Inference):
@classmethod
def validate_task(cls, task):
if task != ModelTask.LANGUAGE_MODEL:
raise ValueError(
f"GeneratorLM does not support task {task}."
f" Tasks supported: {ModelTask.LANGUAGE_MODEL}"
)
def _align_forward(self, batch, predictions):
"""
        For a batch of inputs and its predictions, return a list of
        alignment src-index Tensors of size ``(batch, n_best,)``.
"""
raise NotImplementedError
def translate_batch(self, batch, attn_debug, scoring=False):
"""Translate a batch of sentences."""
max_length = 0 if scoring else self.max_length
with torch.no_grad():
if self.sample_from_topk != 0 or self.sample_from_topp != 0:
decode_strategy = GreedySearchLM(
pad=self._tgt_pad_idx,
bos=self._tgt_bos_idx,
eos=self._tgt_eos_idx,
unk=self._tgt_unk_idx,
start=self._tgt_start_with,
n_best=self.n_best,
batch_size=len(batch["srclen"]),
global_scorer=self.global_scorer,
min_length=self.min_length,
max_length=max_length,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=self._exclusion_idxs,
return_attention=attn_debug or self.replace_unk,
sampling_temp=self.random_sampling_temp,
keep_topk=self.sample_from_topk,
keep_topp=self.sample_from_topp,
beam_size=self.beam_size,
ban_unk_token=self.ban_unk_token,
)
else:
# TODO: support these blacklisted features
assert not self.dump_beam
decode_strategy = BeamSearchLM(
self.beam_size,
batch_size=len(batch["srclen"]),
pad=self._tgt_pad_idx,
bos=self._tgt_bos_idx,
eos=self._tgt_eos_idx,
unk=self._tgt_unk_idx,
start=self._tgt_start_with,
n_best=self.n_best,
global_scorer=self.global_scorer,
min_length=self.min_length,
max_length=max_length,
return_attention=attn_debug or self.replace_unk,
block_ngram_repeat=self.block_ngram_repeat,
exclusion_tokens=self._exclusion_idxs,
stepwise_penalty=self.stepwise_penalty,
ratio=self.ratio,
ban_unk_token=self.ban_unk_token,
)
return self._translate_batch_with_strategy(batch, decode_strategy)
@classmethod
def split_src_to_prevent_padding(cls, src, src_len):
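        """Split ``src`` at the shortest sequence in the batch so the decoder
        consumes a rectangular, pad-free prefix; the remainder is re-injected
        as a target prefix. Illustrative: with ``src_len = [4, 6]``, columns
        ``:4`` of every row become the new ``src`` and columns ``4:`` become
        ``target_prefix``.
        """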
min_len_batch = torch.min(src_len).item()
target_prefix = None
        if 0 < min_len_batch < src.size(1):
target_prefix = src[:, min_len_batch:, :]
src = src[:, :min_len_batch, :]
src_len[:] = min_len_batch
return src, src_len, target_prefix
def tile_to_beam_size_after_initial_step(self, fn_map_state, log_probs):
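        # At step 0 the full src was fed at once, so log_probs is
        # (batch, src_len, vocab): tile it to beam size, then keep only the
        # distribution of the last position.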
if fn_map_state is not None:
log_probs = fn_map_state(log_probs, dim=0)
self.model.decoder.map_state(fn_map_state)
log_probs = log_probs[:, -1, :]
return log_probs
def _translate_batch_with_strategy(self, batch, decode_strategy, left_pad=True):
"""Translate a batch of sentences step by step using cache.
Args:
            batch: a batch of sentences, yielded by the data iterator.
            decode_strategy (DecodeStrategy): A decode strategy to use for
                generating translations step by step.
Returns:
results (dict): The translation results.
"""
# (0) Prep the components of the search.
use_src_map = self.copy_attn
parallel_paths = decode_strategy.parallel_paths # beam_size
batch_size = len(batch["srclen"])
# (1) split src into src and target_prefix to avoid padding.
src = batch["src"]
src_len = batch["srclen"]
if left_pad:
target_prefix = None
else:
src, src_len, target_prefix = self.split_src_to_prevent_padding(
src, src_len
)
# (2) init decoder
self.model.decoder.init_state(src, None, None)
gold_score, gold_log_probs = self._gold_score(
batch, None, src_len, use_src_map, None, batch_size, src
)
# (3) prep decode_strategy. Possibly repeat src objects.
src_map = batch["src_map"] if use_src_map else None
(fn_map_state, src, src_map) = decode_strategy.initialize(
src,
src_len,
src_map,
target_prefix=target_prefix,
)
# (4) Begin decoding step by step:
for step in range(decode_strategy.max_length):
decoder_input = (
src if step == 0 else decode_strategy.current_predictions.view(-1, 1, 1)
)
log_probs, attn = self._decode_and_generate(
decoder_input,
None,
batch,
src_len=decode_strategy.src_len,
src_map=src_map,
step=step if step == 0 else step + max(src_len.tolist()),
batch_offset=decode_strategy.batch_offset,
)
if step == 0:
log_probs = self.tile_to_beam_size_after_initial_step(
fn_map_state, log_probs
)
decode_strategy.advance(log_probs, attn)
            any_finished = any(
                any(sublist) for sublist in decode_strategy.is_finished_list
            )
if any_finished:
decode_strategy.update_finished()
if decode_strategy.done:
break
select_indices = decode_strategy.select_indices
if any_finished:
# Reorder states.
if src_map is not None:
src_map = src_map[select_indices]
if parallel_paths > 1 or any_finished:
# select indexes in model state/cache
self.model.decoder.map_state(lambda state, dim: state[select_indices])
return self.report_results(
gold_score,
gold_log_probs,
batch,
batch_size,
decode_strategy,
)
def _score_target(self, batch, enc_out, src_len, src_map):
src = batch["src"]
src_len = batch["srclen"]
tgt = batch["tgt"]
log_probs, attn = self._decode_and_generate(
src,
None,
batch,
src_len=src_len,
src_map=src_map,
)
log_probs[:, :, self._tgt_pad_idx] = 0
gold_log_probs = log_probs.gather(2, tgt)
gold_scores = gold_log_probs.sum(dim=1).view(-1)
if self.return_gold_log_probs:
return gold_scores, gold_log_probs
return gold_scores, None