Source code for onmt.modules.multi_headed_attn

""" Multi-Head Attention module """
import torch
import torch.nn as nn
from math import log, sqrt
from torch import Tensor
from typing import Optional, Tuple
from torch.nn.functional import scaled_dot_product_attention
from torch.utils.checkpoint import checkpoint
from torch.nn.utils import skip_init
from .alibi_position_bias import AlibiPositionalBias
from torch.distributed import all_reduce
from importlib import import_module

# Helper functions for Rotary Embeddings
# https://arxiv.org/pdf/2104.09864.pdf
# Making maxseqlen a parameter would be too convoluted, so we assume that
# src_seq_len at training and max_length at inference are both < 2048 tokens;
# the tables are resized on the fly in forward() when this is exceeded.


def rotaryembeddings(dim: int, maxseqlen=2048, base=10000, device=None):
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    tmax = torch.arange(maxseqlen, device=inv_freq.device)
    rope = torch.outer(tmax, inv_freq).float()
    # rope is now matrix [maxseqlen, dim/2]
    rope = torch.polar(torch.ones_like(rope), rope)
    rope = torch.cat((rope, rope), dim=1)
    if device is not None:
        rope = rope.to(device)
    cos = rope[:, : rope.size(1) // 2].real.contiguous().half()
    sin = rope[:, : rope.size(1) // 2].imag.contiguous().half()
    return rope, cos, sin
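
# Illustrative sketch (not part of the original module): expected output shapes
# for a typical head dimension of 64 with the default maxseqlen of 2048.
#
#   rope, cos, sin = rotaryembeddings(64)
#   # rope: complex64, shape [2048, 64] (the 32 base frequencies duplicated
#   #       along the last dim by the torch.cat above)
#   # cos, sin: float16, shape [2048, 32], used by the flash-attn kv-cache path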


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_emb(query, key, rope, interleave):
    if interleave:
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        query_ = query.float().reshape(*query.shape[:-1], -1, 2)
        query_ = torch.view_as_complex(query_)
        key_ = key.float().reshape(*key.shape[:-1], -1, 2)
        key_ = torch.view_as_complex(key_)
        rope = rope[:, : rope.size(1) // 2].view(1, query_.size(1), 1, query_.size(3))
        query_out = torch.view_as_real(query_ * rope).flatten(3)
        key_out = torch.view_as_real(key_ * rope).flatten(3)
        return query_out.transpose(1, 2).type_as(query), key_out.transpose(
            1, 2
        ).type_as(key)
    else:
        cos, sin = rope.real, rope.imag
        rotary_dim = cos.size(1)
        head_dim = query.size(3)
        if rotary_dim < head_dim:
            q_embed = (query[:, :, :, :rotary_dim] * cos) + (
                rotate_half(query[:, :, :, :rotary_dim]) * sin
            )
            k_embed = (key[:, :, :, :rotary_dim] * cos) + (
                rotate_half(key[:, :, :, :rotary_dim]) * sin
            )
            q_embed = torch.cat([q_embed, query[:, :, :, rotary_dim:]], dim=-1)
            k_embed = torch.cat([k_embed, key[:, :, :, rotary_dim:]], dim=-1)
        else:
            q_embed = (query * cos) + (rotate_half(query) * sin)
            k_embed = (key * cos) + (rotate_half(key) * sin)
        return q_embed.type_as(query), k_embed.type_as(key)
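
# Illustrative usage sketch (assumption: called standalone with random tensors).
# Queries/keys come in as [batch, heads, seq_len, dim_per_head] and keep that
# shape after rotation; `rope` must be pre-sliced to the current positions.
#
#   rope, _, _ = rotaryembeddings(64)
#   q = torch.randn(2, 8, 10, 64)
#   k = torch.randn(2, 8, 10, 64)
#   q_rot, k_rot = apply_rotary_emb(q, k, rope[:10], interleave=True)
#   # q_rot.shape == k_rot.shape == (2, 8, 10, 64)
#   # interleave=True rotates adjacent (even, odd) pairs via complex multiply;
#   # interleave=False applies the cos/sin tables with rotate_half instead.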


# Helper functions for max_relative_positions
# https://arxiv.org/abs/1803.02155


def relative_matmul(x: Tensor, z: Tensor, transpose: bool) -> Tensor:
    """
    Helper function for relative positions attention.
    https://arxiv.org/pdf/1803.02155.pdf
    x is either the query ``[batch_size x heads x q_len x dim_per_head]``
    (transpose=True) or the attention weights
    ``[batch_size x heads x q_len x k_len]`` (transpose=False);
    z holds the relative position embeddings
    ``[k_len (or 1) x k_len x dim_per_head]``.
    """
    batch_size, heads, length, _ = x.size()
    x_t = x.permute(2, 0, 1, 3)
    x_t_r = x_t.contiguous().view(length, heads * batch_size, -1)
    if transpose:
        z = z.transpose(1, 2)
    x_tz_matmul = torch.matmul(x_t_r, z)
    x_tz_matmul_r = x_tz_matmul.view(length, batch_size, heads, -1)
    x_tz_matmul_r_t = x_tz_matmul_r.permute(1, 2, 0, 3)
    return x_tz_matmul_r_t
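
# Illustrative shapes (sketch): with transpose=True this turns per-position
# relative key embeddings into additive attention scores; with transpose=False
# it folds relative value embeddings back into the context.
#
#   q = torch.randn(2, 8, 5, 64)                       # [batch, heads, q_len, dim_per_head]
#   rel_emb = torch.randn(5, 5, 64)                    # [q_len, k_len, dim_per_head]
#   scores = relative_matmul(q, rel_emb, True)         # -> [2, 8, 5, 5]
#   attn = torch.softmax(scores, dim=-1)
#   ctx_bias = relative_matmul(attn, rel_emb, False)   # -> [2, 8, 5, 64]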


def gen_relative_positions(
    length: int,
    max_relative_positions: int,
    cache: bool = False,
    device: Optional[torch.device] = None,
) -> Tensor:
    """Generate the clipped relative positions matrix
    for a given length and maximum relative positions"""
    if cache:
        distance_mat = torch.arange(-length + 1, 1, 1, device=device).unsqueeze(0)
    else:
        range_vec = torch.arange(length, device=device)
        range_mat = range_vec.unsqueeze(-1).expand(-1, length).transpose(0, 1)
        distance_mat = range_mat - range_mat.transpose(0, 1)
    distance_mat_clipped = torch.clamp(
        distance_mat, min=-max_relative_positions, max=max_relative_positions
    )
    # Shift values to be >= 0
    final_mat = distance_mat_clipped + max_relative_positions
    return final_mat
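
# Illustrative output (sketch): for length=4 and max_relative_positions=2 the
# clipped, shifted distance matrix indexes an embedding table of 2 * 2 + 1 = 5 rows.
#
#   gen_relative_positions(4, 2)
#   # tensor([[2, 3, 4, 4],
#   #         [1, 2, 3, 4],
#   #         [0, 1, 2, 3],
#   #         [0, 0, 1, 2]])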


def _relative_position_bucket(
    relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
    """
    Adapted from Mesh Tensorflow:
    https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/
    mesh_tensorflow/transformer/transformer_layers.py#L593
    Translate relative position to a bucket number for relative attention.
    The relative position is defined as memory_position - query_position,
    i.e. the distance in tokens from the attending position to the attended-to
    position. If bidirectional=False, then positive relative positions are invalid.
    We use smaller buckets for small absolute relative_position and larger buckets for
    larger absolute relative_positions. All relative positions >= max_distance map to the
    same bucket. All relative positions <= -max_distance map to the same bucket.
    This should allow for more graceful generalization to longer sequences than the
    model has been trained on.

    Args:
        relative_position: an int32 Tensor
        bidirectional: a boolean - whether the attention is bidirectional
        num_buckets: an integer
        max_distance: an integer

    Returns:
        a Tensor with the same shape as relative_position, containing int32 values
        in the range [0, num_buckets)
    """
    relative_buckets = 0
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        relative_position = -torch.min(
            relative_position, torch.zeros_like(relative_position)
        )
    # now relative_position is in the range [0, inf)
    # half of the buckets are for exact increments in positions
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    # The other half of the buckets are for logarithmically bigger bins in positions
    # up to max_distance
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large,
        torch.full_like(relative_position_if_large, num_buckets - 1),
    )

    relative_buckets += torch.where(
        is_small, relative_position, relative_position_if_large
    )
    return relative_buckets
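
# Illustrative values (sketch) with the defaults (bidirectional=True,
# num_buckets=32, max_distance=128): the upper half of the buckets is reserved
# for positive offsets, small offsets get exact buckets, and larger offsets
# share logarithmically spaced buckets.
#
#   _relative_position_bucket(torch.tensor([-3, 0, 3]))
#   # -> tensor([3, 0, 19])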


def compute_bias(
    query_length,
    key_length,
    is_decoder,
    max_relative_positions,
    relative_positions_buckets,
    device=None,
):
    """Compute binned relative position bias"""
    context_position = torch.arange(query_length, dtype=torch.long, device=device)[
        :, None
    ]
    memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
    relative_position = (
        memory_position - context_position
    )  # shape (query_length, key_length)
    relative_position_bucket = _relative_position_bucket(
        relative_position,  # shape (query_length, key_length)
        bidirectional=(not is_decoder),
        num_buckets=relative_positions_buckets,
        max_distance=max_relative_positions,
    )
    return relative_position_bucket
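
# Illustrative sketch: the returned bucket matrix is fed to
# `relative_attention_bias` (an nn.Embedding of size
# [relative_positions_buckets, head_count]) and permuted to
# (1, head_count, q_len, k_len) before being added to the raw attention
# scores, as done in the forward pass below.
#
#   buckets = compute_bias(4, 4, is_decoder=True, max_relative_positions=128,
#                          relative_positions_buckets=32)
#   # buckets.shape == (4, 4), dtype torch.long, values in [0, 32)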


# Helper functions to split model dim per head


def shape(x: Tensor, dim_per_head: int) -> Tensor:
    """
    Projection.
    [batchsize x length x modeldim]
    -> [batchsize x heads x length x dimperhead]
    """
    x_0, x_1, _ = x.size()
    return x.view(x_0, x_1, -1, dim_per_head).transpose(1, 2)


def unshape(x: Tensor) -> Tensor:
    """
    Compute context.
    [batchsize x heads x length x dimperhead]
    -> [batchsize x length x modeldim]
    """
    x_0, x_1, _, x_3 = x.size()
    return x.transpose(1, 2).contiguous().view(x_0, -1, x_1 * x_3)
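
# Illustrative round trip (sketch): `shape` splits model_dim into heads,
# `unshape` merges them back.
#
#   x = torch.randn(2, 5, 512)                 # [batch, seq_len, model_dim]
#   h = shape(x, 64)                           # -> [2, 8, 5, 64]
#   assert torch.equal(unshape(h), x)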


class MultiHeadedAttention(torch.nn.Module):
    """Multi-Head Attention module from "Attention is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`.

    Similar to standard `dot` attention but uses
    multiple attention distributions simultaneously
    to select relevant items.

    .. mermaid::

        graph BT
            A[key]
            B[value]
            C[query]
            O[output]
            subgraph Attn
                D[Attn 1]
                E[Attn 2]
                F[Attn N]
            end
            A --> D
            C --> D
            A --> E
            C --> E
            A --> F
            C --> F
            D --> O
            E --> O
            F --> O
            B --> O

    Also includes several additional tricks.

    Args:
        head_count (int): number of parallel heads
        model_dim (int): the dimension of keys/values/queries,
            must be divisible by head_count
        dropout (float): dropout parameter
        max_relative_positions (int): max relative positions
        attn_type: "self" or "context"
    """

    def __init__(
        self,
        head_count: int,
        model_dim: int,
        dropout: float = 0.1,
        is_decoder: bool = True,
        max_relative_positions: int = 0,
        relative_positions_buckets: int = 0,
        rotary_interleave: bool = True,
        rotary_theta: int = 1e4,
        rotary_dim: int = 0,
        attn_type: str = None,
        self_attn_type: str = None,
        add_qkvbias=False,
        num_kv=0,
        use_ckpting=[],
        parallel_gpu=1,
    ) -> None:
        assert (
            model_dim % head_count == 0
        ), "Model dimension must be divisible by the number of heads"
        self.dim_per_head = model_dim // head_count
        super(MultiHeadedAttention, self).__init__()
        self.head_count = head_count
        self.num_kv = num_kv
        self.parallel_gpu = parallel_gpu
        if num_kv == 0:
            assert (
                model_dim % parallel_gpu == 0
            ), "Model dimension must be divisible by the number of partitions"
            self.linear_keys = skip_init(
                nn.Linear,
                in_features=model_dim,
                out_features=model_dim // parallel_gpu,
                bias=add_qkvbias,
            )
            self.linear_values = skip_init(
                nn.Linear,
                in_features=model_dim,
                out_features=model_dim // parallel_gpu,
                bias=add_qkvbias,
            )
        else:
            assert (
                self.dim_per_head * self.num_kv
            ) % parallel_gpu == 0, (
                "Model dimension must be divisible by the number of partitions"
            )
            self.linear_keys = skip_init(
                nn.Linear,
                in_features=model_dim,
                out_features=self.dim_per_head * self.num_kv // parallel_gpu,
                bias=add_qkvbias,
            )
            self.linear_values = skip_init(
                nn.Linear,
                in_features=model_dim,
                out_features=self.dim_per_head * self.num_kv // parallel_gpu,
                bias=add_qkvbias,
            )
        self.linear_query = skip_init(
            nn.Linear,
            in_features=model_dim,
            out_features=model_dim // parallel_gpu,
            bias=add_qkvbias,
        )
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.dropout_p = dropout
        self.final_linear = skip_init(
            nn.Linear,
            in_features=model_dim // parallel_gpu,
            out_features=model_dim,
            bias=add_qkvbias,
        )
        self.is_decoder = is_decoder
        self.max_relative_positions = max_relative_positions
        self.relative_positions_buckets = relative_positions_buckets
        self.attn_type = attn_type
        self.self_attn_type = self_attn_type
        self.layer_cache = (
            False,
            {"keys": torch.tensor([]), "values": torch.tensor([])},
        )
        if relative_positions_buckets > 0:
            self.relative_attention_bias = nn.Embedding(
                relative_positions_buckets, head_count
            )
            self.relative_positions_embeddings = None
        elif max_relative_positions > 0:
            # https://arxiv.org/pdf/1803.02155.pdf
            # In the paper they suggest either two embeddings
            # (relative_key / relative_value) or only relative_key.
            # We use the same embedding for both.
            vocab_size = max_relative_positions * 2 + 1
            self.relative_positions_embeddings = nn.Embedding(
                vocab_size, self.dim_per_head
            )
            self.relative_attention_bias = None
        else:
            self.relative_positions_embeddings = None
            self.relative_attention_bias = None

            if max_relative_positions == -1:  # rotary embeddings
                if rotary_dim == 0:
                    self.rotary_dim = self.dim_per_head
                else:
                    self.rotary_dim = rotary_dim
                self.rope, self.cos, self.sin = rotaryembeddings(
                    self.rotary_dim, base=rotary_theta
                )
                self.rotary_interleave = rotary_interleave
                self.rotary_theta = rotary_theta
            else:
                self.cos = None
                self.sin = None
                self.rotary_interleave = None
            if max_relative_positions == -2:  # alibi positional bias
                self.alibi = AlibiPositionalBias(head_count)

        self.maybe_ckpt = checkpoint if "mha" in use_ckpting else lambda f, x: f(x)

        try:
            flash_pack = import_module("flash_attn")
            if (
                hasattr(flash_pack, "flash_attn_func")
                and torch.cuda.get_device_capability()[0] >= 8
            ):
                self.flash_attn_func = getattr(flash_pack, "flash_attn_func")
                self.flash_attn_with_kvcache = getattr(
                    flash_pack, "flash_attn_with_kvcache"
                )
                self.flash2 = True
            else:
                self.flash2 = False
        except ImportError:
            self.flash2 = False

    def update_dropout(self, dropout: float) -> None:
        self.dropout.p = dropout
        self.dropout_p = dropout
    def forward(
        self,
        key: Tensor,
        value: Tensor,
        query: Tensor,
        mask: Optional[Tensor] = None,
        sliding_window: Optional[int] = 0,
        step: Optional[int] = 0,
        return_attn: Optional[bool] = False,
        self_attn_type: str = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Compute the context vector and the attention vectors.

        Args:
            key (Tensor): set of `key_len`
                key vectors ``(batch, key_len, dim)``
            value (Tensor): set of `key_len`
                value vectors ``(batch, key_len, dim)``
            query (Tensor): set of `query_len`
                query vectors ``(batch, query_len, dim)``
            mask: binary mask 1/0 indicating which keys have
                zero / non-zero attention ``(batch, 1, query_len, key_len)``
            step (int): decoding step (used for Rotary embedding)
        Returns:
            (Tensor, Tensor):

            * output context vectors ``(batch, query_len, dim)``
            * Attention vector in heads ``(batch, head, query_len, key_len)``.
        """
        # 1) Project key, value, and query.
        # As a reminder, at training time layer_cache[0] remains False.
        key_pad_mask = self.layer_cache[1].get("key_pad_mask", None)
        if self.layer_cache[0]:
            # Retrieve keys and values from the KV cache (decoding mode only).
            if self.attn_type == "self":
                query, key, value = (
                    self.linear_query(query),
                    self.linear_keys(query),
                    self.linear_values(query),
                )
                query = shape(query, self.dim_per_head)
                key = shape(key, self.dim_per_head)
                value = shape(value, self.dim_per_head)
                start_pos = step
                seqlen = query.size(2)

                if (
                    step == 0
                    or not self.flash2
                    or self.self_attn_type != "scaled-dot-flash"
                    or self.max_relative_positions not in [0, -1]
                    or query.size(0) > 128
                    or query.dtype != torch.float16
                ):
                    if self.max_relative_positions == -1:  # Rotary Embeddings
                        if seqlen + start_pos > self.rope.size(0):
                            # Resize rotary embeddings.
                            self.rope, self.cos, self.sin = rotaryembeddings(
                                self.rotary_dim,
                                maxseqlen=(seqlen + start_pos + 2048),
                                base=self.rotary_theta,
                                device=self.rope.device,
                            )
                        rope = self.rope[start_pos : start_pos + seqlen]
                        query, key = apply_rotary_emb(
                            query, key, rope, interleave=self.rotary_interleave
                        )

                    if self.layer_cache[1]["keys"].numel() != 0:
                        key = torch.cat((self.layer_cache[1]["keys"], key), dim=2)
                        value = torch.cat((self.layer_cache[1]["values"], value), dim=2)
                        if sliding_window > 0 and key.size(2) > sliding_window:
                            key = key[:, :, 1:, :]
                            value = value[:, :, 1:, :]

                    self.layer_cache[1]["keys"] = key
                    self.layer_cache[1]["values"] = value

                else:
                    if start_pos >= self.layer_cache[1]["keys"].size(2):
                        self.layer_cache[1]["keys"] = torch.cat(
                            [
                                self.layer_cache[1]["keys"],
                                torch.zeros(
                                    self.layer_cache[1]["keys"].shape[:-2]
                                    + (32,)
                                    + self.layer_cache[1]["keys"].shape[-1:],
                                    device=query.device,
                                ).half(),
                            ],
                            dim=-2,
                        )
                        self.layer_cache[1]["values"] = torch.cat(
                            [
                                self.layer_cache[1]["values"],
                                torch.zeros(
                                    self.layer_cache[1]["values"].shape[:-2]
                                    + (32,)
                                    + self.layer_cache[1]["values"].shape[-1:],
                                    device=query.device,
                                ).half(),
                            ],
                            dim=-2,
                        )
                        if (
                            self.max_relative_positions == -1
                            and start_pos + 32 >= self.rope.size(0)
                        ):
                            # Resize rotary embeddings.
                            # We take a margin of 32 tokens as the kv_cache
                            # is incremented by 32 tokens every 32 tokens.
                            self.rope, self.cos, self.sin = rotaryembeddings(
                                self.rotary_dim,
                                maxseqlen=(start_pos + 2048),
                                base=self.rotary_theta,
                                device=self.rope.device,
                            )

                    if sliding_window > 0 and key.size(2) > sliding_window:
                        self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][
                            :, :, 1:, :
                        ]
                        self.layer_cache[1]["values"] = self.layer_cache[1]["values"][
                            :, :, 1:, :
                        ]

                    context = self.flash_attn_with_kvcache(
                        query.transpose(1, 2),
                        self.layer_cache[1]["keys"].transpose(1, 2),
                        self.layer_cache[1]["values"].transpose(1, 2),
                        key.transpose(1, 2),
                        value.transpose(1, 2),
                        rotary_cos=self.cos,
                        rotary_sin=self.sin,
                        cache_seqlens=step,
                        rotary_interleaved=self.rotary_interleave,
                    ).transpose(1, 2)
                    attn_output = self.final_linear(unshape(context))
                    if self.parallel_gpu > 1:
                        all_reduce(attn_output)
                    return attn_output, None

            elif self.attn_type == "context":
                query = self.linear_query(query)
                query = shape(query, self.dim_per_head)
                if self.layer_cache[1]["keys"].numel() == 0:
                    key, value = self.linear_keys(key), self.linear_values(value)
                    key = shape(key, self.dim_per_head)
                    value = shape(value, self.dim_per_head)
                else:
                    key, value = (
                        self.layer_cache[1]["keys"],
                        self.layer_cache[1]["values"],
                    )
                self.layer_cache[1]["keys"] = key
                self.layer_cache[1]["values"] = value

            if key_pad_mask is not None:
                # Increase the cached key pad mask by concatenation.
                # For decoding only.
                if step > 0:
                    y = torch.zeros(
                        (key_pad_mask.size(0), key_pad_mask.size(1), 1),
                        dtype=torch.bool,
                        device=key_pad_mask.device,
                    )
                    self.layer_cache[1]["key_pad_mask"] = torch.cat(
                        (key_pad_mask, y), 2
                    )
                    key_pad_mask = self.layer_cache[1]["key_pad_mask"]
        else:
            # Retrieve keys and values from linear layers (training mode).
            key = self.maybe_ckpt(self.linear_keys, key)
            value = self.maybe_ckpt(self.linear_values, value)
            query = self.maybe_ckpt(self.linear_query, query)
            key = shape(key, self.dim_per_head)
            value = shape(value, self.dim_per_head)
            query = shape(query, self.dim_per_head)

            if self.max_relative_positions == -1:  # Rotary Embeddings
                start_pos = 0
                seqlen = query.size(2)
                if seqlen > self.rope.size(0):
                    # Resize rotary embeddings.
                    self.rope, self.cos, self.sin = rotaryembeddings(
                        self.rotary_dim,
                        maxseqlen=(seqlen + 2048),
                        base=self.rotary_theta,
                        device=query.device,
                    )
                rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
                query, key = apply_rotary_emb(
                    query, key, rope, interleave=self.rotary_interleave
                )

        b, h, l, d = key.size()
        if self.num_kv > 0:
            qh = query.size(1)
            # Expand keys on the heads dimension when there are fewer
            # key/value heads than query heads (multi-/grouped-query variant).
            key = key.view(b, -1, 1, l, d).repeat(1, 1, qh // h, 1, 1)
            key = key.view(b, qh, l, d)
            # Expand values on the heads dimension likewise.
            value = value.view(b, -1, 1, l, d).repeat(1, 1, qh // h, 1, 1)
            value = value.view(b, qh, l, d)

        # 2) When standard positional encoding or rotary embeddings are used,
        #    rely on flash attention.
        # Ultimately flash v2 will be part of pytorch:
        # https://github.com/pytorch/pytorch/pull/105602
        # In the meantime: with a vanilla transformer or rotary embeddings
        # (no rel_pos, no alibi), use flash-attn 2 when seq len > 256,
        # otherwise the memory-efficient SDPA kernel from upstream PyTorch 2.
        flash2 = (
            self.flash2
            and l > 256  # https://github.com/Dao-AILab/flash-attention/issues/591
        )

        if (
            self.max_relative_positions in [-1, 0]
            and not return_attn
            and query.device != torch.device("cpu")
            and self.self_attn_type == "scaled-dot-flash"
        ):
            # Apply flash2 attention.
            causal = self.is_decoder and self.attn_type == "self" and mask is not None
            if self.is_decoder and self.attn_type == "self" and flash2:
                if causal:
                    window_size = (
                        (-1, -1) if sliding_window == 0 else (sliding_window, 0)
                    )
                else:
                    window_size = (-1, -1)
                attn_output = self.flash_attn_func(
                    query.transpose(1, 2),
                    key.transpose(1, 2),
                    value.transpose(1, 2),
                    dropout_p=self.dropout_p,
                    causal=causal,
                    window_size=window_size,
                ).transpose(1, 2)
            else:
                # Apply scaled dot product attention.
                with torch.backends.cuda.sdp_kernel(
                    enable_flash=False, enable_math=True, enable_mem_efficient=True
                ):
                    attn_output = scaled_dot_product_attention(
                        query,
                        key,
                        value,
                        ~mask if mask is not None else None,
                        self.dropout_p,
                        is_causal=causal,
                    )
            attn = None

        else:
            query /= sqrt(self.dim_per_head)
            # batch x num_heads x query_len x key_len
            scores = torch.matmul(query, key.transpose(2, 3))

            if self.relative_attention_bias is not None:
                q_len = key.size(2) if self.layer_cache[0] else query.size(2)
                relative_position_bucket = compute_bias(
                    q_len,
                    key.size(2),
                    self.is_decoder,
                    self.max_relative_positions,
                    self.relative_positions_buckets,
                    device=key.device,
                )
                values = self.relative_attention_bias(
                    relative_position_bucket
                )  # shape (query_length, key_length, num_heads)
                position_bias = values.permute([2, 0, 1]).unsqueeze(
                    0
                )  # shape (1, num_heads, query_length, key_length)
                if self.layer_cache[0]:
                    position_bias = position_bias[:, :, -query.size(2) :, :]
                scores.add_(position_bias)

            elif self.relative_positions_embeddings is not None:
                key_len = key.size(2)
                # 1 or key_len x key_len
                relative_positions_matrix = gen_relative_positions(
                    key_len,
                    self.max_relative_positions,
                    cache=self.layer_cache[0],
                    device=key.device,
                )
                # 1 or key_len x key_len x dim_per_head
                relations_keys = self.relative_positions_embeddings(
                    relative_positions_matrix
                )
                scores.add_(relative_matmul(query, relations_keys, True))

            elif self.max_relative_positions == -2:  # Alibi
                scores = self.alibi(scores)

            scores = scores.float()

            if key_pad_mask is not None and mask is None:
                mask = key_pad_mask.unsqueeze(1)

            if mask is not None:
                # not 100% necessary but expand to the number of heads
                mask = mask.expand(-1, self.head_count // self.parallel_gpu, -1, -1)
                # now mask and scores have the same shape
                scores = scores.masked_fill(mask, -1e18)

            # 3) Apply attention dropout and compute context vectors.
            attn = self.softmax(scores).to(query.dtype)
            drop_attn = self.dropout(attn) if self.dropout_p > 0 else attn

            attn_output = torch.matmul(drop_attn, value)

            if self.relative_positions_embeddings is not None:
                # We use the same embeddings for key and value
                relations_values = relations_keys
                attn_output.add_(relative_matmul(drop_attn, relations_values, False))

        context = unshape(attn_output)

        if key_pad_mask is not None:
            if key_pad_mask.size(0) > 1 and context.size(1) > 1:
                x = key_pad_mask.squeeze(1).unsqueeze(2).expand(
                    -1, -1, context.size(2)
                )
                context = context.masked_fill(x, 0)

        if self.layer_cache[0]:
            attn_output = self.final_linear(context)
        else:
            attn_output = self.maybe_ckpt(self.final_linear, context)

        if self.parallel_gpu > 1:
            all_reduce(attn_output)

        return attn_output, attn
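
# Illustrative end-to-end sketch (assumptions: CPU, default non-flash path,
# a boolean padding mask where True marks keys to ignore; skip_init leaves the
# projection weights uninitialized, so real usage loads a checkpoint first):
#
#   mha = MultiHeadedAttention(head_count=8, model_dim=512, dropout=0.0,
#                              attn_type="self", self_attn_type="scaled-dot")
#   x = torch.randn(2, 5, 512)
#   pad_mask = torch.zeros(2, 1, 1, 5, dtype=torch.bool)   # (batch, 1, 1, key_len)
#   out, attn = mha(x, x, x, mask=pad_mask)
#   # out: [2, 5, 512]; attn: [2, 8, 5, 5] (None when a flash/SDPA path is taken)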