#!/usr/bin/env python
import codecs
import sys
import os
import time
import json
import threading
import re
import traceback
import importlib
import torch
import onmt.opts
from itertools import islice, zip_longest
from copy import deepcopy
from argparse import Namespace
from onmt.constants import DefaultTokens
from onmt.utils.logging import init_logger
from onmt.utils.misc import set_random_seed
from onmt.utils.misc import check_model_config
from onmt.utils.alignment import to_word_align
from onmt.utils.parse import ArgumentParser
from onmt.translate.translator import build_translator
from onmt.transforms.features import InferFeatsTransform
from onmt.inputters.text_utils import (
textbatch_to_tensor,
parse_features,
append_features_to_text,
)
from onmt.utils.alignment import build_align_pharaoh
def critical(func):
    """Decorator for a critical section (mutually exclusive code)."""
    def wrapper(server_model, *args, **kwargs):
        if not server_model.running_lock.acquire(True, 120):
            raise ServerModelError(
                "Model %d running lock timeout" % server_model.model_id
            )
        try:
            o = func(server_model, *args, **kwargs)
        finally:
            server_model.running_lock.release()
        return o
    return wrapper
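# Usage sketch for ``critical`` (illustrative, not part of the API): any
# ServerModel method decorated with ``@critical`` runs under the model's
# ``running_lock``, so concurrent calls such as ``run`` and ``unload`` on
# the same model cannot interleave:
#
#     @critical
#     def unload(self):
#         ...  # at most one thread per ServerModel executes this body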
class Timer:
def __init__(self, start=False):
self.stime = -1
self.prev = -1
self.times = {}
if start:
self.start()
def start(self):
self.stime = time.time()
self.prev = self.stime
self.times = {}
def tick(self, name=None, tot=False):
t = time.time()
if not tot:
elapsed = t - self.prev
else:
elapsed = t - self.stime
self.prev = t
if name is not None:
self.times[name] = elapsed
return elapsed
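# Minimal Timer usage sketch (timings are hypothetical):
#
#     timer = Timer(start=True)
#     ...                                  # some work
#     timer.tick(name="load")              # seconds since the last tick
#     timer.tick(name="total", tot=True)   # seconds since start()
#     timer.times                          # e.g. {"load": 1.2, "total": 3.4}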
class ServerModelError(Exception):
pass
class CTranslate2Translator(object):
"""This class wraps the ``ctranslate2.Translator`` object to
reproduce the ``onmt.translate.translator`` API."""
def __init__(
self,
model_path,
ct2_translator_args,
ct2_translate_batch_args,
target_prefix=False,
preload=False,
report_align=False,
):
import ctranslate2
self.translator = ctranslate2.Translator(model_path, **ct2_translator_args)
self.ct2_translate_batch_args = ct2_translate_batch_args
self.target_prefix = target_prefix
self.report_align = report_align
if preload:
# perform a first request to initialize everything
dummy_translation = self.translate([{"src": {"src": "a"}}])
print(
"Performed a dummy translation to initialize the model",
dummy_translation,
)
time.sleep(1)
self.translator.unload_model(to_cpu=True)
@staticmethod
def convert_onmt_to_ct2_opts(ct2_translator_args, ct2_translate_batch_args, opt):
def setdefault_if_exists_must_match(obj, name, value):
if name in obj:
assert value == obj[name], (
f"{name} is different in"
" OpenNMT-py config and in CTranslate2 config"
f" ({value} vs {obj[name]})"
)
else:
obj.setdefault(name, value)
default_for_translator = {
"inter_threads": 1,
"intra_threads": torch.get_num_threads(),
"compute_type": "default",
}
for name, value in default_for_translator.items():
ct2_translator_args.setdefault(name, value)
onmt_for_translator = {
"device": "cuda" if opt.cuda else "cpu",
"device_index": opt.gpu if opt.cuda else 0,
}
for name, value in onmt_for_translator.items():
setdefault_if_exists_must_match(ct2_translator_args, name, value)
onmt_for_translate_batch_enforce = {
"beam_size": opt.beam_size,
"max_batch_size": opt.batch_size,
"num_hypotheses": opt.n_best,
"max_decoding_length": opt.max_length,
"min_decoding_length": opt.min_length,
}
for name, value in onmt_for_translate_batch_enforce.items():
setdefault_if_exists_must_match(ct2_translate_batch_args, name, value)
def translate(self, examples, batch_size=8, tgt=None):
if "feats" in examples[0]["src"]:
batch = [
append_features_to_text(ex["src"]["src"], ex["src"]["feats"]).split(" ")
for ex in examples
]
else:
batch = [ex["src"]["src"].split(" ") for ex in examples]
if tgt is not None:
tgt = [item.split(" ") for item in tgt]
if self.report_align:
self.ct2_translate_batch_args["return_attention"] = True
preds = self.translator.translate_batch(
batch,
target_prefix=tgt if self.target_prefix else None,
return_scores=True,
**self.ct2_translate_batch_args,
)
scores = [[item["score"] for item in ex] for ex in preds]
predictions = [[" ".join(item["tokens"]) for item in ex] for ex in preds]
if self.report_align:
attentions = [
[torch.Tensor(item["attention"]) for item in ex] for ex in preds
]
align_pharaohs = [
[build_align_pharaoh(item) for item in ex] for ex in attentions
]
aligns = [[" ".join(item[0]) for item in ex] for ex in align_pharaohs]
align_scores = [[" ".join(item[1]) for item in ex] for ex in align_pharaohs]
predictions = [
[
pred
+ DefaultTokens.ALIGNMENT_SEPARATOR
+ align
+ DefaultTokens.ALIGNMENT_SEPARATOR
+ align_score
for pred, align, align_score in zip(*item)
]
for item in zip(predictions, aligns, align_scores)
]
return scores, predictions
def to_cpu(self):
self.translator.unload_model(to_cpu=True)
def to_gpu(self):
self.translator.load_model()
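# Illustrative wiring of the wrapper above (paths and values are
# assumptions, not defaults shipped with the server):
#
#     translator = CTranslate2Translator(
#         "available_models/ende_ct2",  # hypothetical CTranslate2 model dir
#         ct2_translator_args={"inter_threads": 1, "compute_type": "int8"},
#         ct2_translate_batch_args={"beam_size": 5},
#     )
#     scores, predictions = translator.translate(
#         [{"src": {"src": "▁Hello ▁world"}}]  # pre-tokenized source text
#     )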
def parse_features_opts(conf):
features_opt = conf.get("features", None)
if features_opt is not None:
features_opt["n_src_feats"] = features_opt.get("n_src_feats", 0)
features_opt["src_feats_defaults"] = features_opt.get(
"src_feats_defaults", None
)
features_opt["reversible_tokenization"] = features_opt.get(
"reversible_tokenization", "joiner"
)
return features_opt
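# Example "features" block as read by ``parse_features_opts`` above
# (a sketch; all values hypothetical):
#
#     {"features": {"n_src_feats": 1,
#                   "src_feats_defaults": "0",
#                   "reversible_tokenization": "joiner"}}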
class TranslationServer(object):
def __init__(self):
self.models = {}
self.next_id = 0
    def start(self, config_file):
"""Read the config file and pre-/load the models."""
self.config_file = config_file
with open(self.config_file) as f:
self.confs = json.load(f)
self.models_root = self.confs.get("models_root", "./available_models")
for i, conf in enumerate(self.confs["models"]):
if "models" not in conf:
if "model" in conf:
# backwards compatibility for confs
conf["models"] = [conf["model"]]
else:
raise ValueError(
"""Incorrect config file: missing 'models'
parameter for model #%d"""
% i
)
check_model_config(conf, self.models_root)
kwargs = {
"timeout": conf.get("timeout", None),
"load": conf.get("load", None),
"preprocess_opt": conf.get("preprocess", None),
"tokenizer_opt": conf.get("tokenizer", None),
"postprocess_opt": conf.get("postprocess", None),
"custom_opt": conf.get("custom_opt", None),
"on_timeout": conf.get("on_timeout", None),
"model_root": conf.get("model_root", self.models_root),
"ct2_model": conf.get("ct2_model", None),
"ct2_translator_args": conf.get("ct2_translator_args", {}),
"ct2_translate_batch_args": conf.get("ct2_translate_batch_args", {}),
"features_opt": parse_features_opts(conf),
}
kwargs = {k: v for (k, v) in kwargs.items() if v is not None}
model_id = conf.get("id", None)
opt = conf["opt"]
opt["models"] = conf["models"]
self.preload_model(opt, model_id=model_id, **kwargs)
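    # A minimal config file sketch accepted by ``start`` (ids, paths and
    # option values are hypothetical):
    #
    #     {"models_root": "./available_models",
    #      "models": [{"id": 100,
    #                  "model": "model_0.pt",
    #                  "timeout": 600,
    #                  "on_timeout": "to_cpu",
    #                  "load": true,
    #                  "opt": {"gpu": 0, "beam_size": 5}}]}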
    def clone_model(self, model_id, opt, timeout=-1):
        """Clone the model ``model_id``.
        Different options may be passed; if ``opt`` is None, the clone
        reuses the original model's options."""
if model_id in self.models:
if opt is None:
opt = self.models[model_id].user_opt
opt["models"] = self.models[model_id].opt.models
return self.load_model(opt, timeout)
else:
raise ServerModelError("No such model '%s'" % str(model_id))
    def load_model(self, opt, model_id=None, **model_kwargs):
"""Load a model given a set of options"""
model_id = self.preload_model(opt, model_id=model_id, **model_kwargs)
load_time = self.models[model_id].load_time
return model_id, load_time
    def preload_model(self, opt, model_id=None, **model_kwargs):
        """Preload the model: update the internal data structure.
        It will effectively load the model if ``load`` is set."""
if model_id is not None:
if model_id in self.models.keys():
raise ValueError("Model ID %d already exists" % model_id)
else:
model_id = self.next_id
while model_id in self.models.keys():
model_id += 1
self.next_id = model_id + 1
print("Pre-loading model %d" % model_id)
model = ServerModel(opt, model_id, **model_kwargs)
self.models[model_id] = model
return model_id
    def run(self, inputs):
"""Translate ``inputs``
We keep the same format as the Lua version i.e.
``[{"id": model_id, "src": "sequence to translate"},{ ...}]``
We use inputs[0]["id"] as the model id"""
model_id = inputs[0].get("id", 0)
if model_id in self.models and self.models[model_id] is not None:
return self.models[model_id].run(inputs)
else:
print("Error No such model '%s'" % str(model_id))
raise ServerModelError("No such model '%s'" % str(model_id))
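    # Request/response sketch (model id and text hypothetical):
    #
    #     server.run([{"id": 100, "src": "Hello world"}])
    #     # -> (results, scores, n_best, times, aligns, align_scores)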
    def unload_model(self, model_id):
"""Manually unload a model.
It will free the memory and cancel the timer
"""
if model_id in self.models and self.models[model_id] is not None:
self.models[model_id].unload()
else:
raise ServerModelError("No such model '%s'" % str(model_id))
    def list_models(self):
"""Return the list of available models"""
models = []
for _, model in self.models.items():
models += [model.to_dict()]
return models
class ServerModel(object):
"""Wrap a model with server functionality.
Args:
opt (dict): Options for the Translator
model_id (int): Model ID
preprocess_opt (list): Options for preprocess processus or None
tokenizer_opt (dict): Options for the tokenizer or None
postprocess_opt (list): Options for postprocess processus or None
custom_opt (dict): Custom options, can be used within preprocess or
postprocess, default None
load (bool): whether to load the model during :func: ``__init__()``
timeout (int): Seconds before running :func: ``do_timeout()``
Negative values means no timeout
on_timeout (str): Options are [to_cpu, unload]. Set what to do on
timeout (see :func: ``do_timeout()``.)
model_root (str): Path to the model directory
it must contain the model and tokenizer file"""
def __init__(
self,
opt,
model_id,
preprocess_opt=None,
tokenizer_opt=None,
postprocess_opt=None,
custom_opt=None,
load=False,
timeout=-1,
on_timeout="to_cpu",
model_root="./",
ct2_model=None,
ct2_translator_args=None,
ct2_translate_batch_args=None,
features_opt=None,
):
self.model_root = model_root
self.opt = self.parse_opt(opt)
self.custom_opt = custom_opt
self.model_id = model_id
self.preprocess_opt = preprocess_opt
self.tokenizers_opt = tokenizer_opt
self.features_opt = features_opt
self.postprocess_opt = postprocess_opt
self.timeout = timeout
self.on_timeout = on_timeout
self.ct2_model = (
os.path.join(model_root, ct2_model) if ct2_model is not None else None
)
self.ct2_translator_args = ct2_translator_args
self.ct2_translate_batch_args = ct2_translate_batch_args
self.unload_timer = None
self.user_opt = opt
self.tokenizers = None
if len(self.opt.log_file) > 0:
log_file = os.path.join(model_root, self.opt.log_file)
else:
log_file = None
self.logger = init_logger(
log_file=log_file, log_file_level=self.opt.log_file_level, rotate=True
)
self.loading_lock = threading.Event()
self.loading_lock.set()
self.running_lock = threading.Semaphore(value=1)
set_random_seed(self.opt.seed, self.opt.cuda)
if self.preprocess_opt is not None:
self.logger.info("Loading preprocessor")
self.preprocessor = []
for function_path in self.preprocess_opt:
function = get_function_by_path(function_path)
self.preprocessor.append(function)
if self.tokenizers_opt is not None:
if "src" in self.tokenizers_opt and "tgt" in self.tokenizers_opt:
self.logger.info("Loading src & tgt tokenizer")
self.tokenizers = {
"src": self.build_tokenizer(tokenizer_opt["src"]),
"tgt": self.build_tokenizer(tokenizer_opt["tgt"]),
}
else:
self.logger.info("Loading tokenizer")
self.tokenizers_opt = {"src": tokenizer_opt, "tgt": tokenizer_opt}
tokenizer = self.build_tokenizer(tokenizer_opt)
self.tokenizers = {"src": tokenizer, "tgt": tokenizer}
self.feats_transform = None
if self.features_opt is not None:
self.feats_transform = InferFeatsTransform(Namespace(**self.features_opt))
if self.postprocess_opt is not None:
self.logger.info("Loading postprocessor")
self.postprocessor = []
for function_path in self.postprocess_opt:
function = get_function_by_path(function_path)
self.postprocessor.append(function)
if load:
self.load(preload=True)
self.stop_unload_timer()
    def parse_opt(self, opt):
"""Parse the option set passed by the user using ``onmt.opts``
Args:
opt (dict): Options passed by the user
Returns:
opt (argparse.Namespace): full set of options for the Translator
"""
prec_argv = sys.argv
sys.argv = sys.argv[:1]
parser = ArgumentParser()
onmt.opts.translate_opts(parser)
models = opt["models"]
if not isinstance(models, (list, tuple)):
models = [models]
opt["models"] = [os.path.join(self.model_root, model) for model in models]
opt["src"] = "dummy_src"
for k, v in opt.items():
if k == "models":
sys.argv += ["-model"]
sys.argv += [str(model) for model in v]
            elif isinstance(v, bool):
sys.argv += ["-%s" % k]
else:
sys.argv += ["-%s" % k, str(v)]
opt = parser.parse_args()
ArgumentParser.validate_translate_opts(opt)
opt.cuda = opt.gpu > -1
sys.argv = prec_argv
return opt
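    # Sketch of what ``parse_opt`` does (values hypothetical): the user
    # dict is replayed through ``onmt.opts.translate_opts`` via sys.argv,
    # so unknown keys fail argument parsing and defaults are filled in:
    #
    #     opt = server_model.parse_opt(
    #         {"models": ["model_0.pt"], "gpu": -1, "beam_size": 5}
    #     )
    #     opt.beam_size  # 5
    #     opt.cuda       # False, since gpu == -1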
@property
def loaded(self):
return hasattr(self, "translator")
def load(self, preload=False):
self.loading_lock.clear()
timer = Timer()
self.logger.info("Loading model %d" % self.model_id)
timer.start()
try:
if self.ct2_model is not None:
CTranslate2Translator.convert_onmt_to_ct2_opts(
self.ct2_translator_args, self.ct2_translate_batch_args, self.opt
)
self.translator = CTranslate2Translator(
self.ct2_model,
ct2_translator_args=self.ct2_translator_args,
ct2_translate_batch_args=self.ct2_translate_batch_args,
target_prefix=self.opt.tgt_file_prefix,
preload=preload,
report_align=self.opt.report_align,
)
else:
self.translator = build_translator(
self.opt,
report_score=False,
out_file=codecs.open(os.devnull, "w", "utf-8"),
)
except RuntimeError as e:
raise ServerModelError("Runtime Error: %s" % str(e))
timer.tick("model_loading")
self.load_time = timer.tick()
self.reset_unload_timer()
self.loading_lock.set()
@critical
def run(self, inputs):
"""Translate ``inputs`` using this model
Args:
inputs (List[dict[str, str]]): [{'src': '...'},{'src': '...'}]
Returns:
result (list): translations
times (dict): containing times"""
self.stop_unload_timer()
timer = Timer()
timer.start()
self.logger.info("Running translation using %d" % self.model_id)
if not self.loading_lock.is_set():
self.logger.info(
"Model #%d is being loaded by another thread, waiting" % self.model_id
)
if not self.loading_lock.wait(timeout=30):
raise ServerModelError("Model %d loading timeout" % self.model_id)
else:
if not self.loaded:
self.load()
timer.tick(name="load")
elif self.opt.cuda:
self.to_gpu()
timer.tick(name="to_gpu")
texts = []
head_spaces = []
tail_spaces = []
all_preprocessed = []
for i, inp in enumerate(inputs):
src = inp["src"]
whitespaces_before, whitespaces_after = "", ""
match_before = re.search(r"^\s+", src)
match_after = re.search(r"\s+$", src)
if match_before is not None:
whitespaces_before = match_before.group(0)
if match_after is not None:
whitespaces_after = match_after.group(0)
head_spaces.append(whitespaces_before)
# every segment becomes a dict for flexibility purposes
seg_dict = self.maybe_preprocess(inp)
all_preprocessed.append(seg_dict)
for seg, ref, feats in zip_longest(
seg_dict["seg"], seg_dict["ref"], seg_dict["src_feats"]
):
tok = self.maybe_tokenize(seg)
if ref is not None:
ref = self.maybe_tokenize(ref, side="tgt")
feats = self.maybe_transform_feats(seg, tok, feats)
texts.append((tok, ref, feats))
tail_spaces.append(whitespaces_after)
empty_indices = []
examples = []
for i, (tok, ref_tok, feats) in enumerate(texts):
if tok == "":
empty_indices.append(i)
else:
ex = {
"src": {"src": tok},
"tgt": {"tgt": ref_tok} if ref_tok is not None else None,
}
if feats is not None:
ex["src"]["feats"] = feats
examples.append(ex)
scores = []
predictions = []
if len(examples) > 0:
try:
if isinstance(self.translator, CTranslate2Translator):
scores, predictions = self.translator.translate(examples)
else:
device_id = (
self.translator._dev.index if self.translator._use_cuda else -1
)
device = (
torch.device(device_id)
if device_id >= 0
else torch.device("cpu")
)
infer_iter = textbatch_to_tensor(
self.translator.vocabs, examples, device
)
scores, predictions = self.translator._translate(infer_iter)
            except Exception as e:
err = "Error: %s" % str(e)
self.logger.error(err)
self.logger.error("repr(examples): " + repr(examples))
self.logger.error("model: #%s" % self.model_id)
self.logger.error("model opt: " + str(self.opt.__dict__))
self.logger.error(traceback.format_exc())
raise ServerModelError(err)
timer.tick(name="translation")
        self.logger.info(
            "Using model #%d\t%d inputs\n\ttranslation time: %f"
            % (self.model_id, len(texts), timer.times["translation"])
        )
self.reset_unload_timer()
        # NOTE: translator returns lists of `n_best` lists
def flatten_list(_list):
return sum(_list, [])
tiled_texts = [
ex["src"]["src"] for ex in examples for _ in range(self.opt.n_best)
]
results = flatten_list(predictions)
        def maybe_item(x):
            return x.item() if isinstance(x, torch.Tensor) else x
scores = [maybe_item(score_tensor) for score_tensor in flatten_list(scores)]
results = [
self.maybe_detokenize_with_align(result, src)
for result, src in zip(results, tiled_texts)
]
aligns = [align[0] if align is not None else None for _, align in results]
align_scores = [align[1] if align is not None else None for _, align in results]
results = [tokens for tokens, _ in results]
# build back results with empty texts
for i in empty_indices:
j = i * self.opt.n_best
results = results[:j] + [""] * self.opt.n_best + results[j:]
aligns = aligns[:j] + [None] * self.opt.n_best + aligns[j:]
align_scores = (
align_scores[:j] + [None] * self.opt.n_best + align_scores[j:]
)
scores = scores[:j] + [0] * self.opt.n_best + scores[j:]
rebuilt_segs, scores, aligns, align_scores = self.rebuild_seg_packages(
all_preprocessed, results, scores, aligns, align_scores, self.opt.n_best
)
results = [self.maybe_postprocess(seg) for seg in rebuilt_segs]
        head_spaces = [h for h in head_spaces for _ in range(self.opt.n_best)]
        tail_spaces = [t for t in tail_spaces for _ in range(self.opt.n_best)]
results = ["".join(items) for items in zip(head_spaces, results, tail_spaces)]
self.logger.info("Translation Results: %d", len(results))
return (results, scores, self.opt.n_best, timer.times, aligns, align_scores)
    def rebuild_seg_packages(
self, all_preprocessed, results, scores, aligns, align_scores, n_best
):
"""Rebuild proper segment packages based on initial n_seg."""
offset = 0
rebuilt_segs = []
avg_scores = []
merged_aligns = []
merged_align_scores = []
for i, seg_dict in enumerate(all_preprocessed):
n_seg = seg_dict["n_seg"]
sub_results = results[n_best * offset : (offset + n_seg) * n_best]
sub_scores = scores[n_best * offset : (offset + n_seg) * n_best]
sub_aligns = aligns[n_best * offset : (offset + n_seg) * n_best]
sub_align_scores = align_scores[n_best * offset : (offset + n_seg) * n_best]
for j in range(n_best):
_seg_dict = deepcopy(seg_dict)
_seg_dict["seg"] = list(islice(sub_results, j, None, n_best))
rebuilt_segs.append(_seg_dict)
sub_sub_scores = list(islice(sub_scores, j, None, n_best))
avg_score = sum(sub_sub_scores) / n_seg if n_seg != 0 else 0
avg_scores.append(avg_score)
sub_sub_aligns = list(islice(sub_aligns, j, None, n_best))
merged_aligns.append(sub_sub_aligns)
sub_sub_align_scores = list(islice(sub_align_scores, j, None, n_best))
merged_align_scores.append(sub_sub_align_scores)
offset += n_seg
return rebuilt_segs, avg_scores, merged_aligns, merged_align_scores
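    # Worked example for the slicing above (numbers hypothetical): with
    # n_best=2 and two inputs split into n_seg=2 and n_seg=1 segments,
    # ``results`` holds (2 + 1) * 2 = 6 flat hypotheses. For the first
    # input, sub_results = results[0:4]; its best hypotheses are
    # sub_results[0::2] (islice with j=0, step n_best) and its 2nd-best
    # are sub_results[1::2], each score averaged over the n_seg segments.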
    def do_timeout(self):
        """Timeout function that frees GPU memory.
        Moves the model to CPU or unloads it, depending on
        the value of ``self.on_timeout``."""
if self.on_timeout == "unload":
self.logger.info("Timeout: unloading model %d" % self.model_id)
self.unload()
if self.on_timeout == "to_cpu":
self.logger.info("Timeout: sending model %d to CPU" % self.model_id)
self.to_cpu()
@critical
def unload(self):
self.logger.info("Unloading model %d" % self.model_id)
del self.translator
if self.opt.cuda:
torch.cuda.empty_cache()
self.stop_unload_timer()
self.unload_timer = None
def stop_unload_timer(self):
if self.unload_timer is not None:
self.unload_timer.cancel()
def reset_unload_timer(self):
if self.timeout < 0:
return
self.stop_unload_timer()
self.unload_timer = threading.Timer(self.timeout, self.do_timeout)
self.unload_timer.start()
def to_dict(self):
hide_opt = ["models", "src"]
d = {
"model_id": self.model_id,
"opt": {
k: self.user_opt[k] for k in self.user_opt.keys() if k not in hide_opt
},
"models": self.user_opt["models"],
"loaded": self.loaded,
"timeout": self.timeout,
}
if self.tokenizers_opt is not None:
d["tokenizer"] = self.tokenizers_opt
return d
@critical
def to_cpu(self):
"""Move the model to CPU and clear CUDA cache."""
        if isinstance(self.translator, CTranslate2Translator):
self.translator.to_cpu()
else:
self.translator.model.cpu()
if self.opt.cuda:
torch.cuda.empty_cache()
    def to_gpu(self):
"""Move the model to GPU."""
        if isinstance(self.translator, CTranslate2Translator):
self.translator.to_gpu()
else:
torch.cuda.set_device(self.opt.gpu)
self.translator.model.cuda()
    def maybe_preprocess(self, sequence):
"""Preprocess the sequence (or not)"""
if sequence.get("src", None) is not None:
sequence = deepcopy(sequence)
src, src_feats = parse_features(
sequence["src"].strip(),
n_feats=(
self.features_opt["n_src_feats"]
if self.features_opt is not None
else 0
),
defaults=(
self.features_opt["src_feats_defaults"]
if self.features_opt is not None
else None
),
)
sequence["seg"] = [src]
sequence.pop("src")
sequence["ref"] = [sequence.get("ref", None)]
sequence["src_feats"] = [src_feats]
sequence["n_seg"] = 1
if self.preprocess_opt is not None:
return self.preprocess(sequence)
return sequence
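    # Shape sketch (no source features configured; exact feats value
    # depends on ``parse_features``): ``{"src": "Hello world"}`` becomes
    # roughly
    #     {"seg": ["Hello world"], "ref": [None],
    #      "src_feats": [...], "n_seg": 1}
    # after which custom preprocessors may split "seg" further and
    # increase "n_seg" accordingly.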
    def preprocess(self, sequence):
        """Preprocess a single segment package.
        Args:
            sequence (dict): The segment package to preprocess.
        Returns:
            sequence (dict): The preprocessed segment package."""
if self.preprocessor is None:
raise ValueError("No preprocessor loaded")
for function in self.preprocessor:
sequence = function(sequence, self)
return sequence
    def build_tokenizer(self, tokenizer_opt):
"""Build tokenizer described by ``tokenizer_opt``."""
if "type" not in tokenizer_opt:
raise ValueError("Missing mandatory tokenizer option 'type'")
if tokenizer_opt["type"] == "sentencepiece":
if "model" not in tokenizer_opt:
raise ValueError("Missing mandatory tokenizer option 'model'")
import sentencepiece as spm
tokenizer = spm.SentencePieceProcessor()
model_path = os.path.join(self.model_root, tokenizer_opt["model"])
tokenizer.Load(model_path)
elif tokenizer_opt["type"] == "pyonmttok":
if "params" not in tokenizer_opt:
raise ValueError("Missing mandatory tokenizer option 'params'")
import pyonmttok
if tokenizer_opt["mode"] is not None:
mode = tokenizer_opt["mode"]
else:
mode = None
# load can be called multiple times: modify copy
tokenizer_params = dict(tokenizer_opt["params"])
for key, value in tokenizer_opt["params"].items():
if key.endswith("path"):
tokenizer_params[key] = os.path.join(self.model_root, value)
tokenizer = pyonmttok.Tokenizer(mode, **tokenizer_params)
else:
raise ValueError("Invalid value for tokenizer type")
return tokenizer
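    # Example ``tokenizer_opt`` values handled above (file names are
    # hypothetical; any "params" key ending in "path" is resolved
    # relative to ``model_root``):
    #
    #     {"type": "sentencepiece", "model": "wmtende.model"}
    #     {"type": "pyonmttok", "mode": "conservative",
    #      "params": {"joiner_annotate": True, "bpe_model_path": "codes.bpe"}}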
    def maybe_tokenize(self, sequence, side="src"):
"""Tokenize the sequence (or not).
Same args/returns as ``tokenize``"""
if self.tokenizers_opt is not None:
return self.tokenize(sequence, side)
return sequence
    def tokenize(self, sequence, side="src"):
"""Tokenize a single sequence.
Args:
sequence (str): The sequence to tokenize.
Returns:
tok (str): The tokenized sequence."""
if self.tokenizers is None:
raise ValueError("No tokenizer loaded")
if self.tokenizers_opt[side]["type"] == "sentencepiece":
tok = self.tokenizers[side].EncodeAsPieces(sequence)
tok = " ".join(tok)
elif self.tokenizers_opt[side]["type"] == "pyonmttok":
tok, _ = self.tokenizers[side].tokenize(sequence)
tok = " ".join(tok)
return tok
    def tokenizer_marker(self, side="src"):
"""Return marker used in ``side`` tokenizer."""
marker = None
if self.tokenizers_opt is not None:
tokenizer_type = self.tokenizers_opt[side].get("type", None)
if tokenizer_type == "pyonmttok":
params = self.tokenizers_opt[side].get("params", None)
if params is not None:
if params.get("joiner_annotate", None) is not None:
marker = "joiner"
elif params.get("spacer_annotate", None) is not None:
marker = "spacer"
elif tokenizer_type == "sentencepiece":
marker = "spacer"
return marker
    def maybe_detokenize_with_align(self, sequence, src, side="tgt"):
        """De-tokenize (or not) the sequence (with alignment).
        Args:
            sequence (str): The sequence to detokenize, possibly with
                alignment separated by '|||'.
        Returns:
            sequence (str): The detokenized sequence.
            align (str): The alignment corresponding to the detokenized
                src/tgt, or None if the output has no alignment."""
align = None
if self.opt.report_align:
# output contain alignment
sequence, align, align_scores = sequence.split(
DefaultTokens.ALIGNMENT_SEPARATOR
)
if align != "":
align = self.maybe_convert_align(src, sequence, align, align_scores)
sequence = self.maybe_detokenize(sequence, side)
return (sequence, align)
    def maybe_detokenize(self, sequence, side="tgt"):
        """De-tokenize the sequence (or not).
        Same args/returns as :func:`tokenize()`."""
if self.tokenizers_opt is not None and "".join(sequence.split(" ")) != "":
return self.detokenize(sequence, side)
return sequence
    def detokenize(self, sequence, side="tgt"):
        """Detokenize a single sequence.
        Same args/returns as :func:`tokenize()`."""
if self.tokenizers is None:
raise ValueError("No tokenizer loaded")
if self.tokenizers_opt[side]["type"] == "sentencepiece":
detok = self.tokenizers[side].DecodePieces(sequence.split(" "))
elif self.tokenizers_opt[side]["type"] == "pyonmttok":
detok = self.tokenizers[side].detokenize(sequence.split(" "))
return detok
    def maybe_convert_align(self, src, tgt, align, align_scores):
        """Convert alignment to match detokenized src/tgt (or not).
        Args:
            src (str): The tokenized source sequence.
            tgt (str): The tokenized target sequence.
            align (str): The alignment corresponding to the src/tgt pair.
        Returns:
            align (str): The alignment corresponding to the detokenized
                src/tgt.
        """
if self.tokenizers_opt is not None:
src_marker = self.tokenizer_marker(side="src")
tgt_marker = self.tokenizer_marker(side="tgt")
if src_marker is None or tgt_marker is None:
raise ValueError(
"To get decoded alignment, joiner/spacer "
"should be used in both side's tokenizer."
)
elif "".join(tgt.split(" ")) != "":
align = to_word_align(
src, tgt, align, align_scores, src_marker, tgt_marker
)
return align
    def maybe_postprocess(self, sequence):
"""Postprocess the sequence (or not)"""
if self.postprocess_opt is not None:
return self.postprocess(sequence)
else:
return sequence["seg"][0]
    def postprocess(self, sequence):
        """Postprocess a single segment package.
        Args:
            sequence (dict): The segment package to postprocess.
        Returns:
            sequence (str): The postprocessed sequence."""
if self.postprocessor is None:
raise ValueError("No postprocessor loaded")
for function in self.postprocessor:
sequence = function(sequence, self)
return sequence
def get_function_by_path(path, args=None, kwargs=None):
    # ``args``/``kwargs`` are unused; kept to preserve the existing signature.
    module_name = ".".join(path.split(".")[:-1])
    function_name = path.split(".")[-1]
    try:
        module = importlib.import_module(module_name)
    except (ImportError, ValueError) as e:
        print("Cannot import module '%s'" % module_name)
        raise e
    function = getattr(module, function_name)
    return function
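# Usage sketch for ``get_function_by_path`` (module and function names
# hypothetical): pre-/postprocess functions are resolved by dotted path
# and called with the segment package and the owning ServerModel:
#
#     fn = get_function_by_path("my_project.processors.lowercase")
#     sequence = fn(sequence, server_model)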