filter_irregular_batches

opennmt.data.filter_irregular_batches(multiple)

Transformation that filters out batches whose size is not divisible by multiple.
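The filtering rule can be sketched in plain Python, without TensorFlow: a batch is kept only when its length is divisible by multiple. The helpers batch and filter_irregular below are hypothetical illustrations, not part of OpenNMT-tf; the real transformation operates on a tf.data.Dataset rather than on Python lists.

```python
from typing import Iterable, List, Sequence, TypeVar

T = TypeVar("T")


def batch(items: Sequence[T], size: int) -> List[List[T]]:
    """Hypothetical helper: split items into consecutive batches.

    Like tf.data.Dataset.batch without drop_remainder, the last batch
    may be smaller than `size`.
    """
    return [list(items[i:i + size]) for i in range(0, len(items), size)]


def filter_irregular(batches: Iterable[List[T]], multiple: int) -> List[List[T]]:
    """Hypothetical helper: keep only batches whose length is divisible by `multiple`."""
    return [b for b in batches if len(b) % multiple == 0]


batches = batch(list(range(10)), 3)  # sizes 3, 3, 3, 1
kept = filter_irregular(batches, 3)  # the trailing batch [9] is dropped
print(kept)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```

Note that the check is divisibility, not equality with the batch size: with batches of size 6 and multiple=2, a trailing batch of size 4 would still be kept.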

Example

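>>> import tensorflow as tf
>>> import opennmt
>>> dataset = tf.data.Dataset.range(10).batch(3)
>>> # The last batch [9] has size 1, which is not divisible by 3, so it is dropped.
>>> dataset = dataset.apply(opennmt.data.filter_irregular_batches(3))
>>> len(list(iter(dataset)))
3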

Parameters

multiple – The value the batch size must be divisible by; batches of any other size are dropped.

Returns

A tf.data.Dataset transformation.