Modules
Embeddings
-
class
onmt.modules.
Embeddings
(word_vec_size, word_vocab_size, word_padding_idx, position_encoding=False, feat_merge='concat', feat_vec_exponent=0.7, feat_vec_size=-1, feat_padding_idx=[], feat_vocab_sizes=[], dropout=0, sparse=False, freeze_word_vecs=False)[source]¶ Bases:
torch.nn.modules.module.Module
Word embeddings for encoder/decoder.
Additionally includes ability to add sparse input features based on “Linguistic Input Features Improve Neural Machine Translation” [SH16].
[Diagram: the Input feeds a Word Lookup and Feature 1 ... Feature N Lookups; their outputs are merged by MLP/Concat to produce the Output.]
- Parameters
word_vec_size (int) – size of the word embedding vectors.
word_vocab_size (int) – size of dictionary of embeddings for words.
word_padding_idx (int) – padding index for words in the embeddings.
position_encoding (bool) – see
PositionalEncoding
feat_merge (string) – merge action for the features embeddings: concat, sum or mlp.
feat_vec_exponent (float) – when using -feat_merge concat, the feature embedding size is N^feat_vec_exponent, where N is the number of values the feature takes.
feat_vec_size (int) – embedding dimension for features when using -feat_merge mlp
feat_padding_idx (List[int]) – padding index for a list of features in the embeddings.
feat_vocab_sizes (List[int], optional) – list of embedding dictionary sizes, one per feature.
dropout (float) – dropout probability.
sparse (bool) – use sparse embeddings; defaults to False.
freeze_word_vecs (bool) – freeze weights of word vectors.
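The sizing rule described for feat_vec_exponent can be sketched in a few lines. This is a minimal illustration, not the library's code: the helper name concat_feature_dims is hypothetical, and truncating the result with int() is an assumption about how the fractional size is rounded.

```python
def concat_feature_dims(feat_vocab_sizes, feat_vec_exponent=0.7):
    """Return one embedding width per feature (hypothetical helper).

    Each width is N ** feat_vec_exponent, where N is the number of values
    the feature takes. Truncation via int() is an assumption.
    """
    return [int(n ** feat_vec_exponent) for n in feat_vocab_sizes]


# Three features with 4, 50 and 1000 distinct values.
dims = concat_feature_dims([4, 50, 1000])

# With feat_merge='concat' the merged vector is the word vector followed
# by every feature vector, so the widths add up.
word_vec_size = 512
merged_size = word_vec_size + sum(dims)
```

Small vocabularies get very narrow vectors under this rule, which keeps the concatenated embedding from growing much beyond word_vec_size.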
-
property
emb_luts
¶ Embedding look-up table.
-
forward
(source, step=None)[source]¶ Computes the embeddings for words and features.
- Parameters
source (LongTensor) – index tensor
(batch, len, nfeat)
- Returns
Word embeddings
(batch, len, embedding_size)
- Return type
FloatTensor
-
load_pretrained_vectors
(emb_file)[source]¶ Load in pretrained embeddings.
- Parameters
emb_file (str) – path to torch serialized embeddings
-
property
word_lut
¶ Word look-up table.
-
class
onmt.modules.
PositionalEncoding
(dropout, dim, max_len=5000)[source]¶ Bases:
torch.nn.modules.module.Module
Sinusoidal positional encoding for non-recurrent neural networks.
Implementation based on “Attention Is All You Need” [VSP+17]
- Parameters
dropout (float) – dropout parameter
dim (int) – embedding size
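A sinusoidal encoding in the style of [VSP+17] can be sketched as follows. This is an illustration using plain Python lists rather than tensors; the function names are hypothetical, dropout is omitted, and scaling the embeddings by sqrt(dim) before the addition follows the paper rather than anything stated on this page.

```python
import math


def sinusoidal_encoding(max_len, dim):
    """Build a (max_len, dim) table: sin on even columns, cos on odd ones."""
    if dim % 2 != 0:
        raise ValueError("dim must be even so sin/cos columns pair up")
    table = [[0.0] * dim for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, dim, 2):
            # Each column pair uses a different wavelength.
            angle = pos / (10000 ** (i / dim))
            table[pos][i] = math.sin(angle)
            table[pos][i + 1] = math.cos(angle)
    return table


def add_positions(embeddings, table):
    """Scale each embedding row by sqrt(dim) and add its position row."""
    scale = math.sqrt(len(embeddings[0]))
    return [
        [scale * e + p for e, p in zip(row, table[pos])]
        for pos, row in enumerate(embeddings)
    ]


table = sinusoidal_encoding(max_len=8, dim=4)
encoded = add_positions([[0.0] * 4, [0.0] * 4], table)
```

Because the table depends only on max_len and dim, it can be built once and sliced for every batch.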
-
class
onmt.modules.position_ffn.
PositionwiseFeedForward
(d_model, d_ff, dropout=0.1, activation_fn='relu')[source]¶ Bases:
torch.nn.modules.module.Module
A two-layer Feed-Forward-Network with residual layer norm.
- Parameters
d_model (int) – the input size of the first layer of the FFN.
d_ff (int) – the hidden layer size of the FFN, which is the input size of the second layer.
dropout (float) – dropout probability in \([0, 1)\).
activation_fn (ActivationFunction) – activation function used.
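A two-layer feed-forward block with layer norm and a residual connection could look like the sketch below, applied to a single position. The helper names are hypothetical, the order of operations (normalize, expand to d_ff, project back, add the input) is an assumption, and dropout and the learnable layer-norm gain and bias are left out.

```python
import math


def layer_norm(x, eps=1e-6):
    """Normalize one vector to zero mean and unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def linear(x, weight, bias):
    """Affine map; weight has shape (out_features, in_features)."""
    return [sum(w * v for w, v in zip(row, x)) + b
            for row, b in zip(weight, bias)]


def relu(x):
    return [max(0.0, v) for v in x]


def positionwise_ffn(x, w1, b1, w2, b2):
    """x: (d_model,); w1: (d_ff, d_model); w2: (d_model, d_ff)."""
    hidden = relu(linear(layer_norm(x), w1, b1))   # d_model -> d_ff
    out = linear(hidden, w2, b2)                   # d_ff -> d_model
    return [o + v for o, v in zip(out, x)]         # residual connection


x = [1.0, -2.0]
w1 = [[0.5, -0.5], [1.0, 1.0], [0.0, 2.0]]
b1 = [0.0, 0.0, 0.0]
w2_zero = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
b2 = [0.0, 0.0]
y = positionwise_ffn(x, w1, b1, w2_zero, b2)
```

With a zeroed second layer the block reduces to the identity, which shows the role of the residual path.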
Encoders
-
class
onmt.encoders.
EncoderBase
[source]¶ Bases:
torch.nn.modules.module.Module
Base encoder class. Specifies the interface used by different encoder types and required by
onmt.Models.NMTModel
.
-
forward
(src, src_len=None)[source]¶ - Parameters
src (LongTensor) – padded sequences of sparse indices
(batch, src_len, nfeat)
src_len (LongTensor) – length of each sequence
(batch,)
- Returns
enc_out: encoder output used for attention,
(batch, src_len, hidden_size)
For a bidirectional RNN the last dimension is 2 x hidden_size.
enc_final_hs: encoder final hidden state,
(num_layers x dir, batch, hidden_size)
In the case of LSTM this is a tuple.
src_len:
(batch)
- Return type
(FloatTensor, FloatTensor, FloatTensor)
-
-
class
onmt.encoders.
TransformerEncoder
(num_layers, d_model, heads, d_ff, dropout, attention_dropout, embeddings, max_relative_positions, pos_ffn_activation_fn='relu', add_qkvbias=False)[source]¶ Bases:
onmt.encoders.encoder.EncoderBase
The Transformer encoder from “Attention is All You Need” [VSP+17]
- Parameters
num_layers (int) – number of encoder layers
d_model (int) – size of the model
heads (int) – number of heads
d_ff (int) – size of the inner FF layer
dropout (float) – dropout parameters
embeddings (onmt.modules.Embeddings) – embeddings to use, should have positional encodings
pos_ffn_activation_fn (ActivationFunction) – activation function choice for PositionwiseFeedForward layer
- Returns
enc_out
(batch_size, src_len, model_dim)
encoder final state: None in the case of Transformer
src_len
(batch_size)
- Return type
(torch.FloatTensor, torch.FloatTensor)
-
class
onmt.encoders.
RNNEncoder
(rnn_type, bidirectional, num_layers, hidden_size, dropout=0.0, embeddings=None, use_bridge=False)[source]¶ Bases:
onmt.encoders.encoder.EncoderBase
A generic recurrent neural network encoder.
- Parameters
rnn_type (str) – style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU]
bidirectional (bool) – use a bidirectional RNN
num_layers (int) – number of stacked layers
hidden_size (int) – hidden size of each layer
dropout (float) – dropout value for
torch.nn.Dropout
embeddings (onmt.modules.Embeddings) – embedding module to use
-
class
onmt.encoders.
GGNNEncoder
(rnn_type, src_word_vec_size, src_ggnn_size, state_dim, bidir_edges, n_edge_types, n_node, bridge_extra_node, n_steps, src_vocab)[source]¶ Bases:
onmt.encoders.encoder.EncoderBase
- A gated graph neural network configured as an encoder.
Based on github.com/JamesChuanggg/ggnn.pytorch.git, which is based on the paper “Gated Graph Sequence Neural Networks” by Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel.
- Parameters
rnn_type (str) – style of recurrent unit to use, one of [LSTM]
src_ggnn_size (int) – Size of token-to-node embedding input
src_word_vec_size (int) – Size of token-to-node embedding output
state_dim (int) – Number of state dimensions in nodes
n_edge_types (int) – Number of edge types
bidir_edges (bool) – True if reverse edges should be autocreated
n_node (int) – Max nodes in graph
bridge_extra_node (bool) – True indicates only 1st extra node (after token listing) should be used for decoder init.
n_steps (int) – Steps to advance graph encoder for stabilization
src_vocab (int) – Path to source vocabulary. (The ggnn uses src_vocab during training because the graph is built using edge information which requires parsing the input sequence.)
-
class
onmt.encoders.
CNNEncoder
(num_layers, hidden_size, cnn_kernel_width, dropout, embeddings)[source]¶ Bases:
onmt.encoders.encoder.EncoderBase
Encoder based on “Convolutional Sequence to Sequence Learning” [GAG+17].
-
class
onmt.encoders.
MeanEncoder
(num_layers, embeddings)[source]¶ Bases:
onmt.encoders.encoder.EncoderBase
A trivial non-recurrent encoder. Simply applies mean pooling.
- Parameters
num_layers (int) – number of replicated layers
embeddings (onmt.modules.Embeddings) – embedding module to use
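Mean pooling over a padded batch can be sketched as below. The helper name masked_mean is hypothetical, nested lists stand in for tensors, and excluding padding positions by using each sequence's true length is an assumption about how the pooling treats padding.

```python
def masked_mean(emb, src_len):
    """Average the embeddings of the real tokens in each sequence.

    emb: nested list of shape (batch, len, dim), padded to a common len.
    src_len: true length of each sequence, shape (batch,).
    Returns a nested list of shape (batch, dim).
    """
    pooled = []
    for seq, n in zip(emb, src_len):
        dim = len(seq[0])
        # Only the first n positions are real tokens; the rest is padding.
        pooled.append([sum(seq[t][d] for t in range(n)) / n
                       for d in range(dim)])
    return pooled


batch = [
    [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]],   # length 2, one padding row
    [[2.0, 2.0], [4.0, 6.0], [6.0, 10.0]],  # length 3, no padding
]
means = masked_mean(batch, [2, 3])
```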
Decoders
-
class
onmt.decoders.
DecoderBase
(attentional=True)[source]¶ Bases:
torch.nn.modules.module.Module
Abstract class for decoders.
- Parameters
attentional (bool) – The decoder returns non-empty attention.
-
class
onmt.decoders.
TransformerDecoder
(num_layers, d_model, heads, d_ff, copy_attn, self_attn_type, dropout, attention_dropout, embeddings, max_relative_positions, aan_useffn, full_context_alignment, alignment_layer, alignment_heads, pos_ffn_activation_fn='relu', add_qkvbias=False)[source]¶ Bases:
onmt.decoders.transformer.TransformerDecoderBase
The Transformer decoder from “Attention is All You Need”. [VSP+17]
- Parameters
num_layers (int) – number of decoder layers.
d_model (int) – size of the model
heads (int) – number of heads
d_ff (int) – size of the inner FF layer
copy_attn (bool) – if using a separate copy attention
self_attn_type (str) – type of self-attention: scaled-dot or average
dropout (float) – dropout in residual, self-attn(dot) and feed-forward
attention_dropout (float) – dropout in context_attn (and self-attn(avg))
embeddings (onmt.modules.Embeddings) – embeddings to use, should have positional encodings
max_relative_positions (int) – Max distance between inputs in relative positions representations
aan_useffn (bool) – Turn on the FFN layer in the AAN decoder
full_context_alignment (bool) – whether to enable an extra full-context decoder forward pass for alignment
alignment_layer (int) – index of the layer to supervise for alignment guiding
alignment_heads (int) – number of cross-attention heads to use for alignment guiding
add_qkvbias (bool) – whether to add bias to the Key/Value nn.Linear
-
class
onmt.decoders.decoder.
RNNDecoderBase
(rnn_type, bidirectional_encoder, num_layers, hidden_size, attn_type='general', attn_func='softmax', coverage_attn=False, context_gate=None, copy_attn=False, dropout=0.0, embeddings=None, reuse_copy_attn=False, copy_attn_type='general')[source]¶ Bases:
onmt.decoders.decoder.DecoderBase
Base recurrent attention-based decoder class.
Specifies the interface used by different decoder types and required by
NMTModel
.
- Parameters
rnn_type (str) – style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU]
bidirectional_encoder (bool) – use with a bidirectional encoder
num_layers (int) – number of stacked layers
hidden_size (int) – hidden size of each layer
attn_type (str) – see
GlobalAttention
attn_func (str) – see
GlobalAttention
coverage_attn (bool) – see
GlobalAttention
context_gate (str) – see
ContextGate
copy_attn (bool) – setup a separate copy attention mechanism
dropout (float) – dropout value for
torch.nn.Dropout
embeddings (onmt.modules.Embeddings) – embedding module to use
reuse_copy_attn (bool) – reuse the attention for copying
copy_attn_type (str) – The copy attention style. See
GlobalAttention
.
-
forward
(tgt, enc_out, src_len=None, step=None, **kwargs)[source]¶ - Parameters
tgt (LongTensor) – sequences of padded tokens
(batch, tgt_len, nfeats)
enc_out (FloatTensor) – vectors from the encoder
(batch, src_len, hidden)
src_len (LongTensor) – the padded source lengths
(batch,)
- Returns
dec_outs: output from the decoder (after attn)
(batch, tgt_len, hidden)
attns: distribution over src at each tgt
(batch, tgt_len, src_len)
- Return type
(FloatTensor, dict[str, FloatTensor])
-
class
onmt.decoders.
StdRNNDecoder
(rnn_type, bidirectional_encoder, num_layers, hidden_size, attn_type='general', attn_func='softmax', coverage_attn=False, context_gate=None, copy_attn=False, dropout=0.0, embeddings=None, reuse_copy_attn=False, copy_attn_type='general')[source]¶ Bases:
onmt.decoders.decoder.RNNDecoderBase
Standard fully batched RNN decoder with attention.
Faster implementation that uses CuDNN. See
RNNDecoderBase
for options.
Based around the approach from “Neural Machine Translation By Jointly Learning To Align and Translate” [BCB14].
Implemented without input_feeding and currently with no coverage_attn or copy_attn support.
-
class
onmt.decoders.
InputFeedRNNDecoder
(rnn_type, bidirectional_encoder, num_layers, hidden_size, attn_type='general', attn_func='softmax', coverage_attn=False, context_gate=None, copy_attn=False, dropout=0.0, embeddings=None, reuse_copy_attn=False, copy_attn_type='general')[source]¶ Bases:
onmt.decoders.decoder.RNNDecoderBase
Input feeding based decoder.
See
RNNDecoderBase
for options.
Based around the input feeding approach from “Effective Approaches to Attention-based Neural Machine Translation” [LPM15].
-
class
onmt.decoders.
CNNDecoder
(num_layers, hidden_size, attn_type, copy_attn, cnn_kernel_width, dropout, embeddings, copy_attn_type)[source]¶ Bases:
onmt.decoders.decoder.DecoderBase
Decoder based on “Convolutional Sequence to Sequence Learning” [GAG+17].
Consists of residual convolutional layers, with ConvMultiStepAttention.
Attention
-
class
onmt.modules.
GlobalAttention
(dim, coverage=False, attn_type='dot', attn_func='softmax')[source]¶ Bases:
torch.nn.modules.module.Module
Global attention takes a matrix and a query vector. It then computes a parameterized convex combination of the matrix based on the input query.
Constructs a unit mapping a query q of size dim and a source matrix H of size n x dim, to an output of size dim.
[Diagram: the Query and the RNN states H 1, H 2, ... H N feed the Attn block; the attention weights and the states combine to form the Output.]
All models compute the output as \(c = \sum_{j=1}^{\text{SeqLength}} a_j H_j\) where \(a_j\) is the softmax of a score function. They then apply a projection layer to [q, c].
However, they differ in how they compute the attention score.
- Luong Attention (dot, general):
dot: \(\text{score}(H_j,q) = H_j^T q\)
general: \(\text{score}(H_j, q) = H_j^T W_a q\)
- Bahdanau Attention (mlp):
\(\text{score}(H_j, q) = v_a^T \text{tanh}(W_a q + U_a H_j)\)
- Parameters
dim (int) – dimensionality of query and key
coverage (bool) – use coverage term
attn_type (str) – type of attention to use, options [dot,general,mlp]
attn_func (str) – attention function to use, options [softmax,sparsemax]
-
forward
(src, enc_out, src_len=None, coverage=None)[source]¶ - Parameters
src (FloatTensor) – query vectors
(batch, tgt_len, dim)
enc_out (FloatTensor) – encoder out vectors
(batch, src_len, dim)
src_len (LongTensor) – source context lengths
(batch,)
coverage (FloatTensor) – None (not supported yet)
- Returns
Computed vector
(batch, tgt_len, dim)
Attention distributions for each query
(batch, tgt_len, src_len)
- Return type
(FloatTensor, FloatTensor)
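The three score functions listed for GlobalAttention, and the convex combination they feed, can be sketched for a single query as follows. All function names are hypothetical, plain lists stand in for tensors, only the softmax attention function is shown, and the final projection layer over [q, c] is omitted.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def matvec(W, v):
    return [dot(row, v) for row in W]


def softmax(scores):
    m = max(scores)                      # subtract the max for stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def score_dot(h, q):
    """Luong dot: H_j^T q."""
    return dot(h, q)


def score_general(h, q, W_a):
    """Luong general: H_j^T W_a q."""
    return dot(h, matvec(W_a, q))


def score_mlp(h, q, W_a, U_a, v_a):
    """Bahdanau: v_a^T tanh(W_a q + U_a H_j)."""
    hidden = [math.tanh(a + b)
              for a, b in zip(matvec(W_a, q), matvec(U_a, h))]
    return dot(v_a, hidden)


def attend(H, q, score_fn):
    """Return (c, a) where a = softmax(scores) and c = sum_j a_j H_j."""
    a = softmax([score_fn(h, q) for h in H])
    c = [sum(a_j * h[d] for a_j, h in zip(a, H)) for d in range(len(H[0]))]
    return c, a


H = [[1.0, 0.0], [0.0, 1.0]]
identity = [[1.0, 0.0], [0.0, 1.0]]
context, weights = attend(H, [0.0, 0.0], score_dot)
```

Since the weights are a softmax output, they are non-negative and sum to one, so the context vector always lies in the convex hull of the rows of H.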
-
class
onmt.modules.
MultiHeadedAttention
(head_count: int, model_dim: int, dropout: float = 0.1, max_relative_positions: int = 0, attn_type: str = None, add_qkvbias=False)[source]¶ Bases:
torch.nn.modules.module.Module
Multi-Head Attention module from “Attention is All You Need” [VSP+17].
Similar to standard dot attention but uses multiple attention distributions simultaneously to select relevant items.
[Diagram: key and query feed each of Attn 1, Attn 2, ... Attn N; the head outputs and value combine into the output.]
Also includes several additional tricks.
- Parameters
head_count (int) – number of parallel heads
model_dim (int) – the dimension of keys/values/queries, must be divisible by head_count
dropout (float) – dropout parameter
max_relative_positions (int) – max relative positions
attn_type – “self” or “context”
-
forward
(key: torch.Tensor, value: torch.Tensor, query: torch.Tensor, mask: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor][source]¶ Compute the context vector and the attention vectors.
- Parameters
key (Tensor) – set of key_len key vectors
(batch, key_len, dim)
value (Tensor) – set of key_len value vectors
(batch, key_len, dim)
query (Tensor) – set of query_len query vectors
(batch, query_len, dim)
mask – binary mask 1/0 indicating which keys have zero / non-zero attention
(batch, query_len, key_len)
- Returns
output context vectors
(batch, query_len, dim)
Attention vector in heads
(batch, head, query_len, key_len)
.
- Return type
(Tensor, Tensor)
-
class
onmt.modules.
AverageAttention
(model_dim, dropout=0.1, aan_useffn=False, pos_ffn_activation_fn='relu')[source]¶ Bases:
torch.nn.modules.module.Module
Average Attention module from “Accelerating Neural Transformer via an Average Attention Network” [ZXS18].
- Parameters
model_dim (int) – the dimension of keys/values/queries, must be divisible by head_count
dropout (float) – dropout parameter
pos_ffn_activation_fn (ActivationFunction) – activation function choice for PositionwiseFeedForward layer
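The core idea of the average attention network in [ZXS18] is to replace decoder self-attention with a running average over the positions seen so far. The sketch below shows only that cumulative average; the helper name is hypothetical, and the gating and optional feed-forward layers of the full module are omitted.

```python
def cumulative_average(inputs):
    """Return the running mean of the input vectors.

    inputs: list of vectors, shape (len, dim).
    Output row j is the mean of input rows 1..j, so each position only
    depends on itself and earlier positions.
    """
    running = [0.0] * len(inputs[0])
    averages = []
    for j, vec in enumerate(inputs, start=1):
        running = [r + v for r, v in zip(running, vec)]
        averages.append([r / j for r in running])
    return averages


steps = [[2.0, 0.0], [4.0, 6.0], [6.0, 3.0]]
avg = cumulative_average(steps)
```

Keeping the running sum means each decoding step costs the same regardless of how many tokens came before it.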
-
class
onmt.modules.
ConvMultiStepAttention
(input_size)[source]¶ Bases:
torch.nn.modules.module.Module
Conv attention takes a key matrix, a value matrix and a query vector. The attention weights are computed from the key matrix and the query vector, then used to take a weighted sum over the value matrix. The same operation is applied in each decoder conv layer.
-
forward
(base_target_emb, input_from_dec, encoder_out_top, encoder_out_combine)[source]¶ - Parameters
base_target_emb – target emb tensor
(batch, channel, height, width)
input_from_dec – output of dec conv
(batch, channel, height, width)
encoder_out_top – the key matrix for calc of attention weight, which is the top output of encode conv
encoder_out_combine – the value matrix for the attention-weighted sum, which is the combination of base emb and top output of encode
-
-
class
onmt.modules.
CopyGenerator
(input_size, output_size, pad_idx)[source]¶ Bases:
torch.nn.modules.module.Module
An implementation of pointer-generator networks [SLM17].
These networks consider copying words directly from the source sequence.
The copy generator is an extended version of the standard generator that computes three values.
\(p_{softmax}\) the standard softmax over tgt_dict
\(p(z)\) the probability of copying a word from the source
\(p_{copy}\) the probability of copying a particular word, taken directly from the attention distribution.
The model returns a distribution over the extended dictionary, computed as
\(p(w) = p(z=1) p_{copy}(w) + p(z=0) p_{softmax}(w)\)
- Parameters
input_size (int) – size of input representation
output_size (int) – size of output vocabulary
pad_idx (int) –
-
forward
(hidden, attn, src_map)[source]¶ Compute a distribution over the target dictionary extended by the dynamic dictionary implied by copying source words.
- Parameters
hidden (FloatTensor) – hidden output
(batch x tlen, input_size)
attn (FloatTensor) – attn for each
(batch x tlen, slen)
src_map (FloatTensor) – A sparse indicator matrix mapping each source word to its index in the “extended” vocab containing.
(batch, src_len, extra_words)
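The mixture \(p(w) = p(z=1) p_{copy}(w) + p(z=0) p_{softmax}(w)\) can be sketched for a single decoding step as follows. This is a simplified illustration: the function name and the src_to_ext argument are hypothetical, and a flat list of extended-vocabulary indices stands in for the sparse src_map indicator matrix.

```python
def copy_mixture(p_softmax, attn, src_to_ext, p_z):
    """Mix generation and copy distributions over the extended vocabulary.

    p_softmax: probabilities over the fixed target vocabulary.
    attn: attention weight for each source position (used as p_copy).
    src_to_ext: extended-vocabulary index of each source position
        (hypothetical stand-in for src_map).
    p_z: probability of copying, p(z=1).
    """
    ext_size = max(len(p_softmax), max(src_to_ext) + 1)
    out = [0.0] * ext_size
    # Generation branch: p(z=0) * p_softmax(w).
    for w, p in enumerate(p_softmax):
        out[w] += (1.0 - p_z) * p
    # Copy branch: p(z=1) * attention mass of every position holding w.
    for pos, a in enumerate(attn):
        out[src_to_ext[pos]] += p_z * a
    return out


# Source position 0 holds in-vocabulary word 1; position 1 holds an
# out-of-vocabulary word that receives the extra index 3.
dist = copy_mixture(
    p_softmax=[0.5, 0.25, 0.25],
    attn=[0.75, 0.25],
    src_to_ext=[1, 3],
    p_z=0.5,
)
```

A word that is both in the target vocabulary and present in the source receives probability from both branches, as word 1 does here.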
-
class
onmt.modules.structured_attention.
MatrixTree
(eps=1e-05)[source]¶ Bases:
torch.nn.modules.module.Module
Implementation of the matrix-tree theorem for computing marginals of non-projective dependency parsing. This attention layer is used in the paper “Learning Structured Text Representations” [LL17].
-
forward
(input)[source]¶ Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
-