random_shard

opennmt.data.random_shard(shard_size, dataset_size)

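Transformation that splits the dataset into shards of shard_size consecutive examples and reads the shards in a random order. The order of examples within a shard is preserved.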

Example

>>> dataset = tf.data.Dataset.range(6)
>>> dataset = dataset.apply(opennmt.data.random_shard(2, 6))
>>> list(dataset.as_numpy_iterator())
[0, 1, 4, 5, 2, 3]

Parameters

  • shard_size – The number of examples in each shard.

  • dataset_size – The total number of examples in the dataset.

Returns

A tf.data.Dataset transformation, to be passed to tf.data.Dataset.apply.
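
Sketch

The following pure-Python sketch models the reading order suggested by the example above. It is not the library's implementation: the function name random_shard_order and the seed argument are hypothetical, and the handling of a shorter final shard when dataset_size is not a multiple of shard_size is an assumption of this sketch.

```python
import random


def random_shard_order(shard_size, dataset_size, seed=None):
    """Hypothetical model of the order in which example indices are read.

    Assumption: shards are contiguous blocks of `shard_size` indices, and
    the last shard is simply shorter when the sizes do not divide evenly.
    """
    rng = random.Random(seed)

    # Start offset of each shard: 0, shard_size, 2 * shard_size, ...
    offsets = list(range(0, dataset_size, shard_size))

    # Shuffle the shards, not the individual examples.
    rng.shuffle(offsets)

    order = []
    for offset in offsets:
        # Examples inside a shard keep their original order.
        end = min(offset + shard_size, dataset_size)
        order.extend(range(offset, end))
    return order


order = random_shard_order(2, 6, seed=0)
print(order)
```

Each call yields a permutation of the indices in which every shard stays contiguous, so the list shown in the example above is only one of several possible results.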