Source code for cerebras.modelzoo.layers.AlibiPositionEmbeddingLayer

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
import torch.nn as nn

import cerebras.pytorch as cstorch
from cerebras.modelzoo.layers.create_initializer import create_initializer
from cerebras.modelzoo.layers.RelativePositionEmbeddingLayer import (
    RelativePositionEmbeddingLayer,
)


class AlibiPositionEmbeddingLayer(nn.Module):
    """Alibi Position Embedding Layer, symmetric case with bidirectional
    alibi bias supported, as in the paper: https://arxiv.org/abs/2108.12409

    Args:
        num_heads (int): number of attention heads.
        slopes (Tensor): slope values to use for the alibi heads.
            Shape: [num_heads, 1]. Defaults to `None`.
        alibi_trainable_slopes (bool): whether the alibi slopes are trainable parameters.
        slopes_initializer (str): initializer for the alibi slopes if they are trainable.
            Defaults to ``xavier_uniform``.

    Returns:
        position_bias (Tensor): Relative position bias, to be used in attention masking
    """

    def __init__(
        self,
        num_heads,
        slopes=None,
        alibi_trainable_slopes=False,
        slopes_initializer="xavier_uniform",
        scaling_factor=1.0,
        constant_pos_embedding=None,
    ):
        super(AlibiPositionEmbeddingLayer, self).__init__()
        assert slopes is None, "Customized slopes are not supported yet."
        self.num_heads = num_heads
        self.alibi_trainable_slopes = alibi_trainable_slopes

        if not slopes:
            if self.alibi_trainable_slopes:
                slopes = torch.zeros([num_heads, 1])
                self.slopes_initializer = slopes_initializer
            else:
                slopes = torch.tensor(
                    AlibiPositionEmbeddingLayer._get_alibi_slopes(num_heads)
                ).unsqueeze(-1)
        else:
            if self.alibi_trainable_slopes:
                self.slopes_initializer = slopes_initializer

        self.slopes = nn.parameter.Parameter(
            slopes, requires_grad=self.alibi_trainable_slopes
        )

        self.scaling_factor = scaling_factor
        self.constant_pos_embedding = constant_pos_embedding

        self.__reset_parameters()

    def reset_parameters(self):
        self.__reset_parameters()

    def __reset_parameters(self):
        if self.alibi_trainable_slopes:
            create_initializer(self.slopes_initializer)(self.slopes.data)
    def forward(
        self,
        seq_length,
        key_length,
        past_kv=None,
        constant_pos_mask=None,
        batch_size=None,
    ):
        """Return the position bias based on the alibi slopes.

        Args:
            seq_length (int): the length of query tokens.
            key_length (int): the length of key tokens.

        Returns:
            Position bias tensor with shape [num_heads, query_length, key_length]
        """
        position_bias = self._compute_alibi_bias(
            seq_length,
            key_length,
            constant_pos_mask=constant_pos_mask,
            batch_size=batch_size,
        )

        # If keys and values are already cached, we only want the position
        # bias for the last `seq_length` query positions.
        if past_kv is not None:
            position_bias = position_bias[:, :, -seq_length:, :]

        return position_bias
    @staticmethod
    def _get_alibi_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(n).is_integer():
            # In the paper, only models with 2^a heads for some a are trained.
            # This function has some good properties that only hold when the
            # input is a power of 2.
            return get_slopes_power_of_2(n)
        else:
            # To maintain those properties when the number of heads is not a
            # power of 2, fall back to the closest power of 2 and interleave.
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + AlibiPositionEmbeddingLayer._get_alibi_slopes(
                    2 * closest_power_of_2
                )[0::2][: n - closest_power_of_2]
            )

    def _alibi_implementation_expand(self, seq_length, key_length, slopes):
        relative_position = (
            RelativePositionEmbeddingLayer.compute_raw_relative_positions(
                seq_length, key_length, device=slopes.device
            )
        )
        relative_position = (
            torch.abs(relative_position)
            .unsqueeze(0)
            .expand(self.num_heads, -1, -1)
        )
        alibi = (slopes / -self.scaling_factor).unsqueeze(1) * relative_position

        return alibi

    def _constant_image_pos(self, relative_position, constant_pos_mask):
        mask_dim1 = constant_pos_mask[:, :, None]
        mask_dim2 = constant_pos_mask[:, None, :]
        # TODO: only two modalities are supported; the values in
        # `constant_pos_mask` are either 0 or 1.
        image_mask = (1 - mask_dim1) * (1 - mask_dim2)
        image_pos_id = (1 - image_mask) * self.constant_pos_embedding
        # Mask the image portion's position ids to 0, then set them to the
        # value of `self.constant_pos_embedding`.
        relative_position = relative_position * image_mask + image_pos_id
        return relative_position

    def _alibi_implementation_batched(
        self, batch, seq_length, key_length, slopes, constant_pos_mask
    ):
        context_position = torch.arange(
            seq_length, device=slopes.device, dtype=torch.float32
        )[None, :, None].broadcast_to(batch, seq_length, key_length)
        memory_position = torch.arange(
            key_length, device=slopes.device, dtype=torch.float32
        )[None, None, :].broadcast_to(batch, seq_length, key_length)
        relative_position = memory_position - context_position

        if constant_pos_mask is not None:
            # Apply the constant position embedding.
            relative_position = self._constant_image_pos(
                relative_position, constant_pos_mask
            )

        relative_position = (
            torch.abs(relative_position)
            .unsqueeze(1)
            .expand(batch, self.num_heads, -1, -1)
        )
        slopes = slopes / -self.scaling_factor
        # Cast the weight to half dtype before broadcasting (adding the batch
        # dimension to it), then cast it back to float32 after broadcasting.
        slopes = slopes.to(cstorch.amp.get_half_dtype())[
            None, :, :, None
        ].broadcast_to(relative_position.shape)
        slopes = slopes.to(torch.float32)
        alibi = slopes * relative_position

        return alibi

    def _compute_alibi_bias(
        self,
        seq_length,
        key_length,
        slopes=None,
        constant_pos_mask=None,
        batch_size=None,
    ):
        assert (self.constant_pos_embedding is None) or (
            self.constant_pos_embedding is not None
            and constant_pos_mask is not None
        ), "constant_pos_embedding is enabled, but 'constant_pos_mask' is not provided."

        if slopes is None:
            slopes = self.slopes

        if (
            constant_pos_mask is not None
            and self.constant_pos_embedding is not None
        ):
            return self._alibi_implementation_batched(
                batch_size, seq_length, key_length, slopes, constant_pos_mask
            )

        return self._alibi_implementation_expand(seq_length, key_length, slopes)
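
The snippet below is a minimal usage sketch, not part of the module source above. It assumes the cerebras.modelzoo package is importable and that the layer can be instantiated and run on CPU outside a Cerebras compile context. With 8 heads and the default arguments, the fixed (non-trainable) slopes follow the geometric sequence 2^(-8 * (i + 1) / num_heads), and the forward pass returns a bidirectional bias of shape [num_heads, seq_length, key_length].

from cerebras.modelzoo.layers.AlibiPositionEmbeddingLayer import (
    AlibiPositionEmbeddingLayer,
)

# 8 heads -> fixed slopes 2^(-8 * (i + 1) / 8): 0.5, 0.25, ..., 1/256.
alibi = AlibiPositionEmbeddingLayer(num_heads=8)
print(alibi.slopes.squeeze(-1))

# Bias for 16 query tokens attending to 16 key tokens; entry [h, i, j]
# equals -slope_h * |i - j| (scaled by 1 / scaling_factor).
bias = alibi(seq_length=16, key_length=16)
print(bias.shape)  # torch.Size([8, 16, 16])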