class onmt.models.BaseModel(encoder, decoder)[source]

Bases: Module

Core trainable object in OpenNMT. Implements a trainable interface for a simple, generic encoder/decoder or decoder-only model.

forward(src, tgt, src_len, bptt=False, with_align=False)[source]

Forward propagate a src and tgt pair for training.

  • src (Tensor) – A source sequence passed to encoder. Typically for input this will be a padded LongTensor of size (batch, len, features). However, may be an image or other generic input depending on encoder.

  • tgt (LongTensor) – A target sequence passed to decoder. Size (batch, tgt_len, features).

  • src_len (LongTensor) – The src lengths, pre-padding (batch,).

  • bptt (Boolean) – A flag indicating whether truncated BPTT is set. If bptt is False, the decoder state is initialized.

  • with_align (Boolean) – A flag indicating whether to output alignments. Only valid for the Transformer decoder.


  • decoder output (batch, tgt_len, hidden)

  • dictionary of attention weights (batch, tgt_len, src_len)

Return type:

(FloatTensor, dict[str, FloatTensor])
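The shapes above can be illustrated with a toy batch. The sketch below pads token-id sequences to a common length and records the pre-padding lengths that would be passed as src_len. It uses plain Python lists instead of tensors, and pad_batch and the pad id 1 are hypothetical names chosen for this example, not part of OpenNMT.

```python
def pad_batch(seqs, pad_id=1):
    """Right-pad token-id sequences to a common length.

    Returns (padded, lengths): padded is a batch x max_len nested list,
    lengths holds each sequence's length before padding.
    """
    lengths = [len(s) for s in seqs]
    max_len = max(lengths)
    padded = [s + [pad_id] * (max_len - len(s)) for s in seqs]
    return padded, lengths


src, src_len = pad_batch([[5, 6, 7], [8, 9]])
print(src)      # [[5, 6, 7], [8, 9, 1]]
print(src_len)  # [3, 2]
```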

load_safe_state_dict(model_path, precision=torch.float32, device=device(type='cpu'), strict=True, offset=0)[source]

Custom state_dict loading that moves each module to the device as it is loaded.

  • model_path – Model path

  • precision – same as for load_state_dict

  • device – same as for load_state_dict

  • strict – same as for load_state_dict

load_state_dict(checkpoint, precision=torch.float32, device=device(type='cpu'), strict=True, offset=0)[source]

Custom state_dict loading that moves each module to the device as it is loaded.

  • checkpoint – Pytorch serialized checkpoint

  • precision – precision to move each module to

  • device – device to move each module to

  • strict – if True checks model keys wrt state_dict (both ways)
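As a rough picture of what a two-way key check means for strict, the sketch below compares parameter names in both directions. The helper check_keys and the key names are hypothetical; how OpenNMT actually reports a mismatch is not stated here.

```python
def check_keys(model_keys, checkpoint_keys):
    """Compare parameter names both ways, as a strict load would."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)     # in model, not in checkpoint
    unexpected = sorted(checkpoint_keys - model_keys)  # in checkpoint, not in model
    return missing, unexpected


missing, unexpected = check_keys(
    ["encoder.weight", "decoder.weight"],
    ["encoder.weight", "generator.bias"],
)
print(missing)     # ['decoder.weight']
print(unexpected)  # ['generator.bias']
```

Under this reading, a strict load would succeed only when both lists are empty.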

class onmt.models.NMTModel(encoder, decoder)[source]

Bases: BaseModel

NMTModel class. See BaseModel for options.

count_parameters(log=<built-in function print>)[source]

Count number of parameters in model (& print with log callback).


  • encoder side parameter count

  • decoder side parameter count

Return type:

(int, int)
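One way to picture the (encoder, decoder) split is to group parameters by name. The sketch below assumes that any parameter whose name contains "encoder" counts toward the encoder and everything else toward the decoder; that rule, the shape dictionary, and the log messages are assumptions made for illustration.

```python
from math import prod


def count_parameters(param_shapes, log=print):
    """Sum element counts per side from a {name: shape} mapping."""
    enc, dec = 0, 0
    for name, shape in param_shapes.items():
        n = prod(shape)  # number of elements in a tensor of this shape
        if "encoder" in name:
            enc += n
        else:
            dec += n
    log(f"encoder: {enc}")
    log(f"decoder: {dec}")
    return enc, dec


shapes = {
    "encoder.embeddings.weight": (100, 8),
    "decoder.rnn.weight": (8, 8),
    "decoder.rnn.bias": (8,),
}
enc, dec = count_parameters(shapes)  # enc == 800, dec == 72
```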

forward(src, tgt, src_len, bptt=False, with_align=False)[source]

An NMTModel forwards the src side to the encoder. The encoder output enc_out is then forwarded to the decoder along with the target, excluding its last token. The decoder state is initialized with:

  • enc_final_hs in the case of RNNs

  • enc_out + enc_final_hs in the case of CNNs

  • src in the case of Transformer
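Dropping the last target token is the usual teacher-forcing setup: the decoder reads the target up to position t and is scored on the token at position t + 1. A minimal sketch with made-up tokens, where split_target is a hypothetical helper:

```python
def split_target(tgt):
    """Return (decoder input, gold labels) for teacher forcing."""
    dec_in = tgt[:-1]  # target without its last token
    gold = tgt[1:]     # target shifted left by one position
    return dec_in, gold


dec_in, gold = split_target(["<s>", "le", "chat", "dort", "</s>"])
print(dec_in)  # ['<s>', 'le', 'chat', 'dort']
print(gold)    # ['le', 'chat', 'dort', '</s>']
```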

class onmt.models.LanguageModel(encoder=None, decoder=None)[source]

Bases: BaseModel

LanguageModel class. Currently TransformerLMDecoder is the only LM decoder implemented.


decoder (onmt.decoders.TransformerLMDecoder) – a transformer decoder

count_parameters(log=<built-in function print>)[source]

Count number of parameters in model (& print with log callback).

Returns:

  • encoder side parameter count

  • decoder side parameter count

Return type:

(int, int)

forward(src, tgt, src_len, bptt=False, with_align=False)[source]

A LanguageModel forwards the src side to the decoder along with the source lengths vector. It is a decoder-only LM (cf. GPT-2).


class onmt.trainer.Trainer(model, train_loss, valid_loss, scoring_preparator, valid_scorers, optim, trunc_size=0, norm_method='sents', accum_count=[1], accum_steps=[0], n_gpu=1, gpu_rank=1, parallel_mode='data_parallel', report_manager=None, with_align=False, model_saver=None, average_decay=0, average_every=1, model_dtype='fp32', earlystopper=None, dropout=[0.3], attention_dropout=[0.1], dropout_steps=[0], zero_out_prompt_loss=False)[source]

Bases: object

Class that controls the training process.

  • model (onmt.models.model.NMTModel) – model to train

  • train_loss (onmt.utils.loss.LossComputeBase) – training loss computation

  • valid_loss (onmt.utils.loss.LossComputeBase) – validation loss computation

  • scoring_preparator (onmt.translate.utils.ScoringPreparator) – preparator for the calculation of metrics via the _eval_handler method

  • valid_scorers (dict) – keeps in memory the current values of the validation metrics

  • optim (onmt.utils.optimizers.Optimizer) – the optimizer responsible for update

  • trunc_size (int) – length of truncated back propagation through time

  • accum_count (list) – accumulate gradients this many times.

  • accum_steps (list) – steps for accum gradients changes.

  • n_gpu (int) – number of gpu.

  • gpu_rank (int) – ordinal rank of the gpu in the list.

  • report_manager (onmt.utils.ReportMgrBase) – the object that creates reports, or None

  • with_align (bool) – whether to jointly learn alignment (Transformer)

  • model_saver (onmt.models.ModelSaverBase) – the saver is used to save a checkpoint. Thus nothing will be saved if this parameter is None.

  • average_decay (float) – cf opt.average_decay

  • average_every (int) – average model every x steps.

  • model_dtype (str) – fp32 or fp16.

  • earlystopper (onmt.utils.EarlyStopping) – add early stopping mechanism

  • dropout (float) – dropout value in RNN or FF layers.

  • attention_dropout (float) – dropout in attention layers.

  • dropout_steps (list) – dropout values scheduling in steps.

  • zero_out_prompt_loss (bool) – whether to zero-out the prompt loss (mostly for LLM finetuning).
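The paired list parameters (accum_count with accum_steps, and dropout with dropout_steps) read as step-indexed schedules. The sketch below shows one plausible lookup: the value in effect is the one attached to the latest start step already passed. The helper name scheduled_value and the strict comparison at the boundary are assumptions, not the trainer's confirmed behaviour.

```python
def scheduled_value(step, values, steps):
    """Return the value whose start step is the latest one already passed."""
    current = values[0]
    for value, start in zip(values, steps):
        if step > start:
            current = value
    return current


accum_count, accum_steps = [1, 4], [0, 1000]
print(scheduled_value(1, accum_count, accum_steps))     # 1
print(scheduled_value(1000, accum_count, accum_steps))  # 1
print(scheduled_value(1001, accum_count, accum_steps))  # 4
```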

train(train_iter, train_steps, save_checkpoint_steps=5000, valid_iter=None, valid_steps=10000)[source]

The main training loop by iterating over train_iter and possibly running validation on valid_iter.

  • train_iter – An iterator that returns the next training batch.

  • train_steps – Run training for this many iterations.

  • save_checkpoint_steps – Save a checkpoint every this many iterations.

  • valid_iter – A generator that returns the next validation batch.

  • valid_steps – Run evaluation every this many iterations.
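To see how the three step counts interact, the sketch below lists the steps at which validation and checkpointing would fire with the default values. The modulo test and the order of the two events within a step are assumptions, and training_events is a hypothetical helper.

```python
def training_events(train_steps, save_checkpoint_steps=5000, valid_steps=10000):
    """List (step, event) pairs for a run of train_steps iterations."""
    events = []
    for step in range(1, train_steps + 1):
        if valid_steps and step % valid_steps == 0:
            events.append((step, "validate"))
        if save_checkpoint_steps and step % save_checkpoint_steps == 0:
            events.append((step, "save"))
    return events


print(training_events(10000))
# [(5000, 'save'), (10000, 'validate'), (10000, 'save')]
```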


training loss statistics

Return type:


validate(valid_iter, moving_average=None)[source]

Validate model.


valid_iter – validate data iterator


validation loss statistics

Return type:


class onmt.utils.Statistics(loss=0, n_batchs=0, n_sents=0, n_words=0, n_correct=0, computed_metrics={})[source]

Bases: object

Accumulator for loss statistics. Currently calculates:

  • accuracy

  • perplexity

  • elapsed time


compute accuracy

static all_gather_stats(stat, max_size=4096)[source]

Gather a Statistics object across multiple processes/nodes

  • stat (Statistics) – the statistics object to gather across all processes/nodes

  • max_size (int) – max buffer size to use


Statistics, the updated stats object

static all_gather_stats_list(stat_list, max_size=4096)[source]

Gather a Statistics list across all processes/nodes

  • stat_list (list([Statistics])) – list of statistics objects to gather across all processes/nodes

  • max_size (int) – max buffer size to use


list of updated stats

Return type:



compute elapsed time

log_tensorboard(prefix, writer, learning_rate, patience, step)[source]

Display statistics in TensorBoard.

output(step, num_steps, learning_rate, start)[source]

Write out statistics to stdout.

  • step (int) – current step

  • n_batch (int) – total batches

  • start (int) – start time of step.


compute perplexity

update(stat, update_n_src_words=False)[source]

Update statistics by summing values with another Statistics object

  • stat – another statistic object

  • update_n_src_words (bool) – whether to update (sum) n_src_words or not


compute cross entropy
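The quantities listed for this class follow from a few running sums. The sketch below is a toy accumulator, not the real Statistics class: ToyStats is a hypothetical name, and the percentage scaling of accuracy and the cap of 100 inside the exponential are assumptions.

```python
import math
from dataclasses import dataclass


@dataclass
class ToyStats:
    loss: float = 0.0   # summed token-level loss
    n_words: int = 0    # number of target tokens scored
    n_correct: int = 0  # number of correctly predicted tokens

    def update(self, other):
        """Sum another accumulator into this one."""
        self.loss += other.loss
        self.n_words += other.n_words
        self.n_correct += other.n_correct

    def accuracy(self):
        return 100 * self.n_correct / self.n_words

    def xent(self):
        return self.loss / self.n_words

    def ppl(self):
        # Cap the exponent to avoid overflow on very large losses.
        return math.exp(min(self.loss / self.n_words, 100))


total = ToyStats()
total.update(ToyStats(20.0, 10, 7))
total.update(ToyStats(10.0, 10, 9))
print(total.accuracy())  # 80.0
print(total.xent())      # 1.5
```

Perplexity is the exponential of the per-token cross entropy, so here total.ppl() equals math.exp(1.5).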


class onmt.utils.loss.LossCompute(criterion, generator, copy_attn=False, lambda_coverage=0.0, lambda_align=0.0, tgt_shift_index=1, vocab=None, lm_generator=None, lm_prior_lambda=None, lm_prior_tau=None, lm_prior_model=None)[source]

Bases: Module

Class for managing efficient loss computation. Handles accumulating multiple loss computations.

  • criterion (nn loss function) – NLLLoss or custom loss

  • generator (nn.Module) – module that maps the output of the decoder to a distribution over the target vocabulary.

  • copy_attn (bool) – whether copy attention mechanism is on/off

  • lambda_coverage – Hyper-param to apply coverage attention if any

  • lambda_align – Hyper-param for alignment loss

  • tgt_shift_index (int) – 1 for NMT, 0 for LM

  • vocab – target vocab (for copy attention score calculation)

  • lm_generator (ctranslate2.Generator) – LM Generator

  • lm_prior_lambda (float) – weight of LM model in loss

  • lm_prior_tau (float) – scaler for LM loss

forward(batch, output, attns, trunc_start=0, trunc_size=None)[source]

Compute the forward loss. Supports truncated BPTT for long sequences by taking a range in the decoder output sequence to backpropagate through. The range is (trunc_start, trunc_start + trunc_size). Truncation is an approximate efficiency trick that reduces the memory required in the RNN buffers.

  • batch (batch) – batch of labeled examples

  • output (FloatTensor) – output of decoder model (batch, tgt_len, hidden)

  • attns (dict) – dictionary of attention weights (batch, tgt_len, src_len)

  • trunc_start (int) – starting position of truncation window

  • trunc_size (int) – length of truncation window


A tuple with the loss and a onmt.utils.Statistics instance.
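The truncation range can be pictured as a sequence of windows over the target length. A minimal sketch, assuming consecutive non-overlapping windows and a final window clipped to the sequence end; truncation_windows is a hypothetical helper:

```python
def truncation_windows(tgt_len, trunc_size=None):
    """Return (start, end) ranges covering tgt_len positions."""
    if not trunc_size:
        # No truncation: a single window spans the whole sequence.
        return [(0, tgt_len)]
    return [
        (start, min(start + trunc_size, tgt_len))
        for start in range(0, tgt_len, trunc_size)
    ]


print(truncation_windows(10, 4))  # [(0, 4), (4, 8), (8, 10)]
print(truncation_windows(10))     # [(0, 10)]
```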

classmethod from_opts(opt, model, vocab, train=True)[source]

Returns a subclass which wraps around an nn.Module subclass (such as nn.NLLLoss) which defines the loss criterion. The LossCompute object passes relevant data to a Statistics object which handles training/validation logging. The Criterion and LossCompute options are triggered by opt settings.

Mask the prompt in the target side of the batch examples in order to set the loss of the prompt to zero.

For finetuning on specific tasks. The end of the prompt must be indicated by the DefaultTokens.MASK_BEFORE


The masks are supposed to be properly handled by the loss criterion (e.g. nn.CrossEntropyLoss).


batch – The current batch.
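The masking idea can be sketched on a single token list: everything up to the end-of-prompt marker is replaced by a padding symbol that the loss criterion is assumed to ignore. The marker string, the pad symbol, and the choice to mask the marker itself are stand-ins chosen for this example, not the actual value of DefaultTokens.MASK_BEFORE or the library's exact behaviour.

```python
def mask_prompt(tgt_tokens, marker="<mask_before>", pad="<pad>"):
    """Replace the prompt (up to and including the marker) with pad symbols."""
    if marker not in tgt_tokens:
        return list(tgt_tokens)  # no prompt marker, nothing to mask
    end = tgt_tokens.index(marker)
    return [pad] * (end + 1) + list(tgt_tokens[end + 1:])


example = ["Translate", ":", "<mask_before>", "Bonjour", "</s>"]
print(mask_prompt(example))
# ['<pad>', '<pad>', '<pad>', 'Bonjour', '</s>']
```

Only the tokens after the marker then contribute to the loss, which is the behaviour wanted when finetuning on prompt and answer pairs.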


class onmt.utils.Optimizer(optimizer, learning_rate, learning_rate_decay_fn=None, max_grad_norm=None)[source]

Bases: object

Controller class for optimization. Mostly a thin wrapper for optim, but also useful for implementing rate scheduling beyond what is currently available. Also implements necessary methods for training RNNs such as grad manipulations.

  • optimizer – A torch.optim.Optimizer instance.

  • learning_rate – The initial learning rate.

  • learning_rate_decay_fn – An optional callable taking the current step as argument and returning a learning rate scaling factor.

  • max_grad_norm – Clip gradients to this global norm.
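As an illustration of how learning_rate and learning_rate_decay_fn could combine, the sketch below multiplies the initial rate by the factor returned for the current step. The Noam-style schedule is used only as a familiar example of such a callable; its default values and the helper current_learning_rate are assumptions.

```python
def noam_decay(step, warmup_steps=4000, model_size=512):
    """Scaling factor that rises linearly during warmup, then decays as step**-0.5."""
    return model_size ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


def current_learning_rate(initial_lr, step, decay_fn=None):
    """Initial rate times the decay factor, or the initial rate if no decay is set."""
    if decay_fn is None:
        return initial_lr
    return initial_lr * decay_fn(step)


for step in (1000, 4000, 16000):
    print(step, current_learning_rate(2.0, step, noam_decay))
```

With these settings the rate peaks at step 4000, the end of warmup.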

property amp

True if using torch AMP mixed-precision training.


Wrapper for backward pass. Some optimizers require ownership of the backward pass.

classmethod from_opt(model, opt, checkpoint=None)[source]

Builds the optimizer from options.

  • cls – The Optimizer class to instantiate.

  • model – The model to optimize.

  • opt – The dict of user options.

  • checkpoint – An optional checkpoint to load states from.


An Optimizer instance.


Returns the current learning rate.


Update the model parameters based on current gradients.

Optionally, will employ gradient modification or update learning rate.
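The most common gradient modification is clipping to a global norm, which is what max_grad_norm refers to. The sketch below works on a flat list of scalar gradients: if their L2 norm exceeds the limit, all of them are rescaled by the same factor. clip_by_global_norm is a hypothetical helper, and treating a missing or non-positive limit as "no clipping" is an assumption.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale grads so their global L2 norm does not exceed max_norm.

    Returns (clipped gradients, norm before clipping).
    """
    total = math.sqrt(sum(g * g for g in grads))
    if max_norm is None or max_norm <= 0 or total <= max_norm:
        return list(grads), total
    scale = max_norm / total
    return [g * scale for g in grads], total


clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm)     # 5.0
print(clipped)  # values close to [0.6, 0.8]
```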

property training_step

The current training step.


Zero the gradients of optimized parameters.

class onmt.utils.AdaFactor(params, lr=None, beta1=0.9, beta2=0.999, eps1=1e-30, eps2=0.001, cliping_threshold=1, non_constant_decay=True, enable_factorization=True, ams_grad=True, weight_decay=0)[source]

Bases: Optimizer


Performs a single optimization step (parameter update).


closure (Callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.


Unless otherwise specified, this function should not modify the .grad field of the parameters.

class onmt.utils.FusedAdam(params, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-08, eps_inside_sqrt=False, weight_decay=0.0, max_grad_norm=0.0, amsgrad=False)[source]

Bases: Optimizer

Implements Adam algorithm. Currently GPU-only.

Requires Apex to be installed via python setup.py install --cuda_ext --cpp_ext.

  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups.

  • lr (float, optional) – learning rate. (default: 1e-3)

  • betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))

  • eps (float, optional) – term added to the denominator to improve numerical stability. (default: 1e-8)

  • weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)

  • amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper ‘On the Convergence of Adam and Beyond’ (default: False) NOT SUPPORTED in FusedAdam!

  • eps_inside_sqrt (boolean, optional) – in the ‘update parameters’ step, adds eps to the bias-corrected second moment estimate before evaluating square root instead of adding it to the square root of second moment estimate as in the original paper. (default: False)
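The effect of eps_inside_sqrt is easiest to see on a single scalar parameter. The sketch below is the textbook Adam update written in plain Python, not the fused GPU kernel; adam_step is a hypothetical helper, and weight decay and gradient scaling are left out.

```python
import math


def adam_step(p, g, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
              eps_inside_sqrt=False):
    """One Adam update for scalar parameter p with gradient g at step t >= 1."""
    b1, b2 = betas
    m = b1 * m + (1 - b1) * g      # running average of the gradient
    v = b2 * v + (1 - b2) * g * g  # running average of the squared gradient
    m_hat = m / (1 - b1 ** t)      # bias-corrected first moment
    v_hat = v / (1 - b2 ** t)      # bias-corrected second moment
    if eps_inside_sqrt:
        denom = math.sqrt(v_hat + eps)  # eps added before the square root
    else:
        denom = math.sqrt(v_hat) + eps  # eps added after, as in the original paper
    return p - lr * m_hat / denom, m, v


p_out, _, _ = adam_step(1.0, 0.5, 0.0, 0.0, t=1, eps=1.0)
p_in, _, _ = adam_step(1.0, 0.5, 0.0, 0.0, t=1, eps=1.0, eps_inside_sqrt=True)
print(p_out, p_in)  # the two placements give different step sizes
```

With the default eps of 1e-8 the two variants are nearly identical; the exaggerated eps=1.0 above only makes the difference visible.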

step(closure=None, grads=None, output_params=None, scale=1.0, grad_norms=None)[source]

Performs a single optimization step.

  • closure (callable, optional) – A closure that reevaluates the model and returns the loss.

  • grads (list of tensors, optional) – weight gradient to use for the optimizer update. If gradients have type torch.half, parameters are expected to be in type torch.float. (default: None)

  • output_params (list of tensors, optional) – A reduced-precision copy of the updated weights, written out in addition to the regular updated weights. Must be of the same type as the gradients. (default: None)

  • scale (float, optional) – factor to divide gradient tensor values by before applying to weights. (default: 1)