opennmt.utils.average_checkpoints

opennmt.utils.average_checkpoints(model_dir, output_dir, trackables, max_count=8, model_key='model')

Averages object-based checkpoints.

Parameters
  • model_dir – The directory containing checkpoints.

  • output_dir – The directory that will contain the averaged checkpoint.

  • trackables – A dictionary containing the trackable objects included in the checkpoint.

  • max_count – The maximum number of checkpoints to average.

  • model_key – The key in trackables that references the model.
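
To make the role of max_count concrete, here is a minimal sketch that averages plain Python dictionaries standing in for checkpoints. It does not call OpenNMT-tf or TensorFlow. The helper name average_last_checkpoints, the dict-of-lists representation, and the assumption that the most recent max_count checkpoints are averaged element-wise with equal weight are all illustrative and not taken from the library.

```python
def average_last_checkpoints(checkpoints, max_count=8):
    """Average the last `max_count` checkpoints element-wise.

    `checkpoints` is a list ordered oldest to newest. Each entry maps a
    variable name to a list of floats (a hypothetical stand-in for tensors).
    """
    if max_count < 1:
        raise ValueError("max_count must be at least 1")
    if not checkpoints:
        raise ValueError("no checkpoints to average")

    # Assumption: only the most recent `max_count` checkpoints take part.
    selected = checkpoints[-max_count:]

    averaged = {}
    for name in selected[0]:
        # Group the i-th element of this variable across all selected checkpoints.
        columns = zip(*(ckpt[name] for ckpt in selected))
        averaged[name] = [sum(column) / len(selected) for column in columns]
    return averaged


checkpoints = [
    {"w": [0.0, 0.0]},  # oldest
    {"w": [1.0, 2.0]},
    {"w": [3.0, 4.0]},  # newest
]

# With max_count=2 only the two newest checkpoints are averaged.
averaged = average_last_checkpoints(checkpoints, max_count=2)
# averaged == {"w": [2.0, 3.0]}
```

When fewer than max_count checkpoints are available, the sketch simply averages all of them.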

Returns

The path to the directory containing the averaged checkpoint.

Raises
  • ValueError – if output_dir is the same as model_dir.

  • ValueError – if a model is not found in trackables or is not already built.

  • ValueError – if no checkpoints are found in model_dir.
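
The three error conditions above can be sketched as a standalone validation helper. This is not the library's implementation: check_average_inputs, FakeModel, describe_failure, and the checkpoint_paths argument are hypothetical names introduced for illustration. The sketch assumes the model exposes a Keras-style boolean built attribute, and the order of the checks is a guess.

```python
class FakeModel:
    """Hypothetical stand-in for a model with a Keras-style `built` flag."""

    def __init__(self, built):
        self.built = built


def check_average_inputs(model_dir, output_dir, trackables, checkpoint_paths,
                         model_key="model"):
    """Raise ValueError for each documented failure case (illustrative only)."""
    if output_dir == model_dir:
        raise ValueError("output_dir must differ from model_dir")

    model = trackables.get(model_key)
    if model is None:
        raise ValueError("%r not found in trackables" % model_key)
    if not model.built:
        raise ValueError("the model must be built before averaging")

    # `checkpoint_paths` stands in for whatever was discovered in model_dir.
    if not checkpoint_paths:
        raise ValueError("no checkpoints found in %s" % model_dir)


def describe_failure(**overrides):
    """Run the checks on valid defaults plus `overrides`; return the error text."""
    args = dict(
        model_dir="run",
        output_dir="run/avg",
        trackables={"model": FakeModel(built=True)},
        checkpoint_paths=["run/ckpt-1000"],
    )
    args.update(overrides)
    try:
        check_average_inputs(**args)
    except ValueError as error:
        return str(error)
    return None


print(describe_failure())                                          # valid inputs
print(describe_failure(output_dir="run"))                          # same directory
print(describe_failure(trackables={}))                             # missing model
print(describe_failure(trackables={"model": FakeModel(False)}))    # model not built
print(describe_failure(checkpoint_paths=[]))                       # nothing to average
```

Callers that want to handle these cases can wrap the real call in a single except ValueError block, since all three conditions raise the same exception type.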