Framework¶
Model¶
- class onmt.models.BaseModel(encoder, decoder)[source]¶
Bases:
Module
Core trainable object in OpenNMT. Implements a trainable interface for a simple, generic encoder/decoder or decoder-only model.
- Parameters:
encoder (onmt.encoders.EncoderBase) – an encoder object
decoder (onmt.decoders.DecoderBase) – a decoder object
- forward(src, tgt, src_len, bptt=False, with_align=False)[source]¶
Forward propagate a src and tgt pair for training.
- Parameters:
src (Tensor) – A source sequence passed to the encoder. Typically this will be a padded LongTensor of size (batch, len, features); however, it may be an image or other generic input depending on the encoder.
tgt (LongTensor) – A target sequence passed to the decoder, of size (batch, tgt_len, features).
src_len (LongTensor) – The src lengths, pre-padding, of size (batch,).
bptt (Boolean) – A flag indicating whether truncated BPTT is set. If bptt is False, the decoder state is initialized.
with_align (Boolean) – A flag indicating whether to output alignment. Only valid for the transformer decoder.
- Returns:
decoder output (batch, tgt_len, hidden)
dictionary of attention weights (batch, tgt_len, src_len)
- Return type:
(FloatTensor, dict[str, FloatTensor])
- load_safe_state_dict(model_path, precision=torch.float32, device=device(type='cpu'), strict=True, offset=0)[source]¶
Custom state_dict loading to enable moving modules to the device as they are loaded.
- Parameters:
model_path – Model path
precision – precision to move each module to (see load_state_dict)
device – device to move each module to (see load_state_dict)
strict – if True checks model keys wrt state_dict (see load_state_dict)
- load_state_dict(checkpoint, precision=torch.float32, device=device(type='cpu'), strict=True, offset=0)[source]¶
Custom state_dict loading to enable moving modules to the device as they are loaded.
- Parameters:
checkpoint – Pytorch serialized checkpoint
precision – precision to move each module to
device – device to move each module to
strict – if True checks model keys wrt state_dict (both ways)
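As a hedged illustration (not part of the reference itself), loading a serialized checkpoint while casting modules to half precision on a GPU might look like the following sketch; the already-built model and the checkpoint path are assumptions:
import torch

# Assumptions: "model" is an already-built onmt.models.NMTModel and
# "model_step_10000.pt" is an illustrative checkpoint path.
checkpoint = torch.load("model_step_10000.pt", map_location="cpu")

# Per the description above, modules are moved to the requested
# device/precision as their weights are loaded.
model.load_state_dict(
    checkpoint,
    precision=torch.float16,
    device=torch.device("cuda", 0),
    strict=True,
)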
- class onmt.models.NMTModel(encoder, decoder)[source]¶
Bases:
BaseModel
NMTModel Class. See BaseModel for options.
- count_parameters(log=<built-in function print>)[source]¶
Count number of parameters in model (& print with log callback).
- Returns:
encoder side parameter count
decoder side parameter count
- Return type:
(int, int)
- forward(src, tgt, src_len, bptt=False, with_align=False)[source]¶
An NMTModel forwards the src side to the encoder. The encoder output enc_out is then forwarded to the decoder along with the target, excluding the last token. The decoder state is initialized with:
enc_final_hs in the case of RNNs
enc_out + enc_final_hs in the case of CNNs
src in the case of the Transformer
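To make the shapes above concrete, here is a minimal, hedged sketch of a single training-time forward pass; the model construction is omitted ("model" is assumed to be a built NMTModel) and the batch and vocabulary sizes are arbitrary:
import torch

# Assumption: "model" is a built onmt.models.NMTModel (construction not shown).
batch_size, src_max_len, tgt_len, n_feats = 8, 23, 17, 1

src = torch.randint(0, 1000, (batch_size, src_max_len, n_feats))    # (batch, len, features)
tgt = torch.randint(0, 1000, (batch_size, tgt_len, n_feats))        # (batch, tgt_len, features)
src_len = torch.full((batch_size,), src_max_len, dtype=torch.long)  # (batch,) lengths, pre-padding

# The decoder receives the target excluding its last token; its state is
# initialized as described above (enc_final_hs / enc_out + enc_final_hs / src).
dec_out, attns = model(src, tgt, src_len, bptt=False, with_align=False)
# dec_out: (batch, tgt_len, hidden); attns: dict of (batch, tgt_len, src_len) tensors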
- class onmt.models.LanguageModel(encoder=None, decoder=None)[source]¶
Bases:
BaseModel
LanguageModel Class. Currently, TransformerLMDecoder is the only LM decoder implemented.
- Parameters:
decoder (onmt.decoders.TransformerLMDecoder) – a transformer decoder
Trainer¶
- class onmt.trainer.Trainer(model, train_loss, valid_loss, scoring_preparator, valid_scorers, optim, trunc_size=0, norm_method='sents', accum_count=[1], accum_steps=[0], n_gpu=1, gpu_rank=1, parallel_mode='data_parallel', report_manager=None, with_align=False, model_saver=None, average_decay=0, average_every=1, model_dtype='fp32', earlystopper=None, dropout=[0.3], attention_dropout=[0.1], dropout_steps=[0], zero_out_prompt_loss=False)[source]¶
Bases:
object
Class that controls the training process.
- Parameters:
model (onmt.models.model.NMTModel) – model to train
train_loss (onmt.utils.loss.LossComputeBase) – training loss computation
valid_loss (onmt.utils.loss.LossComputeBase) – validation loss computation
scoring_preparator (onmt.translate.utils.ScoringPreparator) – preparator for the calculation of metrics via the _eval_handler method
valid_scorers (dict) – keeps in memory the current values of the validation metrics
optim (onmt.utils.optimizers.Optimizer) – the optimizer responsible for parameter updates
trunc_size (int) – length of truncated back propagation through time
accum_count (list) – accumulate gradients this many times.
accum_steps (list) – steps at which the gradient accumulation count changes.
n_gpu (int) – number of GPUs.
gpu_rank (int) – ordinal rank of the GPU in the list.
report_manager (onmt.utils.ReportMgrBase) – the object that creates reports, or None
with_align (bool) – whether to jointly learn alignment (Transformer)
model_saver (onmt.models.ModelSaverBase) – the saver used to save checkpoints. Nothing will be saved if this parameter is None.
average_decay (float) – cf. opt.average_decay
average_every (int) – average the model every x steps.
model_dtype (str) – fp32 or fp16.
earlystopper (onmt.utils.EarlyStopping) – adds an early stopping mechanism
dropout (float) – dropout value in RNN or FF layers.
attention_dropout (float) – dropout in attention layers.
dropout_steps (list) – dropout value scheduling in steps.
zero_out_prompt_loss (bool) – whether to zero out the prompt loss (mostly for LLM finetuning).
- train(train_iter, train_steps, save_checkpoint_steps=5000, valid_iter=None, valid_steps=10000)[source]¶
The main training loop, iterating over train_iter and possibly running validation on valid_iter.
- Parameters:
train_iter – An iterator that returns the next training batch.
train_steps – Run training for this many iterations.
save_checkpoint_steps – Save a checkpoint every this many iterations.
valid_iter – A generator that returns the next validation batch.
valid_steps – Run evaluation every this many iterations.
- Returns:
training loss statistics
- Return type:
onmt.utils.Statistics
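A hedged usage sketch, assuming a Trainer has already been built (typically from the options) and that the iterators yield batches in the expected format; the accuracy()/ppl() accessors on the returned statistics are assumptions based on the Statistics description below:
# Assumptions: "trainer" is a built onmt.trainer.Trainer; "train_iter" and
# "valid_iter" are batch iterators matching the model's data pipeline.
stats = trainer.train(
    train_iter,
    train_steps=100000,
    save_checkpoint_steps=5000,
    valid_iter=valid_iter,
    valid_steps=10000,
)
print("final training accuracy:", stats.accuracy())   # assumed accessor
print("final training perplexity:", stats.ppl())      # assumed accessor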
- class onmt.utils.Statistics(loss=0, n_batchs=0, n_sents=0, n_words=0, n_correct=0, computed_metrics={})[source]¶
Bases:
object
Accumulator for loss statistics. Currently calculates:
accuracy
perplexity
elapsed time
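A small self-contained sketch of how the accumulator might be used; the update()/accuracy()/ppl() method names are assumptions corresponding to the quantities listed above:
from onmt.utils import Statistics

# Accumulate token-level statistics over two toy "batches".
total = Statistics()
total.update(Statistics(loss=120.0, n_batchs=1, n_sents=32, n_words=900, n_correct=630))
total.update(Statistics(loss=100.0, n_batchs=1, n_sents=32, n_words=850, n_correct=640))

print(total.accuracy())  # assumed: 100 * n_correct / n_words
print(total.ppl())       # assumed: exp(loss / n_words), capped for stability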
- static all_gather_stats(stat, max_size=4096)[source]¶
Gather a Statistics object across multiple processes/nodes
- Parameters:
stat (Statistics) – the statistics object to gather across all processes/nodes
max_size (int) – max buffer size to use
- Returns:
Statistics, the updated stats object
- static all_gather_stats_list(stat_list, max_size=4096)[source]¶
Gather a Statistics list across all processes/nodes
- Parameters:
stat_list (list([Statistics])) – list of statistics objects to gather across all processes/nodes
max_size (int) – max buffer size to use
- Returns:
list of updated stats
- Return type:
our_stats (list([Statistics]))
- log_tensorboard(prefix, writer, learning_rate, patience, step)[source]¶
Display statistics to TensorBoard.
- output(step, num_steps, learning_rate, start)[source]¶
Write out statistics to stdout.
- Parameters:
step (int) – current step
num_steps (int) – total steps
start (int) – start time of step.
Loss¶
- class onmt.utils.loss.LossCompute(criterion, generator, copy_attn=False, lambda_coverage=0.0, lambda_align=0.0, tgt_shift_index=1, vocab=None, lm_generator=None, lm_prior_lambda=None, lm_prior_tau=None, lm_prior_model=None)[source]¶
Bases:
Module
Class for managing efficient loss computation. Handles accumulating multiple loss computations.
- Parameters:
criterion (nn loss function) – NLLLoss or a custom loss
generator (nn.Module) – module that maps the output of the decoder to a distribution over the target vocabulary
copy_attn (bool) – whether the copy attention mechanism is on/off
lambda_coverage – hyper-parameter to apply coverage attention, if any
lambda_align – hyper-parameter for the alignment loss
tgt_shift_index (int) – 1 for NMT, 0 for LM
vocab – target vocab (for copy attention score calculation)
lm_generator (ctranslate2.Generator) – LM generator
lm_prior_lambda (float) – weight of the LM model in the loss
lm_prior_tau (float) – scaling factor for the LM loss
- forward(batch, output, attns, trunc_start=0, trunc_size=None)[source]¶
Compute the forward loss. Supports truncated BPTT for long sequences by taking a range of the decoder output sequence to backpropagate through; the range is (trunc_start, trunc_start + trunc_size). Truncation is an approximate efficiency trick to relieve the memory required in the RNN buffers (a usage sketch follows the return description below).
- Parameters:
batch (batch) – batch of labeled examples
output (FloatTensor) – output of decoder model (batch, tgt_len, hidden)
attns (dict) – dictionary of attention weights (batch, tgt_len, src_len)
trunc_start (int) – starting position of truncation window
trunc_size (int) – length of truncation window
- Returns:
A tuple with the loss and a
onmt.utils.Statistics
instance.
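A hedged single-window sketch of the truncation interface; "loss_compute", "batch", "dec_out" and "attns" are assumed to come from a built LossCompute and a model forward pass, and in the actual Trainer the backward pass is routed through the optimizer:
# Assumptions: "loss_compute" is a built onmt.utils.loss.LossCompute and
# "batch", "dec_out", "attns" come from a model forward pass on that batch.
trunc_start, trunc_size = 64, 32

# Score and backpropagate only the window [64, 96) of the decoder output.
loss, stats = loss_compute(batch, dec_out, attns,
                           trunc_start=trunc_start,
                           trunc_size=trunc_size)
loss.backward()
print(stats.accuracy(), stats.ppl())  # assumed Statistics accessors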
- classmethod from_opts(opt, model, vocab, train=True)[source]¶
Returns a subclass which wraps around an nn.Module subclass (such as nn.NLLLoss) which defines the loss criterion. The LossCompute object passes relevant data to a Statistics object which handles training/validation logging. The Criterion and LossCompute options are triggered by opt settings.
- ignore_prompt(batch)[source]¶
Mask the prompt in the target side of the batch examples in order to set the loss of the prompt to zero. For finetuning on specific tasks. The end of the prompt must be indicated by the DefaultTokens.MASK_BEFORE placeholder. The masks are supposed to be properly handled by the loss criterion (e.g. nn.CrossEntropyLoss).
- Parameters:
batch – The current batch.
Optimizer¶
- class onmt.utils.Optimizer(optimizer, learning_rate, learning_rate_decay_fn=None, max_grad_norm=None)[source]¶
Bases:
object
Controller class for optimization. Mostly a thin wrapper for optim, but also useful for implementing rate scheduling beyond what is currently available. Also implements necessary methods for training RNNs such as grad manipulations.
- Parameters:
optimizer – A torch.optim.Optimizer instance.
learning_rate – The initial learning rate.
learning_rate_decay_fn – An optional callable taking the current step as argument and returning a learning rate scaling factor.
max_grad_norm – Clip gradients to this global norm.
- property amp¶
True if torch AMP mixed precision training is in use.
- backward(loss)[source]¶
Wrapper for the backward pass. Some optimizers require ownership of the backward pass.
- classmethod from_opt(model, opt, checkpoint=None)[source]¶
Builds the optimizer from options.
- Parameters:
cls – The Optimizer class to instantiate.
model – The model to optimize.
opt – The dict of user options.
checkpoint – An optional checkpoint to load states from.
- Returns:
An
Optimizer
instance.
- step()[source]¶
Update the model parameters based on current gradients.
Optionally, it will employ gradient modification or update the learning rate.
- property training_step¶
The current training step.
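A self-contained, hedged sketch of the controller in use; the tiny linear model, the Adam settings, and the Noam-style decay function are illustrative assumptions, not library defaults:
import torch
import torch.nn as nn
from onmt.utils import Optimizer

# Tiny stand-in model so the sketch runs on its own; in practice this would
# be an onmt.models.NMTModel.
model = nn.Linear(16, 4)

def decay_fn(step, warmup=4000):
    # Noam-style scaling, used here only as an example of a decay callable.
    return min(step ** -0.5, step * warmup ** -1.5) * warmup ** 0.5

optim = Optimizer(
    torch.optim.Adam(model.parameters(), lr=1.0),
    learning_rate=1.0,
    learning_rate_decay_fn=decay_fn,
    max_grad_norm=5.0,
)

loss = model(torch.randn(8, 16)).pow(2).mean()  # dummy loss for illustration
optim.backward(loss)  # some wrapped optimizers own the backward pass
optim.step()          # clips gradients, applies the decay factor, updates parameters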
- class onmt.utils.AdaFactor(params, lr=None, beta1=0.9, beta2=0.999, eps1=1e-30, eps2=0.001, cliping_threshold=1, non_constant_decay=True, enable_factorization=True, ams_grad=True, weight_decay=0)[source]¶
Bases:
Optimizer
- step(closure=None)[source]¶
Performs a single optimization step (parameter update).
- Parameters:
closure (Callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.
Note
Unless otherwise specified, this function should not modify the
.grad
field of the parameters.
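A self-contained toy example of the closure form of step(); the linear model and data are placeholders, and AdaFactor is used with the defaults shown in the signature above:
import torch
import torch.nn as nn
from onmt.utils import AdaFactor

model = nn.Linear(10, 4, bias=False)
x, y = torch.randn(32, 10), torch.randn(32, 4)
opt = AdaFactor(model.parameters())

def closure():
    # Re-evaluate the model and return the loss, as described above.
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

loss = opt.step(closure)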
- class onmt.utils.FusedAdam(params, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-08, eps_inside_sqrt=False, weight_decay=0.0, max_grad_norm=0.0, amsgrad=False)[source]¶
Bases:
Optimizer
Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via python setup.py install --cuda_ext --cpp_ext.
- Parameters:
params (iterable) – iterable of parameters to optimize or dicts defining parameter groups.
lr (float, optional) – learning rate. (default: 1e-3)
betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional) – term added to the denominator to improve numerical stability. (default: 1e-8)
weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper ‘On the Convergence of Adam and Beyond’ (default: False) NOT SUPPORTED in FusedAdam!
eps_inside_sqrt (boolean, optional) – in the ‘update parameters’ step, adds eps to the bias-corrected second moment estimate before evaluating square root instead of adding it to the square root of second moment estimate as in the original paper. (default: False)
- step(closure=None, grads=None, output_params=None, scale=1.0, grad_norms=None)[source]¶
Performs a single optimization step.
- Parameters:
closure (callable, optional) – A closure that reevaluates the model and returns the loss.
grads (list of tensors, optional) – weight gradient to use for the optimizer update. If gradients have type torch.half, parameters are expected to be in type torch.float. (default: None)
output_params (list of tensors, optional) – A reduced precision copy of the updated weights written out in addition to the regular updated weights. Must be of the same type as the gradients. (default: None)
scale (float, optional) – factor to divide gradient tensor values by before applying to weights. (default: 1)