Source code for opennmt.schedules.lr_schedules

"""Define learning rate decay functions."""

import inspect

import numpy as np
import tensorflow as tf

from opennmt.utils import misc

_LR_SCHEDULES_REGISTRY = misc.ClassRegistry(
    base_class=tf.keras.optimizers.schedules.LearningRateSchedule
)

register_learning_rate_schedule = _LR_SCHEDULES_REGISTRY.register


def get_lr_schedule_class(name):
    """Returns the learning rate schedule class.

    Args:
      name: The schedule class name.

    Returns:
      A class extending ``tf.keras.optimizers.schedules.LearningRateSchedule``.

    Raises:
      ValueError: if :obj:`name` can not be resolved to an existing schedule.
    """
    schedule_class = getattr(tf.keras.optimizers.schedules, name, None)
    if schedule_class is None:
        schedule_class = _LR_SCHEDULES_REGISTRY.get(name)
    if schedule_class is None:
        raise ValueError("Unknown learning rate schedule: %s" % name)
    return schedule_class


def make_learning_rate_schedule(
    initial_learning_rate,
    schedule_type,
    schedule_params=None,
    schedule_step_duration=1,
    start_step=0,
    minimum_learning_rate=0,
):
    """Creates the learning rate schedule.

    Args:
      initial_learning_rate: The initial learning rate value. This can be ``None``
        if the learning rate is fully defined by the schedule.
      schedule_type: The type of learning rate schedule. A class name from
        ``tf.keras.optimizers.schedules`` or :mod:`opennmt.schedules` as a string.
      schedule_params: Additional parameters passed to the schedule constructor.
      schedule_step_duration: The number of training steps that make 1 schedule step.
      start_step: Start the schedule after this many steps.
      minimum_learning_rate: Do not decay past this learning rate value.

    Returns:
      A ``tf.keras.optimizers.schedules.LearningRateSchedule`` instance.

    Raises:
      ValueError: if :obj:`schedule_type` can not be resolved to an existing
        schedule.

    See Also:
      :class:`opennmt.schedules.ScheduleWrapper`
    """
    if schedule_params is None:
        schedule_params = {}
    schedule_class = get_lr_schedule_class(schedule_type)
    first_arg = inspect.getfullargspec(schedule_class)[0][1]
    if first_arg not in schedule_params:
        schedule_params[first_arg] = initial_learning_rate
    schedule = schedule_class(**schedule_params)
    schedule = ScheduleWrapper(
        schedule,
        step_start=start_step,
        step_duration=schedule_step_duration,
        minimum_learning_rate=minimum_learning_rate,
    )
    return schedule
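
# Illustrative usage sketch (not part of the original module): one way this
# factory can be called. The schedule type, dimensions, and step values below
# are assumptions chosen for the example, not library defaults.
def _example_make_noam_schedule():
    # "NoamDecay" is resolved through the registry; initial_learning_rate fills
    # the first constructor argument ("scale") because it is not listed in
    # schedule_params.
    schedule = make_learning_rate_schedule(
        2.0,
        "NoamDecay",
        schedule_params={"model_dim": 512, "warmup_steps": 4000},
        start_step=0,
        minimum_learning_rate=1e-7,
    )
    # The returned object is a ScheduleWrapper and can be passed directly to a
    # Keras optimizer, or queried at a given training step as done here.
    return schedule(tf.constant(8000, dtype=tf.int64))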
class ScheduleWrapper(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Wrapper to augment a learning rate scheduler behavior."""

    def __init__(
        self, schedule, step_start=0, step_duration=1, minimum_learning_rate=0
    ):
        """Initializes the decay function.

        Args:
          schedule: A ``tf.keras.optimizers.schedules.LearningRateSchedule``.
          step_start: Start decay after this many steps.
          step_duration: The number of training steps that make 1 decay step.
          minimum_learning_rate: Do not decay past this learning rate value.

        See Also:
          :class:`opennmt.schedules.make_learning_rate_schedule`
        """
        self.schedule = schedule
        self.step_start = step_start
        self.step_duration = step_duration
        self.minimum_learning_rate = minimum_learning_rate

    def __call__(self, step):
        # Map the training step to a decay step.
        step = tf.maximum(step - self.step_start, 0)
        step //= self.step_duration
        learning_rate = self.schedule(step)
        return tf.maximum(learning_rate, self.minimum_learning_rate)
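
# Illustrative sketch (not part of the original module): ScheduleWrapper can
# also wrap a built-in Keras schedule, e.g. to delay the decay and clamp the
# minimum value. All numeric values below are assumptions for the example.
def _example_wrapped_exponential_decay():
    base = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=1e-3, decay_steps=10000, decay_rate=0.5
    )
    wrapped = ScheduleWrapper(
        base, step_start=4000, step_duration=1, minimum_learning_rate=1e-5
    )
    # Steps before step_start are clipped to 0, so the learning rate stays at
    # its initial value until step 4000, then decays but never below 1e-5.
    return wrapped(tf.constant(20000, dtype=tf.int64))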
@register_learning_rate_schedule
class NoamDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Defines the decay function described in https://arxiv.org/abs/1706.03762."""

    def __init__(self, scale, model_dim, warmup_steps):
        """Initializes the decay function.

        Args:
          scale: The scale constant.
          model_dim: The model dimension.
          warmup_steps: The number of warmup steps.
        """
        self.scale = tf.cast(scale, tf.float32)
        self.model_dim = tf.cast(model_dim, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)

    def __call__(self, step):
        step = tf.cast(step + 1, tf.float32)
        return (
            self.scale
            * tf.pow(self.model_dim, -0.5)
            * tf.minimum(tf.pow(step, -0.5), step * tf.pow(self.warmup_steps, -1.5))
        )
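
# Illustrative check (not part of the original module): with the values below,
# chosen only for the example, the schedule peaks when the internal step equals
# warmup_steps, at scale * model_dim**-0.5 * warmup_steps**-0.5, i.e.
# 2.0 * 512**-0.5 * 4000**-0.5 ~= 1.4e-3, then decays proportionally to
# step**-0.5.
def _example_noam_peak():
    schedule = NoamDecay(scale=2.0, model_dim=512, warmup_steps=4000)
    # __call__ adds 1 to the step, so step 3999 evaluates the peak value.
    return schedule(tf.constant(3999))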
@register_learning_rate_schedule
class RsqrtDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    r"""Decay based on the reciprocal of the step square root.

    This corresponds to ``rsqrt_decay`` in Tensor2Tensor.

    .. math::

        \text{schedule}(\text{step}) = \frac{\text{scale}}
          {\sqrt{\max(\text{step}, \text{warmup_steps})}}

    See also:

    - :class:`opennmt.schedules.InvSqrtDecay`
    """

    def __init__(self, scale, warmup_steps):
        """Initializes the decay function.

        Args:
          scale: The scale constant.
          warmup_steps: The number of warmup steps.
        """
        self.scale = tf.cast(scale, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        return self.scale * tf.math.rsqrt(tf.maximum(step, self.warmup_steps))
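
# Illustrative sketch (not part of the original module): with assumed values,
# the learning rate is flat at scale / sqrt(warmup_steps) while
# step < warmup_steps, then decays as 1 / sqrt(step).
def _example_rsqrt_decay():
    schedule = RsqrtDecay(scale=1.0, warmup_steps=10000)
    flat = schedule(tf.constant(100))       # 1.0 / sqrt(10000) = 1e-2
    decayed = schedule(tf.constant(40000))  # 1.0 / sqrt(40000) = 5e-3
    return flat, decayed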
@register_learning_rate_schedule
class InvSqrtDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    r"""Decay based on the reciprocal of the step square root.

    This corresponds to ``inverse_sqrt`` in Fairseq and ``--lr-decay-inv-sqrt``
    in Marian.

    During warmup (linear increase of the learning rate):

    .. math::

        \text{schedule}(\text{step}) = \text{init_lr} + (\text{lr} - \text{init_lr})
          \times \frac{\text{step}}{\text{warmup_steps}}

    After warmup:

    .. math::

        \text{schedule}(\text{step}) = \text{lr}
          \times \sqrt{\frac{\text{warmup_steps}}{\text{step}}}

    See also:

    - :class:`opennmt.schedules.RsqrtDecay`
    """

    def __init__(self, learning_rate, warmup_steps, initial_learning_rate=0):
        """Initializes the decay function.

        Args:
          learning_rate: The base learning rate.
          warmup_steps: The number of warmup steps.
          initial_learning_rate: Initial learning rate during warmup.
        """
        self.lr = tf.cast(learning_rate, tf.float32)
        self.init_lr = tf.cast(initial_learning_rate, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)

    def __call__(self, step):
        step = tf.cast(step + 1, tf.float32)

        def _warmup():
            return self.init_lr + (self.lr - self.init_lr) * (step / self.warmup_steps)

        def _after_warmup():
            return self.lr * tf.math.sqrt(self.warmup_steps / step)

        return tf.cond(
            step <= self.warmup_steps,
            true_fn=_warmup,
            false_fn=_after_warmup,
        )
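
# Illustrative check (not part of the original module): with assumed values,
# the two branches meet at the warmup boundary, i.e. both evaluate to
# learning_rate when the internal step equals warmup_steps.
def _example_inv_sqrt_boundary():
    schedule = InvSqrtDecay(
        learning_rate=5e-4, warmup_steps=4000, initial_learning_rate=1e-7
    )
    at_boundary = schedule(tf.constant(3999))  # step + 1 == warmup_steps -> 5e-4
    after = schedule(tf.constant(7999))        # 5e-4 * sqrt(4000 / 8000) ~= 3.5e-4
    return at_boundary, after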
@register_learning_rate_schedule
class CosineAnnealing(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Decay using a cosine annealing schedule."""

    def __init__(self, eta_max, eta_min=0, max_step=1000000, warmup_steps=None):
        """Initializes the decay function.

        Args:
          eta_max: Maximum learning rate.
          eta_min: Minimum learning rate.
          max_step: The last step of the schedule.
          warmup_steps: The number of steps to increment the learning rate linearly
            from 0 to :obj:`eta_max` before annealing.
        """
        self.eta_max = tf.cast(eta_max, tf.float32)
        self.eta_min = tf.cast(eta_min, tf.float32)
        self.max_step = tf.cast(max_step, tf.float32)
        self.warmup_steps = (
            tf.cast(warmup_steps, tf.float32) if warmup_steps is not None else None
        )

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        annealing = lambda: (
            self.eta_min
            + 0.5
            * (self.eta_max - self.eta_min)
            * (1 + tf.cos(np.pi * step / self.max_step))
        )
        linear = lambda: self.eta_max * step / self.warmup_steps
        if self.warmup_steps is None:
            return annealing()
        return tf.cond(
            tf.less(step, self.warmup_steps), true_fn=linear, false_fn=annealing
        )
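
# Illustrative sketch (not part of the original module): assumed values showing
# the linear warmup followed by a half-cosine decay from eta_max to eta_min.
def _example_cosine_annealing():
    schedule = CosineAnnealing(
        eta_max=2e-3, eta_min=1e-5, max_step=500000, warmup_steps=8000
    )
    warming_up = schedule(tf.constant(4000))  # 2e-3 * 4000 / 8000 = 1e-3
    half_way = schedule(tf.constant(250000))  # cos term is 0 -> (eta_max + eta_min) / 2
    return warming_up, half_way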
@register_learning_rate_schedule
class RNMTPlusDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Defines the decay function described in https://arxiv.org/abs/1804.09849."""

    def __init__(
        self, scale, num_replicas, warmup_steps=500, start_step=600000, end_step=1200000
    ):
        """Initializes the decay function.

        Args:
          scale: The scale constant.
          num_replicas: The number of concurrent model replicas.
          warmup_steps: The number of warmup steps.
          start_step: The start step of the exponential decay.
          end_step: The end step of the exponential decay.
        """
        self.scale = tf.cast(scale, tf.float32)
        self.num_replicas = tf.cast(num_replicas, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)
        self.start_step = tf.cast(start_step, tf.float32)
        self.end_step = tf.cast(end_step, tf.float32)

    def __call__(self, step):
        t = tf.cast(step, tf.float32)
        n = self.num_replicas
        p = self.warmup_steps
        s = self.start_step
        e = self.end_step
        return self.scale * tf.minimum(
            tf.minimum(1 + (t * (n - 1)) / (n * p), n),
            n * tf.pow(2 * n, (s - n * t) / (e - s)),
        )
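
# Illustrative sketch (not part of the original module): assumed values showing
# that the schedule saturates at scale * num_replicas once the warmup term
# reaches its cap and before the exponential decay term takes over.
def _example_rnmt_plus_decay():
    schedule = RNMTPlusDecay(scale=1e-4, num_replicas=8)
    # At step 50000 with the default warmup_steps/start_step/end_step, both
    # minimum terms are >= num_replicas, so the value is 8 * 1e-4.
    return schedule(tf.constant(50000))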