batch_dataset
- opennmt.data.batch_dataset(batch_size, padded_shapes=None)
Transformation that batches a dataset.
Example
>>> dataset = dataset.apply(opennmt.data.batch_dataset(...))
- Parameters
batch_size – The batch size.
padded_shapes – The padded shapes for this dataset. If None, the shapes are automatically inferred from the dataset output shapes.
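When padded_shapes is None and the inferred output shapes contain unknown (variable-length) dimensions, tf.data padded batching pads each such dimension to the longest element in the batch. The pure-Python sketch below only illustrates that behaviour for 1-D sequences. It assumes the inferred shapes are handed to tf.data padded batching, and it pads with 0, which is tf.data's default padding value for numeric types. The helper names pad_batch and batch_and_pad are hypothetical and are not part of OpenNMT-tf.

```python
from typing import Iterable, Iterator, List, Sequence


def pad_batch(elements: Sequence[Sequence[int]], pad_value: int = 0) -> List[List[int]]:
    """Hypothetical helper: pad 1-D sequences to the longest length in the batch."""
    max_len = max(len(e) for e in elements)
    return [list(e) + [pad_value] * (max_len - len(e)) for e in elements]


def batch_and_pad(
    dataset: Iterable[Sequence[int]], batch_size: int, pad_value: int = 0
) -> Iterator[List[List[int]]]:
    """Hypothetical helper: group consecutive elements into padded batches.

    The final batch may hold fewer than batch_size elements, mirroring
    tf.data's default (drop_remainder=False).
    """
    batch = []
    for element in dataset:
        batch.append(element)
        if len(batch) == batch_size:
            yield pad_batch(batch, pad_value)
            batch = []
    if batch:
        yield pad_batch(batch, pad_value)


# Each batch is padded independently, to its own longest element.
batches = list(batch_and_pad([[1], [2, 3], [4, 5, 6]], batch_size=2))
```

Because padding is computed per batch, the second batch above keeps its natural length of 3, while the first batch is padded only to length 2.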
- Returns
A tf.data.Dataset transformation.
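For readers unfamiliar with the transformation pattern, the following is a minimal sketch of how a function with this signature could be written using the standard tf.data API, assuming TensorFlow 2.4 or later. It is not OpenNMT-tf's source. The name batch_dataset_sketch is hypothetical, and the sketch assumes the batching is done with tf.data.Dataset.padded_batch, with the padded shapes read from dataset.element_spec when none are given.

```python
import tensorflow as tf


def batch_dataset_sketch(batch_size, padded_shapes=None):
    """Hypothetical stand-in for opennmt.data.batch_dataset (not the library code)."""

    def _transform(dataset):
        shapes = padded_shapes
        if shapes is None:
            # Infer the padded shapes from the dataset output shapes; unknown
            # dimensions are padded to the longest element in each batch.
            shapes = tf.nest.map_structure(lambda spec: spec.shape, dataset.element_spec)
        return dataset.padded_batch(batch_size, padded_shapes=shapes)

    # Returning a function (rather than a dataset) lets callers use Dataset.apply.
    return _transform


# Three variable-length int32 sequences of lengths 1, 2 and 3.
dataset = tf.data.Dataset.from_generator(
    lambda: ([1] * n for n in (1, 2, 3)),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32),
)
batches = list(dataset.apply(batch_dataset_sketch(2)).as_numpy_iterator())
```

Returning a closure keeps the batching parameters separate from the dataset, so the same transformation can be chained with others through Dataset.apply.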