class onmt.models.NMTModel(encoder, decoder)

Bases: onmt.models.model.BaseModel

Core trainable object in OpenNMT. Implements a trainable interface for a simple, generic encoder + decoder model.

Parameters

  • encoder (onmt.encoders.EncoderBase) – an encoder object

  • decoder (onmt.decoders.DecoderBase) – a decoder object

count_parameters(log=<built-in function print>)

Count the number of parameters in the model (and print the counts with the log callback).

Returns

  • encoder side parameter count

  • decoder side parameter count

Return type

(int, int)
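The encoder/decoder split can be sketched without PyTorch. This minimal sketch assumes (an assumption, not stated above) that a parameter counts toward the encoder when its qualified name contains "encoder" and toward the decoder otherwise, so under this rule a generator's parameters land in the decoder count. The function count_parameters_sketch and the parameter names and shapes are hypothetical.

```python
from math import prod


def count_parameters_sketch(named_shapes, log=print):
    """Return (encoder_count, decoder_count) for a mapping name -> shape.

    Hypothetical rule: names containing "encoder" are encoder-side;
    everything else (decoder, generator, ...) is counted as decoder-side.
    """
    enc, dec = 0, 0
    for name, shape in named_shapes.items():
        n = prod(shape)  # number of elements in this parameter tensor
        if "encoder" in name:
            enc += n
        else:
            dec += n
    if log is not None:
        log(f"encoder: {enc}")
        log(f"decoder: {dec}")
    return enc, dec


# Hypothetical parameter names and shapes, for illustration only.
shapes = {
    "encoder.embeddings.weight": (1000, 64),
    "encoder.rnn.weight_ih_l0": (256, 64),
    "decoder.embeddings.weight": (1000, 64),
    "generator.0.weight": (1000, 128),
}
enc, dec = count_parameters_sketch(shapes)
```

With a real torch.nn.Module, the same loop would iterate over named_parameters() and use each tensor's element count instead of a shape tuple.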

forward(src, tgt, lengths, bptt=False, with_align=False)

Forward-propagate a source and target pair for training, possibly starting from an already initialized decoder state.

Parameters

  • src (Tensor) – A source sequence passed to the encoder. For text inputs this is typically a padded LongTensor of size (len, batch, features), but it may be an image or other generic input, depending on the encoder.

  • tgt (LongTensor) – A target sequence passed to the decoder, of size (tgt_len, batch, features).

  • lengths (LongTensor) – The source lengths before padding, of size (batch,).

  • bptt (Boolean) – A flag indicating whether truncated BPTT is in use. If False, the decoder state is initialized with init_state; if True, the state left by the previous truncated chunk is kept.

  • with_align (Boolean) – Whether to also output alignments. Only valid for the Transformer decoder.

Returns

  • decoder output of size (tgt_len, batch, hidden)

  • dictionary of attention distributions, each of size (tgt_len, batch, src_len)

Return type

(FloatTensor, dict[str, FloatTensor])
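The effect of the bptt flag can be illustrated with toy stand-ins. In this sketch ToyEncoder, ToyDecoder and ToyNMTModel are hypothetical classes that pass strings instead of tensors; the encoder returning (enc_state, memory_bank, lengths) and the decoder input being tgt[:-1] (teacher forcing, so the last target token is never fed as input) are assumptions of the sketch rather than facts stated above.

```python
class ToyEncoder:
    """Hypothetical encoder: returns (enc_state, memory_bank, lengths)."""

    def __call__(self, src, lengths):
        memory_bank = [f"h({tok})" for tok in src]  # one "hidden state" per source token
        enc_state = memory_bank[-1]
        return enc_state, memory_bank, lengths


class ToyDecoder:
    """Hypothetical decoder that records how often its state is initialized."""

    def __init__(self):
        self.init_calls = 0
        self.state = None

    def init_state(self, src, memory_bank, enc_state):
        self.init_calls += 1
        self.state = enc_state

    def __call__(self, dec_in, memory_bank, memory_lengths=None, with_align=False):
        # with_align is accepted but ignored by this toy.
        dec_out = [f"y({tok}|{self.state})" for tok in dec_in]
        uniform = [1.0 / len(memory_bank)] * len(memory_bank)
        attns = {"std": [list(uniform) for _ in dec_in]}
        return dec_out, attns


class ToyNMTModel:
    def __init__(self, encoder, decoder):
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, tgt, lengths, bptt=False, with_align=False):
        dec_in = tgt[:-1]  # assumed teacher forcing: drop the last target token
        enc_state, memory_bank, lengths = self.encoder(src, lengths)
        if bptt is False:
            # Fresh sequence: initialize the decoder state from the encoder.
            self.decoder.init_state(src, memory_bank, enc_state)
        # With bptt=True the state left by the previous chunk is reused.
        return self.decoder(dec_in, memory_bank, memory_lengths=lengths,
                            with_align=with_align)


model = ToyNMTModel(ToyEncoder(), ToyDecoder())
out, attns = model.forward(["a", "b"], ["<s>", "x", "y", "</s>"], [2])
out2, _ = model.forward(["a", "b"], ["y", "z", "</s>"], [2], bptt=True)
```

The second call continues a truncated chunk, so the decoder state is not reinitialized.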


class onmt.Trainer(model, train_loss, valid_loss, optim, trunc_size=0, shard_size=32, norm_method='sents', accum_count=[1], accum_steps=[0], n_gpu=1, gpu_rank=1, gpu_verbose_level=0, report_manager=None, with_align=False, model_saver=None, average_decay=0, average_every=1, model_dtype='fp32', earlystopper=None, dropout=[0.3], dropout_steps=[0])

Bases: object

Class that controls the training process.

Parameters

  • model (onmt.models.model.NMTModel) – translation model to train

  • train_loss (onmt.utils.loss.LossComputeBase) – training loss computation

  • valid_loss (onmt.utils.loss.LossComputeBase) – validation loss computation

  • optim (onmt.utils.optimizers.Optimizer) – the optimizer responsible for parameter updates

  • trunc_size (int) – length of truncated backpropagation through time

  • shard_size (int) – compute loss in shards of this size for efficiency

  • data_type (string) – type of the source input: [text]

  • norm_method (string) – normalization methods: [sents|tokens]

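  • accum_count (list) – number of batches over which gradients are accumulated before each optimizer update; one value per entry in accum_steps.

  • accum_steps (list) – training steps at which the corresponding accum_count value takes effect.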

  • report_manager (onmt.utils.ReportMgrBase) – the object that creates reports, or None

  • model_saver (onmt.models.ModelSaverBase) – the saver used to save checkpoints; if this parameter is None, nothing is saved.
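The pairing of accum_count with accum_steps is a step-indexed schedule. A minimal sketch of the lookup, assuming (not stated above) that accum_count[i] applies once the current step is strictly greater than accum_steps[i] and that the first value is the default; accum_count_at is a hypothetical helper name.

```python
def accum_count_at(step, accum_count, accum_steps):
    """Return the gradient-accumulation factor in effect at `step`.

    Hypothetical helper: accum_count[i] applies once step > accum_steps[i];
    before any threshold is passed, accum_count[0] is used.
    """
    if len(accum_count) != len(accum_steps):
        raise ValueError("accum_count and accum_steps must have the same length")
    current = accum_count[0]
    for count, start in zip(accum_count, accum_steps):
        if step > start:
            current = count
    return current


# Accumulate over 1 batch at first, 4 after step 1000, 8 after step 5000.
schedule = ([1, 4, 8], [0, 1000, 5000])
```

With accumulation over k batches, each optimizer update sees gradients from roughly k times as many examples as a single batch contains.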

train(train_iter, train_steps, save_checkpoint_steps=5000, valid_iter=None, valid_steps=10000)

The main training loop: iterates over train_iter and, optionally, runs validation on valid_iter.

Parameters

  • train_iter – A generator that returns the next training batch.

  • train_steps – Run training for this many iterations.

  • save_checkpoint_steps – Save a checkpoint every this many iterations.

  • valid_iter – A generator that returns the next validation batch.

  • valid_steps – Run evaluation every this many iterations.

Returns

The gathered statistics.
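The interplay of train_steps, valid_steps and save_checkpoint_steps can be sketched as a plain loop. This is not the Trainer implementation: train_loop_sketch and its train_step, validate and save callbacks are hypothetical stand-ins for the gradient update, validation and the model saver, and the sketch returns the number of steps run rather than statistics.

```python
def train_loop_sketch(train_iter, train_steps, save_checkpoint_steps=5000,
                      valid_iter=None, valid_steps=10000,
                      train_step=None, validate=None, save=None):
    """Run up to `train_steps` updates, validating and checkpointing periodically.

    train_step, validate and save are hypothetical callbacks; valid_iter
    should be re-iterable because it may be consumed several times.
    """
    step = 0
    for batch in train_iter:
        step += 1
        if train_step is not None:
            train_step(batch)
        if validate is not None and valid_iter is not None and step % valid_steps == 0:
            validate(valid_iter)
        if save is not None and step % save_checkpoint_steps == 0:
            save(step)
        if step >= train_steps:
            break
    return step


events = []
last = train_loop_sketch(
    iter(range(100)), train_steps=10, save_checkpoint_steps=4,
    valid_iter=["valid-batch"], valid_steps=5,
    validate=lambda it: events.append("validate"),
    save=lambda s: events.append(f"save@{s}"),
)
```

Here validation runs at steps 5 and 10 and checkpoints are saved at steps 4 and 8.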

validate(valid_iter, moving_average=None)

Validate the model.

Parameters

  • valid_iter – validation data iterator

Returns

validation loss statistics

Return type

onmt.utils.Statistics

class onmt.utils.Statistics(loss=0, n_words=0, n_correct=0)

Bases: object

Accumulator for loss statistics. Currently calculates:

  • accuracy

  • perplexity

  • elapsed time

accuracy()

compute accuracy

static all_gather_stats(stat, max_size=4096)

Gather a Statistics object across multiple processes/nodes.

Parameters

  • stat (Statistics) – the statistics object to gather across all processes/nodes

  • max_size (int) – max buffer size to use

Returns

Statistics, the updated stats object

static all_gather_stats_list(stat_list, max_size=4096)

Gather a list of Statistics objects across all processes/nodes.

Parameters

  • stat_list (list[Statistics]) – list of statistics objects to gather across all processes/nodes

  • max_size (int) – max buffer size to use

Returns

list of updated stats

Return type

list[Statistics]


elapsed_time()

compute elapsed time

log_tensorboard(prefix, writer, learning_rate, patience, step)

Display statistics in TensorBoard.

output(step, num_steps, learning_rate, start)

Write out statistics to stdout.

Parameters

  • step (int) – current step

  • num_steps (int) – total number of steps

  • learning_rate (float) – current learning rate

  • start (int) – start time of step.

ppl()

compute perplexity

update(stat, update_n_src_words=False)

Update statistics by summing values with another Statistics object.

Parameters

  • stat (Statistics) – another statistics object

  • update_n_src_words (bool) – whether to update (sum) n_src_words or not

xent()

compute cross entropy
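A minimal sketch of how these quantities relate, assuming (not stated above) that loss holds the cross-entropy summed over n_words target tokens: xent is then the per-token average, ppl its exponential, and accuracy the percentage of correctly predicted tokens. StatisticsSketch is a hypothetical class, and capping the exponent at 100 is this sketch's guard against overflow.

```python
import math
import time


class StatisticsSketch:
    """Hypothetical accumulator mirroring the quantities described above."""

    def __init__(self, loss=0.0, n_words=0, n_correct=0):
        self.loss = loss            # summed (not averaged) cross-entropy
        self.n_words = n_words      # number of target tokens scored
        self.n_correct = n_correct  # number of correctly predicted tokens
        self.start_time = time.time()

    def update(self, stat):
        # Sums are additive, so batches can be merged in any order.
        self.loss += stat.loss
        self.n_words += stat.n_words
        self.n_correct += stat.n_correct

    def accuracy(self):
        return 100 * (self.n_correct / self.n_words)

    def xent(self):
        return self.loss / self.n_words

    def ppl(self):
        # Cap the exponent so a diverging loss does not overflow math.exp.
        return math.exp(min(self.loss / self.n_words, 100))

    def elapsed_time(self):
        return time.time() - self.start_time


total = StatisticsSketch()
total.update(StatisticsSketch(loss=20.0, n_words=10, n_correct=6))
total.update(StatisticsSketch(loss=10.0, n_words=10, n_correct=9))
```

Averaging only after summing weights every token equally, so merging two batches of different lengths gives the same result as scoring them together.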


class onmt.utils.loss.LossComputeBase(criterion, generator)

Bases: torch.nn.modules.module.Module

Class for managing efficient loss computation. Handles sharding next-step predictions and accumulating multiple loss computations.

Users can implement their own loss computation strategy by subclassing this class. Subclasses need to implement the _compute_loss() and make_shard_state() methods.

Parameters

  • generator (nn.Module) – module that maps the output of the decoder to a distribution over the target vocabulary.

  • tgt_vocab (Vocab) – torchtext vocab object representing the target output

  • normalization (str) – normalize by “sents” or “tokens”
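Sharding and normalization can be illustrated with plain lists. In this sketch, sharded_loss and token_nll are hypothetical helpers: time steps are split into shards of shard_size, each shard's negative log-likelihood is summed, and the total is divided by the token count ("tokens") or the sentence count ("sents"). The usual motivation for sharding, that each shard's loss can be back-propagated before the next one is computed to bound peak memory, is not modelled here. The sharded total should equal the unsharded one up to floating-point rounding.

```python
import math


def token_nll(log_probs, target):
    """Negative log-likelihood of `target` under one step's log-probabilities."""
    return -log_probs[target]


def sharded_loss(step_log_probs, targets, shard_size, normalization="tokens", n_sents=1):
    """Sum the loss over time steps in shards of `shard_size`, then normalize.

    Hypothetical helper: step_log_probs[t] maps token -> log-probability at
    step t and targets[t] is the gold token. A non-positive shard_size is
    treated as a single shard (a choice made for this sketch).
    """
    if normalization not in ("tokens", "sents"):
        raise ValueError("normalization must be 'tokens' or 'sents'")
    step = shard_size if shard_size > 0 else max(len(targets), 1)
    total = 0.0
    for start in range(0, len(targets), step):
        shard = zip(step_log_probs[start:start + step], targets[start:start + step])
        total += sum(token_nll(lp, tgt) for lp, tgt in shard)
    norm = len(targets) if normalization == "tokens" else n_sents
    return total / norm


# Uniform predictions over a two-word vocabulary for four target tokens.
lp = [{"a": math.log(0.5), "b": math.log(0.5)}] * 4
tg = ["a", "b", "a", "b"]
```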


class onmt.utils.Optimizer(optimizer, learning_rate, learning_rate_decay_fn=None, max_grad_norm=None)

Bases: object

Controller class for optimization. Mostly a thin wrapper around optim, but also useful for implementing rate scheduling beyond what is currently available. Also implements methods needed for training RNNs, such as gradient manipulations.

property amp

True if torch AMP mixed-precision training is used.

backward(loss)

Wrapper for the backward pass. Some optimizers require ownership of the backward pass.

classmethod from_opt(model, opt, checkpoint=None)

Builds the optimizer from options.

Parameters

  • cls – The Optimizer class to instantiate.

  • model – The model to optimize.

  • opt – The dict of user options.

  • checkpoint – An optional checkpoint to load states from.

Returns

An Optimizer instance.

learning_rate()

Returns the current learning rate.
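A minimal sketch of how a learning_rate_decay_fn can scale a base rate, assuming (not stated above) that the decay function maps the current step to a multiplier. ScheduledLR is a hypothetical stand-in for this bookkeeping, and the Noam schedule (linear warm-up followed by inverse square-root decay) is used only as an example decay function; the text above does not say which schedules are available.

```python
def noam_decay(step, warmup_steps, model_size):
    """Noam schedule: linear warm-up, then inverse square-root decay."""
    return model_size ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


class ScheduledLR:
    """Hypothetical stand-in for an optimizer wrapper's learning-rate bookkeeping."""

    def __init__(self, learning_rate, learning_rate_decay_fn=None):
        self._learning_rate = learning_rate
        self._decay_fn = learning_rate_decay_fn
        self.training_step = 1  # counted from 1 so that step ** -0.5 is defined

    def learning_rate(self):
        if self._decay_fn is None:
            return self._learning_rate
        return self._learning_rate * self._decay_fn(self.training_step)

    def step(self):
        # A real wrapper would also update the wrapped optimizer here.
        self.training_step += 1


sched = ScheduledLR(2.0, lambda s: noam_decay(s, warmup_steps=4000, model_size=512))
lr1 = sched.learning_rate()
sched.step()
lr2 = sched.learning_rate()
```

During warm-up the rate grows linearly with the step; it peaks at warmup_steps and then decays.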

step()

Update the model parameters based on current gradients.

Optionally, it also modifies the gradients (for example, gradient-norm clipping) or updates the learning rate.
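A common gradient modification for RNN training is global-norm clipping: if the joint L2 norm of all gradients exceeds a threshold, every gradient is scaled down by the same factor. The sketch below applies the same idea as torch.nn.utils.clip_grad_norm_ to plain Python lists. Whether this wrapper uses exactly that routine, and how it treats max_grad_norm=None, are not stated above. The function clip_by_global_norm is hypothetical.

```python
import math


def clip_by_global_norm(grads, max_grad_norm):
    """Scale all gradients in place so their joint L2 norm is at most max_grad_norm.

    grads is a list of flat gradient lists (one per parameter). Returns the
    norm measured before clipping. None or a non-positive threshold disables
    clipping in this sketch.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if max_grad_norm is not None and max_grad_norm > 0 and total_norm > max_grad_norm:
        scale = max_grad_norm / total_norm
        for grad in grads:
            for i, g in enumerate(grad):
                grad[i] = g * scale
    return total_norm


grads = [[3.0], [4.0]]
norm_before = clip_by_global_norm(grads, 1.0)
```

Scaling all gradients by one shared factor keeps the update direction unchanged and only shortens the step.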

property training_step

The current training step.

zero_grad()

Zero the gradients of optimized parameters.