batch_dataset

opennmt.data.batch_dataset(batch_size, padded_shapes=None)

Transformation that combines consecutive elements of a dataset into padded batches.

Example

>>> dataset = dataset.apply(opennmt.data.batch_dataset(...))

Parameters

  • batch_size – The number of consecutive elements to combine into each batch.

  • padded_shapes – The padded shapes for this dataset. If None, the shapes are automatically inferred from the dataset output shapes.

Returns

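A tf.data.Dataset transformation, that is, a function that takes a dataset and returns the batched dataset, suitable for passing to tf.data.Dataset.apply.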
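To make the "transformation" pattern and the padding behavior concrete without requiring TensorFlow, here is a minimal pure-Python sketch. It is not the OpenNMT-tf implementation. The names batch_transform, padded_length, and pad_value are hypothetical, and the model is simplified to 1-D integer sequences with a single optional fixed length standing in for padded_shapes. When no length is given, each batch is padded to its longest sequence, which mirrors how tf.data pads dimensions of unknown size.

```python
from typing import Callable, Iterable, Iterator, List, Optional


def batch_transform(
    batch_size: int, padded_length: Optional[int] = None, pad_value: int = 0
) -> Callable[[Iterable[List[int]]], Iterator[List[List[int]]]]:
    """Hypothetical sketch: return a function that maps a stream of
    sequences to a stream of padded batches (like a dataset transformation)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")

    def pad(batch: List[List[int]]) -> List[List[int]]:
        # Without a fixed length, pad to the longest sequence in this batch.
        length = padded_length if padded_length is not None else max(len(s) for s in batch)
        padded = []
        for seq in batch:
            if len(seq) > length:
                raise ValueError("sequence is longer than padded_length")
            padded.append(seq + [pad_value] * (length - len(seq)))
        return padded

    def transform(elements: Iterable[List[int]]) -> Iterator[List[List[int]]]:
        batch: List[List[int]] = []
        for element in elements:
            batch.append(list(element))
            if len(batch) == batch_size:
                yield pad(batch)
                batch = []
        # This sketch keeps a final, smaller batch, as tf.data's
        # padded_batch does when drop_remainder=False.
        if batch:
            yield pad(batch)

    return transform


# Usage, analogous to dataset.apply(transformation):
apply_batching = batch_transform(batch_size=2)
batches = list(apply_batching([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10], [11]]))
# → [[[1, 2, 3], [4, 0, 0]], [[5, 6, 0, 0], [7, 8, 9, 10]], [[11]]]
```

Returning a function rather than batching immediately is what lets the result be chained with other transformations, which is the same reason batch_dataset is applied through Dataset.apply.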