opennmt.utils.checkpoint module

Checkpoint utilities.

opennmt.utils.checkpoint.get_checkpoint_variables(checkpoint_path)

Returns variables included in a checkpoint.

Parameters:
  • checkpoint_path – Path to the checkpoint.
Returns:

A dictionary mapping variable names to values.

opennmt.utils.checkpoint.convert_checkpoint(checkpoint_path, output_dir, source_dtype, target_type, session_config=None)

Converts checkpoint variables from one dtype to another.

Parameters:
  • checkpoint_path – The path to the checkpoint to convert.
  • output_dir – The directory that will contain the converted checkpoint.
  • source_dtype – The data type to convert from.
  • target_type – The data type to convert to.
  • session_config – Optional configuration to use when creating the session.
Returns:

The path to the directory containing the converted checkpoint.

Raises:

ValueError – if output_dir points to the same directory as checkpoint_path.
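
The following pure-Python sketch illustrates what converting checkpoint variables from one dtype to another amounts to. It is not how OpenNMT-tf reads or writes TensorFlow checkpoints: it assumes a hypothetical in-memory representation (variable name mapped to a dtype name and a flat list of values), and `Variables`, `convert_variables` and `check_output_dir` are illustrative names, not part of the library. A cast is emulated by round-tripping each value through the standard `struct` module, and variables stored in any other dtype (such as an integer step counter) are copied unchanged.

```python
import os
import struct
from typing import Dict, List, Tuple

# Hypothetical in-memory stand-in for a checkpoint:
# variable name -> (dtype name, flat list of values).
Variables = Dict[str, Tuple[str, List[float]]]


def _cast_float16(x: float) -> float:
    # Round-trip through IEEE 754 half precision to emulate a float16 cast.
    # struct raises OverflowError for magnitudes above the float16 range.
    return struct.unpack("<e", struct.pack("<e", x))[0]


def _cast_float32(x: float) -> float:
    # Round-trip through IEEE 754 single precision.
    return struct.unpack("<f", struct.pack("<f", x))[0]


_CASTS = {"float16": _cast_float16, "float32": _cast_float32}


def convert_variables(variables: Variables, source_dtype: str, target_dtype: str) -> Variables:
    """Casts variables stored as source_dtype to target_dtype; copies the rest unchanged."""
    cast = _CASTS[target_dtype]
    converted: Variables = {}
    for name, (dtype, values) in variables.items():
        if dtype == source_dtype:
            converted[name] = (target_dtype, [cast(v) for v in values])
        else:
            # E.g. an int64 step counter is left untouched.
            converted[name] = (dtype, list(values))
    return converted


def check_output_dir(checkpoint_path: str, output_dir: str) -> None:
    # Same condition as the documented ValueError: the converted checkpoint
    # must not be written into the directory of the source checkpoint.
    source_dir = os.path.dirname(os.path.abspath(checkpoint_path))
    if source_dir == os.path.abspath(output_dir):
        raise ValueError("output_dir must differ from the checkpoint directory")


variables = {"w": ("float32", [0.1, 1.0]), "global_step": ("int64", [1000])}
half = convert_variables(variables, "float32", "float16")
print(half["w"])  # ('float16', [0.0999755859375, 1.0])
```

Converting to a lower precision is lossy (0.1 becomes 0.0999755859375 in float16), so converting back to float32 does not restore the original values.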

opennmt.utils.checkpoint.update_vocab(model_dir, output_dir, current_src_vocab, current_tgt_vocab, new_src_vocab=None, new_tgt_vocab=None, mode='merge', session_config=None)

Updates the last checkpoint to support new vocabularies.

This makes it possible to add new words to a model while keeping the previously learned weights.

Parameters:
  • model_dir – The directory containing checkpoints (the most recent will be loaded).
  • output_dir – The directory that will contain the converted checkpoint.
  • current_src_vocab – Path to the source vocabulary currently used by the model.
  • current_tgt_vocab – Path to the target vocabulary currently used by the model.
  • new_src_vocab – Path to the new source vocabulary to support.
  • new_tgt_vocab – Path to the new target vocabulary to support.
  • mode – Update mode: “merge” keeps all existing words and adds the new ones; “replace” makes the new vocabulary the active one. In both modes, if an existing word appears in the new vocabulary, its learned weights are kept in the converted checkpoint.
  • session_config – Optional configuration to use when creating the session.
Returns:

The path to the directory containing the converted checkpoint.

Raises:

ValueError – if output_dir is the same as model_dir or if mode is invalid.
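
To make the difference between the two modes concrete, here is a pure-Python sketch of how one vocabulary-sized weight matrix (for example an embedding table, one row per token) could be remapped. `update_vocab_rows` is a hypothetical helper, not the library's implementation: it ignores special tokens and checkpoint I/O, and initializing rows of new words with zeros (`init_value`) is an assumption of this sketch.

```python
from typing import List, Sequence, Tuple


def update_vocab_rows(
    current_vocab: Sequence[str],
    new_vocab: Sequence[str],
    weights: Sequence[Sequence[float]],
    mode: str = "merge",
    init_value: float = 0.0,
) -> Tuple[List[str], List[List[float]]]:
    """Builds the updated vocabulary and remaps one row of weights per token.

    weights[i] holds the learned row (e.g. an embedding) of current_vocab[i].
    """
    if mode not in ("merge", "replace"):
        raise ValueError("invalid mode: %r" % mode)
    index = {word: i for i, word in enumerate(current_vocab)}
    if mode == "merge":
        # Keep every existing word at its position, then append unseen words once.
        updated_vocab = list(current_vocab)
        seen = set(current_vocab)
        for word in new_vocab:
            if word not in seen:
                seen.add(word)
                updated_vocab.append(word)
    else:
        # "replace": the new vocabulary becomes the active one, in its own order.
        updated_vocab = list(new_vocab)
    dim = len(weights[0]) if weights else 0
    updated_weights = []
    for word in updated_vocab:
        if word in index:
            # Existing word: keep its learned weights.
            updated_weights.append(list(weights[index[word]]))
        else:
            # New word: nothing learned yet (zero init is an assumption).
            updated_weights.append([init_value] * dim)
    return updated_vocab, updated_weights


current = ["a", "b", "c"]
rows = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
print(update_vocab_rows(current, ["c", "d"], rows, mode="merge"))
# (['a', 'b', 'c', 'd'], [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [0.0, 0.0]])
print(update_vocab_rows(current, ["c", "d"], rows, mode="replace"))
# (['c', 'd'], [[3.0, 3.0], [0.0, 0.0]])
```

In both calls the shared word "c" keeps its learned row, which matches the behavior described for the mode parameter; only the words that end up in the active vocabulary differ.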

opennmt.utils.checkpoint.average_checkpoints(model_dir, output_dir, max_count=8, session_config=None)

Averages the most recent checkpoints found in model_dir into a single checkpoint.

Parameters:
  • model_dir – The directory containing checkpoints.
  • output_dir – The directory that will contain the averaged checkpoint.
  • max_count – The maximum number of checkpoints to average.
  • session_config – Optional configuration to use when creating the session.
Returns:

The path to the directory containing the averaged checkpoint.

Raises:

ValueError – if output_dir is the same as model_dir.
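
Checkpoint averaging is an element-wise mean of each variable across several checkpoints. The pure-Python sketch below shows that computation under stated assumptions: checkpoints are hypothetical in-memory dicts mapping variable names to flat lists of floats, ordered from oldest to newest, and the last `max_count` of them are averaged. `average_variables` is an illustrative name, not part of OpenNMT-tf, and it averages every variable it is given, so integer counters such as a training step would need separate handling.

```python
from typing import Dict, List, Sequence

# Hypothetical in-memory stand-in for a checkpoint: name -> flat list of values.
Variables = Dict[str, List[float]]


def average_variables(checkpoints: Sequence[Variables], max_count: int = 8) -> Variables:
    """Element-wise mean of each variable over the last max_count checkpoints.

    checkpoints is assumed to be ordered from oldest to newest.
    """
    if not checkpoints or max_count < 1:
        raise ValueError("need at least one checkpoint and max_count >= 1")
    selected = list(checkpoints)[-max_count:]
    averaged: Variables = {}
    for name in selected[0]:
        # zip(*) groups the i-th value of this variable across checkpoints.
        columns = zip(*(ckpt[name] for ckpt in selected))
        averaged[name] = [sum(column) / len(selected) for column in columns]
    return averaged


history = [{"w": [0.0, 4.0]}, {"w": [2.0, 6.0]}, {"w": [4.0, 8.0]}]
print(average_variables(history, max_count=2))  # {'w': [3.0, 7.0]}
print(average_variables(history))  # {'w': [2.0, 6.0]}
```

The explicit `max_count < 1` check matters because a slice such as `[-0:]` would silently select every checkpoint instead of none.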