"""Dynamic decoding utilities."""
import abc
import collections
import tensorflow as tf
import tensorflow_addons as tfa
from opennmt import constants
from opennmt.utils import misc
class Sampler(abc.ABC):
"""Base class for samplers."""
    @abc.abstractmethod
def __call__(self, scores, num_samples=1):
"""Samples predictions.
Args:
scores: The scores to sample from, a tensor of shape
``[batch_size, vocab_size]``.
num_samples: The number of samples per batch to produce.
Returns:
A tuple ``(sample_ids, sample_scores)``.
"""
raise NotImplementedError()
    @staticmethod
def from_params(params):
"""Constructs a sampler based on user parameters.
Args:
params: A dictionary of user parameters.
Returns:
A :class:`opennmt.utils.Sampler` instance.
"""
sampling_topk = params.get("sampling_topk", 1)
if sampling_topk == 1:
return BestSampler()
else:
return RandomSampler(
from_top_k=sampling_topk, temperature=params.get("sampling_temperature")
)
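# Usage sketch (illustrative parameter values): "sampling_topk" and
# "sampling_temperature" are the configuration keys read by
# Sampler.from_params above.
#
#     Sampler.from_params({})                    # -> BestSampler (argmax)
#     Sampler.from_params(
#         {"sampling_topk": 10, "sampling_temperature": 0.8}
#     )                                          # -> RandomSampler(10, 0.8)
#     Sampler.from_params({"sampling_topk": 0})  # -> RandomSampler over the
#                                                #    full distribution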
class RandomSampler(Sampler):
"""Randomly samples from model outputs."""
    def __init__(self, from_top_k=None, temperature=None):
"""Initializes the random sampler.
Args:
from_top_k: Sample from the top K predictions instead of the full
distribution.
temperature: Divide logits by this value. High temperatures generate more
random samples.
"""
if from_top_k is not None and from_top_k <= 0:
from_top_k = None
self.from_top_k = from_top_k
self.temperature = temperature
def __call__(self, scores, num_samples=1):
if self.from_top_k is None:
sample_ids = _sample_from(scores, num_samples, temperature=self.temperature)
else:
top_scores, top_ids = tf.nn.top_k(scores, k=self.from_top_k)
sample_ids = _sample_from(
top_scores, num_samples, temperature=self.temperature
)
sample_ids = _gather_from_word_indices(top_ids, sample_ids)
sample_scores = _gather_from_word_indices(scores, sample_ids)
return sample_ids, sample_scores
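# Usage sketch (illustrative values): sample 2 candidate ids per batch entry
# from the top 3 scores, with a temperature that flattens the distribution
# before sampling.
#
#     scores = tf.math.log([[0.1, 0.2, 0.4, 0.3]])
#     sampler = RandomSampler(from_top_k=3, temperature=1.5)
#     ids, ids_scores = sampler(scores, num_samples=2)  # both of shape [1, 2]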
class BestSampler(Sampler):
"""Sample the best predictions."""
def __call__(self, scores, num_samples=1):
sample_scores, sample_ids = tf.nn.top_k(scores, k=num_samples)
return sample_ids, sample_scores
class DecodingStrategy(abc.ABC):
"""Base class for decoding strategies."""
@property
def num_hypotheses(self):
"""The number of hypotheses returned by this strategy."""
return 1
    @staticmethod
def from_params(params, tflite_mode=False):
"""Constructs a decoding strategy based on user parameters.
Args:
params: A dictionary of user parameters.
          tflite_mode: Boolean, should be set to ``True`` only when exporting the
            model to TensorFlow Lite.
Returns:
A :class:`opennmt.utils.DecodingStrategy` instance.
"""
beam_size = params.get("beam_width", 1)
if beam_size > 1:
return BeamSearch(
beam_size,
length_penalty=params.get("length_penalty", 0),
coverage_penalty=params.get("coverage_penalty", 0),
tflite_output_size=params.get("tflite_output_size", 250)
if tflite_mode
else None,
)
else:
return GreedySearch()
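    # Usage sketch (illustrative parameter values): "beam_width",
    # "length_penalty" and "coverage_penalty" are the configuration keys read
    # by from_params above.
    #
    #     DecodingStrategy.from_params({})  # -> GreedySearch()
    #     DecodingStrategy.from_params({"beam_width": 5, "length_penalty": 0.6})
    #     # -> BeamSearch(5, length_penalty=0.6, coverage_penalty=0)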
    @abc.abstractmethod
def initialize(self, start_ids, attention_size=None):
"""Initializes the strategy.
Args:
start_ids: The start decoding ids.
attention_size: If known, the size of the attention vectors (i.e. the
maximum source length).
Returns:
A tuple containing,
- The (possibly transformed) start decoding ids.
- The tensor of finished flags.
- Initial log probabilities per batch.
          - A dictionary of additional tensors used during the decoding.
"""
raise NotImplementedError()
    @abc.abstractmethod
def step(
self,
step,
sampler,
log_probs,
cum_log_probs,
finished,
state=None,
attention=None,
**kwargs
):
"""Updates the strategy state.
Args:
step: The current decoding step.
sampler: The sampler that produces predictions.
log_probs: The model log probabilities.
cum_log_probs: The cumulated log probabilities per batch.
finished: The current finished flags.
state: The decoder state.
attention: The attention vector for the current step.
**kwargs: Additional tensors used by this decoding strategy.
Returns:
A tuple containing,
- The predicted word ids.
- The new cumulated log probabilities.
- The updated finished flags.
- The updated decoder state.
- A dictionary with additional tensors used by this decoding strategy.
"""
raise NotImplementedError()
    @abc.abstractmethod
def finalize(self, outputs, end_id, attention=None, **kwargs):
"""Finalize the predictions.
Args:
outputs: The array of sampled ids.
end_id: The end token id.
attention: The array of attention outputs.
**kwargs: Additional tensors used by this decoding strategy.
Returns:
A tuple containing,
- The final predictions as a tensor of shape :math:`[B, H, T_t]`.
- The final attention history of shape :math:`[B, H, T_t, T_s]`.
- The final sequence lengths of shape :math:`[B, H]`.
"""
raise NotImplementedError()
class GreedySearch(DecodingStrategy):
"""A basic greedy search strategy."""
    def initialize(self, start_ids, attention_size=None):
batch_size = tf.shape(start_ids)[0]
finished = tf.zeros([batch_size], dtype=tf.bool)
initial_log_probs = tf.zeros([batch_size], dtype=tf.float32)
return start_ids, finished, initial_log_probs, {}
    def step(
self,
step,
sampler,
log_probs,
cum_log_probs,
finished,
state=None,
attention=None,
**kwargs
):
sample_ids, sample_log_probs = sampler(log_probs)
sample_ids = tf.reshape(sample_ids, [-1])
sample_log_probs = tf.reshape(sample_log_probs, [-1])
cum_log_probs += sample_log_probs
return sample_ids, cum_log_probs, finished, state, kwargs
    def finalize(self, outputs, end_id, attention=None, **kwargs):
ids = tf.transpose(outputs.stack())
ids = tf.expand_dims(ids, 1)
lengths = _lengths_from_ids(ids, end_id)
if attention is not None:
attention = tf.transpose(attention.stack(), perm=[1, 0, 2])
attention = tf.expand_dims(attention, 1)
return ids, attention, lengths
class BeamSearch(DecodingStrategy):
"""A beam search strategy."""
    def __init__(
self, beam_size, length_penalty=0, coverage_penalty=0, tflite_output_size=None
):
"""Initializes the decoding strategy.
Args:
beam_size: The number of paths to consider per batch.
length_penalty: Length penalty, see https://arxiv.org/abs/1609.08144.
coverage_penalty: Coverage penalty, see https://arxiv.org/abs/1609.08144.
          tflite_output_size: The output size of the TFLite model, or ``None``
            when not exporting to TensorFlow Lite.
"""
self.beam_size = beam_size
self.length_penalty = length_penalty
self.coverage_penalty = coverage_penalty
self._state_reorder_flags = None
self.tflite_output_size = tflite_output_size
@property
def num_hypotheses(self):
return self.beam_size
def _set_state_reorder_flags(self, state_reorder_flags):
"""Sets state reorder flags, a structure matching the decoder state that
        indicates which tensors should be reordered during beam search.
"""
self._state_reorder_flags = state_reorder_flags
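    # Illustrative example (hypothetical state structure): the flags mirror
    # the decoder state and mark which tensors carry a batch dimension that
    # must be gathered when beams are reordered.
    #
    #     state = {"cache": kv_cache, "step": step_counter}  # hypothetical
    #     strategy._set_state_reorder_flags({"cache": True, "step": False})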
    def initialize(self, start_ids, attention_size=None):
batch_size = tf.shape(start_ids)[0]
start_ids = tfa.seq2seq.tile_batch(start_ids, self.beam_size)
finished = tf.zeros([batch_size * self.beam_size], dtype=tf.bool)
# Give all probability to first beam for the first iteration.
initial_log_probs = tf.tile(
[0.0] + [-float("inf")] * (self.beam_size - 1), [batch_size]
)
if self.tflite_output_size is not None:
parent_ids = tf.TensorArray(
tf.int32,
size=self.tflite_output_size,
dynamic_size=False,
element_shape=tf.TensorShape(None),
)
else:
parent_ids = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
extra_vars = {
"parent_ids": parent_ids,
"sequence_lengths": tf.zeros([batch_size * self.beam_size], dtype=tf.int32),
}
if self.coverage_penalty != 0:
if attention_size is None:
raise ValueError(
"The attention size should be known to support coverage penalty"
)
extra_vars["accumulated_attention"] = tf.zeros(
[batch_size * self.beam_size, attention_size]
)
return start_ids, finished, initial_log_probs, extra_vars
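    # For example, with batch_size=2 and beam_size=3, initial_log_probs above
    # is [0, -inf, -inf, 0, -inf, -inf], so the first top-k selection can only
    # extend beam 0 of each batch entry and never produces duplicate beams.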
def _get_scores(
self, log_probs, sequence_lengths, finished, accumulated_attention=None
):
scores = log_probs
if self.length_penalty != 0:
expand_sequence_lengths = tf.expand_dims(sequence_lengths, 1)
scores /= tf.pow(
((5.0 + tf.cast(expand_sequence_lengths + 1, scores.dtype)) / 6.0),
self.length_penalty,
)
if self.coverage_penalty != 0:
# Mask out of range steps with ones (log(1) == 0).
accumulated_attention = tf.where(
tf.equal(accumulated_attention, 0.0),
x=tf.ones_like(accumulated_attention),
y=accumulated_attention,
)
coverage_penalty = tf.reduce_sum(
tf.math.log(tf.minimum(accumulated_attention, 1.0)), 1
)
# Apply coverage penalty to finished predictions.
coverage_penalty *= tf.cast(finished, coverage_penalty.dtype)
scores += self.coverage_penalty * tf.expand_dims(coverage_penalty, 1)
return scores
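    # The updates above follow the penalties of Wu et al. 2016 (GNMT): scores
    # are divided by the length penalty lp(Y) = ((5 + |Y|) / 6) ** alpha,
    # where |Y| includes the token predicted at this step, and the coverage
    # penalty cp = beta * sum_j log(min(a_j, 1.0)) is added, where a_j is the
    # attention accumulated on source position j. Per the masking above, cp
    # only contributes once a hypothesis is finished.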
    def step(
self,
step,
sampler,
log_probs,
cum_log_probs,
finished,
state=None,
attention=None,
**kwargs
):
parent_ids = kwargs["parent_ids"]
sequence_lengths = kwargs["sequence_lengths"]
if self.coverage_penalty != 0:
if attention is None:
raise ValueError(
"Coverage penalty is enabled but the model did not "
"return an attention vector"
)
not_finished = tf.math.logical_not(finished)
attention *= tf.expand_dims(tf.cast(not_finished, attention.dtype), 1)
accumulated_attention = kwargs["accumulated_attention"] + attention
else:
accumulated_attention = None
# Compute scores from log probabilities.
vocab_size = log_probs.shape[-1]
total_probs = log_probs + tf.expand_dims(
cum_log_probs, 1
) # Add current beam probability.
scores = self._get_scores(
total_probs,
sequence_lengths,
finished,
accumulated_attention=accumulated_attention,
)
scores = tf.reshape(scores, [-1, self.beam_size * vocab_size])
total_probs = tf.reshape(total_probs, [-1, self.beam_size * vocab_size])
# Sample predictions.
sample_ids, sample_scores = sampler(scores, num_samples=self.beam_size)
cum_log_probs = tf.reshape(
_gather_from_word_indices(total_probs, sample_ids), [-1]
)
sample_ids = tf.reshape(sample_ids, [-1])
sample_scores = tf.reshape(sample_scores, [-1])
# Resolve beam origin and word ids.
word_ids = sample_ids % vocab_size
beam_ids = sample_ids // vocab_size
beam_indices = (
tf.range(tf.shape(word_ids)[0]) // self.beam_size
) * self.beam_size + beam_ids
# Update sequence_length of unfinished sequence.
sequence_lengths = tf.where(
finished, x=sequence_lengths, y=sequence_lengths + 1
)
# Update state and flags.
finished = tf.gather(finished, beam_indices)
sequence_lengths = tf.gather(sequence_lengths, beam_indices)
parent_ids = parent_ids.write(step, beam_ids)
extra_vars = {
"parent_ids": parent_ids,
"sequence_lengths": sequence_lengths,
}
if accumulated_attention is not None:
extra_vars["accumulated_attention"] = tf.gather(
accumulated_attention, beam_indices
)
if state is not None:
state = _reorder_state(
state, beam_indices, reorder_flags=self._state_reorder_flags
)
return word_ids, cum_log_probs, finished, state, extra_vars
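    # Worked example (illustrative numbers): with beam_size=2 and
    # vocab_size=10, a flattened sample id of 13 resolves to word 3 of beam 1
    # (13 % 10 and 13 // 10), and for batch entry b the flat index gathered
    # from is b * 2 + 1.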
    def finalize(self, outputs, end_id, attention=None, **kwargs):
parent_ids = kwargs["parent_ids"]
sequence_lengths = kwargs["sequence_lengths"]
maximum_lengths = tf.reduce_max(
tf.reshape(sequence_lengths, [-1, self.beam_size]), axis=-1
)
max_time = outputs.size()
array_shape = [max_time, -1, self.beam_size]
step_ids = tf.reshape(outputs.stack(), array_shape)
parent_ids = tf.reshape(parent_ids.stack(), array_shape)
ids = _gather_tree(step_ids, parent_ids, maximum_lengths, end_id)
ids = tf.transpose(ids, perm=[1, 2, 0])
lengths = _lengths_from_ids(ids, end_id)
if attention is not None:
attention = _gather_tree_from_array(attention.stack(), parent_ids, lengths)
attention = tf.transpose(attention, perm=[1, 0, 2])
attention = tf.reshape(
attention, [tf.shape(ids)[0], self.beam_size, max_time, -1]
)
return ids, attention, lengths
class DecodingResult(
collections.namedtuple(
"DecodingResult", ("ids", "lengths", "log_probs", "attention", "state")
)
):
"""Final decoding result.
Args:
ids: The predicted ids of shape :math:`[B, H, T_t]`.
lengths: The produced sequences length of shape :math:`[B, H]`.
log_probs: The cumulated log probabilities of shape :math:`[B, H]`.
attention: The attention history of shape :math:`[B, H, T_t, T_s]`.
state: The final decoding state.
"""
def dynamic_decode(
symbols_to_logits_fn,
start_ids,
end_id=constants.END_OF_SENTENCE_ID,
initial_state=None,
decoding_strategy=None,
sampler=None,
maximum_iterations=None,
minimum_iterations=0,
attention_history=False,
attention_size=None,
tflite_output_size=None,
):
"""Dynamic decoding.
Args:
symbols_to_logits_fn: A callable taking ``(symbols, step, state)`` and
returning ``(logits, state, attention)`` (``attention`` is optional).
start_ids: Initial input IDs of shape :math:`[B]`.
end_id: ID of the end of sequence token.
initial_state: Initial decoder state.
decoding_strategy: A :class:`opennmt.utils.DecodingStrategy`
instance that defines the decoding logic. Defaults to a greedy search.
sampler: A :class:`opennmt.utils.Sampler` instance that samples
predictions from the model output. Defaults to an argmax sampling.
maximum_iterations: The maximum number of iterations to decode for.
minimum_iterations: The minimum number of iterations to decode for.
attention_history: Gather attention history during the decoding.
attention_size: If known, the size of the attention vectors (i.e. the
maximum source length).
      tflite_output_size: If not ``None``, run the decoding in a TensorFlow Lite
        compatible way, with this value as the size of the 1D output tensor.
Returns:
A :class:`opennmt.utils.DecodingResult` instance.
"""
if initial_state is None:
initial_state = {}
if decoding_strategy is None:
decoding_strategy = GreedySearch()
if sampler is None:
sampler = BestSampler()
is_tflite_run = tflite_output_size is not None
def _cond(
step, finished, state, inputs, outputs, attention, cum_log_probs, extra_vars
):
return tf.reduce_any(tf.logical_not(finished))
def _body(
step, finished, state, inputs, outputs, attention, cum_log_probs, extra_vars
):
# Get log probs from the model.
result = symbols_to_logits_fn(inputs, step, state)
logits, state = result[0], result[1]
attn = result[2] if len(result) > 2 else None
logits = tf.cast(logits, tf.float32)
# Penalize or force EOS.
batch_size, vocab_size = misc.shape_list(logits)
eos_max_prob = tf.one_hot(
tf.fill([batch_size], end_id),
vocab_size,
on_value=logits.dtype.max,
off_value=logits.dtype.min,
)
logits = tf.cond(
step < minimum_iterations,
true_fn=lambda: _penalize_token(logits, end_id),
false_fn=lambda: tf.where(
tf.broadcast_to(tf.expand_dims(finished, -1), tf.shape(logits)),
x=eos_max_prob,
y=logits,
),
)
log_probs = tf.nn.log_softmax(logits)
# Run one decoding strategy step.
(
output,
next_cum_log_probs,
finished,
state,
extra_vars,
) = decoding_strategy.step(
step,
sampler,
log_probs,
cum_log_probs,
finished,
state=state,
attention=attn,
**extra_vars,
)
# Update loop vars.
outputs = outputs.write(step, output)
if attention_history:
if attn is None:
raise ValueError(
"attention_history is set but the model did not return attention"
)
attention = attention.write(step, tf.cast(attn, tf.float32))
cum_log_probs = tf.where(finished, x=cum_log_probs, y=next_cum_log_probs)
finished = tf.logical_or(finished, tf.equal(output, end_id))
return (
step + 1,
finished,
state,
output,
outputs,
attention,
cum_log_probs,
extra_vars,
)
start_ids = tf.convert_to_tensor(start_ids)
ids_dtype = start_ids.dtype
start_ids = tf.cast(start_ids, tf.int32)
start_ids, finished, initial_log_probs, extra_vars = decoding_strategy.initialize(
start_ids, attention_size=attention_size
)
step = tf.constant(0, dtype=tf.int32)
if is_tflite_run:
output_shape = tf.TensorShape(None)
outputs = tf.TensorArray(
tf.int32,
size=tflite_output_size,
dynamic_size=False,
element_shape=output_shape,
)
attn_shape = tf.TensorShape(None)
attention = tf.TensorArray(
tf.float32,
size=tflite_output_size,
dynamic_size=False,
element_shape=attn_shape,
)
maximum_iterations = (
tflite_output_size
if maximum_iterations > tflite_output_size
else maximum_iterations
)
else:
output_shape = tf.TensorShape(None)
outputs = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
attn_shape = tf.TensorShape(None)
attention = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
_, _, state, _, outputs, attention, log_probs, extra_vars = tf.while_loop(
_cond,
_body,
loop_vars=(
step,
finished,
initial_state,
start_ids,
outputs,
attention,
initial_log_probs,
extra_vars,
),
shape_invariants=(
step.shape,
finished.shape,
tf.nest.map_structure(_get_shape_invariants, initial_state),
start_ids.shape,
output_shape,
attn_shape,
initial_log_probs.shape,
tf.nest.map_structure(_get_shape_invariants, extra_vars),
),
parallel_iterations=1,
maximum_iterations=maximum_iterations,
)
ids, attention, lengths = decoding_strategy.finalize(
outputs,
end_id,
attention=attention if attention_history else None,
**extra_vars,
)
if attention is not None:
attention = attention[:, :, :-1] # Ignore attention for </s>.
log_probs = tf.reshape(log_probs, [-1, decoding_strategy.num_hypotheses])
ids = tf.cast(ids, ids_dtype)
return DecodingResult(
ids=ids, lengths=lengths, log_probs=log_probs, attention=attention, state=state
)
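# Usage sketch (hypothetical decoder call): the callable only needs to map
# (ids, step, state) to (logits, state) and may additionally return attention.
#
#     def symbols_to_logits_fn(ids, step, state):
#         logits, state = my_decoder(ids, step, state)  # hypothetical decoder
#         return logits, state
#
#     result = dynamic_decode(
#         symbols_to_logits_fn,
#         start_ids=tf.fill([batch_size], constants.START_OF_SENTENCE_ID),
#         decoding_strategy=BeamSearch(5, length_penalty=0.6),
#         maximum_iterations=200,
#     )
#     # result.ids has shape [batch_size, 5, T_t].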
def _reorder_state(state, indices, reorder_flags=None):
"""Gather batch indices from the state tensors."""
def _reorder_one(tensor, reorder=True):
if not reorder or isinstance(tensor, tf.TensorArray) or tensor.shape.ndims == 0:
return tensor
return tf.gather(tensor, indices)
args = [state]
if reorder_flags is not None:
tf.nest.assert_same_structure(state, reorder_flags)
args.append(reorder_flags)
return tf.nest.map_structure(_reorder_one, *args)
def _get_shape_invariants(tensor):
"""Returns the shape of the tensor but sets middle dims to None."""
if isinstance(tensor, tf.TensorArray):
shape = None
else:
shape = tensor.shape.as_list()
for i in range(1, len(shape) - 1):
shape[i] = None
return tf.TensorShape(shape)
def _penalize_token(log_probs, token_id, penalty=-1e7):
"""Penalize token probabilities."""
depth = log_probs.shape[-1]
penalty = tf.one_hot([token_id], depth, on_value=tf.cast(penalty, log_probs.dtype))
return log_probs + penalty
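# Sketch of the effect: _penalize_token adds a large negative constant to one
# vocabulary position; dynamic_decode uses it to keep end_id from being
# selected before minimum_iterations steps have been decoded.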
def _sample_from(logits, num_samples, temperature=None):
"""Sample N values from the unscaled probability distribution."""
if temperature is not None:
logits /= tf.cast(temperature, logits.dtype)
return tf.random.categorical(logits, num_samples, dtype=tf.int32)
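# Temperature sketch (illustrative values): dividing logits by T > 1 flattens
# the distribution while T < 1 sharpens it; logits [2.0, 1.0] at T=2 become
# [1.0, 0.5], shrinking the gap between candidates before sampling.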
def _gather_from_word_indices(tensor, indices):
"""Index the depth dim of a 2D tensor."""
return tf.gather(tensor, indices, axis=-1, batch_dims=1)
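# Example (illustrative values): with tensor [[10., 20., 30.], [40., 50., 60.]]
# and indices [[2], [0]], the result is [[30.], [40.]]; each row of indices
# selects entries from the matching row of tensor.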
def _lengths_from_ids(ids, end_id):
"""Compute sequence lengths from word ids."""
lengths = tf.not_equal(ids, end_id)
lengths = tf.cast(lengths, tf.int32)
lengths = tf.reduce_sum(lengths, axis=-1)
return lengths
# The gather_tree functions are imported from TensorFlow Addons:
# https://github.com/tensorflow/addons/blob/master/tensorflow_addons/seq2seq/beam_search_decoder.py
#
# We do not use the Addons version because the public gather_tree function is
# wrapped by a tf.function. This should not be an issue, but the function is
# unexpectedly garbage collected in our test suite.
def _gather_tree(step_ids, parent_ids, max_sequence_lengths, end_token):
"""Calculates the full beams from the per-step ids and parent beam ids.
For a given beam, past the time step containing the first decoded
``end_token`` all values are filled in with ``end_token``.
Args:
step_ids: The predicted token IDs.
A ``int32`` Tensor of shape ``[max_time, batch_size, beam_width]``.
parent_ids: The parent beam indices.
A ``int32`` Tensor of shape ``[max_time, batch_size, beam_width]``.
max_sequence_lengths: The maximum sequence length of each batch.
A ``int32`` Tensor of shape ``[batch_size]``.
end_token: The end token ID.
Returns:
The reordered token IDs based on ``parent_ids``.
Raises:
InvalidArgumentError: if ``parent_ids`` contains an invalid index.
"""
input_shape = tf.shape(parent_ids)
max_time = input_shape[0]
beam_width = input_shape[2]
max_sequence_lengths = tf.math.minimum(max_sequence_lengths, max_time)
mask = tf.expand_dims(
tf.transpose(tf.sequence_mask(max_sequence_lengths, maxlen=max_time)), -1
)
# Mask out of range ids.
end_tokens = tf.fill(input_shape, end_token)
step_ids = tf.where(mask, x=step_ids, y=end_tokens)
parent_ids = tf.where(mask, x=parent_ids, y=tf.zeros_like(parent_ids))
assert_op = tf.debugging.Assert(
tf.math.reduce_all(
tf.math.logical_and(parent_ids >= 0, parent_ids < beam_width)
),
["All parent ids must be positive and less than beam_width"],
)
# Reverse all sequences as we need to gather from the end.
with tf.control_dependencies([assert_op]):
rev_step_ids = tf.reverse_sequence(
step_ids, max_sequence_lengths, seq_axis=0, batch_axis=1
)
rev_parent_ids = tf.reverse_sequence(
parent_ids, max_sequence_lengths, seq_axis=0, batch_axis=1
)
# Initialize output ids and parent based on last step.
output_ids = tf.TensorArray(step_ids.dtype, size=max_time, dynamic_size=False)
output_ids = output_ids.write(0, rev_step_ids[0])
parent = rev_parent_ids[0]
# For each step, gather ids based on beam origin.
for t in tf.range(1, max_time):
ids = tf.gather(rev_step_ids[t], parent, batch_dims=1)
parent = tf.gather(rev_parent_ids[t], parent, batch_dims=1)
output_ids = output_ids.write(t, ids)
# Reverse sequences to their original order.
output_ids = output_ids.stack()
output_ids = tf.reverse_sequence(
output_ids, max_sequence_lengths, seq_axis=0, batch_axis=1
)
# Ensure that there are only end_token after the first end_token.
in_bound_steps = tf.math.cumsum(tf.cast(output_ids == end_token, tf.int32)) == 0
output_ids = tf.where(in_bound_steps, x=output_ids, y=end_tokens)
return output_ids
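# Worked example (illustrative, one batch entry, beam_width=2), with
# step_ids[t][k] the token of beam k at step t:
#
#     step_ids   = [[a0, b0], [a1, b1]]
#     parent_ids = [[0,  0],  [1,  0]]
#
# The backward pass starts from the last step and follows each beam's parent:
# final beam 0 extends beam 1 of the previous step, giving [b0, a1], while
# final beam 1 extends beam 0, giving [a0, b1].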
def _gather_tree_from_array(t, parent_ids, sequence_length):
"""Calculates the full beams for a ``TensorArray``.
Args:
t: A stacked ``TensorArray`` of size ``max_time`` that contains Tensors of
shape ``[batch_size, beam_width, s]`` or ``[batch_size * beam_width, s]``
where ``s`` is the depth shape.
parent_ids: The parent ids of shape ``[max_time, batch_size, beam_width]``.
sequence_length: The sequence length of shape ``[batch_size, beam_width]``.
Returns:
A Tensor which is a stacked ``TensorArray`` of the same size and type as
``t`` and where beams are sorted in each Tensor according to
``parent_ids``.
"""
max_time = parent_ids.shape[0] or tf.shape(parent_ids)[0]
batch_size = parent_ids.shape[1] or tf.shape(parent_ids)[1]
beam_width = parent_ids.shape[2] or tf.shape(parent_ids)[2]
# Generate beam ids that will be reordered by gather_tree.
beam_ids = tf.reshape(tf.range(beam_width), [1, 1, -1])
beam_ids = tf.tile(beam_ids, [max_time, batch_size, 1])
max_sequence_lengths = tf.cast(tf.reduce_max(sequence_length, axis=1), tf.int32)
sorted_beam_ids = _gather_tree(
step_ids=beam_ids,
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=beam_width + 1,
)
# For out of range steps, simply copy the same beam.
in_bound_steps = tf.transpose(
tf.sequence_mask(sequence_length, maxlen=max_time), perm=[2, 0, 1]
)
sorted_beam_ids = tf.where(in_bound_steps, x=sorted_beam_ids, y=beam_ids)
# Gather from a tensor with collapsed additional dimensions.
final_shape = tf.shape(t)
gather_from = tf.reshape(t, [max_time, batch_size, beam_width, -1])
ordered = tf.gather(gather_from, sorted_beam_ids, axis=2, batch_dims=2)
ordered = tf.reshape(ordered, final_shape)
return ordered