Framework
Model

class onmt.models.BaseModel(encoder, decoder)
Bases: torch.nn.modules.module.Module
Core trainable object in OpenNMT. Implements a trainable interface for a simple, generic encoder/decoder or decoder-only model.
- Parameters
encoder (onmt.encoders.EncoderBase) – an encoder object
decoder (onmt.decoders.DecoderBase) – a decoder object

forward(src, tgt, src_len, bptt=False, with_align=False)
Forward propagate a src and tgt pair for training.
- Parameters
src (Tensor) – A source sequence passed to the encoder. Typically a padded LongTensor of size (batch, len, features); however, it may be an image or other generic input depending on the encoder.
tgt (LongTensor) – A target sequence passed to the decoder, of size (batch, tgt_len, features).
src_len (LongTensor) – The source lengths, pre-padding, of size (batch,).
bptt (bool) – A flag indicating whether truncated BPTT is set. If False, initialize the decoder state.
with_align (bool) – A flag indicating whether to output alignment. Only valid for the transformer decoder.
- Returns
decoder output (batch, tgt_len, hidden)
dictionary of attention weights (batch, tgt_len, src_len)
- Return type
(FloatTensor, dict[str, FloatTensor])

class onmt.models.NMTModel(encoder, decoder)
Bases: onmt.models.model.BaseModel
NMTModel class. See BaseModel for options.

count_parameters(log=<built-in function print>)
Count the number of parameters in the model (and print them with the log callback).
- Returns
encoder side parameter count
decoder side parameter count
- Return type
(int, int)

forward(src, tgt, src_len, bptt=False, with_align=False)
An NMTModel forwards the src side to the encoder. The encoder output enc_out is then forwarded to the decoder along with the target, excluding the last token. The decoder state is initialized with:
enc_final_hs in the case of RNNs
enc_out + enc_final_hs in the case of CNNs
src in the case of the Transformer
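
A minimal usage sketch (illustrative only): it assumes encoder and decoder are already-built onmt.encoders / onmt.decoders instances that agree on hidden size, and that src, tgt and src_len follow the shapes documented in BaseModel.forward above.

    import onmt

    # Assemble the model from pre-built components.
    model = onmt.models.NMTModel(encoder, decoder)

    # Count (and log) the parameters on each side.
    enc_params, dec_params = model.count_parameters(log=print)

    # Forward one batch: src/tgt are (batch, len, features) LongTensors,
    # src_len is (batch,).
    dec_out, attns = model(src, tgt, src_len, bptt=False, with_align=False)
    # dec_out: (batch, tgt_len, hidden); attns maps attention names to
    # (batch, tgt_len, src_len) tensors (the exact keys depend on the decoder).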

class onmt.models.LanguageModel(encoder=None, decoder=None)
Bases: onmt.models.model.BaseModel
LanguageModel class. Currently TransformerLMDecoder is the only LM decoder implemented.
- Parameters
decoder (onmt.decoders.TransformerLMDecoder) – a transformer decoder
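
A hedged sketch for the decoder-only case, assuming lm_decoder is an already-built onmt.decoders.TransformerLMDecoder and the inputs follow the BaseModel.forward contract:

    import onmt

    lm = onmt.models.LanguageModel(decoder=lm_decoder)
    dec_out, attns = lm(src, tgt, src_len)  # same signature as BaseModel.forward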

Trainer

class onmt.Trainer(model, train_loss, valid_loss, scoring_preparator, train_scorers, valid_scorers, optim, trunc_size=0, norm_method='sents', accum_count=[1], accum_steps=[0], n_gpu=1, gpu_rank=1, train_eval_steps=200, report_manager=None, with_align=False, model_saver=None, average_decay=0, average_every=1, model_dtype='fp32', earlystopper=None, dropout=[0.3], attention_dropout=[0.1], dropout_steps=[0])
Bases: object
Class that controls the training process.
- Parameters
model (onmt.models.model.NMTModel) – translation model to train
train_loss (onmt.utils.loss.LossComputeBase) – training loss computation
valid_loss (onmt.utils.loss.LossComputeBase) – validation loss computation
scoring_preparator (onmt.translate.utils.ScoringPreparator) – preparator for the calculation of metrics via the training_eval_handler method
train_scorers (dict) – keeps in memory the current values of the training metrics
valid_scorers (dict) – keeps in memory the current values of the validation metrics
optim (onmt.utils.optimizers.Optimizer) – the optimizer responsible for parameter updates
trunc_size (int) – length of truncated back propagation through time
accum_count (list) – accumulate gradients this many times
accum_steps (list) – steps at which the gradient accumulation count changes
n_gpu (int) – number of GPUs
gpu_rank (int) – ordinal rank of the GPU in the list
train_eval_steps (int) – process a validation every x steps
report_manager (onmt.utils.ReportMgrBase) – the object that creates reports, or None
with_align (bool) – whether to jointly learn alignment (Transformer)
model_saver (onmt.models.ModelSaverBase) – the saver used to save checkpoints; nothing will be saved if this parameter is None
average_decay (float) – cf. opt.average_decay
average_every (int) – average the model every x steps
model_dtype (str) – fp32 or fp16
earlystopper (onmt.utils.EarlyStopping) – add an early stopping mechanism
dropout (float) – dropout value in RNN or FF layers
attention_dropout (float) – dropout in attention layers
dropout_steps (list) – dropout value scheduling in steps

train(train_iter, train_steps, save_checkpoint_steps=5000, valid_iter=None, valid_steps=10000)
The main training loop, iterating over train_iter and possibly running validation on valid_iter.
- Parameters
train_iter – An iterator that returns the next training batch.
train_steps – Run training for this many iterations.
save_checkpoint_steps – Save a checkpoint every this many iterations.
valid_iter – A generator that returns the next validation batch.
valid_steps – Run evaluation every this many iterations.
- Returns
training loss statistics
- Return type
onmt.utils.Statistics
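
A hedged sketch of driving this loop, assuming trainer was constructed with the arguments documented above and that train_iter / valid_iter yield batches in the format the model expects:

    stats = trainer.train(
        train_iter,
        train_steps=100000,
        save_checkpoint_steps=5000,
        valid_iter=valid_iter,
        valid_steps=10000,
    )
    # stats is an onmt.utils.Statistics object; accessor names such as
    # accuracy() and ppl() are assumptions based on the metrics it accumulates.
    print("train acc:", stats.accuracy(), "ppl:", stats.ppl())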

class onmt.utils.Statistics(loss=0, n_batchs=0, n_sents=0, n_words=0, n_correct=0, computed_metrics={})
Bases: object
Accumulator for loss statistics. Currently calculates:
accuracy
perplexity
elapsed time

static all_gather_stats(stat, max_size=4096)
Gather a Statistics object across multiple processes/nodes.
- Parameters
stat (Statistics) – the statistics object to gather across all processes/nodes
max_size (int) – max buffer size to use
- Returns
Statistics, the updated stats object

static all_gather_stats_list(stat_list, max_size=4096)
Gather a Statistics list across all processes/nodes.
- Parameters
stat_list (list([Statistics])) – list of statistics objects to gather across all processes/nodes
max_size (int) – max buffer size to use
- Returns
list of updated stats
- Return type
our_stats (list([Statistics]))

log_tensorboard(prefix, writer, learning_rate, patience, step)
Display statistics to TensorBoard.

output(step, num_steps, learning_rate, start)
Write out statistics to stdout.
- Parameters
step (int) – current step
num_steps (int) – total number of steps
start (int) – start time of the step
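
An illustrative sketch of how these pieces fit together in a multi-process run (constructor and method arguments follow the signatures above; everything else is an assumption):

    import time
    import onmt

    start = time.time()
    stats = onmt.utils.Statistics(loss=12.5, n_batchs=1, n_sents=32,
                                  n_words=480, n_correct=300)

    # Merge per-rank statistics across processes/nodes before reporting.
    merged = onmt.utils.Statistics.all_gather_stats(stats, max_size=4096)

    # Periodic console report, as the report manager would do.
    merged.output(step=100, num_steps=100000, learning_rate=2e-4, start=start)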

Loss

class onmt.utils.loss.LossCompute(criterion, generator, copy_attn=False, lambda_coverage=0.0, lambda_align=0.0, tgt_shift_index=1, vocab=None, lm_generator=None, lm_prior_lambda=None, lm_prior_tau=None, lm_prior_model=None)
Bases: torch.nn.modules.module.Module
Class for managing efficient loss computation. Handles accumulating multiple loss computations.
- Parameters
criterion (nn loss function) – NLLLoss or a custom loss
generator (nn.Module) – module that maps the output of the decoder to a distribution over the target vocabulary
copy_attn (bool) – whether the copy attention mechanism is on/off
lambda_coverage – hyper-parameter to apply coverage attention, if any
lambda_align – hyper-parameter for the alignment loss
tgt_shift_index (int) – 1 for NMT, 0 for LM
vocab – target vocab (for copy attention score calculation)
lm_generator (ctranslate2.Generator) – LM generator
lm_prior_lambda (float) – weight of the LM model in the loss
lm_prior_tau (float) – scaler for the LM loss

forward(batch, output, attns, trunc_start=0, trunc_size=None)
Compute the forward loss. Supports truncated BPTT for long sequences by taking a range in the decoder output sequence to back propagate in, from trunc_start to trunc_start + trunc_size. Truncation is an approximate efficiency trick to relieve the memory required in the RNN buffers.
- Parameters
batch (batch) – batch of labeled examples
output (FloatTensor) – output of the decoder model, (batch, tgt_len, hidden)
attns (dict) – dictionary of attention weights, (batch, tgt_len, src_len)
trunc_start (int) – starting position of the truncation window
trunc_size (int) – length of the truncation window
- Returns
A tuple with the loss and an onmt.utils.Statistics instance.
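
A minimal sketch of the truncation mechanism, assuming compute_loss is a LossCompute instance, output and attns come from the decoder for a full-length target, and optim is the onmt.utils.Optimizer documented below:

    import onmt

    trunc_size = 32
    tgt_len = output.size(1)
    total_stats = onmt.utils.Statistics()
    total_loss = 0.0

    for trunc_start in range(0, tgt_len, trunc_size):
        # Loss restricted to the window (trunc_start, trunc_start + trunc_size).
        loss, stats = compute_loss(batch, output, attns,
                                   trunc_start=trunc_start,
                                   trunc_size=trunc_size)
        total_loss = total_loss + loss
        total_stats.update(stats)  # Statistics.update() is an assumption

    # The real trainer back-propagates each window separately to bound memory;
    # summing the windows here just keeps the sketch simple.
    optim.backward(total_loss)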

classmethod from_opts(opt, model, vocab, train=True)
Returns a subclass which wraps around an nn.Module subclass (such as nn.NLLLoss) which defines the loss criterion. The LossCompute object passes relevant data to a Statistics object which handles training/validation logging. The Criterion and LossCompute options are triggered by opt settings.
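
A short hedged sketch of building the training and validation losses from parsed options (opt, model and vocab are assumed to already exist):

    train_loss = onmt.utils.loss.LossCompute.from_opts(opt, model, vocab, train=True)
    valid_loss = onmt.utils.loss.LossCompute.from_opts(opt, model, vocab, train=False)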

Optimizer

class onmt.utils.Optimizer(optimizer, learning_rate, learning_rate_decay_fn=None, max_grad_norm=None)
Bases: object
Controller class for optimization. Mostly a thin wrapper for optim, but also useful for implementing rate scheduling beyond what is currently available. Also implements necessary methods for training RNNs such as grad manipulations.
- Parameters
optimizer – A torch.optim.Optimizer instance.
learning_rate – The initial learning rate.
learning_rate_decay_fn – An optional callable taking the current step as argument and returning a learning rate scaling factor.
max_grad_norm – Clip gradients to this global norm.

property amp
True if torch AMP mixed precision training is in use.

backward(loss)
Wrapper for the backward pass. Some optimizers require ownership of the backward pass.

classmethod from_opt(model, opt, checkpoint=None)
Builds the optimizer from options.
- Parameters
cls – The Optimizer class to instantiate.
model – The model to optimize.
opt – The dict of user options.
checkpoint – An optional checkpoint to load states from.
- Returns
An Optimizer instance.

step()
Update the model parameters based on current gradients. Optionally, will employ gradient modification or update the learning rate.

property training_step
The current training step.
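
A hedged sketch of one update cycle, assuming model and opt (the parsed training configuration) exist and loss is the scalar training loss:

    import onmt

    optim = onmt.utils.Optimizer.from_opt(model, opt, checkpoint=None)

    optim.backward(loss)  # some optimizers need to own the backward pass
    optim.step()          # may clip gradients and rescale the learning rate
    print("step:", optim.training_step, "amp enabled:", optim.amp)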

class onmt.utils.AdaFactor(params, lr=None, beta1=0.9, beta2=0.999, eps1=1e-30, eps2=0.001, cliping_threshold=1, non_constant_decay=True, enable_factorization=True, ams_grad=True, weight_decay=0)
Bases: torch.optim.optimizer.Optimizer

step(closure=None)
Performs a single optimization step (parameter update).
- Parameters
closure (Callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.
Note: Unless otherwise specified, this function should not modify the .grad field of the parameters.
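
Since AdaFactor subclasses torch.optim.Optimizer, it follows the usual zero_grad / backward / step cycle; model and a scalar loss are assumed in this sketch:

    import onmt

    optimizer = onmt.utils.AdaFactor(model.parameters())  # defaults as documented above

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()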

class onmt.utils.FusedAdam(params, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-08, eps_inside_sqrt=False, weight_decay=0.0, max_grad_norm=0.0, amsgrad=False)
Bases: torch.optim.optimizer.Optimizer
Implements the Adam algorithm. Currently GPU-only. Requires Apex to be installed via python setup.py install --cuda_ext --cpp_ext.
- Parameters
params (iterable) – iterable of parameters to optimize or dicts defining parameter groups.
lr (float, optional) – learning rate. (default: 1e-3)
betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional) – term added to the denominator to improve numerical stability. (default: 1e-8)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper "On the Convergence of Adam and Beyond" (default: False). NOT SUPPORTED in FusedAdam!
eps_inside_sqrt (boolean, optional) – in the ‘update parameters’ step, adds eps to the bias-corrected second moment estimate before evaluating square root instead of adding it to the square root of second moment estimate as in the original paper. (default: False)

step(closure=None, grads=None, output_params=None, scale=1.0, grad_norms=None)
Performs a single optimization step.
- Parameters
closure (callable, optional) – A closure that reevaluates the model and returns the loss.
grads (list of tensors, optional) – weight gradients to use for the optimizer update. If gradients have type torch.half, parameters are expected to be in type torch.float. (default: None)
output_params (list of tensors, optional) – A reduced-precision copy of the updated weights written out in addition to the regular updated weights. Must be of the same type as the gradients. (default: None)
scale (float, optional) – factor to divide gradient tensor values by before applying to weights. (default: 1)