Source code for onmt.encoders.mean_encoder

"""Define a minimal encoder."""
from onmt.encoders.encoder import EncoderBase
from onmt.utils.misc import sequence_mask
import torch


class MeanEncoder(EncoderBase):
    """A trivial non-recurrent encoder. Simply applies mean pooling.

    Args:
        num_layers (int): number of replicated layers
        embeddings (onmt.modules.Embeddings): embedding module to use
    """

    def __init__(self, num_layers, embeddings):
        super(MeanEncoder, self).__init__()
        self.num_layers = num_layers
        self.embeddings = embeddings
    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        return cls(opt.enc_layers, embeddings)
    def forward(self, src, src_len=None):
        """See :func:`EncoderBase.forward()`"""
        emb = self.embeddings(src)  # (batch, seq_len, emb_dim)
        batch, _, emb_dim = emb.size()

        if src_len is not None:
            # we avoid padding while mean pooling: weight real tokens by
            # 1/length and padding by 0, then sum with a batched matmul
            mask = (~sequence_mask(src_len)).float()
            mask = mask / src_len.unsqueeze(1).float()
            mean = torch.bmm(mask.unsqueeze(1), emb).squeeze(1)
        else:
            mean = emb.mean(1)

        # replicate the pooled state for each of the stacked layers;
        # (mean, mean) mimics an LSTM's (h_n, c_n) final-state pair
        mean = mean.expand(self.num_layers, batch, emb_dim)
        enc_out = emb
        enc_final_hs = (mean, mean)
        return enc_out, enc_final_hs, src_len
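
A minimal usage sketch, under stated assumptions: MeanEncoder only requires that `embeddings` map token ids to a (batch, seq_len, emb_dim) tensor, so a plain torch.nn.Embedding (with hypothetical sizes) stands in for onmt.modules.Embeddings here, and a SimpleNamespace with an `enc_layers` attribute stands in for a parsed options object in `from_opt`. Real pipelines pass the library's own embedding module and option namespace.

import torch
import torch.nn as nn
from types import SimpleNamespace

from onmt.encoders.mean_encoder import MeanEncoder

# stand-in for onmt.modules.Embeddings (hypothetical vocab/size choices)
embeddings = nn.Embedding(num_embeddings=100, embedding_dim=8)

# direct construction, or via the alternate constructor
encoder = MeanEncoder(num_layers=2, embeddings=embeddings)
encoder = MeanEncoder.from_opt(SimpleNamespace(enc_layers=2), embeddings)

src = torch.randint(0, 100, (4, 7))   # (batch, seq_len) token ids
src_len = torch.tensor([7, 5, 3, 7])  # true length of each sequence
enc_out, (h, c), _ = encoder(src, src_len)
print(enc_out.shape, h.shape)         # (4, 7, 8) and (2, 4, 8)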
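
The bmm in forward() computes a padding-aware mean: each row of `mask` holds 1/length on real tokens and 0 on padding, so the batched matrix product sums every sequence's embeddings with those weights. The following self-contained sketch checks that against a naive per-sequence loop; it reimplements `sequence_mask` locally under the convention the code above assumes (True marks padding positions), which is an assumption, not the library function itself.

import torch

def sequence_mask(lengths, max_len=None):
    # assumed convention: True at padding positions (index >= length)
    max_len = max_len or int(lengths.max())
    return torch.arange(max_len, device=lengths.device) >= lengths.unsqueeze(1)

batch, max_len, emb_dim = 3, 5, 4
emb = torch.randn(batch, max_len, emb_dim)
src_len = torch.tensor([5, 3, 2])

# bmm-based masked mean, exactly as in forward()
mask = (~sequence_mask(src_len)).float() / src_len.unsqueeze(1).float()
mean_bmm = torch.bmm(mask.unsqueeze(1), emb).squeeze(1)

# naive reference: average only the real tokens of each sequence
mean_loop = torch.stack([emb[b, : src_len[b]].mean(0) for b in range(batch)])

assert torch.allclose(mean_bmm, mean_loop, atol=1e-6)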