Modules

Embeddings

class onmt.modules.Embeddings(word_vec_size, word_vocab_size, word_padding_idx, position_encoding=False, position_encoding_type='SinusoidalInterleaved', feat_merge='concat', feat_vec_exponent=0.7, feat_vec_size=-1, feat_padding_idx=[], feat_vocab_sizes=[], dropout=0, sparse=False, freeze_word_vecs=False)[source]

Bases: Module

Word embeddings for encoder/decoder.

Additionally includes ability to add sparse input features based on “Linguistic Input Features Improve Neural Machine Translation” [SH16].

Diagram: the input feeds a word lookup and feature lookups 1 to N; their outputs are merged (MLP/Concat) to produce the output.
Parameters:
  • word_vec_size (int) – dimension of the word embedding vectors.

  • word_vocab_size (int) – size of dictionary of embeddings for words.

  • word_padding_idx (int) – padding index for words in the embeddings.

  • position_encoding (bool) – see PositionalEncoding

  • feat_merge (string) – merge action for the features embeddings: concat, sum or mlp.

  • feat_vec_exponent (float) – when using -feat_merge concat, feature embedding size is N^feat_vec_exponent, where N is the number of values the feature takes.

  • feat_vec_size (int) – embedding dimension for features when using -feat_merge mlp

  • feat_padding_idx (List[int]) – padding index for a list of features in the embeddings.

  • feat_vocab_sizes (List[int], optional) – list of size of dictionary of embeddings for each feature.

  • dropout (float) – dropout probability.

  • sparse (bool) – use sparse embeddings (default: False).

  • freeze_word_vecs (bool) – freeze weights of word vectors.
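The feat_vec_exponent rule above can be illustrated with a short, framework-free sketch. The helper name concat_feature_dims and the sample sizes are hypothetical, and truncating the fractional result to an integer is an assumption rather than something this page states.

```python
def concat_feature_dims(feat_vocab_sizes, feat_vec_exponent=0.7):
    """Per-feature embedding widths under feat_merge='concat' (N ** exponent)."""
    # Assumption: the fractional result is truncated to an integer width.
    return [int(n ** feat_vec_exponent) for n in feat_vocab_sizes]


word_vec_size = 512            # illustrative value
feat_vocab_sizes = [100, 8]    # two hypothetical features
feature_dims = concat_feature_dims(feat_vocab_sizes)

# With concatenation, the output width is the word width plus all feature widths.
embedding_size = word_vec_size + sum(feature_dims)
print(feature_dims, embedding_size)
```

Under these assumptions a feature with 100 values gets a 25-dimensional embedding, so large feature vocabularies grow the output width only sublinearly.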

property emb_luts

Embedding look-up table.

forward(source, step=None)[source]

Computes the embeddings for words and features.

Parameters:

source (LongTensor) – index tensor (batch, len, nfeat)

Returns:

Word embeddings (batch, len, embedding_size)

Return type:

FloatTensor

load_pretrained_vectors(emb_file)[source]

Load in pretrained embeddings.

Parameters:

emb_file (str) – path to torch serialized embeddings

property word_lut

Word look-up table.

class onmt.modules.PositionalEncoding(dim, enc_type, max_len=5000)[source]

Bases: Module

Sinusoidal positional encoding for non-recurrent neural networks.

Implementation based on “Attention Is All You Need” [VSP+17]

Parameters:

dim (int) – embedding size

forward(emb, step=None)[source]

Embed inputs.

Parameters:
  • emb (FloatTensor) – Sequence of word vectors (batch_size, seq_len, self.dim)

  • step (int or NoneType) – If stepwise (seq_len = 1), use the encoding for this position.
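As a rough illustration of the sinusoidal scheme from [VSP+17], the sketch below builds an interleaved sine/cosine table in plain Python and shows how a single row would be selected for stepwise decoding. The function names are hypothetical, and the module's actual tensor layout, scaling, and dropout are not reproduced here.

```python
import math


def sinusoidal_table(max_len, dim):
    """Return max_len rows of dim values: sin on even columns, cos on odd ones."""
    if dim % 2 != 0:
        raise ValueError("dim must be even for interleaved sin/cos pairs")
    table = []
    for pos in range(max_len):
        row = [0.0] * dim
        for i in range(0, dim, 2):
            # Wavelengths form a geometric progression with base 10000.
            angle = pos / (10000 ** (i / dim))
            row[i] = math.sin(angle)
            row[i + 1] = math.cos(angle)
        table.append(row)
    return table


def encoding_for_step(table, step):
    """Stepwise case (seq_len = 1): use only the row for this position."""
    return table[step]


pe = sinusoidal_table(max_len=50, dim=8)
print(encoding_for_step(pe, 0))
```

In the full-sequence case the first seq_len rows would be added to the word vectors position by position; in the stepwise case only the row for the current step is needed.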

class onmt.modules.position_ffn.PositionwiseFeedForward(d_model, d_ff, dropout=0.1, activation_fn='relu', add_ffnbias=True, parallel_residual=False, layer_norm='standard', norm_eps=1e-06, use_ckpting=[], parallel_gpu=1)[source]

Bases: Module

A two-layer Feed-Forward-Network with residual layer norm.

Parameters:
  • d_model (int) – the size of input for the first-layer of the FFN.

  • d_ff (int) – the hidden layer size of the second-layer of the FFN.

  • dropout (float) – dropout probability in \([0, 1)\).

  • activation_fn (ActivationFunction) – activation function used.

  • layer_norm (string) – ‘standard’ or ‘rms’

forward(x)[source]

Layer definition.

Parameters:

x – (batch_size, input_len, model_dim)

Returns:

Output (batch_size, input_len, model_dim).

Return type:

(FloatTensor)
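The computation can be sketched for a single position vector in plain Python. This is a minimal illustration that assumes layer normalization is applied to the input before the first layer (pre-norm), uses ReLU, and omits dropout and the learned normalization scale; all function names and weights are hypothetical.

```python
import math


def layer_norm(vec, eps=1e-6):
    """Standard layer norm without learned scale or shift."""
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]


def linear(vec, weight, bias):
    """weight has shape (out_features, in_features)."""
    return [sum(w * v for w, v in zip(row, vec)) + b
            for row, b in zip(weight, bias)]


def position_ffn(vec, w1, b1, w2, b2):
    # First layer: d_model -> d_ff, followed by ReLU.
    hidden = [max(0.0, h) for h in linear(layer_norm(vec), w1, b1)]
    # Second layer: d_ff -> d_model.
    out = linear(hidden, w2, b2)
    # Residual connection keeps the output shape equal to the input shape.
    return [o + v for o, v in zip(out, vec)]


x = [1.0, 3.0]                                   # d_model = 2
w1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]        # d_ff = 3
b1 = [0.0, 0.0, 0.0]
w2 = [[0.5, 0.0, 0.0], [0.0, 0.5, 0.0]]
b2 = [0.0, 0.0]
y = position_ffn(x, w1, b1, w2, b2)
print(y)
```

Because the same weights are applied to every position independently, the output keeps the (batch_size, input_len, model_dim) shape of the input.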

Encoders

class onmt.encoders.EncoderBase(*args, **kwargs)[source]

Bases: Module

Base encoder class. Specifies the interface used by different encoder types and required by onmt.Models.NMTModel.

forward(src, src_len=None)[source]
Parameters:
  • src (LongTensor) – padded sequences of sparse indices (batch, src_len, nfeat)

  • src_len (LongTensor) – length of each sequence (batch,)

Returns:

  • enc_out: encoder output used for attention (batch, src_len, hidden_size); for a bidirectional RNN the last dimension is 2 x hidden_size.

  • enc_final_hs: encoder final hidden state (num_layers x dir, batch, hidden_size). In the case of LSTM this is a tuple.

  • src_len (batch)

Return type:

(FloatTensor, FloatTensor, FloatTensor)

class onmt.encoders.TransformerEncoder(num_layers, d_model, heads, d_ff, dropout, attention_dropout, embeddings, max_relative_positions, relative_positions_buckets, pos_ffn_activation_fn='relu', add_qkvbias=False, num_kv=0, add_ffnbias=True, parallel_residual=False, layer_norm='standard', norm_eps=1e-06, use_ckpting=[], parallel_gpu=1, rotary_interleave=True, rotary_theta=10000.0, rotary_dim=0)[source]

Bases: EncoderBase

The Transformer encoder from “Attention is All You Need” [VSP+17]

Parameters:
  • num_layers (int) – number of encoder layers

  • d_model (int) – size of the model

  • heads (int) – number of heads

  • d_ff (int) – size of the inner FF layer

  • dropout (float) – dropout parameters

  • embeddings (onmt.modules.Embeddings) – embeddings to use, should have positional encodings

  • pos_ffn_activation_fn (ActivationFunction) – activation function choice for PositionwiseFeedForward layer

Returns:

  • enc_out (batch_size, src_len, model_dim)

  • encoder final state: None in the case of Transformer

  • src_len (batch_size)

Return type:

(torch.FloatTensor, torch.FloatTensor)

forward(src, src_len=None)[source]

See EncoderBase.forward()

classmethod from_opt(opt, embeddings)[source]

Alternate constructor.

class onmt.encoders.RNNEncoder(rnn_type, bidirectional, num_layers, hidden_size, dropout=0.0, embeddings=None, use_bridge=False)[source]

Bases: EncoderBase

A generic recurrent neural network encoder.

Parameters:
  • rnn_type (str) – style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU]

  • bidirectional (bool) – use a bidirectional RNN

  • num_layers (int) – number of stacked layers

  • hidden_size (int) – hidden size of each layer

  • dropout (float) – dropout value for torch.nn.Dropout

  • embeddings (onmt.modules.Embeddings) – embedding module to use

forward(src, src_len=None)[source]

See EncoderBase.forward()

classmethod from_opt(opt, embeddings)[source]

Alternate constructor.

class onmt.encoders.GGNNEncoder(rnn_type, src_word_vec_size, src_ggnn_size, state_dim, bidir_edges, n_edge_types, n_node, bridge_extra_node, n_steps, src_vocab)[source]

Bases: EncoderBase

A gated graph neural network configured as an encoder.

Based on github.com/JamesChuanggg/ggnn.pytorch.git, which is based on the paper “Gated Graph Sequence Neural Networks” by Y. Li, D. Tarlow, M. Brockschmidt, and R. Zemel.

Parameters:
  • rnn_type (str) – style of recurrent unit to use, one of [LSTM]

  • src_ggnn_size (int) – Size of token-to-node embedding input

  • src_word_vec_size (int) – Size of token-to-node embedding output

  • state_dim (int) – Number of state dimensions in nodes

  • n_edge_types (int) – Number of edge types

  • bidir_edges (bool) – True if reverse edges should be autocreated

  • n_node (int) – Max nodes in graph

  • bridge_extra_node (bool) – True indicates only 1st extra node (after token listing) should be used for decoder init.

  • n_steps (int) – Steps to advance graph encoder for stabilization

  • src_vocab (int) – Path to source vocabulary. (The ggnn uses src_vocab during training because the graph is built using edge information which requires parsing the input sequence.)

forward(src, src_len=None)[source]

See EncoderBase.forward()

classmethod from_opt(opt, embeddings)[source]

Alternate constructor.

class onmt.encoders.CNNEncoder(num_layers, hidden_size, cnn_kernel_width, dropout, embeddings)[source]

Bases: EncoderBase

Encoder based on “Convolutional Sequence to Sequence Learning” [GAG+17].

forward(input, src_len=None, hidden=None)[source]

See EncoderBase.forward()

classmethod from_opt(opt, embeddings)[source]

Alternate constructor.

class onmt.encoders.MeanEncoder(num_layers, embeddings)[source]

Bases: EncoderBase

A trivial non-recurrent encoder. Simply applies mean pooling.

Parameters:
  • num_layers (int) – number of replicated layers

  • embeddings (onmt.modules.Embeddings) – embedding module to use
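Mean pooling over a padded batch can be sketched as follows. This is a plain-Python illustration, assuming that padded positions are excluded from the average when src_len is given; the function name and data layout (nested lists instead of tensors) are hypothetical.

```python
def mean_pool(embedded, src_len=None):
    """Average each sequence's vectors over time.

    embedded: list (batch) of lists (positions) of vectors (dim).
    src_len:  optional true lengths; positions beyond them are padding.
    """
    pooled = []
    for b, seq in enumerate(embedded):
        n = len(seq) if src_len is None else src_len[b]
        dim = len(seq[0])
        pooled.append([sum(seq[t][d] for t in range(n)) / n
                       for d in range(dim)])
    return pooled


batch = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[5.0, 6.0], [0.0, 0.0]],   # second position is padding
]
masked = mean_pool(batch, src_len=[2, 1])
unmasked = mean_pool(batch)
print(masked, unmasked)
```

Without the lengths, the padding vector is averaged in and drags the second sequence's mean toward zero, which is why passing src_len matters for padded batches.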

forward(src, src_len=None)[source]

See EncoderBase.forward()

classmethod from_opt(opt, embeddings)[source]

Alternate constructor.

Decoders

class onmt.decoders.DecoderBase(attentional=True)[source]

Bases: Module

Abstract class for decoders.

Parameters:

attentional (bool) – The decoder returns non-empty attention.

classmethod from_opt(opt, embeddings)[source]

Alternate constructor.

Subclasses should override this method.

class onmt.decoders.TransformerDecoder(num_layers, d_model, heads, d_ff, copy_attn, self_attn_type, dropout, attention_dropout, embeddings, max_relative_positions, relative_positions_buckets, aan_useffn, full_context_alignment, alignment_layer, alignment_heads, pos_ffn_activation_fn='relu', add_qkvbias=False, num_kv=0, add_ffnbias=True, parallel_residual=False, shared_layer_norm=False, layer_norm='standard', norm_eps=1e-06, use_ckpting=[], parallel_gpu=1, sliding_window=0, rotary_interleave=True, rotary_theta=10000.0, rotary_dim=0, num_experts=0, num_experts_per_tok=2)[source]

Bases: TransformerDecoderBase

The Transformer decoder from “Attention is All You Need”. [VSP+17]

Parameters:
  • num_layers (int) – number of decoder layers.

  • d_model (int) – size of the model

  • heads (int) – number of heads

  • d_ff (int) – size of the inner FF layer

  • copy_attn (bool) – if using a separate copy attention

  • self_attn_type (str) – type of self-attention: scaled-dot, scaled-dot-flash or average

  • dropout (float) – dropout in residual, self-attn(dot) and feed-forward

  • attention_dropout (float) – dropout in context_attn (and self-attn(avg))

  • embeddings (onmt.modules.Embeddings) – embeddings to use, should have positional encodings

  • max_relative_positions (int) – Max distance between inputs in relative positions representations

  • relative_positions_buckets (int) – Number of buckets when using relative position bias

  • aan_useffn (bool) – Turn on the FFN layer in the AAN decoder

  • full_context_alignment (bool) – whether enable an extra full context decoder forward for alignment

  • alignment_layer (int) – index of the layer to supervise for alignment guiding

  • alignment_heads (int) – number of cross-attention heads to use for alignment guiding

  • pos_ffn_activation_fn (ActivationFunction) – activation function choice for PositionwiseFeedForward layer

  • add_qkvbias (bool) – whether to add bias to the Key/Value nn.Linear

  • num_kv (int) – number of heads for KV when different vs Q (multiquery)

  • add_ffnbias (bool) – whether to add bias to the FF nn.Linear

  • parallel_residual (bool) – Use parallel residual connections in each layer block, as used by the GPT-J and GPT-NeoX models

  • shared_layer_norm (bool) – When using parallel residual, share the input and post attention layer norms.

  • layer_norm (string) – type of layer normalization standard/rms

  • norm_eps (float) – layer norm epsilon

  • use_ckpting (List) – layers for which we checkpoint for backward

  • parallel_gpu (int) – Number of gpu for tensor parallelism

  • sliding_window (int) – Width of the band mask and KV cache (cf Mistral Model)

  • rotary_interleave (bool) – Interleave the head dimensions when rotary embeddings are applied

  • rotary_theta (int) – rotary base theta

  • rotary_dim (int) – in some cases the rotary dim is lower than head dim

  • num_experts (int) – Number of experts for MoE

  • num_experts_per_tok (int) – Number of experts choice per token
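To make the sliding_window band mask concrete, the sketch below builds a causal band mask in plain Python. The convention used here (each position attends to itself and to the sliding_window - 1 positions before it) is an assumption modelled on the Mistral-style description, not a statement of how this class builds its mask; band_mask is a hypothetical helper and expects a positive window.

```python
def band_mask(seq_len, sliding_window):
    """allowed[i][j] is True when query position i may attend to key position j."""
    return [
        [(j <= i) and (i - j < sliding_window) for j in range(seq_len)]
        for i in range(seq_len)
    ]


allowed = band_mask(seq_len=4, sliding_window=2)
for row in allowed:
    # Print 1 for visible keys and 0 for masked keys.
    print("".join("1" if visible else "0" for visible in row))
```

With such a band, a key/value cache never needs to hold more than sliding_window past positions, which is the link between the mask width and the cache width mentioned in the parameter description.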

forward(tgt, enc_out=None, step=None, **kwargs)[source]

Decode, possibly stepwise. During training, step is always None; during decoding, step increases.

Parameters:
  • tgt (Tensor) – batch x tlen x feats

  • enc_out (Tensor) – encoder output (batch x slen x model_dim)

class onmt.decoders.decoder.RNNDecoderBase(rnn_type, bidirectional_encoder, num_layers, hidden_size, attn_type='general', attn_func='softmax', coverage_attn=False, context_gate=None, copy_attn=False, dropout=0.0, embeddings=None, reuse_copy_attn=False, copy_attn_type='general')[source]

Bases: DecoderBase

Base recurrent attention-based decoder class.

Specifies the interface used by different decoder types and required by NMTModel.

Parameters:
  • rnn_type (str) – style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU]

  • bidirectional_encoder (bool) – use with a bidirectional encoder

  • num_layers (int) – number of stacked layers

  • hidden_size (int) – hidden size of each layer

  • attn_type (str) – see GlobalAttention

  • attn_func (str) – see GlobalAttention

  • coverage_attn (str) – see GlobalAttention

  • context_gate (str) – see ContextGate

  • copy_attn (bool) – setup a separate copy attention mechanism

  • dropout (float) – dropout value for torch.nn.Dropout

  • embeddings (onmt.modules.Embeddings) – embedding module to use

  • reuse_copy_attn (bool) – reuse the attention for copying

  • copy_attn_type (str) – The copy attention style. See GlobalAttention.

forward(tgt, enc_out, src_len=None, step=None, **kwargs)[source]
Parameters:
  • tgt (LongTensor) – sequences of padded tokens (batch, tgt_len, nfeats).

  • enc_out (FloatTensor) – vectors from the encoder (batch, src_len, hidden).

  • src_len (LongTensor) – the padded source lengths (batch,).

Returns:

  • dec_outs: output from the decoder (after attn) (batch, tgt_len, hidden).

  • attns: distribution over src at each tgt (batch, tgt_len, src_len).

Return type:

(FloatTensor, dict[str, FloatTensor])

classmethod from_opt(opt, embeddings)[source]

Alternate constructor.

init_state(src, _, enc_final_hs)[source]

Initialize decoder state with last state of the encoder.

class onmt.decoders.StdRNNDecoder(rnn_type, bidirectional_encoder, num_layers, hidden_size, attn_type='general', attn_func='softmax', coverage_attn=False, context_gate=None, copy_attn=False, dropout=0.0, embeddings=None, reuse_copy_attn=False, copy_attn_type='general')[source]

Bases: RNNDecoderBase

Standard fully batched RNN decoder with attention.

Faster implementation that uses CuDNN. See RNNDecoderBase for options.

Based around the approach from “Neural Machine Translation By Jointly Learning To Align and Translate” [BCB14]

Implemented without input_feeding and currently with no coverage_attn or copy_attn support.

class onmt.decoders.InputFeedRNNDecoder(rnn_type, bidirectional_encoder, num_layers, hidden_size, attn_type='general', attn_func='softmax', coverage_attn=False, context_gate=None, copy_attn=False, dropout=0.0, embeddings=None, reuse_copy_attn=False, copy_attn_type='general')[source]

Bases: RNNDecoderBase

Input feeding based decoder.

See RNNDecoderBase for options.

Based around the input feeding approach from “Effective Approaches to Attention-based Neural Machine Translation” [LPM15]

class onmt.decoders.CNNDecoder(num_layers, hidden_size, attn_type, copy_attn, cnn_kernel_width, dropout, embeddings, copy_attn_type)[source]

Bases: DecoderBase

Decoder based on “Convolutional Sequence to Sequence Learning” [GAG+17].

Consists of residual convolutional layers, with ConvMultiStepAttention.

forward(tgt, enc_out, step=None, **kwargs)[source]

See onmt.decoders.decoder.RNNDecoderBase.forward()

classmethod from_opt(opt, embeddings)[source]

Alternate constructor.

init_state(_, enc_out, enc_hidden)[source]

Init decoder state.

Attention

class onmt.modules.GlobalAttention(dim, coverage=False, attn_type='dot', attn_func='softmax')[source]

Bases: Module

Global attention takes a matrix and a query vector. It then computes a parameterized convex combination of the matrix based on the input query.

Constructs a unit mapping a query q of size dim and a source matrix H of size n x dim, to an output of size dim.

Diagram: the query and the RNN states H 1 to H N feed the attention unit (Attn); the attention weights and the states are then combined to produce the output.

All models compute the output as \(c = \sum_{j=1}^{\text{SeqLength}} a_j H_j\) where \(a_j\) is the softmax of a score function. They then apply a projection layer to [q, c].

However, they differ in how they compute the attention score.

  • Luong Attention (dot, general):
    • dot: \(\text{score}(H_j,q) = H_j^T q\)

    • general: \(\text{score}(H_j, q) = H_j^T W_a q\)

  • Bahdanau Attention (mlp):
    • \(\text{score}(H_j, q) = v_a^T \text{tanh}(W_a q + U_a H_j)\)
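The dot variant of the score and the resulting convex combination can be sketched in plain Python for a single query. The function names are hypothetical, and the final projection over [q, c] as well as the general and mlp parameters \(W_a\), \(U_a\), \(v_a\) are left out to keep the sketch short.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot_attention(query, source):
    """query: vector q of size dim; source: n rows H_j of size dim."""
    # score(H_j, q) = H_j^T q
    scores = [sum(h * q for h, q in zip(row, query)) for row in source]
    weights = softmax(scores)                      # the a_j
    dim = len(query)
    # c = sum_j a_j H_j, a convex combination of the source rows.
    context = [sum(a * row[d] for a, row in zip(weights, source))
               for d in range(dim)]
    return context, weights


H = [[2.0, 0.0], [0.0, 2.0]]
q = [1.0, 0.0]
c, a = dot_attention(q, H)
print(c, a)
```

Swapping the score line for \(H_j^T W_a q\) or the tanh form would give the general and mlp variants while leaving the softmax and the weighted sum unchanged.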

Parameters:
  • dim (int) – dimensionality of query and key

  • coverage (bool) – use coverage term

  • attn_type (str) – type of attention to use, options [dot,general,mlp]

  • attn_func (str) – attention function to use, options [softmax,sparsemax]

forward(src, enc_out, src_len=None, coverage=None)[source]
Parameters:
  • src (FloatTensor) – query vectors (batch, tgt_len, dim)

  • enc_out (FloatTensor) – encoder out vectors (batch, src_len, dim)

  • src_len (LongTensor) – source context lengths (batch,)

  • coverage (FloatTensor) – None (not supported yet)

Returns:

  • Computed vector (batch, tgt_len, dim)

  • Attention distributions for each query (batch, tgt_len, src_len)

Return type:

(FloatTensor, FloatTensor)

score(h_t, h_s)[source]
Parameters:
  • h_t (FloatTensor) – sequence of queries (batch, tgt_len, dim)

  • h_s (FloatTensor) – sequence of sources (batch, src_len, dim)

Returns:

raw attention scores (unnormalized) for each src index

(batch, tgt_len, src_len)

Return type:

FloatTensor

class onmt.modules.MultiHeadedAttention(head_count: int, model_dim: int, dropout: float = 0.1, is_decoder: bool = True, max_relative_positions: int = 0, relative_positions_buckets: int = 0, rotary_interleave: bool = True, rotary_theta: int = 10000.0, rotary_dim: int = 0, attn_type: str | None = None, self_attn_type: str | None = None, add_qkvbias=False, num_kv=0, use_ckpting=[], parallel_gpu=1)[source]

Bases: Module

Multi-Head Attention module from “Attention is All You Need” [VSP+17].

Similar to standard dot attention but uses multiple attention distributions simultaneously to select relevant items.

Diagram: the key and the query feed each of the attention heads Attn 1 to Attn N; the head outputs and the value are then combined to produce the output.

Also includes several additional tricks.

Parameters:
  • head_count (int) – number of parallel heads

  • model_dim (int) – the dimension of keys/values/queries, must be divisible by head_count

  • dropout (float) – dropout parameter

  • max_relative_positions (int) – max relative positions

  • attn_type – “self” or “context”

forward(key: Tensor, value: Tensor, query: Tensor, mask: Tensor | None = None, sliding_window: int | None = 0, step: int | None = 0, return_attn: bool | None = False, self_attn_type: str | None = None) Tuple[Tensor, Tensor][source]

Compute the context vector and the attention vectors.

Parameters:
  • key (Tensor) – set of key_len key vectors (batch, key_len, dim)

  • value (Tensor) – set of key_len value vectors (batch, key_len, dim)

  • query (Tensor) – set of query_len query vectors (batch, query_len, dim)

  • mask – binary mask 1/0 indicating which keys have zero / non-zero attention (batch, query_len, key_len)

  • step (int) – decoding step (used for Rotary embedding)

Returns:

  • output context vectors (batch, query_len, dim)

  • Attention vector in heads (batch, head, query_len, key_len).

Return type:

(Tensor, Tensor)
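The shapes above can be made concrete with a framework-free sketch of scaled dot-product attention over several heads for one batch element. It assumes each head simply reads its own slice of the vectors; the learned query/key/value/output projections, masking, dropout, and relative or rotary positions are all omitted, and multi_head_attention is a hypothetical name.

```python
import math


def multi_head_attention(keys, values, queries, head_count):
    """keys/values: key_len vectors; queries: query_len vectors (all of size dim)."""
    model_dim = len(queries[0])
    if model_dim % head_count != 0:
        raise ValueError("model_dim must be divisible by head_count")
    dim_per_head = model_dim // head_count
    outputs = [[] for _ in queries]   # (query_len, dim) once all heads are done
    attns = []                        # (head, query_len, key_len)
    for h in range(head_count):
        lo, hi = h * dim_per_head, (h + 1) * dim_per_head
        head_attn = []
        for qi, q in enumerate(queries):
            # Scaled dot product between this head's slice of q and each key.
            scores = [sum(a * b for a, b in zip(q[lo:hi], k[lo:hi]))
                      / math.sqrt(dim_per_head) for k in keys]
            peak = max(scores)
            exps = [math.exp(s - peak) for s in scores]
            total = sum(exps)
            weights = [e / total for e in exps]
            head_attn.append(weights)
            # Weighted sum of this head's slice of the values; heads are concatenated.
            outputs[qi].extend(sum(w * v[d] for w, v in zip(weights, values))
                               for d in range(lo, hi))
        attns.append(head_attn)
    return outputs, attns


x = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]]
out, attn = multi_head_attention(x, x, x, head_count=2)
print(len(out), len(out[0]), len(attn), len(attn[0]), len(attn[0][0]))
```

Each head produces its own distribution over the keys, which is why the attention output carries a head dimension while the context vectors keep the model dimension.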

class onmt.modules.AverageAttention(model_dim, dropout=0.1, aan_useffn=False, pos_ffn_activation_fn='relu')[source]

Bases: Module

Average Attention module from “Accelerating Neural Transformer via an Average Attention Network” [ZXS18].

Parameters:
  • model_dim (int) – the dimension of keys/values/queries, must be divisible by head_count

  • dropout (float) – dropout parameter

  • pos_ffn_activation_fn (ActivationFunction) – activation function choice for PositionwiseFeedForward layer

forward(layer_in, mask=None, step=None)[source]
Parameters:

layer_in (FloatTensor) – (batch, t_len, dim)

Returns:

  • gating_out (batch, tlen, dim)

  • average_out: average attention (batch, input_len, dim)

Return type:

(FloatTensor, FloatTensor)
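The core idea of [ZXS18], replacing decoder self-attention with a cumulative average over the positions seen so far, can be sketched in plain Python. The gating layer and the optional FFN are omitted, and cumulative_average is a hypothetical helper operating on one sequence.

```python
def cumulative_average(layer_in):
    """layer_in: t_len vectors; returns the running mean up to each position."""
    running = [0.0] * len(layer_in[0])
    averaged = []
    for t, vec in enumerate(layer_in, start=1):
        # Keep a running sum so each step costs O(dim) instead of O(t * dim).
        running = [r + v for r, v in zip(running, vec)]
        averaged.append([r / t for r in running])
    return averaged


seq = [[2.0, 0.0], [4.0, 2.0], [6.0, 4.0]]
avg = cumulative_average(seq)
print(avg)
```

Because position t only depends on positions up to t, the average is causal by construction, and stepwise decoding only needs the running sum and the step count.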

class onmt.modules.ConvMultiStepAttention(input_size)[source]

Bases: Module

Conv attention takes a key matrix, a value matrix and a query vector. The attention weights are computed from the key matrix and the query vector, then used to form a weighted sum over the value matrix. The same operation is applied in each decoder conv layer.

apply_mask(mask)[source]

Apply mask

forward(base_target_emb, input_from_dec, encoder_out_top, encoder_out_combine)[source]
Parameters:
  • base_target_emb – target emb tensor (batch, channel, height, width)

  • input_from_dec – output of dec conv (batch, channel, height, width)

  • encoder_out_top – the key matrix for calc of attention weight, which is the top output of encode conv

  • encoder_out_combine – the value matrix for the attention-weighted sum, which is the combination of base emb and top output of encode

class onmt.modules.CopyGenerator(input_size, output_size, pad_idx)[source]

Bases: Module

An implementation of pointer-generator networks [SLM17].

These networks consider copying words directly from the source sequence.

The copy generator is an extended version of the standard generator that computes three values.

  • \(p_{softmax}\) the standard softmax over tgt_dict

  • \(p(z)\) the probability of copying a word from the source

  • \(p_{copy}\) the probability of copying a particular word, taken directly from the attention distribution.

The model returns a distribution over the extended dictionary, computed as

\(p(w) = p(z=1) p_{copy}(w) + p(z=0) p_{softmax}(w)\)
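The mixture can be checked numerically with a small plain-Python sketch for one target position. The names copy_mixture, p_copy_gate (standing for \(p(z=1)\)) and src_to_extended are hypothetical, and mapping every source position straight into a single extended index space is a simplifying assumption about how the extended dictionary is laid out.

```python
def copy_mixture(p_softmax, p_copy_gate, attn, src_to_extended, extended_size):
    """Return p(w) over the extended dictionary for one target position.

    p_softmax:       generation distribution over the target vocabulary.
    p_copy_gate:     p(z=1), the probability of copying.
    attn:            attention over source positions, used as p_copy.
    src_to_extended: extended-dictionary index of each source position.
    """
    probs = [0.0] * extended_size
    # Generation branch: p(z=0) * p_softmax(w).
    for w, p in enumerate(p_softmax):
        probs[w] += (1.0 - p_copy_gate) * p
    # Copy branch: p(z=1) * p_copy(w), summing attention on repeated words.
    for pos, a in enumerate(attn):
        probs[src_to_extended[pos]] += p_copy_gate * a
    return probs


p_w = copy_mixture(
    p_softmax=[0.5, 0.3, 0.2],     # 3-word target vocabulary
    p_copy_gate=0.25,
    attn=[0.6, 0.4],               # 2 source positions
    src_to_extended=[1, 3],        # index 3 is a source-only word
    extended_size=4,
)
print(p_w)
```

The source-only word at index 3 receives probability mass only through the copy branch, which is what lets the model emit words outside the target vocabulary.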

Parameters:
  • input_size (int) – size of input representation

  • output_size (int) – size of output vocabulary

  • pad_idx (int) –

forward(hidden, attn, src_map)[source]

Compute a distribution over the target dictionary extended by the dynamic dictionary implied by copying source words.

Parameters:
  • hidden (FloatTensor) – hidden output (batch x tlen, input_size)

  • attn (FloatTensor) – attn for each (batch x tlen, slen)

  • src_map (FloatTensor) – A sparse indicator matrix mapping each source word to its index in the “extended” vocab containing. (batch, src_len, extra_words)

class onmt.modules.structured_attention.MatrixTree(eps=1e-05)[source]

Bases: Module

Implementation of the matrix-tree theorem for computing marginals of non-projective dependency parsing. This attention layer is used in the paper “Learning Structured Text Representations” [LL17].

forward(input)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.