# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from cerebras.modelzoo.layers.create_initializer import create_initializer
from cerebras.pytorch.utils.kernel import kernel_annotater
class MultiheadAttention(nn.Module):
"""Multi-head attention layer. Adapted from:
https://pytorch.org/docs/stable/_modules/torch/nn/modules/activation.html#MultiheadAttention.
Args:
        embed_dim (int): Dimension of the input embeddings and of the
            attention output.
num_heads (int): Number of attention heads.
inner_dim (int): Number of output units in attention query/key/value projection. Defaults to ``embed_dim``.
dropout (float): Dropout rate for key-query weights. Defaults to 0.0.
batch_first (bool): If True, then the input and output tensors are
provided as (batch, seq, feature), otherwise the format will be
(seq, batch, feature). Default: True (batch, seq, feature).
add_bias_kv (bool): If specified, adds bias to the key and value sequences at dim=0. Default: False.
add_zero_attn (bool): If specified, adds a new batch of zeros to the key and value
sequences at dim=1. Default: False
kdim (int): Number of input units in the key projection
vdim (int): Number of input units in the value projection
use_projection_bias (bool): Whether to use bias in the key, query, and
value projections.
use_ffn_bias (bool): Whether to use bias in the output projection.
attention_initializer (str): Projection kernel initializer. Defaults to
``xavier_uniform``.
attention_q_initializer: Query projection kernel initializer. If not
specified, the query will be initialized via ``attention_initializer``
output_layer_initializer (str or initializer): If not None, use this
initializer for the output transform layer. Defaults to None.
bias_initializer (str): Bias initializer. Defaults to ``zeros``.
attention_type (str): The attention variant to execute. Currently
accepts ``dot_product`` and ``scaled_dot_product``. Defaults to
``scaled_dot_product``.
        scale_qk_dot_by_d (bool): If ``True``, scales the QK^T dot product by
            d (the per-head dimension, inner_dim / num_heads) instead of sqrt(d).
attention_logits_alpha (float): Scales the QK^T dot product. Used to stabilize logits in muP training.
softmax_dtype_fp32 (bool): Use an FP32 softmax implementation.
        attention_kernel (str | None): Kernel to use. Accepted values:
            ``None`` - default implementation.
            ``fast_attention`` - experimental optimized implementation.
device (optional): Device to create the model parameters on, can be a cuda device or CS device.
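    Example:
        A minimal usage sketch; the sizes below are illustrative assumptions
        and the call assumes the Cerebras ModelZoo environment is available::

            mha = MultiheadAttention(embed_dim=64, num_heads=8)
            x = torch.rand(2, 16, 64)   # (batch, seq, embed_dim)
            out = mha(x, x, x)          # self-attention -> (2, 16, 64)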
"""
def __init__(
self,
embed_dim,
num_heads,
inner_dim=None,
dropout=0.0,
batch_first=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
use_projection_bias=None,
use_ffn_bias=False,
attention_initializer="xavier_uniform",
attention_q_initializer=None,
output_layer_initializer=None,
bias_initializer="zeros",
attention_type="scaled_dot_product",
scale_qk_dot_by_d=False,
attention_logits_alpha=1.0,
q_projection_scale=1.0,
k_projection_scale=1.0,
v_projection_scale=1.0,
output_projection_scale=1.0,
softmax_dtype_fp32=True,
attention_kernel=None,
scale_qk_dot_by_layer_idx=False,
logit_softcapping=None,
device=None,
):
_SUPPORTED_ATTENTION_TYPES = [
"dot_product",
"scaled_dot_product",
"scaled_cosine",
]
assert (
attention_type in _SUPPORTED_ATTENTION_TYPES
), f"Attention type {attention_type} is not supported."
assert (
embed_dim % num_heads == 0
), f"embed_dim {embed_dim} must be divisible by num_heads {num_heads}."
if inner_dim is not None:
assert (
inner_dim % num_heads == 0
), "inner_dim must be divisible by num_heads."
assert batch_first, "Currently, only batch_first=True is supported"
assert not add_bias_kv, "add_bias_kv=True is not supported."
assert not add_zero_attn, "add_zero_attn=True is not supported."
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.inner_dim = inner_dim if inner_dim is not None else embed_dim
self.num_heads = num_heads
self.attention_type = attention_type
self.use_projection_bias = use_projection_bias
self.use_ffn_bias = use_ffn_bias
self.proj_q_dense_layer = nn.Linear(
self.embed_dim,
self.inner_dim,
bias=use_projection_bias,
device=device,
)
self.proj_k_dense_layer = nn.Linear(
self.kdim,
self.inner_dim,
bias=use_projection_bias,
device=device,
)
self.proj_v_dense_layer = nn.Linear(
self.vdim,
self.inner_dim,
bias=use_projection_bias,
device=device,
)
if self.attention_type == "scaled_cosine":
self.logits_scale = nn.Parameter(
torch.log(10 * torch.ones((self.num_heads, 1, 1)))
)
self.dropout_layer = nn.Dropout(dropout)
self.proj_output_dense_layer = nn.Linear(
self.inner_dim,
self.embed_dim,
bias=use_ffn_bias,
device=device,
)
# handle initialization
output_initializer = attention_initializer
if output_layer_initializer is not None:
output_initializer = output_layer_initializer
self.initializer = attention_initializer
self.query_initializer = self.initializer
if attention_q_initializer is not None:
self.query_initializer = attention_q_initializer
self.output_initializer = output_initializer
self.bias_initializer = bias_initializer
self.softmax_dtype_fp32 = softmax_dtype_fp32
if attention_kernel:
attention_kernel = attention_kernel.upper()
self.using_kernel = kernel_annotater(attention_kernel)
self.scale_qk_dot_by_d = scale_qk_dot_by_d
self.attention_logits_alpha = attention_logits_alpha
self.q_projection_scale = q_projection_scale
self.k_projection_scale = k_projection_scale
self.v_projection_scale = v_projection_scale
self.output_projection_scale = output_projection_scale
self.scale_qk_dot_by_layer_idx = scale_qk_dot_by_layer_idx
self.logit_softcapping = logit_softcapping
self.__reset_parameters()
def reset_parameters(self):
self.__reset_parameters()
def __reset_parameters(self):
# bias initialization
bias_initializer = create_initializer(self.bias_initializer)
if self.use_projection_bias:
bias_initializer(self.proj_q_dense_layer.bias.data)
bias_initializer(self.proj_k_dense_layer.bias.data)
bias_initializer(self.proj_v_dense_layer.bias.data)
if self.use_ffn_bias:
bias_initializer(self.proj_output_dense_layer.bias.data)
# q projection
weight_initializer = create_initializer(self.query_initializer)
weight_initializer(self.proj_q_dense_layer.weight.data)
# k, v projections
weight_initializer = create_initializer(self.initializer)
weight_initializer(self.proj_k_dense_layer.weight.data)
weight_initializer(self.proj_v_dense_layer.weight.data)
# output projections
weight_initializer = create_initializer(self.output_initializer)
weight_initializer(self.proj_output_dense_layer.weight.data)
    def forward(
self,
q,
k,
v,
attn_mask=None,
key_padding_mask=None,
need_weights=False,
average_attn_weights=True,
past_kv=None,
cache_present_kv=False,
past_kv_self_attn=True,
position_bias=None,
rotary_position_embedding_helper=None,
layer_idx=None,
**extra_args,
):
"""Applies the attention mechanism to queries ``q``, keys ``k`` and values ``v``.
Args:
q (Tensor): Queries, shape ``[batch_size, seq_length, embed_dim]``.
k (Tensor): Keys, shape ``[batch_size, seq_length, embed_dim]``.
v (Tensor): Values, shape ``[batch_size, seq_length, embed_dim]``.
            attn_mask (Tensor): Attention mask. Can be 2D of shape
                ``[query_length, seq_length]``, 3D of shape
                ``[batch_size, query_length, seq_length]``, or 4D of shape
                ``[batch_size, num_heads, query_length, seq_length]``.
key_padding_mask (Tensor): If specified, a mask of shape (N, S) indicating
which elements within key to ignore for the purpose of attention
(i.e. treat as “padding”). Defaults to None.
need_weights (bool): If specified, returns attn_output_weights in addition
to attn_outputs. Default: False.
average_attn_weights (bool): If true, indicates that the returned attn_weights
should be averaged across heads. Otherwise, attn_weights are provided
separately per head. Note that this flag only has an effect when
need_weights=True. Default: True (i.e. average weights across heads)
past_kv (tuple(tensor, tensor)): Past keys and values. Tensors have shape
``[batch_size, num_heads, seq_length, embed_dim / num_heads]``.
The 0th and 1st tensor contain the past keys and values, respectively.
Defaults to ``None``.
cache_present_kv (bool): Specifies if the present keys and values
must be cached and returned. Needed to speed up the
computations when the decoder is called within an
autoregressive loop. Defaults to ``False``.
past_kv_self_attn (bool): Specifies whether the past keys & values should be
used for self-attention (true) or cross-attention (false). Ignored if
past_kv is not provided. Default: True
position_bias (Tensor): Tensor containing position bias to apply in attention
with shape ``[num_heads, query_length, key_length]``.
rotary_position_embedding_helper (Optional[RotaryPositionEmbeddingHelper]):
A helper class to apply rotary embedding on the input tensor.
Returns:
Attention output tensor with shape ``[batch_size, seq_length, embed_dim]``.
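        Example:
            A sketch of a self-attention call with an additive causal mask;
            ``attn`` (an instance of this module) and ``x`` are assumed to
            exist and the shapes are illustrative::

                seq_len = x.shape[1]
                causal = torch.triu(
                    torch.full((seq_len, seq_len), float("-inf")), diagonal=1
                )
                out = attn(x, x, x, attn_mask=causal)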
"""
assert not (
rotary_position_embedding_helper and position_bias
), "Cannot specify both rotary and relative position embeddings, pick one!"
assert (
past_kv is None and not cache_present_kv
), "Cannot provide past_kv because inference is not supported yet."
# Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = q.shape[:2]
real_seq_length = seq_length
assert (
real_seq_length > 1
), "Sequence length 1 is currently unsupported."
constant_pos_mask = None
if extra_args and ("constant_pos_mask" in extra_args):
constant_pos_mask = extra_args["constant_pos_mask"]
# construct query, key and value vector with a linear projection and split into heads
q = self.construct_query_vector(
q, attn_mask=attn_mask, key_padding_mask=key_padding_mask
)
k = self.construct_key_vector(
k, attn_mask=attn_mask, key_padding_mask=key_padding_mask
)
v = self.construct_value_vector(
v, attn_mask=attn_mask, key_padding_mask=key_padding_mask
)
offset_length, real_seq_length = self.get_sequence_length(
past_kv, real_seq_length
)
# Scale k for muP transfer before Transpose to get around the compile issue
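        # Combined with the 1/sqrt(d) applied to q in
        # calculate_attention_logits, this extra 1/sqrt(d) on k makes the
        # QK^T product effectively scaled by 1/d (the muP convention).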
if (
self.scale_qk_dot_by_d
and self.attention_type == "scaled_dot_product"
):
depth = self.inner_dim // self.num_heads
k = k * torch.tensor(1 / float(depth) ** 0.5, dtype=k.dtype)
# rotary embedding helper
k = self.apply_rotary_position_embedding(
k,
rotary_position_embedding_helper,
real_seq_length,
offset_length,
constant_pos_mask=constant_pos_mask,
)
q = self.apply_rotary_position_embedding(
q,
rotary_position_embedding_helper,
real_seq_length,
offset_length,
constant_pos_mask=constant_pos_mask,
)
# q, k now have shape [batch_size, num_heads, seq_length, head_dim]
q = self.process_q_before_logits_calc(q)
k = self.process_k_before_logits_calc(k)
v = self.process_v_before_logits_calc(v)
k, v = self.process_past_kv(past_kv, past_kv_self_attn, k, v)
present_kv = self.construct_present_kv(cache_present_kv, k, v)
logits = self.calculate_attention_logits(q, k, layer_idx)
attn_mask_processed = self.process_attention_mask(attn_mask, past_kv, q)
key_padding_mask_processed = self.process_key_padding_mask(
key_padding_mask, attn_mask, past_kv, q
)
attention_bias = self.combine_masks(
attn_mask_processed, key_padding_mask_processed
)
logits = self.apply_position_bias(logits, position_bias)
logits = self.apply_attention_bias(logits, attention_bias)
attention_scores = self.calculate_attention_scores(logits)
attention_output = self.calculate_attention_output(attention_scores, v)
if cache_present_kv:
return attention_output, present_kv
if not need_weights:
return attention_output
else:
if average_attn_weights:
attention_scores = torch.mean(attention_scores, dim=1).squeeze()
return (
attention_output,
attention_scores,
)
def _split_heads(self, x, rotary):
"""Split x into different heads, and transpose the resulting value. The
tensor is transposed to insure the inner dimensions hold the correct
values during the matrix multiplication.
Args:
x: A tensor with shape ``[batch_size, seq_length, hidden_size]``.
Returns:
If rotary is true, a tensor with shape
``[batch_size, seq_length, num_heads, hidden_size/num_heads]``
else, a tensor with shape
``[batch_size, num_heads, seq_length, hidden_size/num_heads]``
"""
batch_size, seq_length, hidden_size = x.shape
depth = hidden_size // self.num_heads
# Transpose the result if not rotary
if rotary:
return x.view(batch_size, seq_length, self.num_heads, depth)
return x.view(batch_size, seq_length, self.num_heads, depth).transpose(
1, 2
)
def _combine_heads(self, x):
"""Combine tensor that has been split.
Args:
x: A tensor with shape
``[batch_size, num_heads, seq_length, embed_dim/num_heads]``.
Returns:
A tensor with shape ``[batch_size, seq_length, embed_dim]``.
"""
batch_size, num_heads, seq_length, depth = x.shape
return x.transpose(1, 2).reshape(
batch_size, seq_length, num_heads * depth
)
def construct_query_vector(self, q, attn_mask=None, key_padding_mask=None):
# linear projection
q = self.proj_q_dense_layer(q) * self.q_projection_scale
# split into heads
q = self._split_heads(q, rotary=True)
return q
def construct_key_vector(self, k, attn_mask=None, key_padding_mask=None):
# linear projection
k = self.proj_k_dense_layer(k) * self.k_projection_scale
# split into heads
k = self._split_heads(k, rotary=True)
return k
def construct_value_vector(self, v, attn_mask=None, key_padding_mask=None):
# linear projection
v = self.proj_v_dense_layer(v) * self.v_projection_scale
# split into heads
v = self._split_heads(v, rotary=False)
return v
def get_sequence_length(self, past_kv, real_seq_length):
offset_length = 0
if past_kv is not None:
offset_length = past_kv[0].shape[-2]
real_seq_length += offset_length
return offset_length, real_seq_length
def apply_rotary_position_embedding(
self,
vector,
rotary_position_embedding_helper,
real_seq_length,
offset_length,
constant_pos_mask=None,
):
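        # When a rotary helper is provided, rotation is applied in the
        # [batch, seq, heads, head_dim] layout produced by _split_heads;
        # the tensor is then transposed to [batch, heads, seq, head_dim].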
if rotary_position_embedding_helper:
vector = rotary_position_embedding_helper.rotate_tensor(
vector,
real_seq_length,
offset=offset_length,
constant_pos_mask=constant_pos_mask,
)
vector = vector.transpose(1, 2)
return vector
def process_q_before_logits_calc(self, q):
        # May be overridden by other attention schemes.
return q
def process_k_before_logits_calc(self, k):
        # May be overridden by other attention schemes.
return k
def process_v_before_logits_calc(self, v):
        # May be overridden by other attention schemes.
return v
def process_past_kv(self, past_kv, past_kv_self_attn, k, v):
if past_kv is not None:
k_past, v_past = past_kv[0], past_kv[1]
if past_kv_self_attn:
k = torch.cat([k_past, k], dim=-2)
v = torch.cat([v_past, v], dim=-2)
else:
k, v = k_past, v_past
return k, v
def construct_present_kv(self, cache_present_kv, k, v):
present_kv = None
if cache_present_kv:
present_kv = (k, v)
return present_kv
def calculate_attention_logits(self, q, k, layer_idx=None):
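        # Logits are attention_logits_alpha * (q @ k^T). The scaling depends
        # on attention_type: "scaled_dot_product" divides q by sqrt(head_dim),
        # "scaled_cosine" L2-normalizes q and k and applies a learned,
        # clamped per-head scale, and "dot_product" applies no scaling.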
if self.attention_type == "scaled_dot_product":
depth = self.inner_dim // self.num_heads
q = q * torch.tensor(
1 / float(depth) ** 0.5,
dtype=q.dtype,
)
elif self.attention_type == "scaled_cosine":
q = F.normalize(q, p=2.0, dim=-1)
k = F.normalize(k, p=2.0, dim=-1)
if self.scale_qk_dot_by_layer_idx:
q = q * torch.tensor(
1 / float(layer_idx + 1),
dtype=q.dtype,
)
# calculate dot product attention
logits = self.attention_logits_alpha * self.using_kernel(torch.matmul)(
q, k.transpose(-1, -2)
) # (B, H, Lq, E) * (B, H, E, Lk) -> (B, H, Lq, Lk)
if self.attention_type == "scaled_cosine":
logits_scale = torch.clamp(
self.logits_scale, max=math.log(1.0 / 0.01)
).exp()
logits = logits * logits_scale
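        # Optional soft-capping bounds the logits to (-cap, cap) via
        # cap * tanh(logits / cap).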
if self.logit_softcapping is not None:
logits = (
torch.tanh(logits / self.logit_softcapping)
* self.logit_softcapping
)
return logits
def process_attention_mask(self, attn_mask, past_kv, q):
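        # Normalizes attn_mask to a broadcastable 4D view of shape
        # [batch_size, num_heads, query_length, key_length]; missing batch
        # or head dimensions stay at size 1 so they broadcast over logits.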
attn_mask_reshaped = None
# apply attention mask
if attn_mask is not None:
# 2D [query_length, sequence_length]
# 3D [batch_size, query_length, sequence_length]
# 4D [batch_size, num_heads, query_length, sequence_length]
assert len(attn_mask.shape) in [
2,
3,
4,
], "Only 2D, 3D or 4D masks are supported for now"
if (
not attn_mask.is_floating_point()
and not attn_mask.dtype == torch.bool
):
attn_mask = attn_mask.to(torch.bool)
# for broadcasting over all heads
num_heads = 1
if len(attn_mask.shape) == 2:
if past_kv is not None:
past_mask = torch.zeros(
                        (q.shape[0], past_kv[0].shape[-2]),
dtype=attn_mask.dtype,
)
attn_mask = torch.cat([past_mask, attn_mask], axis=-1)
query_length, all_seq_length = attn_mask.shape
# for broadcasting over all batches
batch_size = 1
elif len(attn_mask.shape) == 3:
if past_kv is not None:
past_mask = torch.zeros(
                        (q.shape[0], q.shape[-2], past_kv[0].shape[-2]),
dtype=attn_mask.dtype,
)
attn_mask = torch.cat([past_mask, attn_mask], axis=-1)
batch_size, query_length, all_seq_length = attn_mask.shape
else:
num_heads = attn_mask.shape[1]
if past_kv is not None:
past_mask = torch.zeros(
                        (q.shape[0], num_heads, q.shape[-2], past_kv[0].shape[-2]),
dtype=attn_mask.dtype,
)
attn_mask = torch.cat([past_mask, attn_mask], axis=-1)
(
batch_size,
num_heads,
query_length,
all_seq_length,
) = attn_mask.shape
# compute the attention_bias based on the mask.
attn_mask_reshaped = attn_mask.view(
batch_size, num_heads, query_length, all_seq_length
)
return attn_mask_reshaped
def process_key_padding_mask(self, key_padding_mask, attn_mask, past_kv, q):
key_padding_mask_reshaped = None
if key_padding_mask is not None:
if (
not key_padding_mask.is_floating_point()
and not key_padding_mask.dtype == torch.bool
):
key_padding_mask = key_padding_mask.to(torch.bool)
num_heads = 1
query_length = 1
if len(key_padding_mask.shape) == 2:
if past_kv is not None:
past_mask = torch.zeros(
                        (q.shape[0], past_kv[0].shape[-2]),
dtype=key_padding_mask.dtype,
)
key_padding_mask = torch.cat(
[past_mask, key_padding_mask], axis=-1
)
batch_size, all_seq_length = key_padding_mask.shape
elif len(key_padding_mask.shape) == 3:
if past_kv is not None:
past_mask = torch.zeros(
                        (q.shape[0], q.shape[-2], past_kv[0].shape[-2]),
dtype=key_padding_mask.dtype,
)
key_padding_mask = torch.cat(
[past_mask, key_padding_mask], axis=-1
)
(
batch_size,
query_length,
all_seq_length,
) = key_padding_mask.shape
else:
num_heads = key_padding_mask.shape[1]
if past_kv is not None:
past_mask = torch.zeros(
                        (q.shape[0], num_heads, q.shape[-2], past_kv[0].shape[-2]),
dtype=key_padding_mask.dtype,
)
key_padding_mask = torch.cat(
[past_mask, key_padding_mask], axis=-1
)
(
batch_size,
num_heads,
query_length,
all_seq_length,
) = key_padding_mask.shape
# compute the attention_bias based on the mask.
key_padding_mask_reshaped = key_padding_mask.view(
batch_size, num_heads, query_length, all_seq_length
)
return key_padding_mask_reshaped
def combine_masks(self, attn_mask_reshaped, key_padding_mask_reshaped):
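        # Merge rules: two float masks are added; two boolean masks are OR'd;
        # a mixed pair fills the float mask with -inf wherever the boolean
        # mask is True.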
attention_bias = None
if (
attn_mask_reshaped is not None
and key_padding_mask_reshaped is not None
):
# Need to broadcast over dimensions before merging
(
attn_mask_reshaped,
key_padding_mask_reshaped,
) = torch.broadcast_tensors(
attn_mask_reshaped, key_padding_mask_reshaped
)
# Need to merge attention mask and key padding mask:
attn_mask_is_float = attn_mask_reshaped.is_floating_point()
key_padding_is_float = key_padding_mask_reshaped.is_floating_point()
if attn_mask_is_float and key_padding_is_float:
attention_bias = attn_mask_reshaped + key_padding_mask_reshaped
elif attn_mask_is_float:
mask_neg_inf = torch.tensor(
float("-inf"), dtype=attn_mask_reshaped.dtype
)
attention_bias = attn_mask_reshaped.masked_fill(
key_padding_mask_reshaped, mask_neg_inf
)
elif key_padding_is_float:
mask_neg_inf = torch.tensor(
float("-inf"), dtype=key_padding_mask_reshaped.dtype
)
attention_bias = key_padding_mask_reshaped.masked_fill(
attn_mask_reshaped, mask_neg_inf
)
else:
attention_bias = attn_mask_reshaped.logical_or(
key_padding_mask_reshaped
)
elif attn_mask_reshaped is not None:
attention_bias = attn_mask_reshaped
elif key_padding_mask_reshaped is not None:
attention_bias = key_padding_mask_reshaped
return attention_bias
def apply_attention_bias(self, logits, attention_bias):
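        # Boolean biases are converted to additive form: masked (True)
        # positions become -inf so they contribute nothing after softmax.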
if attention_bias is not None:
if attention_bias.dtype == torch.bool:
final_attention_bias = torch.zeros_like(
attention_bias, dtype=logits.dtype
)
mask_neg_inf = torch.tensor(
float("-inf"), dtype=final_attention_bias.dtype
)
final_attention_bias.masked_fill_(attention_bias, mask_neg_inf)
attention_bias = final_attention_bias
logits += attention_bias.type_as(logits).broadcast_to(logits.shape)
return logits
def apply_position_bias(self, logits, position_bias):
# Add relative position bias, if any
if position_bias is not None:
logits += position_bias.type_as(logits).broadcast_to(logits.shape)
return logits
def calculate_attention_scores(self, logits):
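        # Softmax is optionally computed in FP32 for numerical stability and
        # cast back to the logits dtype before dropout.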
if self.softmax_dtype_fp32 and logits.dtype != torch.float32:
attention_scores = nn.functional.softmax(
logits.float(), dim=-1
).type_as(logits)
else:
attention_scores = nn.functional.softmax(logits, dim=-1)
attention_scores = self.dropout_layer(attention_scores)
return attention_scores
def calculate_attention_output(self, attention_scores, v):
# Shape: (batch_size, num_heads, query_length, embed_dim / num_heads)
attention_output = self.using_kernel(torch.matmul)(attention_scores, v)
# Recombine heads --> [batch_size, seq_length, embed_dim].
attention_output = self._combine_heads(attention_output)
# Run the combined outputs through another linear projection layer.
attention_output = (
self.output_projection_scale
* self.proj_output_dense_layer(attention_output)
)
return attention_output
def check_extra_params(params):
        assert all(
            k in {"attention_kernel"} for k in params.keys()
        ), "Unexpected extra params for attention module `MultiheadAttention`"