Source code for onmt.modules.embeddings

""" Embeddings module """
import math
import warnings

import torch
import torch.nn as nn
from torch.nn.utils import skip_init

from onmt.modules.util_class import Elementwise
from onmt.utils.logging import logger


class SequenceTooLongError(Exception):
    pass


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for non-recurrent neural networks.

    Implementation based on "Attention Is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`

    Args:
        dim (int): embedding size
    """

    def __init__(self, dim, enc_type, max_len=5000):
        if dim % 2 != 0:
            raise ValueError(
                "Cannot use sin/cos positional encoding with "
                "odd dim (got dim={:d})".format(dim)
            )
        if enc_type == "SinusoidalInterleaved":
            pe = torch.zeros(max_len, dim)
            position = torch.arange(0, max_len).unsqueeze(1)
            div_term = torch.exp(
                (
                    torch.arange(0, dim, 2, dtype=torch.float)
                    * -(math.log(10000.0) / dim)
                )
            )
            pe[:, 0::2] = torch.sin(position.float() * div_term)
            pe[:, 1::2] = torch.cos(position.float() * div_term)
        elif enc_type == "SinusoidalConcat":
            half_dim = dim // 2
            pe = math.log(10000) / (half_dim - 1)
            pe = torch.exp(torch.arange(half_dim, dtype=torch.float) * -pe)
            pe = torch.arange(max_len, dtype=torch.float).unsqueeze(1) * pe.unsqueeze(0)
            pe = torch.cat([torch.sin(pe), torch.cos(pe)], dim=1).view(max_len, -1)
        else:
            raise ValueError(
                "Choice of Position encoding is SinusoidalInterleaved or"
                " SinusoidalConcat."
            )
        pe = pe.unsqueeze(1)  # we keep pe (len x batch x dim) for back comp
        super(PositionalEncoding, self).__init__()
        self.register_buffer("pe", pe)
        self.dim = dim
    def forward(self, emb, step=None):
        """Embed inputs.

        Args:
            emb (FloatTensor): Sequence of word vectors
                ``(batch_size, seq_len, self.dim)``
            step (int or NoneType): If stepwise (``seq_len = 1``), use
                the encoding for this position.
        """
        pe = self.pe.transpose(0, 1)  # (batch x len x dim)
        emb = emb * math.sqrt(self.dim)
        step = step or 0
        if pe.size(1) < step + emb.size(1):
            raise SequenceTooLongError(
                f"Sequence is {emb.size(1) + step} but PositionalEncoding is"
                f" limited to {self.pe.size(1)}. See max_len argument."
            )
        emb = emb + pe[:, step : emb.size(1) + step, :]
        return emb
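
# Illustrative usage sketch (added for exposition; not part of the upstream
# module). It shows how PositionalEncoding might be applied to a batch of word
# vectors, following the shapes in the forward() docstring. The dim, batch and
# step values below are arbitrary assumptions.
def _positional_encoding_usage_sketch():
    pos_enc = PositionalEncoding(dim=512, enc_type="SinusoidalInterleaved")
    emb = torch.zeros(8, 20, 512)  # (batch_size, seq_len, dim) dummy word vectors
    full = pos_enc(emb)  # adds a sinusoidal encoding to every position
    # Stepwise decoding: a single timestep receives the encoding for `step`.
    one_step = pos_enc(emb[:, :1, :], step=19)
    return full, one_step
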
class Embeddings(nn.Module):
    """Words embeddings for encoder/decoder.

    Additionally includes ability to add sparse input features
    based on "Linguistic Input Features Improve Neural Machine Translation"
    :cite:`sennrich2016linguistic`.

    .. mermaid::

        graph LR
            A[Input]
            C[Feature 1 Lookup]
            A-->B[Word Lookup]
            A-->C
            A-->D[Feature N Lookup]
            B-->E[MLP/Concat]
            C-->E
            D-->E
            E-->F[Output]

    Args:
        word_vec_size (int): size of the dictionary of embeddings.
        word_vocab_size (int): size of dictionary of embeddings for words.
        word_padding_idx (int): padding index for words in the embeddings.
        position_encoding (bool): see :class:`~onmt.modules.PositionalEncoding`
        feat_merge (string): merge action for the features embeddings:
            concat, sum or mlp.
        feat_vec_exponent (float): when using `-feat_merge concat`, feature
            embedding size is N^feat_vec_exponent, where N is the
            number of values the feature takes.
        feat_vec_size (int): embedding dimension for features when using
            `-feat_merge mlp`
        feat_padding_idx (List[int]): padding index for a list of features
            in the embeddings.
        feat_vocab_sizes (List[int], optional): list of size of dictionary
            of embeddings for each feature.
        dropout (float): dropout probability.
        sparse (bool): sparse embeddings, default False.
        freeze_word_vecs (bool): freeze weights of word vectors.
    """

    def __init__(
        self,
        word_vec_size,
        word_vocab_size,
        word_padding_idx,
        position_encoding=False,
        position_encoding_type="SinusoidalInterleaved",
        feat_merge="concat",
        feat_vec_exponent=0.7,
        feat_vec_size=-1,
        feat_padding_idx=[],
        feat_vocab_sizes=[],
        dropout=0,
        sparse=False,
        freeze_word_vecs=False,
    ):
        self._validate_args(
            feat_merge,
            feat_vocab_sizes,
            feat_vec_exponent,
            feat_vec_size,
            feat_padding_idx,
        )

        if feat_padding_idx is None:
            feat_padding_idx = []
        self.word_padding_idx = word_padding_idx

        self.word_vec_size = word_vec_size

        # Dimensions and padding for constructing the word embedding matrix
        vocab_sizes = [word_vocab_size]
        emb_dims = [word_vec_size]
        pad_indices = [word_padding_idx]

        # Dimensions and padding for feature embedding matrices
        # (these have no effect if feat_vocab_sizes is empty)
        if feat_merge == "sum":
            feat_dims = [word_vec_size] * len(feat_vocab_sizes)
        elif feat_vec_size > 0:
            feat_dims = [feat_vec_size] * len(feat_vocab_sizes)
        else:
            feat_dims = [int(vocab**feat_vec_exponent) for vocab in feat_vocab_sizes]
        vocab_sizes.extend(feat_vocab_sizes)
        emb_dims.extend(feat_dims)
        pad_indices.extend(feat_padding_idx)

        # The embedding matrix look-up tables. The first look-up table
        # is for words. Subsequent ones are for features, if any exist.
        emb_params = zip(vocab_sizes, emb_dims, pad_indices)
        embeddings = [
            skip_init(
                nn.Embedding,
                num_embeddings=vocab,
                embedding_dim=dim,
                padding_idx=pad,
                sparse=sparse,
            )
            for vocab, dim, pad in emb_params
        ]
        emb_luts = Elementwise(feat_merge, embeddings)

        # The final output size of word + feature vectors. This can vary
        # from the word vector size if and only if features are defined.
        # This is the attribute you should access if you need to know
        # how big your embeddings are going to be.
        self.embedding_size = sum(emb_dims) if feat_merge == "concat" else word_vec_size

        # The sequence of operations that converts the input sequence
        # into a sequence of embeddings. At minimum this consists of
        # looking up the embeddings for each word and feature in the
        # input. Model parameters may require the sequence to contain
        # additional operations as well.
        super(Embeddings, self).__init__()
        self.make_embedding = nn.Sequential()
        self.make_embedding.add_module("emb_luts", emb_luts)

        if feat_merge == "mlp" and len(feat_vocab_sizes) > 0:
            in_dim = sum(emb_dims)
            mlp = nn.Sequential(nn.Linear(in_dim, word_vec_size), nn.ReLU())
            self.make_embedding.add_module("mlp", mlp)

        self.position_encoding = position_encoding
        self.dropout = nn.Dropout(p=dropout)
        self.dropout_p = dropout

        if self.position_encoding:
            pe = PositionalEncoding(self.embedding_size, position_encoding_type)
            self.make_embedding.add_module("pe", pe)

        if freeze_word_vecs:
            self.word_lut.weight.requires_grad = False

    def _validate_args(
        self,
        feat_merge,
        feat_vocab_sizes,
        feat_vec_exponent,
        feat_vec_size,
        feat_padding_idx,
    ):
        if feat_merge == "sum":
            # features must use word_vec_size
            if feat_vec_exponent != 0.7:
                warnings.warn(
                    "Merging with sum, but got non-default "
                    "feat_vec_exponent. It will be unused."
                )
            if feat_vec_size != -1:
                warnings.warn(
                    "Merging with sum, but got non-default "
                    "feat_vec_size. It will be unused."
                )
        elif feat_vec_size > 0:
            # features will use feat_vec_size
            if feat_vec_exponent != -1:
                warnings.warn(
                    "Not merging with sum and positive "
                    "feat_vec_size, but got non-default "
                    "feat_vec_exponent. It will be unused."
                )
        else:
            if feat_vec_exponent <= 0:
                raise ValueError(
                    "Using feat_vec_exponent to determine "
                    "feature vec size, but got feat_vec_exponent "
                    "less than or equal to 0."
                )
        n_feats = len(feat_vocab_sizes)
        if n_feats != len(feat_padding_idx):
            raise ValueError(
                "Got unequal number of feat_vocab_sizes and "
                "feat_padding_idx ({:d} != {:d})".format(n_feats, len(feat_padding_idx))
            )

    @property
    def word_lut(self):
        """Word look-up table."""
        return self.make_embedding[0][0]

    @property
    def emb_luts(self):
        """Embedding look-up table."""
        return self.make_embedding[0]
    def load_pretrained_vectors(self, emb_file):
        """Load in pretrained embeddings.

        Args:
            emb_file (str) : path to torch serialized embeddings
        """
        if emb_file:
            pretrained = torch.load(emb_file)
            pretrained_vec_size = pretrained.size(1)
            if self.word_vec_size > pretrained_vec_size:
                self.word_lut.weight.data[:, :pretrained_vec_size] = pretrained
            elif self.word_vec_size < pretrained_vec_size:
                self.word_lut.weight.data.copy_(pretrained[:, : self.word_vec_size])
            else:
                self.word_lut.weight.data.copy_(pretrained)
    def forward(self, source, step=None):
        """Computes the embeddings for words and features.

        Args:
            source (LongTensor): index tensor ``(batch, len, nfeat)``

        Returns:
            FloatTensor: Word embeddings ``(batch, len, embedding_size)``
        """
        if self.position_encoding:
            for i, module in enumerate(self.make_embedding._modules.values()):
                if i == len(self.make_embedding._modules.values()) - 1:
                    source = module(source, step=step)
                else:
                    source = module(source)
        else:
            source = self.make_embedding(source)
        if self.dropout_p > 0:
            return self.dropout(source)
        else:
            return source
    def update_dropout(self, dropout):
        self.dropout.p = dropout
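
# Illustrative usage sketch (added for exposition; not part of the upstream
# module). It builds an Embeddings layer with one input feature merged by
# concatenation, as described in the class docstring. The vocabulary sizes,
# padding indices and input shape below are arbitrary assumptions.
def _embeddings_usage_sketch():
    emb = Embeddings(
        word_vec_size=500,
        word_vocab_size=10000,
        word_padding_idx=1,
        position_encoding=True,
        feat_merge="concat",
        feat_padding_idx=[1],
        feat_vocab_sizes=[20],
        dropout=0.1,
    )
    # Input indices have shape (batch, len, nfeat): one word column plus one
    # feature column. Values 2..19 are valid for both vocabularies here.
    src = torch.randint(2, 20, (8, 15, 2))
    out = emb(src)  # -> (8, 15, emb.embedding_size), i.e. word dim + feature dim
    return out
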
# Some utility functions for pretrained embeddings
def read_embeddings(path, skip_lines=0, filter_set=None):
    """
    Read an embeddings file in the glove format.
    """
    embs = dict()
    total_vectors_in_file = 0
    with open(path, "rb") as f:
        for i, line in enumerate(f):
            if i < skip_lines:
                continue
            if not line:
                break
            if len(line) == 0:
                # is this reachable?
                continue

            l_split = line.decode("utf8").strip().split(" ")
            if len(l_split) == 2:
                continue
            total_vectors_in_file += 1
            if filter_set is not None and l_split[0] not in filter_set:
                continue
            embs[l_split[0]] = [float(em) for em in l_split[1:]]
    return embs, total_vectors_in_file


def calc_vocab_load_stats(vocab, loaded_embed_dict):
    matching_count = len(set(vocab.ids_to_tokens) & set(loaded_embed_dict.keys()))
    missing_count = len(vocab) - matching_count
    percent_matching = matching_count / len(vocab) * 100
    return matching_count, missing_count, percent_matching


def convert_to_torch_tensor(word_to_float_list_dict, vocab):
    dim = len(next(iter(word_to_float_list_dict.values())))
    tensor = torch.zeros((len(vocab), dim))
    for word, values in word_to_float_list_dict.items():
        tensor[vocab.tokens_to_ids[word]] = torch.Tensor(values)
    return tensor


def prepare_pretrained_embeddings(opt, vocabs):
    if all(
        [
            opt.both_embeddings is None,
            opt.src_embeddings is None,
            opt.tgt_embeddings is None,
        ]
    ):
        return

    assert (
        opt.save_data
    ), "-save_data is required when using pretrained embeddings."

    vocs = []
    for side in ["src", "tgt"]:
        vocab = vocabs[side]
        vocs.append(vocab)
    enc_vocab, dec_vocab = vocs

    skip_lines = 1 if opt.embeddings_type == "word2vec" else 0
    if opt.both_embeddings is not None:
        set_of_src_and_tgt_vocab = set(enc_vocab.ids_to_tokens) | set(
            dec_vocab.ids_to_tokens
        )
        logger.info(
            "Reading encoder and decoder embeddings from {}".format(opt.both_embeddings)
        )
        src_vectors, total_vec_count = read_embeddings(
            opt.both_embeddings, skip_lines, set_of_src_and_tgt_vocab
        )
        tgt_vectors = src_vectors
        logger.info("\tFound {} total vectors in file".format(total_vec_count))
    else:
        if opt.src_embeddings is not None:
            logger.info("Reading encoder embeddings from {}".format(opt.src_embeddings))
            src_vectors, total_vec_count = read_embeddings(
                opt.src_embeddings, skip_lines, filter_set=set(enc_vocab.ids_to_tokens)
            )
            logger.info("\tFound {} total vectors in file.".format(total_vec_count))
        else:
            src_vectors = None
        if opt.tgt_embeddings is not None:
            logger.info("Reading decoder embeddings from {}".format(opt.tgt_embeddings))
            tgt_vectors, total_vec_count = read_embeddings(
                opt.tgt_embeddings, skip_lines, filter_set=set(dec_vocab.ids_to_tokens)
            )
            logger.info("\tFound {} total vectors in file".format(total_vec_count))
        else:
            tgt_vectors = None

    logger.info("After filtering to vectors in vocab:")
    if opt.src_embeddings is not None or opt.both_embeddings is not None:
        logger.info(
            "\t* enc: %d match, %d missing, (%.2f%%)"
            % calc_vocab_load_stats(enc_vocab, src_vectors)
        )
    if opt.tgt_embeddings is not None or opt.both_embeddings is not None:
        logger.info(
            "\t* dec: %d match, %d missing, (%.2f%%)"
            % calc_vocab_load_stats(dec_vocab, tgt_vectors)
        )

    # Write to file
    enc_output_file = opt.save_data + ".enc_embeddings.pt"
    dec_output_file = opt.save_data + ".dec_embeddings.pt"
    if opt.src_embeddings is not None or opt.both_embeddings is not None:
        logger.info("\nSaving encoder embeddings as:\n\t* enc: %s" % enc_output_file)
        torch.save(convert_to_torch_tensor(src_vectors, enc_vocab), enc_output_file)
        # set the opt in place
        opt.pre_word_vecs_enc = enc_output_file
    if opt.tgt_embeddings is not None or opt.both_embeddings is not None:
        logger.info("\nSaving decoder embeddings as:\n\t* dec: %s" % dec_output_file)
        torch.save(convert_to_torch_tensor(tgt_vectors, dec_vocab), dec_output_file)
        # set the opt in place
        opt.pre_word_vecs_dec = dec_output_file
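
# Illustrative sketch (added for exposition; not part of the upstream module).
# It chains the helpers above: read a GloVe-format file, report coverage, and
# build a tensor aligned with a vocab. The `vocab` argument is assumed to
# expose `ids_to_tokens` / `tokens_to_ids` as the helpers above already do;
# the default file path is a placeholder.
def _pretrained_embeddings_sketch(vocab, glove_path="embeddings.glove.txt"):
    # Keep only vectors for tokens that actually occur in the vocab.
    embs, total = read_embeddings(
        glove_path, skip_lines=0, filter_set=set(vocab.ids_to_tokens)
    )
    match, missing, pct = calc_vocab_load_stats(vocab, embs)
    logger.info("%d match, %d missing (%.2f%%) of %d vectors" % (match, missing, pct, total))
    # Rows for tokens without a pretrained vector stay zero-initialized.
    return convert_to_torch_tensor(embs, vocab)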