Source code for cerebras.modelzoo.layers.TransformerEncoderLayer

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/transformer.py
"""

from typing import Callable, Optional, Type, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Dropout, LayerNorm

from cerebras.modelzoo.layers.AttentionHelper import get_attention_module
from cerebras.modelzoo.layers.FeedForwardNetwork import (
    FeedForwardNetwork,
    FeedForwardNetworkConfig,
)
from cerebras.modelzoo.layers.RotaryPositionEmbeddingHelper import (
    RotaryPositionEmbeddingHelper,
)
from cerebras.modelzoo.layers.StochasticDepth import StochasticDepth


class TransformerEncoderLayer(nn.Module):
    r"""TransformerEncoderLayer is made up of self-attention and a feedforward network.

    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
    Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all
    you need. In Advances in Neural Information Processing Systems, pages 6000-6010.
    Users may modify or implement it in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multihead attention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of the intermediate layer; can be a
            string ("relu" or "gelu") or a unary callable. Default: gelu.
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_layer: the normalization class that will be used before/after FF layers
            (default=nn.LayerNorm).
        norm_first: if ``True``, layer norm is done prior to attention and feedforward
            operations, respectively. Otherwise it's done after. Default: ``False`` (after).
        attention_dropout_rate: Attention dropout rate. If None, defaults to dropout.
        use_projection_bias_in_attention: Add bias to Q, K, V projections in the
            Attention layer. Defaults to False.
        attention_type: Should be in ["scaled_dot_product", "dot_product"].
        scale_qk_dot_by_d (bool): If ``True``, scales the QK^T dot product by
            d (=hidden/d_head) instead of sqrt(d).
        attention_softmax_fp32: Use FP32 softmax in attention block.
        attention_inner_dim (int): Number of output units in attention query/key/value
            projection. Defaults to d_model.
        add_cross_attention: If ``True``, adds a cross-attention layer between
            encoder/decoder; otherwise, only self-attention is used in the decoder
            (GPT-style models should set this to ``False``).
        use_ffn_bias_in_attention: Add bias in the concluding FFN in the Attention
            layer. Defaults to False.
        use_ffn_bias: Add bias in all dense layers of the decoder's ffn sublayer.
        attention_initializer: Attention layer initializer. Defaults to "xavier_uniform".
        attention_q_initializer: Query projection kernel initializer. If not specified,
            the query will be initialized via ``attention_initializer``.
        attention_output_layer_initializer: Attention output layer projection
            initializer. If not specified, the output will be initialized via
            ``attention_initializer``.
        ffn_initializer: FFN layer initializer. Defaults to "xavier_uniform".
        ffn_output_layer_initializer: If not None, initialize the last FFN layer
            with this initializer. Defaults to None.
        use_ff_layer1_dropout: If ``True``, dropout will be enabled after the first
            feed forward layer. Default: True.
        use_ff_layer2_dropout: If ``True``, dropout will be enabled after the second
            feed forward layer. Default: True.
        ffn_dropout_rate: Controls the dropout rate of the FFN's first layer. If None,
            defaults to dropout.
        layerscale_value: Initial value to use for LayerScale in vision transformers.
            Defaults to None.
        stochastic_depth_drop_prob: Drop probability for stochastic depth per sample
            (when applied in the main path of residual blocks).
        stochastic_depth_mode: Should be in ["batch", "row"].
    Example:
        When ``batch_first`` is ``True``:

        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = encoder_layer(src)
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = "gelu",
        layer_norm_eps: float = 1e-5,
        batch_first: bool = True,
        norm_layer: Type[nn.Module] = LayerNorm,
        norm_first: bool = False,
        attention_module: Union[str, nn.Module] = "aiayn_attention",
        extra_attention_params={},
        device=None,
        attention_dropout_rate: Optional[float] = None,
        attention_type="scaled_dot_product",
        scale_qk_dot_by_d=False,
        attention_logits_alpha=1.0,
        q_projection_scale=1.0,
        k_projection_scale=1.0,
        v_projection_scale=1.0,
        output_projection_scale=1.0,
        attention_softmax_fp32: Optional[bool] = True,
        attention_inner_dim=None,
        use_projection_bias_in_attention=False,
        use_ffn_bias_in_attention=False,
        use_ffn_bias=False,
        attention_initializer="xavier_uniform",
        attention_q_initializer=None,
        attention_output_layer_initializer=None,
        attention_logit_softcapping=None,
        ffn_initializer="xavier_uniform",
        ffn_output_layer_initializer=None,
        use_ff_layer1_dropout: bool = True,
        use_ff_layer2_dropout: bool = True,
        ffn_dropout_rate: Optional[float] = None,
        layerscale_value: Optional[float] = None,
        stochastic_depth_drop_prob: Optional[float] = 0.0,
        stochastic_depth_mode: Optional[str] = "batch",
    ) -> None:
        super(TransformerEncoderLayer, self).__init__()

        assert batch_first, "Currently, only batch_first=True is supported"

        if attention_dropout_rate is None:
            attention_dropout_rate = dropout

        AttentionModule = get_attention_module(
            attention_module, extra_attention_params
        )

        self.self_attn = AttentionModule(
            d_model,
            nhead,
            inner_dim=attention_inner_dim,
            dropout=attention_dropout_rate,
            batch_first=batch_first,
            attention_type=attention_type,
            scale_qk_dot_by_d=scale_qk_dot_by_d,
            attention_logits_alpha=attention_logits_alpha,
            q_projection_scale=q_projection_scale,
            k_projection_scale=k_projection_scale,
            v_projection_scale=v_projection_scale,
            output_projection_scale=output_projection_scale,
            softmax_dtype_fp32=attention_softmax_fp32,
            use_projection_bias=use_projection_bias_in_attention,
            use_ffn_bias=use_ffn_bias_in_attention,
            attention_initializer=attention_initializer,
            attention_q_initializer=attention_q_initializer,
            output_layer_initializer=attention_output_layer_initializer,
            logit_softcapping=attention_logit_softcapping,
            device=device,
            **extra_attention_params,
        )

        self.layerscale_value = layerscale_value
        if self.layerscale_value is not None:
            self.layer_scale1 = nn.Parameter(
                self.layerscale_value * torch.ones(d_model)
            )
            self.layer_scale2 = nn.Parameter(
                self.layerscale_value * torch.ones(d_model)
            )

        self.stochastic_depth_drop_prob = stochastic_depth_drop_prob
        self.stochastic_depth_mode = stochastic_depth_mode
        if self.stochastic_depth_drop_prob > 0.0:
            self.drop_path = StochasticDepth(
                self.stochastic_depth_drop_prob, self.stochastic_depth_mode
            )
        else:
            self.drop_path = nn.Identity()

        if ffn_dropout_rate is None:
            ffn_dropout_rate = dropout

        self.ffn = FeedForwardNetwork(
            FeedForwardNetworkConfig(
                input_unit=d_model,
                layers_units=[dim_feedforward, d_model],
                layers_activation=[activation, None],
                layers_dropout_rates=[
                    ffn_dropout_rate if use_ff_layer1_dropout else None,
                    dropout if use_ff_layer2_dropout else None,
                ],
                use_bias=use_ffn_bias,
                kernel_initializer=ffn_initializer,
                output_layer_initializer=ffn_output_layer_initializer,
                bias_initializer="zeros",
                device=device,
            )
        )
        self.norm_first = norm_first
        self.norm1 = norm_layer(
            d_model,
            eps=layer_norm_eps,
            device=device,
        )
        self.norm2 = norm_layer(
            d_model,
            eps=layer_norm_eps,
            device=device,
        )
        self.dropout1 = Dropout(dropout)

        self.__reset_parameters()

    def reset_parameters(self):
        self.__reset_parameters()

    def __reset_parameters(self):
        self.self_attn.reset_parameters()
        self.ffn.reset_parameters()
        # Reset the norm layers to their defaults (weight=1, bias=0) when present.
        if hasattr(self.norm1, 'bias') and hasattr(self.norm1.bias, "data"):
            self.norm1.bias.data.zero_()
        if hasattr(self.norm1, 'weight') and hasattr(self.norm1.weight, "data"):
            self.norm1.weight.data.fill_(1.0)
        if hasattr(self.norm2, 'bias') and hasattr(self.norm2.bias, "data"):
            self.norm2.bias.data.zero_()
        if hasattr(self.norm2, 'weight') and hasattr(self.norm2.weight, "data"):
            self.norm2.weight.data.fill_(1.0)
    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        rotary_position_embedding_helper: Optional[
            RotaryPositionEmbeddingHelper
        ] = None,
        self_attn_position_bias: Optional[Tensor] = None,
        **extra_args,
    ) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            rotary_position_embedding_helper (Optional[RotaryPositionEmbeddingHelper]):
                A helper class to apply rotary embedding on the input tensor.
            self_attn_position_bias: the tensor containing position bias to apply
                in self-attention; can be obtained from relative or alibi position
                embeddings.

        Shape:
            see the docs in the Transformer class.
        """

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        x = src
        if self.norm_first:
            x = x + self._sa_block(
                self.norm1(x),
                src_mask,
                src_key_padding_mask,
                rotary_position_embedding_helper=rotary_position_embedding_helper,
                self_attn_position_bias=self_attn_position_bias,
                **extra_args,
            )
            x = x + self._ffn_block(self.norm2(x))
        else:
            x = self.norm1(
                x
                + self._sa_block(
                    x,
                    src_mask,
                    src_key_padding_mask,
                    rotary_position_embedding_helper=rotary_position_embedding_helper,
                    self_attn_position_bias=self_attn_position_bias,
                    **extra_args,
                )
            )
            x = self.norm2(x + self._ffn_block(x))

        return x
    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        rotary_position_embedding_helper: Optional[
            RotaryPositionEmbeddingHelper
        ] = None,
        self_attn_position_bias: Optional[Tensor] = None,
        **extra_args,
    ) -> Tensor:
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            rotary_position_embedding_helper=rotary_position_embedding_helper,
            position_bias=self_attn_position_bias,
            need_weights=False,
            **extra_args,
        )
        x = self.dropout1(x)
        return x

    # feed-forward block
    def _ffn_block(
        self,
        x: Tensor,
    ) -> Tensor:
        x = self.ffn(x)
        return x
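
# Minimal usage sketch (illustrative, not part of the upstream module): build a
# pre-norm variant of this layer and run a single forward pass. Shapes follow the
# batch_first=True convention (batch, seq, feature) asserted in __init__; all
# other arguments are left at their defaults, so this assumes the default
# "aiayn_attention" module resolves in your environment.
if __name__ == "__main__":
    encoder_layer = TransformerEncoderLayer(
        d_model=512,
        nhead=8,
        dim_feedforward=2048,
        dropout=0.1,
        norm_first=True,  # pre-norm: LayerNorm before attention/FFN
    )
    src = torch.rand(32, 10, 512)  # (batch, seq, feature)
    out = encoder_layer(src)
    assert out.shape == src.shape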