Source code for opennmt.layers.common

"""Defines common layers."""

import tensorflow as tf

from opennmt.utils.misc import shape_list


def dropout(x, rate, training=None):
    """Simple dropout layer."""
    if not training or rate == 0:
        return x
    return tf.nn.dropout(x, rate)
def gelu(x):
    """Gaussian Error Linear Unit activation function described in
    https://arxiv.org/abs/1606.08415.
    """
    return tf.nn.gelu(x, approximate=True)
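# Usage sketch (illustrative only, not part of the original module): the two
# helpers above are plain functions rather than Keras layers, so the training
# flag must be passed explicitly by the caller.
def _example_activations():
    x = tf.random.uniform([2, 4])
    y = gelu(x)                                  # approximate GELU activation
    y = dropout(y, rate=0.1, training=True)      # dropout applied only in training
    return dropout(y, rate=0.1, training=False)  # no-op outside of training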
class Dense(tf.keras.layers.Dense):
    """Small ``tf.keras.layers.Dense`` extension to possibly reuse an existing
    weight matrix.
    """
    def __init__(self, units, weight=None, transpose=False, **kwargs):
        """Initializes the layer.

        Args:
          units: Positive integer, dimensionality of the output space.
          weight: The weight to reuse.
          transpose: Whether :obj:`weight` should be transposed or not.
          kwargs: Additional layer arguments.
        """
        super().__init__(units, **kwargs)
        self.set_kernel(weight, transpose=transpose)
    def set_kernel(self, weight, transpose=False):
        """Use :obj:`weight` as the kernel weights matrix.

        Args:
          weight: The weight to use.
          transpose: Whether :obj:`weight` should be transposed or not.

        Raises:
          ValueError: if the layer is already built.
        """
        if self.built:
            raise ValueError("The layer is already built")
        self.weight = weight
        self.transpose = transpose
    def add_weight(self, name, *args, **kwargs):
        if self.weight is not None and name == "kernel":
            return self.weight
        return super().add_weight(name, *args, **kwargs)
    def call(self, inputs):
        shape = shape_list(inputs)
        rank = len(shape)
        if rank > 2:
            inputs = tf.reshape(inputs, [-1, shape[-1]])
        if inputs.dtype is tf.float16 and self.units % 8 != 0:
            # In float16, pad the kernel so the output dimension is a multiple
            # of 8, then slice the padding away from the result.
            padding_size = 8 - self.units % 8
            paddings = (
                [[0, padding_size], [0, 0]]
                if self.transpose
                else [[0, 0], [0, padding_size]]
            )
            kernel = tf.pad(self.kernel, paddings)
            outputs = tf.matmul(inputs, kernel, transpose_b=self.transpose)
            outputs = outputs[:, : self.units]
        else:
            outputs = tf.matmul(inputs, self.kernel, transpose_b=self.transpose)
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias)
        if self.activation is not None:
            outputs = self.activation(outputs)
        if rank > 2:
            outputs = tf.reshape(outputs, shape[:-1] + [self.units])
        return outputs
    def map_v1_weights(self, weights):
        m = [(self.kernel, weights["kernel"])]
        if self.use_bias:
            m.append((self.bias, weights["bias"]))
        return m
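# Usage sketch (illustrative only, not part of the original module): a common
# reason to reuse a weight is to tie the output projection to an embedding
# matrix. The names and sizes below are hypothetical.
def _example_tied_projection(vocab_size=32, depth=16):
    embedding = tf.Variable(tf.random.normal([vocab_size, depth]), name="embedding")
    # The embedding has shape [vocab_size, depth], so it is used transposed.
    projection = Dense(vocab_size, weight=embedding, transpose=True)
    inputs = tf.random.normal([4, 10, depth])
    return projection(inputs)  # shape [4, 10, vocab_size], kernel shared with the embedding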
class LayerNorm(tf.keras.layers.LayerNormalization):
    """Layer normalization."""
    def map_v1_weights(self, weights):
        return [(self.beta, weights["beta"]), (self.gamma, weights["gamma"])]
class LayerWrapper(tf.keras.layers.Layer):
    """Layer wrapper for input/output normalization, input/output dropout
    and residual connection.
    """
    def __init__(
        self,
        layer,
        normalize_input=False,
        normalize_output=False,
        input_dropout=0,
        output_dropout=0,
        residual_connection=False,
        **kwargs
    ):
        """Initializes the layer.

        Args:
          layer: The layer to wrap.
          normalize_input: Apply layer normalization on the input.
          normalize_output: Apply layer normalization on the output.
          input_dropout: The probability to drop units in the layer input.
          output_dropout: The probability to drop units in the layer output.
          residual_connection: Add the inputs to layer outputs (if their shapes
            are compatible).
          kwargs: Additional layer arguments.
        """
        super().__init__(**kwargs)
        self.layer = layer
        self.input_layer_norm = LayerNorm() if normalize_input else None
        self.output_layer_norm = LayerNorm() if normalize_output else None
        self.input_dropout = input_dropout
        self.output_dropout = output_dropout
        self.residual_connection = residual_connection
    def call(self, inputs, *args, **kwargs):
        """Runs the wrapper."""
        training = kwargs.get("training")
        x = inputs
        if self.input_layer_norm is not None:
            x = self.input_layer_norm(x)
        x = dropout(x, self.input_dropout, training=training)
        all_outputs = self.layer(x, *args, **kwargs)
        if isinstance(all_outputs, tuple):
            outputs = all_outputs[0]
            extra_outputs = list(all_outputs)[1:]
        else:
            outputs = all_outputs
            extra_outputs = None
        outputs = dropout(outputs, self.output_dropout, training=training)
        if self.residual_connection and outputs.shape[-1] == inputs.shape[-1]:
            outputs += inputs
        if self.output_layer_norm is not None:
            outputs = self.output_layer_norm(outputs)
        if extra_outputs:
            return tuple([outputs] + extra_outputs)
        return outputs
    # The wrapper should be serializable to be used in tf.keras.layers.Bidirectional.

    def get_config(self):
        """Returns the layer wrapper configuration."""
        config = {
            "layer": tf.keras.layers.serialize(self.layer),
            "normalize_input": self.input_layer_norm is not None,
            "normalize_output": self.output_layer_norm is not None,
            "input_dropout": self.input_dropout,
            "output_dropout": self.output_dropout,
            "residual_connection": self.residual_connection,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
    @classmethod
    def from_config(cls, config):
        """Creates a layer wrapper from its configuration."""
        layer = tf.keras.layers.deserialize(config.pop("layer"))
        return cls(layer, **config)
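# Usage sketch (illustrative only, not part of the original module): wrapping a
# feed-forward layer Transformer-style, with pre-normalization, output dropout,
# and a residual connection. Hyperparameters below are made up.
def _example_wrapped_ffn(depth=64):
    wrapper = LayerWrapper(
        tf.keras.layers.Dense(depth, activation=gelu),
        normalize_input=True,
        output_dropout=0.1,
        residual_connection=True,
    )
    x = tf.random.normal([2, 20, depth])
    # The residual connection is applied because the inner layer preserves the
    # last dimension; the output has the same shape as x.
    return wrapper(x, training=True)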