Source code for opennmt.optimizers.utils

"""Optimization utilities."""

import inspect

import tensorflow as tf
import tensorflow_addons as tfa

from packaging.version import Version
from tensorflow_addons.optimizers.weight_decay_optimizers import (
    DecoupledWeightDecayExtension,
)

from opennmt.utils import misc

if Version(tf.__version__) >= Version("2.11.0"):
    tf_optimizers = tf.keras.optimizers.legacy
else:
    tf_optimizers = tf.keras.optimizers

_OPTIMIZERS_REGISTRY = misc.ClassRegistry(base_class=tf_optimizers.Optimizer)

register_optimizer = _OPTIMIZERS_REGISTRY.register

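# Illustrative use of the registry (a sketch, assuming register_optimizer can be
# applied as a bare class decorator via misc.ClassRegistry.register; "MyAdam" is
# a hypothetical subclass name, not part of this module):
#
#     @register_optimizer
#     class MyAdam(tf_optimizers.Adam):
#         """Custom optimizer resolvable by name through get_optimizer_class."""
#
# After registration, get_optimizer_class("MyAdam") returns this class.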

def get_optimizer_class(name):
    """Returns the optimizer class.

    Args:
      name: The optimizer name.

    Returns:
      A class extending ``tf.keras.optimizers.legacy.Optimizer``.

    Raises:
      ValueError: if :obj:`name` can not be resolved to an optimizer class.
    """
    optimizer_class = getattr(tf_optimizers, name, None)
    if optimizer_class is None:
        optimizer_class = getattr(tfa.optimizers, name, None)
    if optimizer_class is None:
        optimizer_class = _OPTIMIZERS_REGISTRY.get(name)
    if optimizer_class is None:
        raise ValueError("Unknown optimizer class: %s" % name)
    return optimizer_class

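# Illustrative resolution order (a sketch; the names below are only examples):
# a name is looked up first in the Keras optimizer module selected above, then
# in tensorflow_addons, then in the local registry.
#
#     adam_class = get_optimizer_class("Adam")           # found in tf_optimizers
#     lazy_adam_class = get_optimizer_class("LazyAdam")  # found in tfa.optimizers
#     get_optimizer_class("NotAnOptimizer")               # raises ValueError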

def make_optimizer(name, learning_rate, **kwargs):
    """Creates the optimizer.

    Args:
      name: The name of the optimizer class in ``tf.keras.optimizers.legacy``
        or ``tfa.optimizers`` as a string.
      learning_rate: The learning rate or learning rate schedule to use.
      **kwargs: Additional optimizer arguments. If ``weight_decay`` is set,
        the optimizer will be extended with decoupled weight decay.

    Returns:
      A ``tf.keras.optimizers.legacy.Optimizer`` instance.

    Raises:
      ValueError: if :obj:`name` can not be resolved to an optimizer class.
    """
    optimizer_class = get_optimizer_class(name)
    if "weight_decay" in kwargs:
        if DecoupledWeightDecayExtension not in inspect.getmro(optimizer_class):
            optimizer_class = tfa.optimizers.extend_with_decoupled_weight_decay(
                optimizer_class
            )
    optimizer = optimizer_class(learning_rate=learning_rate, **kwargs)
    return optimizer
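

# Illustrative usage of make_optimizer (a sketch; the optimizer name and
# hyper-parameter values below are only examples, not defaults of this module):
#
#     optimizer = make_optimizer("Adam", 2e-4, beta_1=0.9, weight_decay=0.01)
#
# Passing weight_decay extends the Adam class with tfa's decoupled weight decay
# (equivalent to AdamW) before the instance is created.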


class GradientAccumulator:
    """Gradient accumulation utility.

    When used with a distribution strategy, the accumulator should be called in a
    replica context. Gradients will be accumulated locally on each replica and
    without synchronization. Users should then call ``.gradients``, scale the
    gradients if required, and pass the result to ``apply_gradients``.
    """

    # We use the ON_READ synchronization policy so that no synchronization is
    # performed on assignment. To get the value, we call .value() which returns the
    # value on the current replica without synchronization.

    def __init__(self):
        """Initializes the accumulator."""
        self._gradients = []
        self._accum_steps = None

    @property
    def step(self):
        """Number of accumulated steps."""
        if self._accum_steps is None:
            self._accum_steps = tf.Variable(
                tf.constant(0, dtype=tf.int64),
                trainable=False,
                synchronization=tf.VariableSynchronization.ON_READ,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
            )
        return self._accum_steps.value()

    @property
    def gradients(self):
        """The accumulated gradients on the current replica."""
        if not self._gradients:
            raise ValueError(
                "The accumulator should be called first to initialize the gradients"
            )
        return list(gradient.value() for gradient in self._gradients)

    def __call__(self, gradients):
        """Accumulates :obj:`gradients` on the current replica."""
        if not self._gradients:
            _ = self.step  # Create the step variable.
            self._gradients.extend(
                [
                    tf.Variable(
                        tf.zeros_like(gradient),
                        trainable=False,
                        synchronization=tf.VariableSynchronization.ON_READ,
                    )
                    for gradient in gradients
                ]
            )
        if len(gradients) != len(self._gradients):
            raise ValueError(
                "Expected %s gradients, but got %d"
                % (len(self._gradients), len(gradients))
            )
        for accum_gradient, gradient in zip(self._gradients, gradients):
            accum_gradient.assign_add(gradient, read_value=False)
        self._accum_steps.assign_add(1)

    def reset(self):
        """Resets the accumulated gradients on the current replica."""
        if not self._gradients:
            return
        self._accum_steps.assign(0)
        for gradient in self._gradients:
            shape = (
                gradient.shape
                if gradient.shape.is_fully_defined()
                else tf.shape(gradient)
            )
            gradient.assign(tf.zeros(shape, dtype=gradient.dtype), read_value=False)
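

# Illustrative accumulation loop (a sketch, assuming the caller defines `model`,
# `optimizer`, and `compute_loss`; these names are not part of this module):
#
#     accumulator = GradientAccumulator()
#
#     def accumulate(batch):
#         with tf.GradientTape() as tape:
#             loss = compute_loss(model, batch)
#         gradients = tape.gradient(loss, model.trainable_variables)
#         accumulator(gradients)  # Accumulate locally on the current replica.
#
#     def apply():
#         step = tf.cast(accumulator.step, tf.float32)
#         gradients = [g / step for g in accumulator.gradients]  # Scale if required.
#         optimizer.apply_gradients(zip(gradients, model.trainable_variables))
#         accumulator.reset()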