Doc: Framework


class onmt.Models.NMTModel(encoder, decoder, multigpu=False)[source]

Core trainable object in OpenNMT. Implements a trainable interface for a simple, generic encoder + decoder model.

  • encoder (EncoderBase) – an encoder object
  • decoder (RNNDecoderBase) – a decoder object
  • multigpu (bool) – setup for multigpu support
forward(src, tgt, lengths, dec_state=None)[source]

Forward propagate a src and tgt pair for training. Possibly initialized with a beginning decoder state.

  • src (Tensor) – a source sequence passed to the encoder. Typically this will be a padded LongTensor of size [len x batch x features]; however, it may be an image or other generic input depending on the encoder.
  • tgt (LongTensor) – a target sequence of size [tgt_len x batch].
  • lengths (LongTensor) – the src lengths, pre-padding [batch].
  • dec_state (DecoderState, optional) – initial decoder state

Returns:

  • decoder output [tgt_len x batch x hidden]
  • dictionary of attention distributions [tgt_len x batch x src_len]
  • final decoder state

Return type:

(FloatTensor, dict, onmt.Models.DecoderState)
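The input layout described above can be illustrated with a small sketch. The helper `pad_time_major` and the padding index `PAD = 1` are assumptions made for this example and are not OpenNMT functions; plain Python lists stand in for tensors, and the feature dimension is omitted. The sketch only shows how variable-length source sequences become a padded [len x batch] layout together with the pre-padding `lengths` vector.

```python
# Hypothetical helper for illustration; not part of OpenNMT.
PAD = 1  # assumed padding index


def pad_time_major(seqs, pad=PAD):
    """Return (padded, lengths), where padded[t][b] is token t of sequence b."""
    lengths = [len(s) for s in seqs]      # pre-padding lengths, shape [batch]
    max_len = max(lengths)
    padded = [
        [s[t] if t < len(s) else pad for s in seqs]  # one row per time step
        for t in range(max_len)
    ]
    return padded, lengths


src, lengths = pad_time_major([[5, 6, 7], [8, 9], [4]])
print(len(src), len(src[0]))  # len x batch
print(lengths)                # src lengths before padding
```

Here the first index is the time step and the second is the batch element, matching the [len x batch] convention used for `src` and `tgt`.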

class onmt.Models.DecoderState[source]

Interface for grouping together the current state of a recurrent decoder. In the simplest case it just represents the hidden state of the model, but it can also be used for implementing various forms of input_feeding and non-recurrent models.

Modules need to implement this to utilize beam search decoding.


class onmt.Trainer(model, train_loss, valid_loss, optim, trunc_size=0, shard_size=32, data_type='text', norm_method='sents', grad_accum_count=1)[source]

Class that controls the training process.

  • model (onmt.Models.NMTModel) – translation model to train
  • train_loss (onmt.Loss.LossComputeBase) – training loss computation
  • valid_loss (onmt.Loss.LossComputeBase) – validation loss computation
  • optim (onmt.Optim.Optim) – the optimizer responsible for update
  • trunc_size (int) – length of truncated back propagation through time
  • shard_size (int) – compute loss in shards of this size for efficiency
  • data_type (string) – type of the source input: [text|img|audio]
  • norm_method (string) – normalization methods: [sents|tokens]
  • grad_accum_count (int) – accumulate gradients this many times.
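
As a rough illustration of how `trunc_size` and `norm_method` might be interpreted, the sketch below splits a target sequence into truncation windows and picks a normalization count. Both function names are hypothetical, and the exact window boundaries used by the trainer are an assumption; the sketch treats `trunc_size=0` as "no truncation", consistent with the default in the signature above.

```python
# Hypothetical helpers for illustration; not part of OpenNMT.
def truncation_windows(tgt_len, trunc_size=0):
    """Split target positions into (start, end) windows for truncated BPTT."""
    if tgt_len <= 0:
        return []
    # Assumption: trunc_size == 0 means the whole sequence is one window.
    size = trunc_size if trunc_size else tgt_len
    return [(start, min(start + size, tgt_len))
            for start in range(0, tgt_len, size)]


def normalization_count(n_sents, n_tokens, norm_method="sents"):
    """Pick the number the loss is divided by: sentences or tokens."""
    if norm_method == "sents":
        return n_sents
    if norm_method == "tokens":
        return n_tokens
    raise ValueError("norm_method must be 'sents' or 'tokens'")


print(truncation_windows(10, 4))
print(normalization_count(32, 640, "tokens"))
```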
drop_checkpoint(opt, epoch, fields, valid_stats)[source]

Save a resumable checkpoint.

  • opt (dict) – option object
  • epoch (int) – epoch number
  • fields (dict) – fields and vocabulary
  • valid_stats – statistics of last validation run
train(train_iter, epoch, report_func=None)[source]

Train next epoch.

  • train_iter – training data iterator
  • epoch (int) – the epoch number
  • report_func (fn) – function for logging

Returns: epoch loss statistics
Return type: stats (onmt.Statistics)

Validate model.

  • valid_iter – validation data iterator

Returns: validation loss statistics
Return type: onmt.Statistics
class onmt.Statistics(loss=0, n_words=0, n_correct=0)[source]

Accumulator for loss statistics. Currently calculates:

  • accuracy
  • perplexity
  • elapsed time
output(epoch, batch, n_batches, start)[source]

Write out statistics to stdout.

  • epoch (int) – current epoch
  • batch (int) – current batch
  • n_batches (int) – total batches
  • start (int) – start time of epoch.

Returns: log message.

Return type:

msg (str)
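
The three quantities listed for this accumulator can be sketched as follows. `StatsSketch` is a hypothetical stand-in, not the `onmt.Statistics` class, and it assumes that `loss` is a summed negative log-likelihood in nats, so that perplexity is the exponential of the per-word loss.

```python
import math
import time


class StatsSketch:
    """Stand-in accumulator for illustration; not onmt.Statistics."""

    def __init__(self, loss=0.0, n_words=0, n_correct=0):
        self.loss = loss            # assumed: summed NLL in nats
        self.n_words = n_words
        self.n_correct = n_correct
        self.start_time = time.time()

    def update(self, other):
        """Fold another batch's statistics into this accumulator."""
        self.loss += other.loss
        self.n_words += other.n_words
        self.n_correct += other.n_correct

    def accuracy(self):
        # Percentage of correctly predicted words (requires n_words > 0).
        return 100.0 * self.n_correct / self.n_words

    def ppl(self):
        # Perplexity as exp of the average per-word loss.
        return math.exp(self.loss / self.n_words)

    def elapsed_time(self):
        return time.time() - self.start_time


total = StatsSketch()
total.update(StatsSketch(loss=20.0, n_words=10, n_correct=5))
total.update(StatsSketch(loss=10.0, n_words=10, n_correct=7))
print(total.accuracy(), total.ppl())
```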


class onmt.Loss.LossComputeBase(generator, tgt_vocab)[source]

Class for managing efficient loss computation. Handles sharding next step predictions and accumulating multiple loss computations.

Users can implement their own loss computation strategy by subclassing this class. Subclasses need to implement the _compute_loss() and make_shard_state() methods.

  • generator (nn.Module) – module that maps the output of the decoder to a distribution over the target vocabulary.
  • tgt_vocab (Vocab) – torchtext vocab object representing the target output
  • normalization (str) – normalize by “sents” or “tokens”
monolithic_compute_loss(batch, output, attns)[source]

Compute the forward loss for the batch.

  • batch (batch) – batch of labeled examples
  • output (FloatTensor) – output of decoder model [tgt_len x batch x hidden]
  • attns (dict of FloatTensor) – dictionary of attention distributions [tgt_len x batch x src_len]

Returns: loss statistics

Return type:


sharded_compute_loss(batch, output, attns, cur_trunc, trunc_size, shard_size, normalization)[source]

Compute the forward loss and backpropagate. Computation is done with shards and optionally truncation for memory efficiency.

Also supports truncated BPTT for long sequences by taking a range in the decoder output sequence to back propagate in. Range is from (cur_trunc, cur_trunc + trunc_size).

Note: sharding is an exact efficiency trick that relieves the memory required for the generation buffers. Truncation is an approximate efficiency trick that relieves the memory required in the RNN buffers.
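
The claim that sharding is exact can be checked with a small sketch: summing the loss shard by shard gives the same value as computing it over all positions at once. The function names and the toy logits are assumptions for this example; plain Python lists stand in for tensors, and only forward values are compared (the per-shard backward pass is omitted).

```python
import math


def nll(logits, target):
    """Negative log-likelihood of `target` under softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def monolithic_loss(all_logits, targets):
    """Loss summed over every position in one pass."""
    return sum(nll(l, t) for l, t in zip(all_logits, targets))


def sharded_loss(all_logits, targets, shard_size):
    """Same loss, accumulated over shards of at most shard_size positions."""
    total = 0.0
    for i in range(0, len(targets), shard_size):
        total += monolithic_loss(all_logits[i:i + shard_size],
                                 targets[i:i + shard_size])
    return total


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3], [1.0, 1.0, 1.0],
          [-0.5, 2.5, 0.0], [0.0, 3.0, 1.0]]
targets = [0, 2, 1, 1, 2]
print(monolithic_loss(logits, targets))
print(sharded_loss(logits, targets, shard_size=2))
```

Only a shard's worth of per-position values needs to be held at a time, which is where the memory saving would come from.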

  • batch (batch) – batch of labeled examples
  • output (FloatTensor) – output of decoder model [tgt_len x batch x hidden]
  • attns (dict) – dictionary of attention distributions [tgt_len x batch x src_len]
  • cur_trunc (int) – starting position of truncation window
  • trunc_size (int) – length of truncation window
  • shard_size (int) – maximum number of examples in a shard
  • normalization (int) – Loss is divided by this number

Returns: validation loss statistics

Return type:



class onmt.Optim.Optim(method, lr, max_grad_norm, lr_decay=1, start_decay_at=None, beta1=0.9, beta2=0.999, adagrad_accum=0.0, decay_method=None, warmup_steps=4000, model_size=None)[source]

Controller class for optimization. Mostly a thin wrapper for optim, but also useful for implementing rate scheduling beyond what is currently available. Also implements necessary methods for training RNNs such as grad manipulations.

  • method (str) – one of [sgd, adagrad, adadelta, adam]
  • lr (float) – learning rate
  • lr_decay (float, optional) – learning rate decay multiplier
  • start_decay_at (int, optional) – epoch to start learning rate decay
  • beta1, beta2 (float, optional) – parameters for adam
  • adagrad_accum (float, optional) – initialization parameter for adagrad
  • decay_method (str, optional) – custom decay options
  • warmup_steps (int, optional) – parameter for noam decay
  • model_size (int, optional) – parameter for noam decay
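
For orientation, the schedule usually called "noam" decay (from the Transformer paper) is sketched below. Whether this class applies exactly this formula is an assumption; `noam_lr` is a hypothetical helper written for this example, using the `warmup_steps` and `model_size` parameters listed above.

```python
# Hypothetical helper for illustration; not part of OpenNMT.
def noam_lr(base_lr, step, model_size, warmup_steps=4000):
    """Linear warmup for warmup_steps updates, then inverse-sqrt decay."""
    if step < 1:
        raise ValueError("step must be >= 1")
    return base_lr * model_size ** -0.5 * min(step ** -0.5,
                                              step * warmup_steps ** -1.5)


for step in (1, 1000, 4000, 16000):
    print(step, noam_lr(1.0, step, model_size=512))
```

Under this formula the rate rises until `step == warmup_steps` and falls afterwards.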

Update the model parameters based on current gradients.

Optionally, will employ gradient modification or update learning rate.
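
One common form of gradient modification is clipping by global norm, which the `max_grad_norm` argument suggests. The sketch below is an assumption about what such a step could look like, using plain SGD on Python lists; `clip_by_global_norm` and `sgd_step` are hypothetical names, not methods of this class.

```python
import math


# Hypothetical helpers for illustration; not part of OpenNMT.
def clip_by_global_norm(grads, max_grad_norm):
    """Rescale grads so their L2 norm does not exceed max_grad_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if max_grad_norm and total_norm > max_grad_norm:
        scale = max_grad_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm


def sgd_step(params, grads, lr, max_grad_norm=None):
    """One plain SGD update, with optional gradient clipping first."""
    grads, _ = clip_by_global_norm(grads, max_grad_norm)
    return [p - lr * g for p, g in zip(params, grads)]


clipped, norm = clip_by_global_norm([3.0, 4.0], max_grad_norm=1.0)
print(norm, clipped)
print(sgd_step([1.0, 1.0], [3.0, 4.0], lr=0.5, max_grad_norm=1.0))
```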

update_learning_rate(ppl, epoch)[source]

Decay learning rate if validation performance does not improve or we hit the start_decay_at limit.
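
The rule described above can be sketched as a small state machine. `LrDecaySketch` is a hypothetical stand-in, not the `onmt.Optim.Optim` class. It assumes that performance is measured by validation perplexity (lower is better), and that once decay has been triggered it is applied on every later call; the second point is an assumption, not something the description states.

```python
class LrDecaySketch:
    """Stand-in for illustration; not onmt.Optim.Optim."""

    def __init__(self, lr, lr_decay=1.0, start_decay_at=None):
        self.lr = lr
        self.lr_decay = lr_decay
        self.start_decay_at = start_decay_at
        self.last_ppl = None
        self.start_decay = False

    def update_learning_rate(self, ppl, epoch):
        # Trigger 1: the configured epoch limit has been reached.
        if self.start_decay_at is not None and epoch >= self.start_decay_at:
            self.start_decay = True
        # Trigger 2: validation perplexity got worse than last time.
        if self.last_ppl is not None and ppl > self.last_ppl:
            self.start_decay = True
        # Assumption: decay stays active once triggered.
        if self.start_decay:
            self.lr *= self.lr_decay
        self.last_ppl = ppl
        return self.lr


sched = LrDecaySketch(lr=1.0, lr_decay=0.5, start_decay_at=3)
for epoch, ppl in enumerate([10.0, 9.0, 8.0, 7.0], start=1):
    print(epoch, sched.update_learning_rate(ppl, epoch))
```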