Doc: Framework

Model

Trainer

class onmt.Trainer(model, train_loss, valid_loss, optim, trunc_size=0, shard_size=32, data_type='text', norm_method='sents', grad_accum_count=1, n_gpu=1, gpu_rank=1, gpu_verbose_level=0, report_manager=None, model_saver=None)[source]

Class that controls the training process.

Parameters:
  • model (onmt.models.model.NMTModel) – translation model to train
  • train_loss (onmt.utils.loss.LossComputeBase) – training loss computation
  • valid_loss (onmt.utils.loss.LossComputeBase) – validation loss computation
  • optim (onmt.utils.optimizers.Optimizer) – the optimizer responsible for update
  • trunc_size (int) – length of truncated back propagation through time
  • shard_size (int) – compute loss in shards of this size for efficiency
  • data_type (string) – type of the source input: [text|img|audio]
  • norm_method (string) – normalization methods: [sents|tokens]
  • grad_accum_count (int) – number of batches over which to accumulate gradients before each parameter update
  • report_manager (onmt.utils.ReportMgrBase) – the object that creates reports, or None
  • model_saver (onmt.models.ModelSaverBase) – the saver is used to save a checkpoint. Thus nothing will be saved if this parameter is None
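
A minimal construction sketch, assuming model, train_loss, valid_loss and optim have already been built by the surrounding training script (e.g. with the usual onmt builders); every object name below is an illustrative placeholder, not part of this API.

import onmt

trainer = onmt.Trainer(
    model, train_loss, valid_loss, optim,
    trunc_size=0,          # no truncated back-propagation through time
    shard_size=32,         # compute the loss in shards of 32 examples
    data_type='text',      # source input is text
    norm_method='tokens',  # normalize the loss by token count
    grad_accum_count=4,    # accumulate gradients over 4 batches per update
    report_manager=None,   # no progress reports
    model_saver=None)      # nothing will be checkpointed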
train(train_iter_fct, valid_iter_fct, train_steps, valid_steps)[source]

The main training loop: iterates over the training data (i.e. train_iter_fct) and periodically runs validation (i.e. iterating over valid_iter_fct); see the usage sketch after the parameter list below.

Parameters:
  • train_iter_fct (function) – a function that returns the train iterator. e.g. something like train_iter_fct = lambda: generator(*args, **kwargs)
  • valid_iter_fct (function) – same as train_iter_fct, for valid data
  • train_steps (int) – total number of training steps to run
  • valid_steps (int) – run validation every this many steps
  • save_checkpoint_steps (int) – save a checkpoint every this many steps
Returns:

None
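
A hedged usage sketch of the call pattern described above; make_train_iter and make_valid_iter are hypothetical factories standing in for whatever builds the data iterators in the surrounding script.

train_iter_fct = lambda: make_train_iter()  # returns a fresh training iterator on each call
valid_iter_fct = lambda: make_valid_iter()  # same, for the validation data

trainer.train(train_iter_fct, valid_iter_fct,
              train_steps=100000,  # total number of training steps
              valid_steps=10000)   # run validation every 10000 steps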

validate(valid_iter)[source]
Validate the model.
valid_iter: validation data iterator
Returns: validation loss statistics
Return type: onmt.utils.Statistics
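
For a standalone validation pass, validate can be called directly; the returned statistics object is assumed here to expose the usual onmt.utils.Statistics accessors (accuracy(), ppl()).

stats = trainer.validate(valid_iter_fct())  # evaluate the model on the validation set
print("valid accuracy: %.2f, perplexity: %.2f" % (stats.accuracy(), stats.ppl()))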

Loss

Optim

onmt.Optim.Optim

alias of onmt.utils.optimizers.Optimizer