Source code for opennmt.encoders.mean_encoder
"""Define a minimal encoder."""
import tensorflow as tf
from opennmt.encoders.encoder import Encoder
[docs]class MeanEncoder(Encoder):
"""A simple encoder that takes the mean of its inputs."""
[docs] def call(self, inputs, sequence_length=None, training=None):
outputs = tf.identity(inputs)
if sequence_length is not None:
inputs = tf.RaggedTensor.from_tensor(inputs, lengths=sequence_length)
state = tf.reduce_mean(inputs, axis=1)
return (outputs, state, sequence_length)