# Source code for cerebras.modelzoo.layers.Transformer

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/transformer.py
"""

from typing import Any, Callable, Optional, Union

import torch.nn as nn
from torch import Tensor
from torch.nn import LayerNorm

from cerebras.modelzoo.layers.TransformerDecoder import TransformerDecoder
from cerebras.modelzoo.layers.TransformerDecoderLayer import (
    TransformerDecoderLayer,
)
from cerebras.modelzoo.layers.TransformerEncoder import TransformerEncoder
from cerebras.modelzoo.layers.TransformerEncoderLayer import (
    TransformerEncoderLayer,
)


class Transformer(nn.Module):
    r"""A transformer model. User is able to modify the attributes as needed.

    The architecture is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
    Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all
    you need. In Advances in Neural Information Processing Systems,
    pages 6000-6010. Users can build the BERT
    (https://arxiv.org/abs/1810.04805) model with corresponding parameters.

    Args:
        d_model: the number of expected features in the encoder/decoder
            inputs (default=512).
        nhead: the number of heads in the multihead attention models
            (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder
            (default=6). Ignored when ``custom_encoder`` is given.
        num_decoder_layers: the number of sub-decoder-layers in the decoder
            (default=6). Ignored when ``custom_decoder`` is given.
        dim_feedforward: the dimension of the feedforward network model
            (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of encoder/decoder intermediate
            layer, can be a string ("relu" or "gelu") or a unary callable.
            Default: "gelu".
        custom_encoder: custom encoder (default=None). When given, it is used
            as ``self.encoder`` as-is and ``self.encoder_norm`` is ``None``.
        custom_decoder: custom decoder (default=None). When given, it is used
            as ``self.decoder`` as-is and ``self.decoder_norm`` is ``None``.
        layer_norm_eps: the eps value in layer normalization components
            (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are
            provided as (batch, seq, feature). Only ``True`` is currently
            supported; any other value fails an assertion. Default: ``True``.
        norm_first: if ``True``, encoder and decoder layers will perform
            LayerNorms before other attention and feedforward operations,
            otherwise after. Default: ``False`` (after).
        attention_type: Should be in ["scaled_dot_product", "dot_product"].
        use_projection_bias_in_attention: Add bias to Q,K,V projections in the
            Attention layer. Defaults to False.
        use_ffn_bias_in_attention: Add bias in the concluding FFN in the
            Attention layer. Defaults to False.
        use_ffn_bias: Add bias in all dense layers of the FFN sublayer of both
            the encoder and the decoder layers. Defaults to False.
        attention_initializer: Attention layer initializer.
            Defaults to "xavier_uniform".
        ffn_initializer: FFN layer initializer. Defaults to "xavier_uniform".
        device (optional): Device to create the model parameters on, can be a
            cuda device or CS device.

    Examples::
        >>> transformer_model = Transformer(nhead=16, num_encoder_layers=12)
        >>> src = torch.rand((32, 10, 512))  # (batch, src_seq, feature)
        >>> tgt = torch.rand((32, 20, 512))  # (batch, tgt_seq, feature)
        >>> out = transformer_model(src, tgt)

    Note: A full example to apply PyTorch's nn.Transformer module for the word
    language model is available in
    https://github.com/pytorch/examples/tree/master/word_language_model
    """

    def __init__(
        self,
        d_model: int = 512,
        nhead: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = "gelu",
        custom_encoder: Optional[Any] = None,
        custom_decoder: Optional[Any] = None,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = True,
        norm_first: bool = False,
        attention_type="scaled_dot_product",
        use_projection_bias_in_attention=False,
        use_ffn_bias_in_attention=False,
        use_ffn_bias=False,
        attention_initializer="xavier_uniform",
        ffn_initializer="xavier_uniform",
        device=None,
    ) -> None:
        super().__init__()
        assert batch_first, "Currently, only batch_first=True is supported"

        if custom_encoder is not None:
            self.encoder = custom_encoder
            # No final norm is built for a user-supplied encoder. Define the
            # attribute anyway so __reset_parameters (and callers) can rely on
            # it existing; previously this path raised AttributeError.
            self.encoder_norm = None
        else:
            encoder_layer = TransformerEncoderLayer(
                d_model,
                nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation,
                layer_norm_eps=layer_norm_eps,
                batch_first=batch_first,
                norm_first=norm_first,
                attention_type=attention_type,
                use_projection_bias_in_attention=use_projection_bias_in_attention,
                use_ffn_bias_in_attention=use_ffn_bias_in_attention,
                use_ffn_bias=use_ffn_bias,
                attention_initializer=attention_initializer,
                ffn_initializer=ffn_initializer,
                device=device,
            )
            # Registered here (before self.encoder) to keep the original
            # submodule/state_dict ordering; also handed to the encoder stack.
            self.encoder_norm = LayerNorm(
                d_model,
                eps=layer_norm_eps,
                device=device,
            )
            self.encoder = TransformerEncoder(
                encoder_layer, num_encoder_layers, self.encoder_norm
            )

        if custom_decoder is not None:
            self.decoder = custom_decoder
            # Same reasoning as for the encoder: keep the attribute defined.
            self.decoder_norm = None
        else:
            decoder_layer = TransformerDecoderLayer(
                d_model,
                nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation,
                layer_norm_eps=layer_norm_eps,
                batch_first=batch_first,
                norm_first=norm_first,
                attention_type=attention_type,
                use_projection_bias_in_attention=use_projection_bias_in_attention,
                use_ffn_bias_in_attention=use_ffn_bias_in_attention,
                use_ffn_bias=use_ffn_bias,
                attention_initializer=attention_initializer,
                ffn_initializer=ffn_initializer,
                device=device,
            )
            self.decoder_norm = LayerNorm(
                d_model,
                eps=layer_norm_eps,
                device=device,
            )
            self.decoder = TransformerDecoder(
                decoder_layer, num_decoder_layers, self.decoder_norm
            )

        self.__reset_parameters()

        self.d_model = d_model
        self.nhead = nhead
        self.batch_first = batch_first

    def forward(
        self,
        src: Tensor,
        tgt: Tensor,
        src_mask: Optional[Tensor] = None,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Take in and process masked source/target sequences.

        ``src`` is run through ``self.encoder`` to produce ``memory``; ``tgt``
        and ``memory`` are then run through ``self.decoder``, whose result is
        returned. All masks are forwarded unchanged to the encoder/decoder.

        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
            src_mask: the additive mask for the src sequence (optional).
            tgt_mask: the additive mask for the tgt sequence (optional).
            memory_mask: the additive mask for the encoder output (optional).
            src_key_padding_mask: the ByteTensor mask for src keys per batch
                (optional).
            tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch
                (optional).
            memory_key_padding_mask: the ByteTensor mask for memory keys per
                batch (optional).

        Shape (``batch_first=True``, the only supported layout):
            - src: :math:`(S, E)` for unbatched input, :math:`(N, S, E)`
              otherwise.
            - tgt: :math:`(T, E)` for unbatched input, :math:`(N, T, E)`
              otherwise.
            - src_mask: :math:`(S, S)` or
              :math:`(N\cdot\text{num\_heads}, S, S)`.
            - tgt_mask: :math:`(T, T)` or
              :math:`(N\cdot\text{num\_heads}, T, T)`.
            - memory_mask: :math:`(T, S)`.
            - src_key_padding_mask: :math:`(S)` for unbatched input otherwise
              :math:`(N, S)`.
            - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise
              :math:`(N, T)`.
            - memory_key_padding_mask: :math:`(S)` for unbatched input
              otherwise :math:`(N, S)`.

            Note: [src/tgt/memory]_mask ensures that position i is allowed to
            attend the unmasked positions. If a ByteTensor is provided, the
            non-zero positions are not allowed to attend while the zero
            positions will be unchanged. If a BoolTensor is provided,
            positions with ``True`` are not allowed to attend while ``False``
            values will be unchanged. If a FloatTensor is provided, it will be
            added to the attention weight.
            [src/tgt/memory]_key_padding_mask provides specified elements in
            the key to be ignored by the attention. If a ByteTensor is
            provided, the non-zero positions will be ignored while the zero
            positions will be unchanged. If a BoolTensor is provided, the
            positions with the value of ``True`` will be ignored while the
            position with the value of ``False`` will be unchanged.

            - output: :math:`(T, E)` for unbatched input, :math:`(N, T, E)`
              otherwise.

            Note: Due to the multi-head attention architecture in the
            transformer model, the output sequence length of a transformer is
            the same as the input sequence (i.e. target) length of the
            decoder.

            where S is the source sequence length, T is the target sequence
            length, N is the batch size, E is the feature number.

        Raises:
            RuntimeError: if batched ``src`` and ``tgt`` have different batch
                sizes, or if the last dimension of either is not ``d_model``.

        Examples:
            >>> output = transformer_model(
            ...     src, tgt, src_mask=src_mask, tgt_mask=tgt_mask
            ... )
        """
        # Batch-size agreement can only be checked for batched (3-D) inputs.
        if src.dim() == 3:
            batch_dim = 0 if self.batch_first else 1
            if src.size(batch_dim) != tgt.size(batch_dim):
                raise RuntimeError(
                    f"The batch number of `src` and `tgt` must be equal. "
                    f"They were {src.size(batch_dim)} and {tgt.size(batch_dim)} respectively."
                )

        if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
            raise RuntimeError(
                f"The feature number of `src` and `tgt` must be equal to `d_model`. "
                f"They were {src.size(-1)}, {tgt.size(-1)}, and {self.d_model} respectively."
            )

        memory = self.encoder(
            src, mask=src_mask, src_key_padding_mask=src_key_padding_mask
        )
        output = self.decoder(
            tgt,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return output

    def reset_parameters(self):
        """Re-initialize the final encoder/decoder LayerNorms.

        Sets weight to 1 and bias to 0 on ``encoder_norm`` / ``decoder_norm``
        when they exist. Norms are ``None`` when a custom encoder/decoder was
        supplied; those are skipped.
        """
        self.__reset_parameters()

    @staticmethod
    def _reset_norm(norm):
        """Fill ``norm.weight`` with 1 and zero ``norm.bias``, skipping absent ones."""
        if norm is None:
            return
        if getattr(norm, "weight", None) is not None:
            norm.weight.data.fill_(1.0)
        if getattr(norm, "bias", None) is not None:
            norm.bias.data.zero_()

    def __reset_parameters(self):
        # Only the two final norms owned by this class are touched; the layers
        # inside the encoder/decoder stacks (or custom modules) are left as-is.
        # getattr(..., None) keeps this safe even if an attribute was never set.
        self._reset_norm(getattr(self, "encoder_norm", None))
        self._reset_norm(getattr(self, "decoder_norm", None))