average_checkpoints
- opennmt.utils.average_checkpoints(model_dir, output_dir, trackables, max_count=8, model_key='model')
Averages object-based checkpoints.
- Parameters
model_dir – The directory containing checkpoints, or a list of checkpoint paths.
output_dir – The directory that will contain the averaged checkpoint.
trackables – A dictionary containing the trackable objects included in the checkpoint.
max_count – The maximum number of checkpoints to average.
model_key – The key in trackables that references the model.
- Returns
The path to the directory containing the averaged checkpoint.
- Raises
ValueError – if output_dir is the same as model_dir.
ValueError – if a model is not found in trackables or is not already built.
ValueError – if no checkpoints are found in model_dir.
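The idea behind checkpoint averaging can be illustrated with a minimal, library-free sketch. It is not the OpenNMT-tf implementation: the helper name average_variables and the dict-of-lists stand-in for checkpoint contents are hypothetical, and it assumes that when more than max_count checkpoints are available, the most recent ones are the ones averaged.

```python
def average_variables(checkpoints, max_count=8):
    """Average same-named variables over at most `max_count` checkpoints.

    `checkpoints` is a hypothetical stand-in for checkpoint files: a list of
    dicts mapping variable names to lists of floats, ordered oldest to newest.
    """
    if max_count < 1:
        raise ValueError("max_count must be at least 1")
    if not checkpoints:
        # Mirrors the documented error when no checkpoints are found.
        raise ValueError("No checkpoints to average")

    # Assumption: keep only the newest `max_count` checkpoints.
    selected = checkpoints[-max_count:]
    count = len(selected)

    averaged = {}
    for name in selected[0]:
        # Group the i-th value of this variable across all selected checkpoints.
        columns = zip(*(ckpt[name] for ckpt in selected))
        averaged[name] = [sum(values) / count for values in columns]
    return averaged


checkpoints = [
    {"w": [1.0, 2.0], "b": [0.0]},
    {"w": [3.0, 4.0], "b": [1.0]},
    {"w": [5.0, 6.0], "b": [2.0]},
]
print(average_variables(checkpoints, max_count=2))
# {'w': [4.0, 5.0], 'b': [1.5]}
```

With max_count=2, only the last two entries contribute, so each averaged value is the mean of the two newest checkpoints.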