Source code for cerebras.modelzoo.layers.TransformerDecoderLayer
# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/transformer.py
"""
from typing import Callable, Optional, Tuple, Type, Union
import torch.nn as nn
from torch import Tensor
from torch.nn import Dropout, LayerNorm
from cerebras.modelzoo.layers.AttentionHelper import get_attention_module
from cerebras.modelzoo.layers.FeedForwardNetwork import (
FeedForwardNetwork,
FeedForwardNetworkConfig,
)
from cerebras.modelzoo.layers.RotaryPositionEmbeddingHelper import (
RotaryPositionEmbeddingHelper,
)
from cerebras.modelzoo.layers.SparseMoEBlock import SparseMoEBlock
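# Cached key/value layouts: self-attention and cross-attention each cache a
# (key, value) pair; when cross-attention is enabled, the combined cache is
# (self_key, self_value, cross_key, cross_value).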
SelfAttnKV = Tuple[Tensor, Tensor]
CrossAttnKV = Tuple[Tensor, Tensor]
SelfAndCrossAttnKV = Tuple[Tensor, Tensor, Tensor, Tensor]
class TransformerDecoderLayer(nn.Module):
r"""
TransformerDecoderLayer is made up of self-attn, multihead-attn and feedforward network.
This standard decoder layer is based on the paper "Attention Is All You Need".
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
in a different way during application.
Args:
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multihead-attention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
activation: the activation function of the intermediate layer, can be a string
("relu" or "gelu") or a unary callable. Default: gelu
layer_norm_eps: the eps value in layer normalization components (default=1e-5).
batch_first: If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
norm_layer: the normalization class used for the attention and feedforward sublayers (default=nn.LayerNorm)
norm_first: if ``True``, layer norm is done prior to self attention, multihead
attention and feedforward operations, respectively. Otherwise it's done after.
Default: ``False`` (after).
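norm_first_sandwich: if ``True``, additionally applies layer norm to the output of
the self-attention and feedforward sublayers before they are added back to the
residual stream (sandwich normalization). Requires ``norm_first=True``.
Default: ``False``.
attention_module: the attention implementation to use, given either as a string
identifier or as an ``nn.Module`` class. Default: ``"aiayn_attention"``.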
attention_dropout_rate: Attention dropout rate. If None, defaults to dropout.
attention_softmax_fp32: Use FP32 softmax in attention block.
use_projection_bias_in_attention: Add bias to Q,K,V projections
in the Attention layer. Defaults to False.
attention_type: Should be one of ``"scaled_dot_product"`` or ``"dot_product"``.
scale_qk_dot_by_d (bool): If ``True``, scales the QK^T dot product by d (the per-head dimension) instead of sqrt(d).
attention_logits_alpha (float): Scales the QK^T dot product. Used to stabilize logits in muP training.
attention_inner_dim (int): Number of output units in the attention query/key/value projections. Defaults to ``d_model``.
add_cross_attention: If ``True``, adds a cross-attention layer between encoder and decoder;
otherwise, only self-attention is used in the decoder (GPT-style models should set this to ``False``).
use_ffn_bias_in_attention: Add bias in the concluding FFN
in the Attention layer. Defaults to False.
use_ffn_bias: Add bias in all dense layers of the decoder's ffn sublayer
attention_initializer: Attention layer initializer. Defaults to "xavier_uniform".
attention_q_initializer: Query projection kernel initializer. If not
specified, the query will be initialized via ``attention_initializer``
attention_output_layer_initializer: attention output layer projection initializer. If not
specified, the output will be initialized via ``attention_initializer``
ffn_initializer: FFN layer initializer. Defaults to "xavier_uniform".
ffn_output_layer_initializer: If not None, initialize the last FFN layer
with this initializer. Defaults to None.
use_ff_layer1_dropout: If ``True``, dropout will be enabled after the first feed forward layer. Default: True
use_ff_layer2_dropout: If ``True``, dropout will be enabled after the second feed forward layer. Default: True
ffn_dropout_rate: Controls dropout rate of FF's first layer. If None, defaults to dropout.
moe_params: A dict of MoE params including num_experts, top_k and load_balancing_loss_coef
Examples:
>>> import torch
>>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
>>> memory = torch.rand(32, 10, 512)
>>> tgt = torch.rand(32, 20, 512)
>>> out = decoder_layer(tgt, memory)
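A decoder-only (GPT-style) configuration disables cross-attention; the following is an
illustrative sketch (argument values are arbitrary, ``tgt`` is reused from above):
>>> gpt_layer = TransformerDecoderLayer(
...     d_model=512, nhead=8, norm_first=True, add_cross_attention=False
... )
>>> out = gpt_layer(tgt)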
"""
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = "gelu",
layer_norm_eps: float = 1e-5,
batch_first: bool = True,
norm_layer: Type[nn.Module] = LayerNorm,
norm_first: bool = False,
norm_first_sandwich: bool = False,
attention_module: Union[str, nn.Module] = "aiayn_attention",
extra_attention_params={},
extra_ffn_params={},
device=None,
add_cross_attention: bool = True,
attention_dropout_rate: Optional[float] = None,
attention_softmax_fp32: Optional[bool] = True,
attention_type="scaled_dot_product",
scale_qk_dot_by_d=False,
attention_logits_alpha: Optional[float] = 1.0,
q_projection_scale=1.0,
k_projection_scale=1.0,
v_projection_scale=1.0,
output_projection_scale=1.0,
scale_qk_dot_by_layer_idx=False,
attention_inner_dim=None,
cross_attention_kv_dim=None,
use_projection_bias_in_attention=False,
use_ffn_bias_in_attention=False,
use_ffn_bias=False,
attention_initializer="xavier_uniform",
attention_q_initializer=None,
attention_output_layer_initializer=None,
attention_logit_softcapping=None,
ffn_initializer="xavier_uniform",
ffn_output_layer_initializer=None,
use_ff_layer1_dropout: bool = True,
use_ff_layer2_dropout: bool = True,
ffn_dropout_rate: Optional[float] = None,
moe_params=dict(num_experts=1),
) -> None:
super(TransformerDecoderLayer, self).__init__()
assert batch_first, "Currently, only batch_first=True is supported"
self.add_cross_attention = add_cross_attention
if attention_dropout_rate is None:
attention_dropout_rate = dropout
AttentionModule = get_attention_module(
attention_module, extra_attention_params
)
self.self_attn = AttentionModule(
d_model,
nhead,
inner_dim=attention_inner_dim,
dropout=attention_dropout_rate,
batch_first=batch_first,
attention_type=attention_type,
scale_qk_dot_by_d=scale_qk_dot_by_d,
attention_logits_alpha=attention_logits_alpha,
q_projection_scale=q_projection_scale,
k_projection_scale=k_projection_scale,
v_projection_scale=v_projection_scale,
output_projection_scale=output_projection_scale,
softmax_dtype_fp32=attention_softmax_fp32,
use_projection_bias=use_projection_bias_in_attention,
use_ffn_bias=use_ffn_bias_in_attention,
attention_initializer=attention_initializer,
attention_q_initializer=attention_q_initializer,
output_layer_initializer=attention_output_layer_initializer,
scale_qk_dot_by_layer_idx=scale_qk_dot_by_layer_idx,
logit_softcapping=attention_logit_softcapping,
device=device,
**extra_attention_params,
)
self.norm_first = norm_first
if norm_first_sandwich:
assert self.norm_first, (
"When norm_first_sandwich is enabled, norm_first must be "
"enabled too"
)
self.norm_first_sandwich = norm_first_sandwich
self.norm1 = norm_layer(
d_model,
eps=layer_norm_eps,
device=device,
)
self.dropout1 = Dropout(dropout)
self.norm3 = norm_layer(
d_model,
eps=layer_norm_eps,
device=device,
)
if norm_first_sandwich:
self.norm1_post = norm_layer(
d_model,
eps=layer_norm_eps,
device=device,
)
self.norm3_post = norm_layer(
d_model,
eps=layer_norm_eps,
device=device,
)
if self.add_cross_attention:
if cross_attention_kv_dim is None:
cross_attention_kv_dim = d_model
self.multihead_attn = AttentionModule(
d_model,
nhead,
kdim=cross_attention_kv_dim,
vdim=cross_attention_kv_dim,
inner_dim=attention_inner_dim,
dropout=attention_dropout_rate,
batch_first=batch_first,
attention_type=attention_type,
scale_qk_dot_by_d=scale_qk_dot_by_d,
attention_logits_alpha=attention_logits_alpha,
q_projection_scale=q_projection_scale,
k_projection_scale=k_projection_scale,
v_projection_scale=v_projection_scale,
output_projection_scale=output_projection_scale,
softmax_dtype_fp32=attention_softmax_fp32,
use_projection_bias=use_projection_bias_in_attention,
use_ffn_bias=use_ffn_bias_in_attention,
attention_initializer=attention_initializer,
attention_q_initializer=attention_q_initializer,
output_layer_initializer=attention_output_layer_initializer,
scale_qk_dot_by_layer_idx=scale_qk_dot_by_layer_idx,
logit_softcapping=attention_logit_softcapping,
device=device,
**extra_attention_params,
)
self.norm2 = norm_layer(
d_model,
eps=layer_norm_eps,
device=device,
)
self.dropout2 = Dropout(dropout)
if ffn_dropout_rate is None:
ffn_dropout_rate = dropout
ffn_config = FeedForwardNetworkConfig(
input_unit=d_model,
layers_units=[dim_feedforward, d_model],
layers_activation=[activation, None],
layers_dropout_rates=[
ffn_dropout_rate if use_ff_layer1_dropout else None,
dropout if use_ff_layer2_dropout else None,
],
use_bias=use_ffn_bias,
kernel_initializer=ffn_initializer,
output_layer_initializer=ffn_output_layer_initializer,
bias_initializer="zeros",
device=device,
**moe_params,
**extra_ffn_params,
)
self.moe_enabled = moe_params["num_experts"] > 1
if self.moe_enabled:
self.ffn = SparseMoEBlock(ffn_config)
else:
self.ffn = FeedForwardNetwork(ffn_config)
self.__reset_parameters()
def reset_parameters(self):
self.__reset_parameters()
def __reset_parameters(self):
self.self_attn.reset_parameters()
self.ffn.reset_parameters()
if hasattr(self.norm1, 'bias') and hasattr(self.norm1.bias, 'data'):
self.norm1.bias.data.zero_()
if hasattr(self.norm1, 'weight') and hasattr(self.norm1.weight, 'data'):
self.norm1.weight.data.fill_(1.0)
if self.norm3 is not None:
if hasattr(self.norm3, 'bias') and hasattr(self.norm3.bias, 'data'):
self.norm3.bias.data.zero_()
if hasattr(self.norm3, 'weight') and hasattr(
self.norm3.weight, 'data'
):
self.norm3.weight.data.fill_(1.0)
if self.norm_first_sandwich:
if hasattr(self.norm1_post, 'bias') and hasattr(
self.norm1_post.bias, 'data'
):
self.norm1_post.bias.data.zero_()
if hasattr(self.norm1_post, 'weight') and hasattr(
self.norm1_post.weight, 'data'
):
self.norm1_post.weight.data.fill_(1.0)
if self.norm3_post is not None:
if hasattr(self.norm3_post, 'bias') and hasattr(
self.norm3_post.bias, 'data'
):
self.norm3_post.bias.data.zero_()
if hasattr(self.norm3_post, 'weight') and hasattr(
self.norm3_post.weight, 'data'
):
self.norm3_post.weight.data.fill_(1.0)
if self.add_cross_attention:
self.multihead_attn.reset_parameters()
if hasattr(self.norm2, 'bias') and hasattr(self.norm2.bias, 'data'):
self.norm2.bias.data.zero_()
if hasattr(self.norm2, 'weight') and hasattr(
self.norm2.weight, 'data'
):
self.norm2.weight.data.fill_(1.0)
def forward(
self,
tgt: Tensor,
memory: Optional[Tensor] = None,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
rotary_position_embedding_helper: Optional[
RotaryPositionEmbeddingHelper
] = None,
past_kv: Optional[Union[SelfAttnKV, SelfAndCrossAttnKV]] = None,
cache_present_kv: bool = False,
self_attn_position_bias: Optional[Tensor] = None,
cross_attn_position_bias: Optional[Tensor] = None,
layer_idx: Optional[int] = None,
expert_hash_idx: Optional[Tensor] = None,
**extra_args,
) -> Union[Tensor, Tuple[Tensor, Union[SelfAttnKV, SelfAndCrossAttnKV]]]:
r"""Pass the inputs (and mask) through the decoder layer.
Args:
tgt: the sequence to the decoder layer (required).
memory: the sequence from the last layer of the encoder (required when ``add_cross_attention`` is ``True``; otherwise optional).
tgt_mask: the mask for the tgt sequence (optional).
memory_mask: the mask for the memory sequence (optional).
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
memory_key_padding_mask: the mask for the memory keys per batch (optional).
rotary_position_embedding_helper (Optional[RotaryPositionEmbeddingHelper]):
A helper class to apply rotary embedding on the input tensor.
past_kv: Past keys and values for self attention and (if applicable) cross
attention modules. Key/value tensors have shape
``[batch_size, num_heads, seq_length, embed_dim / num_heads]``. (optional).
cache_present_kv: Specifies if the present keys and values
must be cached and returned. Needed to speed up the
computations when the decoder is called within an
autoregressive loop. (optional).
self_attn_position_bias: the tensor containing position bias to apply in self-attention,
can be obtained from relative or alibi position embeddings.
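cross_attn_position_bias: the tensor containing position bias to apply in
cross-attention; only used when ``add_cross_attention`` is ``True`` (optional).
layer_idx: index of this layer within the decoder stack; forwarded to the
attention modules (e.g. for ``scale_qk_dot_by_layer_idx``) (optional).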
expert_hash_idx: tensor containing mixture-of-experts expert
selection indices for each token in the batch. Only used with
MoE with hash-based routing enabled (optional).
Shape:
see the docs in Transformer class.
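Example:
An illustrative sketch, reusing ``decoder_layer``, ``tgt`` and ``memory`` from the
class-level example; the additive ``-inf`` mask convention below is an assumption
and depends on the configured attention module:
>>> causal_mask = torch.triu(
...     torch.full((20, 20), float("-inf")), diagonal=1
... )
>>> out = decoder_layer(tgt, memory, tgt_mask=causal_mask)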
"""
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
assert (
past_kv is None and not cache_present_kv
), "Cannot provide past_kv because inference is not supported yet."
ffn_extra_args = {}
if (
extra_args
and ("token_modality_idx" in extra_args)
and (extra_args["token_modality_idx"] is not None)
):
ffn_extra_args["token_modality_idx"] = extra_args[
"token_modality_idx"
]
extra_args.pop('token_modality_idx')
if self.moe_enabled and expert_hash_idx is not None:
ffn_extra_args["expert_hash_idx"] = expert_hash_idx
x = tgt
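# norm_first=True selects the pre-LN layout: each sublayer input is normalized
# before self-attention / cross-attention / FFN, and with norm_first_sandwich the
# sublayer output is additionally normalized before the residual add.
# norm_first=False uses the classic post-LN layout (normalize after the residual add).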
if self.norm_first:
attn1_out = self._sa_block(
self.norm1(x),
tgt_mask,
tgt_key_padding_mask,
rotary_position_embedding_helper=rotary_position_embedding_helper,
past_kv=past_kv[:2] if past_kv is not None else None,
cache_present_kv=cache_present_kv,
self_attn_position_bias=self_attn_position_bias,
layer_idx=layer_idx,
**extra_args,
)
post_attn1 = attn1_out[0]
if self.norm_first_sandwich:
post_attn1 = self.norm1_post(post_attn1)
x = x + post_attn1
if self.add_cross_attention:
attn2_out = self._mha_block(
self.norm2(x),
memory,
memory_mask,
memory_key_padding_mask,
past_kv=past_kv[2:] if past_kv is not None else None,
cache_present_kv=cache_present_kv,
cross_attn_position_bias=cross_attn_position_bias,
layer_idx=layer_idx,
**extra_args,
)
x = x + attn2_out[0]
ffn_output = (
self.ffn(self.norm3(x), **ffn_extra_args)
if self.norm3 is not None
else self.ffn(x, **ffn_extra_args)
)
if self.moe_enabled:
(
ffn_output,
routing_weights,
expert_mask,
) = ffn_output
post_ffn_output = ffn_output
if self.norm_first_sandwich:
post_ffn_output = self.norm3_post(post_ffn_output)
x = x + post_ffn_output
else:
attn1_out = self._sa_block(
x,
tgt_mask,
tgt_key_padding_mask,
rotary_position_embedding_helper=rotary_position_embedding_helper,
past_kv=past_kv[:2] if past_kv is not None else None,
cache_present_kv=cache_present_kv,
self_attn_position_bias=self_attn_position_bias,
layer_idx=layer_idx,
**extra_args,
)
x = self.norm1(x + attn1_out[0])
if self.add_cross_attention:
attn2_out = self._mha_block(
x,
memory,
memory_mask,
memory_key_padding_mask,
past_kv=past_kv[2:] if past_kv is not None else None,
cache_present_kv=cache_present_kv,
cross_attn_position_bias=cross_attn_position_bias,
layer_idx=layer_idx,
**extra_args,
)
x = self.norm2(x + attn2_out[0])
ffn_output = self.ffn(x, **ffn_extra_args)
if self.moe_enabled:
(
ffn_output,
routing_weights,
expert_mask,
) = ffn_output
x = (
self.norm3(x + ffn_output)
if self.norm3 is not None
else x + ffn_output
)
if not self.moe_enabled:
if not cache_present_kv:
return x
else:
present_kv = (
attn1_out[1]
if not self.add_cross_attention
else attn1_out[1] + attn2_out[1]
)
return x, present_kv
else:
if not cache_present_kv:
return x, routing_weights, expert_mask
else:
present_kv = (
attn1_out[1]
if not self.add_cross_attention
else attn1_out[1] + attn2_out[1]
)
return x, routing_weights, expert_mask, present_kv
# self-attention block
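# Applies self-attention with x as query, key and value (with optional rotary
# embeddings and position bias) and returns a tuple:
# (dropout(attn_output),) or (dropout(attn_output), present_kv).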
def _sa_block(
self,
x: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
rotary_position_embedding_helper: Optional[
RotaryPositionEmbeddingHelper
] = None,
past_kv: Optional[SelfAttnKV] = None,
cache_present_kv: bool = False,
self_attn_position_bias: Optional[Tensor] = None,
layer_idx: Optional[int] = None,
**extra_args,
) -> Tuple[Tensor, ...]:
attn_out = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
rotary_position_embedding_helper=rotary_position_embedding_helper,
past_kv=past_kv,
cache_present_kv=cache_present_kv,
position_bias=self_attn_position_bias,
layer_idx=layer_idx,
**extra_args,
)
if cache_present_kv:
out, present_kv = attn_out
else:
out = attn_out
out = (self.dropout1(out),)
if cache_present_kv:
out += (present_kv,)
return out
# multihead attention block
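# Applies cross-attention with x as queries and the encoder memory as keys/values;
# returns (dropout(attn_output),) or (dropout(attn_output), present_kv).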
def _mha_block(
self,
x: Tensor,
mem: Tensor,
attn_mask: Optional[Tensor],
key_padding_mask: Optional[Tensor],
past_kv: Optional[CrossAttnKV] = None,
cache_present_kv: bool = False,
cross_attn_position_bias: Optional[Tensor] = None,
layer_idx: Optional[int] = None,
**extra_args,
) -> Tuple[Tensor, ...]:
attn_out = self.multihead_attn(
x,
mem,
mem,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
past_kv=past_kv,
cache_present_kv=cache_present_kv,
past_kv_self_attn=False,
position_bias=cross_attn_position_bias,
layer_idx=layer_idx,
**extra_args,
)
if cache_present_kv:
x, present_kv = attn_out
else:
x = attn_out
out = (self.dropout2(x),)
if cache_present_kv:
out += (present_kv,)
return out