"""Defines common layers."""
import tensorflow as tf
from opennmt.utils.misc import shape_list


def dropout(x, rate, training=None):
    """Simple dropout layer."""
    if not training or rate == 0:
        return x
    return tf.nn.dropout(x, rate)
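

# Usage sketch (illustration only, not part of the module API; the helper name
# below is hypothetical): `dropout` is a no-op when `training` is falsy or the
# rate is 0, so callers can invoke it unconditionally.
def _example_dropout_usage():
    x = tf.random.normal([4, 8])
    y = dropout(x, 0.1, training=True)   # units randomly zeroed, kept units scaled by 1/(1-0.1)
    z = dropout(x, 0.1, training=False)  # returns x unchanged
    return y, z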


def gelu(x):
    """Gaussian Error Linear Unit activation function described in
    https://arxiv.org/abs/1606.08415.
    """
    return tf.nn.gelu(x, approximate=True)
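

# Usage sketch (illustration only; the helper name is hypothetical): `gelu` is a
# plain callable and can be passed as the `activation` argument of a Keras layer.
# With approximate=True, TensorFlow uses the tanh approximation
# 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
def _example_gelu_usage():
    hidden = tf.keras.layers.Dense(64, activation=gelu)
    return hidden(tf.zeros([2, 16]))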


class Dense(tf.keras.layers.Dense):
    """Small ``tf.keras.layers.Dense`` extension to possibly reuse an existing weight
    matrix.
    """

    def __init__(self, units, weight=None, transpose=False, **kwargs):
        """Initializes the layer.

        Args:
          units: Positive integer, dimensionality of the output space.
          weight: The weight to reuse.
          transpose: Whether :obj:`weight` should be transposed or not.
          kwargs: Additional layer arguments.
        """
        super().__init__(units, **kwargs)
        self.set_kernel(weight, transpose=transpose)

    def set_kernel(self, weight, transpose=False):
        """Use :obj:`weight` as the kernel weights matrix.

        Args:
          weight: The weight to use.
          transpose: Whether :obj:`weight` should be transposed or not.

        Raises:
          ValueError: if the layer is already built.
        """
        if self.built:
            raise ValueError("The layer is already built")
        self.weight = weight
        self.transpose = transpose

    def add_weight(self, name, *args, **kwargs):
        # Return the reused weight for the kernel instead of creating a new variable.
        if self.weight is not None and name == "kernel":
            return self.weight
        return super().add_weight(name, *args, **kwargs)

    def call(self, inputs):
        shape = shape_list(inputs)
        rank = len(shape)
        if rank > 2:
            inputs = tf.reshape(inputs, [-1, shape[-1]])
        if inputs.dtype is tf.float16 and self.units % 8 != 0:
            # In float16, pad the output dimension of the kernel to the next
            # multiple of 8 (and slice the result back afterwards) so the
            # matmul dimensions are Tensor Core friendly.
            padding_size = 8 - self.units % 8
            paddings = (
                [[0, padding_size], [0, 0]]
                if self.transpose
                else [[0, 0], [0, padding_size]]
            )
            kernel = tf.pad(self.kernel, paddings)
            outputs = tf.matmul(inputs, kernel, transpose_b=self.transpose)
            outputs = outputs[:, : self.units]
        else:
            outputs = tf.matmul(inputs, self.kernel, transpose_b=self.transpose)
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias)
        if self.activation is not None:
            outputs = self.activation(outputs)
        if rank > 2:
            outputs = tf.reshape(outputs, shape[:-1] + [self.units])
        return outputs

    def map_v1_weights(self, weights):
        m = [(self.kernel, weights["kernel"])]
        if self.use_bias:
            m.append((self.bias, weights["bias"]))
        return m
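

# Usage sketch (illustration only; names below are hypothetical): this `Dense`
# extension is typically used for weight tying, e.g. reusing a
# [vocab_size, depth] embedding matrix as the output projection by transposing
# it to a [depth, vocab_size] kernel.
def _example_weight_tying(embedding_matrix, vocab_size):
    projection = Dense(vocab_size, weight=embedding_matrix, transpose=True)
    depth = embedding_matrix.shape[-1]
    logits = projection(tf.zeros([2, 5, depth]))  # -> shape [2, 5, vocab_size]
    return logits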


class LayerNorm(tf.keras.layers.LayerNormalization):
    """Layer normalization."""

    def map_v1_weights(self, weights):
        return [(self.beta, weights["beta"]), (self.gamma, weights["gamma"])]


class LayerWrapper(tf.keras.layers.Layer):
    """Layer wrapper for input/output normalization, input/output dropout and
    residual connection.
    """

    def __init__(
        self,
        layer,
        normalize_input=False,
        normalize_output=False,
        input_dropout=0,
        output_dropout=0,
        residual_connection=False,
        **kwargs
    ):
        """Initializes the layer.

        Args:
          layer: The layer to wrap.
          normalize_input: Apply layer normalization on the input.
          normalize_output: Apply layer normalization on the output.
          input_dropout: The probability to drop units in the layer input.
          output_dropout: The probability to drop units in the layer output.
          residual_connection: Add the inputs to the layer outputs (if their shapes
            are compatible).
          kwargs: Additional layer arguments.
        """
        super().__init__(**kwargs)
        self.layer = layer
        self.input_layer_norm = LayerNorm() if normalize_input else None
        self.output_layer_norm = LayerNorm() if normalize_output else None
        self.input_dropout = input_dropout
        self.output_dropout = output_dropout
        self.residual_connection = residual_connection

    def call(self, inputs, *args, **kwargs):
        """Runs the wrapper."""
        training = kwargs.get("training")
        x = inputs
        if self.input_layer_norm is not None:
            x = self.input_layer_norm(x)
        x = dropout(x, self.input_dropout, training=training)
        all_outputs = self.layer(x, *args, **kwargs)
        if isinstance(all_outputs, tuple):
            outputs = all_outputs[0]
            extra_outputs = list(all_outputs)[1:]
        else:
            outputs = all_outputs
            extra_outputs = None
        outputs = dropout(outputs, self.output_dropout, training=training)
        if self.residual_connection and outputs.shape[-1] == inputs.shape[-1]:
            outputs += inputs
        if self.output_layer_norm is not None:
            outputs = self.output_layer_norm(outputs)
        if extra_outputs:
            return tuple([outputs] + extra_outputs)
        return outputs

    # The wrapper should be serializable to be used in tf.keras.layers.Bidirectional.

    def get_config(self):
        """Returns the layer wrapper configuration."""
        config = {
            "layer": tf.keras.layers.serialize(self.layer),
            "normalize_input": self.input_layer_norm is not None,
            "normalize_output": self.output_layer_norm is not None,
            "input_dropout": self.input_dropout,
            "output_dropout": self.output_dropout,
            "residual_connection": self.residual_connection,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        """Creates a layer wrapper from its configuration."""
        layer = tf.keras.layers.deserialize(config.pop("layer"))
        return cls(layer, **config)
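

# Usage sketch (illustration only; names below are hypothetical): wrapping a
# feed-forward layer in the Transformer "pre-norm" style, i.e. layer
# normalization on the input, dropout on the output, and a residual connection
# (applied here because input and output depths match).
def _example_layer_wrapper():
    inner = tf.keras.layers.Dense(512, activation=tf.nn.relu)
    wrapped = LayerWrapper(
        inner,
        normalize_input=True,
        output_dropout=0.1,
        residual_connection=True,
    )
    x = tf.zeros([2, 10, 512])
    return wrapped(x, training=True)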