split_heads

opennmt.layers.split_heads(inputs, num_heads)

Splits the depth dimension of a tensor into \(H\) heads and moves the head dimension in front of the time dimension.

Parameters
  • inputs – A tf.Tensor of shape \([B, T, D]\), where \(B\) is the batch size, \(T\) the sequence length, and \(D\) the depth.

  • num_heads – The number of heads \(H\). \(D\) should be divisible by \(H\).

Returns

A tf.Tensor of shape \([B, H, T, D / H]\).
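The shape change can be illustrated with a NumPy sketch. It assumes the usual multi-head attention convention: the depth \(D\) is cut into \(H\) contiguous chunks of size \(D / H\), and the head axis is then moved before the time axis. The helper split_heads_np is hypothetical and is not part of OpenNMT-tf; in TensorFlow the same steps map to tf.reshape followed by tf.transpose with perm=[0, 2, 1, 3].

```python
import numpy as np


def split_heads_np(inputs: np.ndarray, num_heads: int) -> np.ndarray:
    """Hypothetical NumPy sketch: reshape [B, T, D] into [B, H, T, D / H]."""
    batch, time, depth = inputs.shape
    if depth % num_heads != 0:
        raise ValueError(f"depth {depth} is not divisible by num_heads {num_heads}")
    # Cut the last dimension into H contiguous chunks: [B, T, H, D / H].
    outputs = inputs.reshape(batch, time, num_heads, depth // num_heads)
    # Move the head axis in front of the time axis: [B, H, T, D / H].
    return outputs.transpose(0, 2, 1, 3)


x = np.arange(2 * 3 * 8).reshape(2, 3, 8)  # B=2, T=3, D=8
y = split_heads_np(x, num_heads=4)         # H=4, so D / H = 2
print(y.shape)  # (2, 4, 3, 2)
```

Under this convention, head h at time step t holds the slice x[b, t, h * (D // H):(h + 1) * (D // H)] of the input.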