get_dataset_size

opennmt.data.get_dataset_size(dataset, batch_size=5000)

Get the dataset size, i.e. the number of elements it contains.

Example

>>> dataset = tf.data.Dataset.range(5)
>>> opennmt.data.get_dataset_size(dataset).numpy()
5

Parameters

  • dataset – A dataset.

  • batch_size – The batch size used to speed up the scan, or None to scan the dataset as-is.

Returns

The dataset size as a tensor (hence the .numpy() call in the example), or None if the dataset is infinite.
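
Sketch

The following is a minimal sketch of why scanning in batches can help, not the OpenNMT-tf implementation. It relies only on public tf.data calls (batch, reduce, cardinality). The helper name count_elements is hypothetical, and the infinite-dataset check assumes a TensorFlow 2 release that provides Dataset.cardinality() and eager execution. With a batch size, each reduce step accounts for a whole batch instead of a single element.

```python
import tensorflow as tf


def count_elements(dataset, batch_size=5000):
    """Hypothetical helper illustrating a batched scan (not the library code)."""
    # Assumption: Dataset.cardinality() exists and eager execution is enabled.
    if dataset.cardinality().numpy() == tf.data.INFINITE_CARDINALITY:
        return None

    if batch_size is not None:
        # Group elements so that one reduce step covers a whole batch.
        dataset = dataset.batch(batch_size)

    def _add(count, element):
        if batch_size is None:
            # Unbatched scan: every element counts as one.
            return count + 1
        # Batched scan: the leading dimension of the first component
        # is the number of elements in this batch.
        first = tf.nest.flatten(element)[0]
        return count + tf.cast(tf.shape(first)[0], count.dtype)

    return dataset.reduce(tf.constant(0, dtype=tf.int64), _add)


size = count_elements(tf.data.Dataset.range(12), batch_size=5)
print(size.numpy())
```

In this sketch the last batch may be smaller than batch_size, which is why the count adds the actual batch length rather than batch_size itself.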