batch_sequence_dataset
- opennmt.data.batch_sequence_dataset(batch_size, batch_type='examples', batch_multiplier=1, batch_size_multiple=1, length_bucket_width=None, length_fn=None, maximum_length=None, pad_to_bucket_boundary=False, padded_shapes=None)[source]
Transformation that batches a dataset of sequences.
This implements an example-based and a token-based batching strategy with optional bucketing of sequences.
Bucketing makes the batches contain sequences of similar lengths to optimize the training efficiency. For example, if length_bucket_width is 5, sequences will be organized by the following length buckets:
1 - 5 | 6 - 10 | 11 - 15 | …
Then when building the next batch, sequences are selected from the same length bucket.
If the dataset has parallel elements (e.g. a parallel source and target dataset), each element is assigned to the bucket corresponding to the maximum length among its parallel elements.
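The bucket assignment rule above can be sketched in plain Python (an illustrative sketch of the described behavior, not the library's actual implementation):

```python
def bucket_id(length, width):
    """Map a sequence length to its bucket index, assuming buckets
    1..width, width+1..2*width, etc., as described above."""
    return (length - 1) // width

def parallel_bucket_id(lengths, width):
    """For parallel elements (e.g. a source/target pair), use the
    maximum length across all parallel elements."""
    return bucket_id(max(lengths), width)

# With length_bucket_width=5: lengths 1-5 share bucket 0, 6-10 bucket 1, etc.
assert bucket_id(5, 5) == bucket_id(1, 5) == 0
assert bucket_id(6, 5) == 1
# A source of length 4 paired with a target of length 12 falls in the 11-15 bucket.
assert parallel_bucket_id([4, 12], 5) == 2
```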
Example
>>> dataset = dataset.apply(opennmt.data.batch_sequence_dataset(...))
- Parameters
  - batch_size – The batch size.
  - batch_type – The training batching strategy to use: can be “examples” or “tokens”.
  - batch_multiplier – The batch size multiplier.
  - batch_size_multiple – When batch_type is “tokens”, ensure that the resulting batch size is a multiple of this value.
  - length_bucket_width – The width of the length buckets to select batch candidates from. None to not constrain batch formation.
  - length_fn – A function or list of functions (in case of a parallel dataset) that take features as argument and return the associated sequence length.
  - maximum_length – If known, the maximum length or list of maximum lengths (in case of a parallel dataset). This argument is required with pad_to_bucket_boundary.
  - pad_to_bucket_boundary – Pad each batch to the length bucket boundary.
  - padded_shapes – The padded shapes for this dataset. If None, the shapes are automatically inferred from the dataset output shapes.
- Returns
  A tf.data.Dataset transformation.
- Raises
  - ValueError – if batch_type is not one of “examples” or “tokens”.
  - ValueError – if batch_type is “tokens” but length_bucket_width is not set.
  - ValueError – if the number of length functions in length_fn does not match the number of parallel elements.
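To illustrate how batch_size and batch_size_multiple interact under token-based batching, here is a hedged sketch (not the library code): with batch_type=“tokens”, the number of examples per batch is roughly the token budget divided by the bucket's maximum sequence length, rounded down to a multiple of batch_size_multiple.

```python
def examples_per_token_batch(batch_size, bucket_max_length, batch_size_multiple=1):
    """Illustrative approximation only: how many examples fit in one
    token-based batch for a bucket whose longest sequence has
    bucket_max_length tokens."""
    n = batch_size // bucket_max_length        # token budget / longest sequence
    n -= n % batch_size_multiple               # round down to the multiple
    return max(n, batch_size_multiple)         # never return an empty batch

# A 4096-token budget with sequences up to 25 tokens fits 163 examples,
# or 160 when the batch size must be a multiple of 8.
assert examples_per_token_batch(4096, 25) == 163
assert examples_per_token_batch(4096, 25, batch_size_multiple=8) == 160
```

Rounding to a multiple (e.g. 8) is commonly used so batch dimensions align with hardware-friendly sizes, such as Tensor Core tile shapes.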
See also