"""Checkpoint utilities."""
import copy
import os
import tempfile
import tensorflow as tf
from opennmt.utils import misc
class Checkpoint:
"""Wrapper around TensorFlow checkpoints utilities."""
def __init__(self, model, optimizer=None, model_dir=None, keep_checkpoint_max=8):
"""Initializes the wrapper.
Args:
model: A :class:`opennmt.models.Model` to save.
optimizer: The optimizer instance.
model_dir: The directory where checkpoints will be saved. If not set, a
temporary directory will be used.
keep_checkpoint_max: The maximum number of checkpoints to keep.
"""
if model_dir is None:
model_dir = tempfile.mkdtemp()
trackables = {}
trackables["model"] = model
if optimizer is not None:
trackables["optimizer"] = optimizer
self._model = model
self._optimizer = optimizer
self._model_dir = model_dir
self._checkpoint = tf.train.Checkpoint(**trackables)
self._checkpoint_manager = tf.train.CheckpointManager(
self._checkpoint, model_dir, keep_checkpoint_max
)
@classmethod
def from_config(cls, config, model, optimizer=None):
"""Creates a checkpoint wrapper from the configuration.
Args:
config: The user configuration.
model: A :class:`opennmt.models.Model` to save.
optimizer: The optimizer instance.
Returns:
A :class:`opennmt.utils.Checkpoint` instance.
"""
train_config = config.get("train")
if train_config is None:
train_config = {}
keep_checkpoint_max = max(
train_config.get("keep_checkpoint_max", 8),
train_config.get("average_last_checkpoints", 0),
)
return cls(
model,
optimizer=optimizer,
model_dir=config.get("model_dir"),
keep_checkpoint_max=keep_checkpoint_max,
)
@property
def model(self):
"""The managed model."""
return self._model
@property
def optimizer(self):
"""The managed optimizer."""
return self._optimizer
@property
def model_dir(self):
"""The model directory."""
return self._model_dir
@property
def last_saved_step(self):
"""The last training step that was saved."""
latest_checkpoint = self._checkpoint_manager.latest_checkpoint
if latest_checkpoint is None:
return None
return get_step_from_checkpoint_prefix(latest_checkpoint)
def save(self, step=None):
"""Saves a checkpoint.
Args:
step: The step to save for. If ``None``, get the value from ``optimizer.iterations``.
Returns:
The path to the saved checkpoint.
"""
if step is None:
step = self._optimizer.iterations
path = self._checkpoint_manager.save(checkpoint_number=step)
tf.get_logger().info("Saved checkpoint %s", path)
return path
def restore(self, checkpoint_path=None, weights_only=False):
"""Restores a checkpoint.
Args:
checkpoint_path: Path to a checkpoint to restore. If not set, the latest
checkpoint from :obj:`model_dir` will be restored.
weights_only: Only restore model weights.
Returns:
Path to the restored checkpoint.
"""
if weights_only:
checkpoint = tf.train.Checkpoint(model=self._model)
else:
checkpoint = self._checkpoint
if checkpoint_path is not None:
if tf.io.gfile.isdir(checkpoint_path):
checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
elif self._checkpoint_manager.latest_checkpoint is not None:
checkpoint_path = self._checkpoint_manager.latest_checkpoint
if checkpoint_path is None:
tf.get_logger().warning("No checkpoint to restore in %s", self._model_dir)
return None
if is_v1_checkpoint(checkpoint_path):
tf.get_logger().info("Upgrading V1 checkpoint...")
# Work with copies of the model and optimizer as the downstream task might
# need to create the variables differently (e.g. under a distribution
# strategy scope).
tmp_model = misc.clone_layer(self._model)
tmp_optimizer = (
copy.deepcopy(self._optimizer) if self._optimizer is not None else None
)
tmp_model.create_variables(optimizer=tmp_optimizer)
step = _restore_v1_checkpoint(
checkpoint_path, tmp_model, optimizer=tmp_optimizer
)
# Save an updated checkpoint in the model directory and restore this one instead.
tmp_checkpoint = Checkpoint(
tmp_model, optimizer=tmp_optimizer, model_dir=self._model_dir
)
checkpoint_path = tmp_checkpoint.save(step)
return self.restore(
checkpoint_path=checkpoint_path, weights_only=weights_only
)
load_status = checkpoint.restore(checkpoint_path)
load_status.expect_partial()
tf.get_logger().info("Restored checkpoint %s", checkpoint_path)
return checkpoint_path
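# Illustrative usage sketch (not part of the module): the catalog model,
# optimizer, and directory below are assumptions for the example.
#
#   model = opennmt.models.TransformerBase()
#   optimizer = tf.keras.optimizers.Adam()
#   checkpoint = Checkpoint(model, optimizer=optimizer, model_dir="/tmp/run")
#   checkpoint.restore()        # restores the latest checkpoint in /tmp/run, if any
#   checkpoint.save(step=1000)  # writes /tmp/run/ckpt-1000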
def get_step_from_checkpoint_prefix(prefix):
"""Extracts the training step from the checkpoint file prefix."""
return int(prefix.split("-")[-1])
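# For example (illustrative path), a prefix written by ``tf.train.CheckpointManager``
# such as "/tmp/run/ckpt-5000" yields:
#
#   get_step_from_checkpoint_prefix("/tmp/run/ckpt-5000")  # -> 5000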
def is_v1_checkpoint(checkpoint_path):
"""Returns ``True`` if the checkpoint at :obj:`checkpoint_path` has been
trained with OpenNMT-tf v1.
"""
if tf.io.gfile.isdir(checkpoint_path):
checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
return os.path.basename(checkpoint_path).startswith("model")
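# V2 checkpoints saved by ``tf.train.CheckpointManager`` use the "ckpt-N"
# prefix, while OpenNMT-tf v1 checkpoints are named like "model.ckpt-N",
# hence the basename check. Illustrative paths:
#
#   is_v1_checkpoint("/data/run/model.ckpt-500000")  # -> True
#   is_v1_checkpoint("/data/run/ckpt-500000")        # -> False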
def get_checkpoint_variables(checkpoint_path):
"""Returns variables included in a checkpoint.
Args:
checkpoint_path: Path to the checkpoint.
Returns:
A dictionary mapping variable names to values.
"""
reader = tf.train.load_checkpoint(checkpoint_path)
return {
name: reader.get_tensor(name)
for name in reader.get_variable_to_shape_map().keys()
}
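# Illustrative sketch: inspect the tensors stored in a checkpoint (the path is
# an assumption for the example).
#
#   for name, value in get_checkpoint_variables("/tmp/run/ckpt-5000").items():
#       print(name, value.shape)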
def average_checkpoints(
model_dir, output_dir, trackables, max_count=8, model_key="model"
):
"""Averages object-based checkpoints.
Args:
model_dir: The directory containing checkpoints, or a list of checkpoint paths.
output_dir: The directory that will contain the averaged checkpoint.
trackables: A dictionary containing the trackable objects included in the
checkpoint.
max_count: The maximum number of checkpoints to average.
model_key: The key in :obj:`trackables` that references the model.
Returns:
The path to the directory containing the averaged checkpoint.
Raises:
ValueError: if :obj:`output_dir` is the same as :obj:`model_dir`.
ValueError: if a model is not found in :obj:`trackables` or is not already
built.
ValueError: if no checkpoints are found in :obj:`model_dir`.
See Also:
:func:`opennmt.utils.average_checkpoints_into_layer`
"""
model = trackables.get(model_key)
if model is None:
raise ValueError("%s not found in trackables %s" % (model_key, trackables))
if isinstance(model_dir, list):
checkpoints_path = sorted(model_dir, key=get_step_from_checkpoint_prefix)
else:
if model_dir == output_dir:
raise ValueError("Model and output directory must be different")
checkpoint_state = tf.train.get_checkpoint_state(model_dir)
if checkpoint_state is None:
raise ValueError("No checkpoints found in %s" % model_dir)
checkpoints_path = checkpoint_state.all_model_checkpoint_paths
if len(checkpoints_path) > max_count:
checkpoints_path = checkpoints_path[-max_count:]
average_checkpoints_into_layer(checkpoints_path, model, model_key)
last_step = get_step_from_checkpoint_prefix(checkpoints_path[-1])
checkpoint = tf.train.Checkpoint(**trackables)
new_checkpoint_manager = tf.train.CheckpointManager(
checkpoint, output_dir, max_to_keep=None
)
path = new_checkpoint_manager.save(checkpoint_number=last_step)
tf.get_logger().info("Saved averaged checkpoint to %s", path)
return output_dir
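# Illustrative usage sketch, assuming an initialized and built model (names and
# paths below are examples only):
#
#   model = opennmt.models.TransformerBase()
#   model.initialize(data_config)   # data_config points to the vocabularies
#   model.create_variables()
#   average_checkpoints("/tmp/run", "/tmp/avg", {"model": model}, max_count=5)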
def average_checkpoints_into_layer(checkpoints, layer, layer_prefix):
"""Updates the layer weights with their average value in the checkpoints.
Args:
checkpoints: A non-empty list of checkpoint paths.
layer: A ``tf.keras.layers.Layer`` instance.
layer_prefix: The name/scope that prefixes the layer's variable names in the
checkpoints.
Raises:
ValueError: if :obj:`checkpoints` is empty.
ValueError: if :obj:`layer` is not already built.
See Also:
:func:`opennmt.utils.average_checkpoints`
"""
if not checkpoints:
raise ValueError("There should be at least one checkpoint")
if not layer.built:
raise ValueError("The layer should be built before calling this function")
# Reset the layer variables to 0.
for variable in layer.variables:
variable.assign(tf.zeros_like(variable))
# Get a map from variable names in the checkpoint to variables in the layer.
names_to_variables = misc.get_variables_name_mapping(layer, layer_prefix)
num_checkpoints = len(checkpoints)
tf.get_logger().info("Averaging %d checkpoints...", num_checkpoints)
for checkpoint_path in checkpoints:
tf.get_logger().info("Reading checkpoint %s...", checkpoint_path)
reader = tf.train.load_checkpoint(checkpoint_path)
for path in reader.get_variable_to_shape_map().keys():
if not path.startswith(layer_prefix) or ".OPTIMIZER_SLOT" in path:
continue
variable = names_to_variables[path]
value = reader.get_tensor(path)
variable.assign_add(value / num_checkpoints)
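# Lower-level sketch: average a specific list of checkpoints directly into an
# already built model. "model" is the prefix under which :class:`Checkpoint`
# and :func:`average_checkpoints` save the model (paths are illustrative).
#
#   paths = ["/tmp/run/ckpt-9000", "/tmp/run/ckpt-10000"]
#   average_checkpoints_into_layer(paths, model, "model")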
_V1_OPTIM_SCOPE = "optim"
_V1_SLOTS_MAPPING = {"Adam": "m", "Adam_1": "v"}
def _restore_v1_checkpoint(checkpoint_path, model, optimizer=None):
v1_variables = get_checkpoint_variables(checkpoint_path)
v1_structure = _variables_to_structure(v1_variables)
step = v1_structure["global_step"]
if optimizer is not None:
optimizer.iterations.assign(step)
if _V1_OPTIM_SCOPE in v1_structure:
slots = v1_structure[_V1_OPTIM_SCOPE]
del v1_structure[_V1_OPTIM_SCOPE]
v1_structure = _merge_optimizer_slots(v1_structure, slots)
mapping = model.map_v1_weights(v1_structure)
existing_variables = set(variable.ref() for variable in model.variables)
mapped_variables = set(variable.ref() for variable, _ in mapping)
missing_mapping = existing_variables.difference(mapped_variables)
if missing_mapping:
raise ValueError(
"The following variables were not mapped: %s"
% (", ".join(var.name for var in missing_mapping))
)
# Assign each variable and possibly the optimizer slots.
for v2_variable, v1_variable in mapping:
if isinstance(v1_variable, tuple):
v1_variable, v1_slots = v1_variable
else:
v1_slots = None
v2_variable.assign(v1_variable)
if v1_slots is not None:
for slot_name, value in v1_slots.items():
v2_slot = optimizer.get_slot(v2_variable, slot_name)
v2_slot.assign(value)
return step
def _variables_to_structure(variables):
"""Represents variables a nested dictionary with scope names as keys."""
structure = {}
for name, value in variables.items():
fields = name.split("/")
cur = structure
for i, key in enumerate(fields):
if key not in cur:
if i + 1 == len(fields):
cur[key] = value
break
cur[key] = {}
cur = cur[key]
return structure
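# For example (hypothetical variable names), the flat mapping
#   {"encoder/w": a, "decoder/w": b, "global_step": 10}
# becomes the nested structure
#   {"encoder": {"w": a}, "decoder": {"w": b}, "global_step": 10}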
def _merge_optimizer_slots(variables, slots):
"""Replaces leaves in the variables structure by tuples of
(variable, dict of optimizer slots).
"""
if isinstance(variables, dict):
merged = {}
for key, value in variables.items():
if key not in slots:
merged[key] = copy.deepcopy(value)
else:
merged[key] = _merge_optimizer_slots(value, slots[key])
return merged
else:
new_slots = {}
for name, value in slots.items():
name = _V1_SLOTS_MAPPING.get(name)
if name is None:
# Just ignore the optimizer slots if their name is not listed.
return variables
new_slots[name] = value
return (variables, new_slots)
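# For example (hypothetical names), with variables {"encoder": {"w": w}} and
# slots {"encoder": {"w": {"Adam": m, "Adam_1": v}}}, the merged structure is
# {"encoder": {"w": (w, {"m": m, "v": v})}}: each weight is paired with its
# Adam slots renamed through _V1_SLOTS_MAPPING.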