Runner
- class opennmt.Runner(model, config, auto_config=None, mixed_precision=False, jit_compile=False, seed=None)
Class for running and exporting models.
Inherits from: builtins.object
- __init__(model, config, auto_config=None, mixed_precision=False, jit_compile=False, seed=None)
Initializes the runner parameters.
- Parameters
model – An opennmt.models.Model instance to run, or a callable that returns such an instance.
config – The run configuration.
auto_config – If True, use the automatic configuration values defined by model. If not set, the parameter is read from the run configuration.
mixed_precision – Enable mixed precision.
jit_compile – Compile the model with XLA when possible.
seed – The random seed to set.
- Raises
TypeError – if model is not an opennmt.models.Model instance or a callable.
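With auto_config enabled, the model's automatic values have to be combined with the user's run configuration. The sketch below shows one way to do that, assuming explicit user values take precedence and nested sections are merged key by key. `merge_config` and all configuration keys and values here are illustrative assumptions, not part of this reference.

```python
import copy
from typing import Any, Dict


def merge_config(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: recursively merge `override` into a copy of `base`.

    Nested dicts are merged key by key; any other value in `override`
    replaces the corresponding value in `base`. Neither input is mutated.
    """
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Illustrative values only: defaults a model might provide, and a user config.
auto_config = {"params": {"optimizer": "Adam", "learning_rate": 2.0},
               "train": {"batch_size": 3072}}
user_config = {"params": {"learning_rate": 1.0}, "model_dir": "run/"}

config = merge_config(auto_config, user_config)
print(config)
# {'params': {'optimizer': 'Adam', 'learning_rate': 1.0}, 'train': {'batch_size': 3072}, 'model_dir': 'run/'}
```

Merging recursively rather than replacing whole sections lets a user override a single value (here the learning rate) while keeping the other automatic defaults in the same section.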
- property model
The opennmt.models.Model executed by this runner.
- property model_dir
The active model directory.
- train(num_devices=1, with_eval=False, checkpoint_path=None, hvd=None, return_summary=False, fallback_to_cpu=True, continue_from_checkpoint=False)
Runs the training loop.
- Parameters
num_devices – Number of devices to use for training.
with_eval – Enable evaluation during training.
checkpoint_path – The checkpoint path to load the model weights from.
hvd – Optional Horovod module.
return_summary – If True, also return a summary of the training.
fallback_to_cpu – If no GPU is detected, allow the training to run on CPU.
continue_from_checkpoint – Continue training from the checkpoint passed to checkpoint_path. Otherwise, only the model weights are loaded.
- Returns
The path to the final model directory and, if return_summary is set, a dictionary with various training statistics.
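The difference between the two ways of loading checkpoint_path can be modeled with a toy checkpoint that stores weights, a step counter, and optimizer state. This is a sketch of the documented semantics only, assuming "only the model weights are loaded" means the step counter and optimizer state start fresh; `TrainingState` and `restore` are hypothetical and do not reflect how the runner actually stores or loads TensorFlow checkpoints.

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class TrainingState:
    """Toy stand-in for a training checkpoint (hypothetical)."""
    weights: Dict[str, float] = field(default_factory=dict)
    step: int = 0
    optimizer_slots: Dict[str, float] = field(default_factory=dict)


def restore(checkpoint: TrainingState,
            continue_from_checkpoint: bool) -> TrainingState:
    """Return the state training starts from (hypothetical helper)."""
    if continue_from_checkpoint:
        # Resume: weights, step counter, and optimizer state are all restored.
        return TrainingState(dict(checkpoint.weights), checkpoint.step,
                             dict(checkpoint.optimizer_slots))
    # Weights only: the step counter and optimizer state start fresh.
    return TrainingState(weights=dict(checkpoint.weights))


saved = TrainingState(weights={"w": 0.5}, step=1000, optimizer_slots={"w/m": 0.1})
print(restore(saved, continue_from_checkpoint=True).step)   # 1000
print(restore(saved, continue_from_checkpoint=False).step)  # 0
```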
- evaluate(features_file=None, labels_file=None, checkpoint_path=None)
Runs evaluation.
- Parameters
features_file – The input features file to evaluate. If not set, eval_features_file is read from the data configuration.
labels_file – The output labels file to evaluate. If not set, eval_labels_file is read from the data configuration.
checkpoint_path – The checkpoint path to load the model weights from.
- Returns
A dict of evaluation metrics.
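The fallback described for features_file and labels_file amounts to filling each unset argument from the data configuration independently. A minimal sketch, assuming the run configuration is a dict with a "data" section; `resolve_eval_files` and the file names are hypothetical.

```python
from typing import Any, Dict, Optional, Tuple


def resolve_eval_files(config: Dict[str, Any],
                       features_file: Optional[str] = None,
                       labels_file: Optional[str] = None
                       ) -> Tuple[Optional[str], Optional[str]]:
    """Hypothetical helper: fill unset files from the data configuration."""
    data = config.get("data", {})
    if features_file is None:
        features_file = data.get("eval_features_file")
    if labels_file is None:
        labels_file = data.get("eval_labels_file")
    return features_file, labels_file


# Illustrative run configuration with a "data" section.
config = {"data": {"eval_features_file": "src-val.txt",
                   "eval_labels_file": "tgt-val.txt"}}

print(resolve_eval_files(config))                            # ('src-val.txt', 'tgt-val.txt')
print(resolve_eval_files(config, features_file="test.src"))  # ('test.src', 'tgt-val.txt')
```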
- average_checkpoints(output_dir, max_count=8, checkpoint_paths=None)
Averages checkpoints.
- Parameters
output_dir – The directory that will contain the averaged checkpoint.
max_count – The maximum number of checkpoints to average.
checkpoint_paths – The list of checkpoints to average. If not set, the last max_count checkpoints of the current model directory are averaged.
- Returns
The path to the directory containing the averaged checkpoint.
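Checkpoint averaging takes the element-wise mean of each variable across several checkpoints. The sketch below shows that computation on plain Python lists, assuming checkpoints are ordered from oldest to newest so that the last max_count entries are the most recent. The dict-of-lists `Checkpoint` representation and this `average_checkpoints` function are illustrative stand-ins, not the runner's implementation.

```python
from typing import Dict, List

# Hypothetical representation: variable name -> flat list of values.
Checkpoint = Dict[str, List[float]]


def average_checkpoints(checkpoints: List[Checkpoint], max_count: int = 8) -> Checkpoint:
    """Average every variable element-wise over the last `max_count` checkpoints.

    `checkpoints` is assumed to be ordered from oldest to newest, and all
    checkpoints are assumed to contain the same variables and shapes.
    """
    if max_count < 1:
        raise ValueError("max_count must be at least 1")
    if not checkpoints:
        raise ValueError("no checkpoints to average")
    selected = checkpoints[-max_count:]
    averaged: Checkpoint = {}
    for name in selected[0]:
        values = [ckpt[name] for ckpt in selected]
        # zip(*values) groups the i-th element of every checkpoint together.
        averaged[name] = [sum(column) / len(selected) for column in zip(*values)]
    return averaged


history = [{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}, {"w": [5.0, 6.0]}]
print(average_checkpoints(history, max_count=2))  # {'w': [4.0, 5.0]}
```

The explicit `max_count < 1` check matters because `checkpoints[-0:]` would silently select every checkpoint.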
- update_vocab(output_dir, src_vocab=None, tgt_vocab=None)
Updates model vocabularies.
- Parameters
output_dir – Directory where the updated checkpoint will be saved.
src_vocab – Path to the new source vocabulary.
tgt_vocab – Path to the new target vocabulary.
- Returns
Path to the new checkpoint directory.
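Updating a vocabulary means the vocabulary-indexed weights must be realigned with the new token list. A minimal sketch for one embedding matrix, assuming tokens present in both vocabularies keep their trained row and tokens dropped from the old vocabulary lose theirs. New tokens are zero-initialized here only to keep the example deterministic; the library may initialize them differently. `update_embeddings` is hypothetical.

```python
from typing import List


def update_embeddings(old_vocab: List[str], new_vocab: List[str],
                      old_embeddings: List[List[float]]) -> List[List[float]]:
    """Build an embedding matrix aligned with `new_vocab` (hypothetical helper).

    Rows of tokens present in both vocabularies are copied; rows of tokens
    only in `new_vocab` are zero-initialized for simplicity. Tokens only
    in `old_vocab` are dropped.
    """
    old_index = {token: i for i, token in enumerate(old_vocab)}
    dim = len(old_embeddings[0])
    new_embeddings = []
    for token in new_vocab:
        if token in old_index:
            new_embeddings.append(list(old_embeddings[old_index[token]]))
        else:
            new_embeddings.append([0.0] * dim)
    return new_embeddings


old_vocab = ["<unk>", "hello", "world"]
embeddings = [[0.5, 0.5], [0.1, 0.2], [0.3, 0.4]]
new_vocab = ["<unk>", "world", "there"]
print(update_embeddings(old_vocab, new_vocab, embeddings))
# [[0.5, 0.5], [0.3, 0.4], [0.0, 0.0]]
```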
- infer(features_file, predictions_file=None, checkpoint_path=None, log_time=False)
Runs inference.
- Parameters
features_file – The file(s) to infer from.
predictions_file – If set, predictions are saved in this file, otherwise they are printed on the standard output.
checkpoint_path – Path to a specific checkpoint to load. If None, the latest checkpoint is used.
log_time – If True, several time metrics are printed in the logs at the end of the inference loop.
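Several methods fall back to "the latest" checkpoint when checkpoint_path is None. The sketch below picks the checkpoint with the highest step from a directory listing, assuming TensorFlow-style file names such as `ckpt-1500.index`. It is a filename-based approximation with a hypothetical `latest_checkpoint` function; TensorFlow's own `tf.train.latest_checkpoint` reads the `checkpoint` state file instead.

```python
import re
from typing import Iterable, Optional

# Matches TensorFlow-style checkpoint index files such as "ckpt-1500.index".
_INDEX_FILE = re.compile(r"^(?P<prefix>.+-(?P<step>\d+))\.index$")


def latest_checkpoint(filenames: Iterable[str]) -> Optional[str]:
    """Return the checkpoint prefix with the highest step, or None (hypothetical)."""
    best_step, best_prefix = -1, None
    for name in filenames:
        match = _INDEX_FILE.match(name)
        # Compare steps as integers: as strings, "500" would sort after "1000".
        if match and int(match.group("step")) > best_step:
            best_step = int(match.group("step"))
            best_prefix = match.group("prefix")
    return best_prefix


files = ["checkpoint", "ckpt-500.index", "ckpt-500.data-00000-of-00001",
         "ckpt-1500.index", "ckpt-1000.index"]
print(latest_checkpoint(files))  # ckpt-1500
```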
- export(export_dir, checkpoint_path=None, exporter=None)
Exports a model.
- Parameters
export_dir – The export directory.
checkpoint_path – The checkpoint path to export. If None, the latest checkpoint is used.
exporter – An opennmt.utils.Exporter instance. Defaults to opennmt.utils.SavedModelExporter.
- score(features_file, predictions_file=None, checkpoint_path=None, output_file=None)
Scores existing predictions.
- Parameters
features_file – The input file.
predictions_file – The predictions file to score.
checkpoint_path – Path to a specific checkpoint to load. If None, the latest checkpoint is used.
output_file – The file where the scores are saved. If not set, the scores are printed on the standard output.
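Scoring existing predictions is commonly reported as the log-likelihood of each prediction under the model. The sketch below assumes that definition and shows the documented fallback from output_file to standard output. The helper `score_predictions`, its per-token probability input, and the one-number-per-line output format are assumptions; the library's actual score output may contain additional fields.

```python
import math
from typing import List, Optional, Sequence


def score_predictions(token_probs: Sequence[Sequence[float]],
                      output_file: Optional[str] = None) -> List[str]:
    """Write one log-likelihood score per prediction (hypothetical helper).

    Each inner sequence holds the model probability of every token of one
    prediction; the score is the sum of their natural logarithms.
    """
    lines = [f"{sum(math.log(p) for p in probs):.4f}" for probs in token_probs]
    if output_file is not None:
        with open(output_file, "w", encoding="utf-8") as f:
            f.write("\n".join(lines) + "\n")
    else:
        # No output file: print the scores on the standard output instead.
        print("\n".join(lines))
    return lines


score_predictions([[0.5, 0.5], [1.0]])  # prints -1.3863 and 0.0000 on two lines
```

Summing log-probabilities rather than multiplying raw probabilities avoids numerical underflow on long sequences.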