opennmt.utils.Sampler

class opennmt.utils.Sampler

Base class for samplers.

Inherits from: builtins.object

abstract __call__(scores, num_samples=1)

Samples predictions.

Parameters
  • scores – The scores to sample from, a tensor of shape [batch_size, vocab_size].

  • num_samples – The number of samples to produce for each batch entry.

Returns

A tuple (sample_ids, sample_scores).
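To illustrate the contract of __call__, here is a minimal sketch of a subclass. It is not the library's implementation: it uses plain Python lists in place of TensorFlow tensors so it runs without TensorFlow, and the Sampler class below is a stand-in for opennmt.utils.Sampler. GreedySampler is a hypothetical name. For each batch entry it returns the num_samples highest-scoring ids together with their scores.

```python
import abc
import heapq


class Sampler(abc.ABC):
    """Plain-Python stand-in for the abstract interface documented above."""

    @abc.abstractmethod
    def __call__(self, scores, num_samples=1):
        """Return (sample_ids, sample_scores), one list per batch entry."""


class GreedySampler(Sampler):
    """Hypothetical subclass: keeps the num_samples best-scoring ids per row."""

    def __call__(self, scores, num_samples=1):
        sample_ids, sample_scores = [], []
        for row in scores:  # each row has length vocab_size
            # Indices of the num_samples largest scores, best first.
            best = heapq.nlargest(num_samples, range(len(row)), key=row.__getitem__)
            sample_ids.append(best)
            sample_scores.append([row[i] for i in best])
        return sample_ids, sample_scores


# scores has shape [batch_size=2, vocab_size=4].
scores = [[0.1, 2.0, -1.0, 0.5],
          [3.0, 0.0, 1.0, 2.5]]
ids, vals = GreedySampler()(scores, num_samples=2)
print(ids)   # [[1, 3], [0, 3]]
print(vals)  # [[2.0, 0.5], [3.0, 2.5]]
```

Because __call__ is abstract, instantiating the base class directly raises TypeError; only subclasses that implement __call__ can be used.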

static from_params(params)

Constructs a sampler based on user parameters.

Parameters

params – A dictionary of user parameters.

Returns

An opennmt.utils.Sampler instance.
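The following sketch shows one way a from_params factory can choose a concrete sampler from a parameter dictionary. Everything here is an assumption for illustration: the keys sampling_topk and sampling_temperature are modeled on OpenNMT-tf's sampling options but should be checked against the library's parameter reference, and GreedySampler and TopKRandomSampler are hypothetical classes written in plain Python, not the library's own samplers. A top-k of 1 selects greedy decoding; any other value selects temperature-scaled random sampling restricted to the k best candidates.

```python
import abc
import math
import random


class Sampler(abc.ABC):
    @abc.abstractmethod
    def __call__(self, scores, num_samples=1):
        ...

    @staticmethod
    def from_params(params):
        """Hypothetical factory: greedy when sampling_topk is 1, random otherwise."""
        topk = params.get("sampling_topk", 1)  # assumed key name
        if topk == 1:
            return GreedySampler()
        return TopKRandomSampler(topk, params.get("sampling_temperature", 1.0))


class GreedySampler(Sampler):
    def __call__(self, scores, num_samples=1):
        # Sort each row's ids by score (best first) and keep num_samples of them.
        ids = [sorted(range(len(r)), key=r.__getitem__, reverse=True)[:num_samples]
               for r in scores]
        return ids, [[r[i] for i in row_ids] for r, row_ids in zip(scores, ids)]


class TopKRandomSampler(Sampler):
    def __init__(self, top_k, temperature, seed=None):
        self.top_k = top_k
        self.temperature = temperature  # must be > 0
        self.rng = random.Random(seed)

    def __call__(self, scores, num_samples=1):
        ids, vals = [], []
        for row in scores:
            # Restrict to the top_k candidates (top_k <= 0 means the whole vocabulary).
            k = self.top_k if self.top_k > 0 else len(row)
            cands = sorted(range(len(row)), key=row.__getitem__, reverse=True)[:k]
            # Softmax weights over candidate scores; subtracting the max avoids overflow.
            m = max(row[i] for i in cands)
            weights = [math.exp((row[i] - m) / self.temperature) for i in cands]
            # Draw num_samples ids with replacement.
            picked = self.rng.choices(cands, weights=weights, k=num_samples)
            ids.append(picked)
            vals.append([row[i] for i in picked])
        return ids, vals


print(type(Sampler.from_params({})).__name__)  # GreedySampler
sampler = Sampler.from_params({"sampling_topk": 2, "sampling_temperature": 0.7})
ids, _ = sampler([[0.1, 2.0, -1.0, 0.5]], num_samples=5)
# With top-k = 2, every sampled id is one of the two best candidates, 1 or 3.
```

Keeping the construction logic in a static factory lets calling code stay agnostic about which sampling strategy the user configured.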