Source code for cerebras.modelzoo.common.utils.model.attention_utils

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np


class SparseAttentionBuilder:
    def __init__(
        self,
        num_heads,
        max_sequence_length,
        num_different_head_attn_configs=1,
    ):
        self.num_heads = num_heads
        self.max_sequence_length = max_sequence_length
        self.num_different_head_attn_configs = num_different_head_attn_configs
        self.attention_mask = np.zeros(
            (self.num_heads, self.max_sequence_length, self.max_sequence_length)
        )
        self.mask_is_built = False

    def build_mask(self):
        raise NotImplementedError(
            "Make sure that a `build_mask` function "
            f"is correctly implemented in {self.__class__.__module__}"
        )

    def get_pytorch_mask(self, dtype=None):
        import torch

        # Avoid rebuilding mask if already exists
        if not self.mask_is_built:
            self.build_mask()

        attn_matrix = self.attention_mask
        if self.num_different_head_attn_configs == 1:
            attn_matrix = self.attention_mask[0, :, :]

        mask = 1 - attn_matrix
        sparse_attn_mask = torch.Tensor(mask).to(dtype)
        return sparse_attn_mask
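

# --- Illustrative sketch, not part of the original module ---
# A minimal, hypothetical subclass showing the `SparseAttentionBuilder`
# contract: `build_mask` writes 1s into `self.attention_mask` at allowed
# (query, key) positions and sets `self.mask_is_built`; `get_pytorch_mask`
# then returns `1 - mask`, so 1 marks positions a query may NOT attend to.
class _ExampleCausalAttentionBuilder(SparseAttentionBuilder):
    """Hypothetical example: a plain lower-triangular (causal) mask."""

    def build_mask(self):
        causal = np.tril(
            np.ones((self.max_sequence_length, self.max_sequence_length))
        )
        # Broadcast the same pattern to every head in this example.
        self.attention_mask[:] = causal
        self.mask_is_built = True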


class FixedSparseAttentionBuilder(SparseAttentionBuilder):
    def __init__(
        self,
        num_heads,
        max_sequence_length,
        num_different_head_attn_configs=1,
        local_attn_ctx=4,
        global_attn_ctx=1,
        attention_type="unidirectional",
        horizontal_global_attention=False,
    ):
        super().__init__(
            num_heads, max_sequence_length, num_different_head_attn_configs
        )
        self.local_attn_ctx = local_attn_ctx
        self.global_attn_ctx = global_attn_ctx
        assert self.local_attn_ctx % self.global_attn_ctx == 0, (
            f"Number of blocks in a local window, {self.local_attn_ctx}, "
            f"must be divisible by the number of global blocks, {self.global_attn_ctx}!"
        )
        self.attention_type = attention_type.lower()
        assert (
            self.attention_type == "unidirectional"
            or self.attention_type == "bidirectional"
        ), "Only 'unidirectional' or 'bidirectional' attention types are supported"
        self.horizontal_global_attention = horizontal_global_attention
        if self.horizontal_global_attention:
            assert (
                self.attention_type == "bidirectional"
            ), "Horizontal global attention is only supported by the 'bidirectional' attention type"

    def set_local_attention(self, head_id):
        # Allow each position to attend within its own local window of
        # `local_attn_ctx` positions.
        for i in range(0, self.max_sequence_length, self.local_attn_ctx):
            end = min(i + self.local_attn_ctx, self.max_sequence_length)
            for row in range(i, end):
                col_end = (
                    row + 1 if self.attention_type == "unidirectional" else end
                )
                self.attention_mask[head_id, row, i:col_end] = 1

    def set_global_attention(self, head_id):
        first_global_attn_col_idx = (
            self.local_attn_ctx
            - (1 + head_id % self.num_different_head_attn_configs)
            * self.global_attn_ctx
        )

        # Set global attention in every full local window; the last,
        # possibly partial, window is handled separately below.
        end = self.max_sequence_length - (
            self.max_sequence_length % self.local_attn_ctx
        )
        for i in range(first_global_attn_col_idx, end, self.local_attn_ctx):
            # Vertical global attention.
            first_global_attn_row_idx = (
                0 if self.attention_type == "bidirectional" else i
            )
            self.attention_mask[
                head_id,
                first_global_attn_row_idx:,
                i : i + self.global_attn_ctx,
            ] = 1
            if self.horizontal_global_attention:
                self.attention_mask[head_id, i : i + self.global_attn_ctx, :] = 1

        # Set global attention in the last (partial) local window, if any.
        if end < self.max_sequence_length:
            start = min(
                end + first_global_attn_col_idx,
                self.max_sequence_length - self.global_attn_ctx,
            )
            end = start + self.global_attn_ctx
            first_global_attn_row_idx = (
                0 if self.attention_type == "bidirectional" else start
            )
            self.attention_mask[
                head_id, first_global_attn_row_idx:, start:end
            ] = 1
            if self.horizontal_global_attention:
                self.attention_mask[head_id, start:end, :] = 1

    def trim_to_autoregressive_format(self, head_id):
        # Ensure that, in the unidirectional case, the mask stays in an
        # autoregressive (lower-triangular) format.
        autoregressive_mask = np.ones(
            (self.max_sequence_length, self.max_sequence_length)
        )
        autoregressive_mask = np.tril(autoregressive_mask)
        self.attention_mask[head_id] = np.multiply(
            autoregressive_mask, self.attention_mask[head_id]
        )

    def build_mask(self):
        num_head_attn_configs_to_build = (
            self.num_heads if self.num_different_head_attn_configs > 1 else 1
        )
        for head_id in range(0, num_head_attn_configs_to_build):
            self.set_local_attention(head_id)
            self.set_global_attention(head_id)
            if self.attention_type == "unidirectional":
                self.trim_to_autoregressive_format(head_id)
        self.mask_is_built = True
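

# --- Illustrative usage, not part of the original module ---
# A minimal sketch of building a fixed sparse mask for a single shared head
# configuration; the parameter values below are arbitrary examples.
def _example_fixed_sparse_mask():
    import torch  # only needed to materialize the mask tensor

    builder = FixedSparseAttentionBuilder(
        num_heads=4,
        max_sequence_length=16,
        local_attn_ctx=4,
        global_attn_ctx=1,
        attention_type="unidirectional",
    )
    # With the default `num_different_head_attn_configs=1` the returned
    # tensor has shape (16, 16); entries equal to 1 mark key positions that
    # the corresponding query position is not allowed to attend to.
    return builder.get_pytorch_mask(dtype=torch.float16)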


class StridedSparseAttentionBuilder(SparseAttentionBuilder):
    def __init__(
        self,
        num_heads,
        max_sequence_length,
        local_attn_ctx=4,
        stride=2,
        attention_type="unidirectional",
    ):
        super().__init__(num_heads, max_sequence_length, 1)
        self.local_attn_ctx = local_attn_ctx
        self.stride = stride
        self.attention_type = attention_type
        assert (
            self.attention_type == "unidirectional"
            or self.attention_type == "bidirectional"
        ), "Only 'unidirectional' or 'bidirectional' attention types are supported"

    def set_local_attention(self):
        # Allow each position to attend to a sliding local window of
        # `local_attn_ctx` positions.
        for row_idx in range(0, self.max_sequence_length):
            col_start = max(0, row_idx - self.local_attn_ctx + 1)
            col_end = (
                row_idx + 1
                if self.attention_type == "unidirectional"
                else row_idx + self.local_attn_ctx
            )
            self.attention_mask[0, row_idx, col_start:col_end] = 1

    def set_global_attention(self):
        # Allow each position to additionally attend to every `stride`-th
        # position counting back from just before its local window.
        for row_idx in range(0, self.max_sequence_length):
            col_end = (
                row_idx - self.local_attn_ctx
            )  # if self.attention_type == 'unidirectional' else self.max_sequence_length - 1
            col_start = -1
            if col_end < 0:
                continue
            global_attention_indices = [
                *range(col_end, col_start, -self.stride)
            ]
            self.attention_mask[0, row_idx, global_attention_indices] = 1

        if self.attention_type == "bidirectional":
            for row_idx in range(0, self.max_sequence_length):
                col_start = row_idx + self.local_attn_ctx
                col_end = self.max_sequence_length
                global_attention_indices = [
                    *range(col_start, col_end, self.stride)
                ]
                self.attention_mask[0, row_idx, global_attention_indices] = 1

    def build_mask(self):
        self.set_local_attention()
        self.set_global_attention()
        self.mask_is_built = True
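

# --- Illustrative usage, not part of the original module ---
# A minimal sketch of inspecting the strided pattern directly as a NumPy
# array, without converting to a PyTorch tensor; the parameter values are
# arbitrary examples.
def _example_strided_sparse_mask():
    builder = StridedSparseAttentionBuilder(
        num_heads=2,
        max_sequence_length=16,
        local_attn_ctx=4,
        stride=2,
        attention_type="unidirectional",
    )
    builder.build_mask()
    # `attention_mask[0]` holds 1 where attention is allowed: a local band of
    # width `local_attn_ctx` ending at each query position, plus every
    # `stride`-th earlier key position counting back from just before the band.
    return builder.attention_mask[0]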