batch_sequence_dataset

opennmt.data.batch_sequence_dataset(batch_size, batch_type='examples', batch_multiplier=1, batch_size_multiple=1, length_bucket_width=None, length_fn=None, maximum_length=None, pad_to_bucket_boundary=False, padded_shapes=None)

Transformation that batches a dataset of sequences.

This implements example-based and token-based batching strategies, with optional bucketing of sequences by length.

Bucketing places sequences of similar lengths in the same batch, which reduces padding and improves training efficiency. For example, if length_bucket_width is 5, sequences are organized into the following length buckets:

1 - 5 | 6 - 10 | 11 - 15 | …

When building each batch, all of its sequences are then selected from the same length bucket.
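A minimal pure-Python sketch of the bucketing idea, assuming buckets laid out as in the listing above (1 - width, width+1 - 2*width, and so on). The helpers `bucket_id` and `bucketed_batches` are hypothetical illustrations, not part of the OpenNMT-tf API; the real transformation operates on a `tf.data.Dataset`, and its exact bucket indexing may differ.

```python
from collections import defaultdict


def bucket_id(length, width):
    """0-based bucket index (assumes length >= 1): 1-5 -> 0, 6-10 -> 1, ... for width=5."""
    return (length - 1) // width


def bucketed_batches(lengths, width, batch_size):
    """Group sequence lengths into buckets, then batch within each bucket.

    Batches are example-based here: each holds at most batch_size sequences,
    and every sequence in a batch comes from the same bucket.
    """
    buckets = defaultdict(list)
    batches = []
    for length in lengths:
        key = bucket_id(length, width)
        buckets[key].append(length)
        # Emit a batch as soon as its bucket is full.
        if len(buckets[key]) == batch_size:
            batches.append(buckets.pop(key))
    # Flush the partially filled buckets at the end of the data.
    for key in sorted(buckets):
        batches.append(buckets[key])
    return batches


print(bucketed_batches([3, 12, 4, 7, 1, 9, 14], width=5, batch_size=2))
# [[3, 4], [7, 9], [12, 14], [1]]
```

Sequences 3 and 4 share a batch because both fall in the 1 - 5 bucket, even though 12 arrived between them.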

If the dataset has parallel elements (e.g. a parallel source and target dataset), each element is assigned to the bucket corresponding to the maximum length across its parallel components.
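For parallel elements, the bucket follows the longest component. This sketch assumes each parallel element is described by a tuple of lengths (for example source and target lengths) and uses the same 1-based bucket layout as the listing above; `parallel_bucket_id` is a hypothetical helper.

```python
def parallel_bucket_id(lengths, width):
    """Assign a parallel element to the bucket of its longest component.

    lengths: a tuple such as (source_length, target_length), each >= 1.
    """
    return (max(lengths) - 1) // width


# A pair with a 4-token source and a 12-token target lands in the
# 11 - 15 bucket (index 2), because the target is the longer side.
print(parallel_bucket_id((4, 12), width=5))  # 2
```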

Example

>>> dataset = dataset.apply(opennmt.data.batch_sequence_dataset(...))

Parameters
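  • batch_size – The batch size: a number of examples when batch_type is “examples”, or a number of tokens when batch_type is “tokens”.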

  • batch_type – The training batching strategy to use: can be “examples” or “tokens”.

  • batch_multiplier – The batch size multiplier: batch_size is multiplied by this value.

  • batch_size_multiple – When batch_type is “tokens”, ensure that the resulting batch size is a multiple of this value.

  • length_bucket_width – The width of the length buckets from which batch candidates are selected. Set to None to disable bucketing, so that batch formation is not constrained by sequence length.

  • length_fn – A function, or a list of functions in the case of a parallel dataset, that takes the features as argument and returns the associated sequence length.

  • maximum_length – If known, the maximum length, or a list of maximum lengths in the case of a parallel dataset. This argument is required when pad_to_bucket_boundary is enabled.

  • pad_to_bucket_boundary – If True, pad each batch to the upper boundary of its length bucket instead of to the longest sequence in the batch.

  • padded_shapes – The padded shapes for this dataset. If None, the shapes are automatically inferred from the dataset output shapes.
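With token-based batching, buckets of longer sequences should hold fewer examples. The sketch below shows one plausible way the parameters could combine for a single bucket, assuming batch_size is a token budget that is scaled by batch_multiplier, divided by the bucket's upper length boundary, and rounded up to a multiple of batch_size_multiple. It also assumes that, with pad_to_bucket_boundary, the padding target is the bucket's upper boundary capped by maximum_length. These rounding and capping details are assumptions, and `examples_per_token_batch` and `padded_length` are hypothetical helpers rather than OpenNMT-tf functions.

```python
def examples_per_token_batch(token_budget, bucket_max_length,
                             batch_multiplier=1, batch_size_multiple=1):
    """Number of examples per batch for one length bucket (token-based).

    Assumptions: the token budget is scaled by batch_multiplier, and the
    result is rounded *up* to a multiple of batch_size_multiple, which may
    slightly exceed the budget.
    """
    budget = token_budget * batch_multiplier
    # Every example in the bucket has at most bucket_max_length tokens.
    size = max(1, budget // bucket_max_length)
    # Round up to the next multiple of batch_size_multiple.
    return -(-size // batch_size_multiple) * batch_size_multiple


def padded_length(bucket_index, width, maximum_length=None):
    """Assumed padding target when pad_to_bucket_boundary is enabled.

    The upper boundary of bucket i is (i + 1) * width (5, 10, 15, ... for
    width=5); it is capped by maximum_length when that is known.
    """
    boundary = (bucket_index + 1) * width
    if maximum_length is not None:
        boundary = min(boundary, maximum_length)
    return boundary


# A 4096-token budget for the 6 - 10 bucket (upper boundary 10):
print(examples_per_token_batch(4096, bucket_max_length=10))  # 409
print(examples_per_token_batch(4096, bucket_max_length=10, batch_size_multiple=8))  # 416
print(padded_length(bucket_index=1, width=5))  # 10
```

Dividing by the bucket boundary rather than by each sequence's actual length keeps the padded batch close to the token budget, since every sequence in the bucket is at most that long.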

Returns

A tf.data.Dataset transformation.
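A dataset transformation is a callable that takes a dataset and returns a new one, which is the argument type `tf.data.Dataset.apply` expects. The pure-Python sketch below mimics that calling convention on a plain list, using a hypothetical `batch_examples` factory and a toy `apply` helper; it illustrates only the pattern, not the actual batching logic.

```python
from typing import Callable, List, Sequence

Transformation = Callable[[Sequence[int]], List[List[int]]]


def batch_examples(batch_size: int) -> Transformation:
    """Hypothetical factory that returns a transformation, like batch_sequence_dataset."""

    def _transform(dataset: Sequence[int]) -> List[List[int]]:
        # Split the input into consecutive batches of batch_size items.
        return [list(dataset[i:i + batch_size])
                for i in range(0, len(dataset), batch_size)]

    return _transform


def apply(dataset, transformation):
    """Toy stand-in for tf.data.Dataset.apply: call the transformation on the data."""
    return transformation(dataset)


print(apply([1, 2, 3, 4, 5], batch_examples(2)))  # [[1, 2], [3, 4], [5]]
```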

Raises
  • ValueError – if batch_type is not one of “examples” or “tokens”.

  • ValueError – if batch_type is “tokens” but length_bucket_width is not set.

  • ValueError – if the number of length functions in length_fn does not match the number of parallel elements.
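The three error conditions above can be expressed as argument checks. The following is a sketch of how such validation might look, written as a hypothetical `check_batching_args` function; it assumes a single length function counts as one, and in the library the length-function check may only run once the number of parallel elements is known from the dataset.

```python
def check_batching_args(batch_type, length_bucket_width=None,
                        length_fn=None, num_parallel_elements=1):
    """Raise ValueError for the invalid combinations documented above."""
    if batch_type not in ("examples", "tokens"):
        raise ValueError("Invalid batch type: %r; expected 'examples' or 'tokens'"
                         % batch_type)
    if batch_type == "tokens" and length_bucket_width is None:
        raise ValueError("Token-based batching requires length_bucket_width to be set")
    if length_fn is not None:
        # Treat a single function as a list of one (assumption).
        fns = length_fn if isinstance(length_fn, (list, tuple)) else [length_fn]
        if len(fns) != num_parallel_elements:
            raise ValueError("%d length functions were passed but the dataset has "
                             "%d parallel elements" % (len(fns), num_parallel_elements))


check_batching_args("tokens", length_bucket_width=1)  # valid: no error
try:
    check_batching_args("sentences")
except ValueError as error:
    print(error)
```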