# Source code for onmt.modules.position_ffn

"""Position feed-forward network from "Attention is All You Need"."""

import torch.nn as nn
import torch.nn.functional as F


class ActivationFunction:
    """Namespace of string identifiers for the supported activations.

    Each attribute's value is the plain string used to select an
    activation (for example as a key of ``ACTIVATION_FUNCTIONS``).
    """

    relu, gelu = "relu", "gelu"

# Lookup table from each ActivationFunction identifier to the
# torch.nn.functional callable that implements it.
ACTIVATION_FUNCTIONS = dict(
    (
        (ActivationFunction.relu, F.relu),
        (ActivationFunction.gelu, F.gelu),
    )
)


class PositionwiseFeedForward(nn.Module):
    """A two-layer feed-forward network with layer norm and a residual.

    The input is layer-normalized, projected to ``d_ff``, passed through the
    activation and dropout, projected back to ``d_model``, passed through
    dropout again, and finally added to the original (un-normalized) input.

    Args:
        d_model (int): the size of the input to the first layer of the FFN;
            also the size of the output of the second layer.
        d_ff (int): the hidden size, i.e. the output size of the first layer
            and the input size of the second layer.
        dropout (float): dropout probability in :math:`[0, 1)`.
        activation_fn (ActivationFunction): identifier of the activation
            function used; must be a key of ``ACTIVATION_FUNCTIONS``
            (a ``KeyError`` is raised otherwise).
    """

    def __init__(
        self,
        d_model,
        d_ff,
        dropout=0.1,
        activation_fn=ActivationFunction.relu,
    ):
        super(PositionwiseFeedForward, self).__init__()
        # Creation order is kept as-is: it fixes both the parameter
        # registration order and the order of random initialization.
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout_1 = nn.Dropout(dropout)
        self.activation = ACTIVATION_FUNCTIONS[activation_fn]
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        """Layer definition.

        Args:
            x: ``(batch_size, input_len, model_dim)``

        Returns:
            (FloatTensor): Output ``(batch_size, input_len, model_dim)``.
        """
        # Normalization is applied before the first projection; the residual
        # below uses the original, un-normalized ``x``.
        inter = self.dropout_1(self.activation(self.w_1(self.layer_norm(x))))
        output = self.dropout_2(self.w_2(inter))
        return output + x

    def update_dropout(self, dropout):
        """Set the drop probability of both dropout layers in place.

        Args:
            dropout (float): new dropout probability.
        """
        self.dropout_1.p = dropout
        self.dropout_2.p = dropout