opennmt.utils.parallel module

Utilities to run graph execution in parallel across devices.

class opennmt.utils.parallel.GraphDispatcher(num_devices=None, daisy_chain_variables=True, devices=None, session_config=None)

Bases: object

Helper class to replicate graph parts on multiple devices and dispatch sharded batches.

__init__(num_devices=None, daisy_chain_variables=True, devices=None, session_config=None)

Initializes the dispatcher.

Parameters:
  • num_devices – The number of devices to dispatch on.
  • daisy_chain_variables – If True, variables are copied between devices in a daisy-chain fashion (credits to Tensor2Tensor).
  • devices – List of devices to use (takes priority over num_devices).
  • session_config – Session configuration to use when querying available devices.
Raises:

ValueError – if the number of visible devices is lower than num_devices.
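To make the daisy_chain_variables option more concrete, the pure-Python simulation below lists the transfers needed to place one variable on every device. It does not use TensorFlow; the function copy_plan and the "host" source name are hypothetical, and the sketch only assumes that daisy chaining means each device reads its copy from the previous device in the chain rather than from the variable's original location.

```python
from typing import List, Tuple


def copy_plan(devices: List[str], daisy_chain: bool,
              source: str = "host") -> List[Tuple[str, str]]:
    """Return the (from, to) transfers that put one variable on every device.

    Hypothetical illustration only; not part of opennmt.utils.parallel.
    """
    transfers = []
    previous = source
    for device in devices:
        if daisy_chain:
            # Each device copies from the device before it in the chain.
            transfers.append((previous, device))
            previous = device
        else:
            # Every device copies directly from the original location.
            transfers.append((source, device))
    return transfers


devices = ["/gpu:0", "/gpu:1", "/gpu:2"]
print(copy_plan(devices, daisy_chain=True))
# [('host', '/gpu:0'), ('/gpu:0', '/gpu:1'), ('/gpu:1', '/gpu:2')]
print(copy_plan(devices, daisy_chain=False))
# [('host', '/gpu:0'), ('host', '/gpu:1'), ('host', '/gpu:2')]
```

Under this assumption, chaining means the original location serves only one transfer, while the remaining copies travel between neighbouring devices.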

shard(data)

Shards a structure of tf.Tensor for dispatching.

Parameters: data – A tf.Tensor or a dictionary of tf.Tensor.
Returns: A list of structures with the same layout as data, one shard per device.
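As a rough illustration of what sharding means for the two accepted input kinds, the sketch below uses plain Python lists in place of tf.Tensor and splits along the first (batch) dimension. The names shard and split_evenly are hypothetical, and requiring the batch size to be divisible by the number of shards is an assumption modeled on tf.split with an integer split count; the library's actual behavior may differ.

```python
from typing import Dict, List, Sequence, Union

Batch = Union[Sequence, Dict[str, Sequence]]


def split_evenly(values: Sequence, num_shards: int) -> List[list]:
    """Split a sequence along its first dimension into equal parts."""
    if len(values) % num_shards != 0:
        # Assumption: like tf.split with an integer count, an uneven split is an error.
        raise ValueError("batch size %d is not divisible by %d"
                         % (len(values), num_shards))
    size = len(values) // num_shards
    return [list(values[i * size:(i + 1) * size]) for i in range(num_shards)]


def shard(data: Batch, num_shards: int) -> list:
    """Return one shard per device, each with the same structure as data."""
    if isinstance(data, dict):
        # Turn a dict of batches into a list of dicts, one per shard.
        shards = [{} for _ in range(num_shards)]
        for name, values in data.items():
            for i, part in enumerate(split_evenly(values, num_shards)):
                shards[i][name] = part
        return shards
    return split_evenly(data, num_shards)


batch = {"ids": [1, 2, 3, 4], "length": [5, 6, 7, 8]}
print(shard(batch, 2))
# [{'ids': [1, 2], 'length': [5, 6]}, {'ids': [3, 4], 'length': [7, 8]}]
```

The key point is that a dictionary input comes back as a list of dictionaries with the same keys, so each device receives a complete, smaller batch.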
repeat(data)

Ensures that the object is a dispatchable list.

Parameters: data – The object to convert.
Returns: data if it is a valid list, otherwise a list where data is replicated once per device.
Raises: ValueError – if data is a list that is not dispatchable, i.e. its length differs from the number of devices.
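The rule described above can be sketched in a few lines of plain Python. The function name repeat and the explicit num_devices parameter are hypothetical stand-ins for the method and the dispatcher's device count; this is not the library's code.

```python
from typing import Any, List


def repeat(data: Any, num_devices: int) -> List[Any]:
    """Return a list with exactly one entry per device.

    Hypothetical stand-in for GraphDispatcher.repeat.
    """
    if not isinstance(data, list):
        # A non-list value is replicated once per device.
        return [data] * num_devices
    if len(data) != num_devices:
        # A list is treated as already per-device, so its length must match.
        raise ValueError("expected a list of length %d, got %d"
                         % (num_devices, len(data)))
    return data


print(repeat(0.1, 3))         # [0.1, 0.1, 0.1]
print(repeat(["a", "b"], 2))  # ['a', 'b']
```

One consequence of this rule is that a list can never be passed as a single shared value: any list is interpreted as one entry per device.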
__call__(fun, *args, **kwargs)

Dispatches fun calls across devices.

Each argument must either be a non-list value, which is replicated on every device, or a list whose length equals the number of devices used for dispatching.

Parameters:
  • fun – A callable.
  • *args – The callable arguments.
  • **kwargs – The callable keyword arguments.
Returns:

The sharded outputs of fun.
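The dispatch pattern can be approximated in plain Python: normalize every argument to one entry per device, then call fun once per device with the i-th entry of each argument. The names dispatch, _per_device and scaled_sum are hypothetical; in the library each call would presumably also be placed under a TensorFlow device scope, which this sketch omits.

```python
from typing import Any, Callable, List


def _per_device(value: Any, num_devices: int) -> List[Any]:
    # Replicate non-lists; lists must already have one entry per device.
    if not isinstance(value, list):
        return [value] * num_devices
    if len(value) != num_devices:
        raise ValueError("list argument must have one entry per device")
    return value


def dispatch(devices: List[str], fun: Callable, *args, **kwargs) -> List[Any]:
    """Call fun once per device and collect the per-device outputs."""
    n = len(devices)
    per_arg = [_per_device(a, n) for a in args]
    per_kwarg = {k: _per_device(v, n) for k, v in kwargs.items()}
    outputs = []
    for i in range(n):
        # Pass the i-th slice of every argument (device placement omitted).
        call_args = [a[i] for a in per_arg]
        call_kwargs = {k: v[i] for k, v in per_kwarg.items()}
        outputs.append(fun(*call_args, **call_kwargs))
    return outputs


def scaled_sum(values, scale=1):
    return sum(values) * scale


print(dispatch(["/gpu:0", "/gpu:1"], scaled_sum, [[1, 2], [3, 4]], scale=10))
# [30, 70]
```

Here the sharded list argument is split across the two calls, while the scalar scale=10 is replicated to both.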

opennmt.utils.parallel.split_batch(data, num_shards)

Splits data into num_shards shards.

opennmt.utils.parallel.get_devices(num_devices=None, session_config=None)

Returns available devices.

Parameters:
  • num_devices – The number of devices to get.
  • session_config – An optional session configuration to use when querying available devices.
Returns:

A list of devices.

Raises:

ValueError – if num_devices is set but fewer than num_devices devices are visible.
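The selection logic can be sketched as follows. Here the visible devices are passed in as plain strings instead of being queried from TensorFlow, and the function name select_devices is hypothetical. Returning every visible device when num_devices is None, and keeping the first num_devices entries otherwise, are assumptions rather than documented behavior.

```python
from typing import List, Optional


def select_devices(visible: List[str],
                   num_devices: Optional[int] = None) -> List[str]:
    """Pick devices from the visible ones, honouring num_devices."""
    if num_devices is None:
        # Assumption: with no explicit request, all visible devices are used.
        return list(visible)
    if len(visible) < num_devices:
        raise ValueError("requested %d devices but only %d are visible"
                         % (num_devices, len(visible)))
    # Assumption: the first num_devices visible devices are kept.
    return list(visible[:num_devices])


print(select_devices(["/gpu:0", "/gpu:1", "/gpu:2"], num_devices=2))
# ['/gpu:0', '/gpu:1']
```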