# Source code for cerebras.modelzoo.layers.FeedForwardNetwork

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
from dataclasses import dataclass
from typing import Callable, ClassVar, List, Optional, Union

import torch.nn as nn
from torch import Tensor

from cerebras.modelzoo.layers.activations import (
    get_activation,
    is_glu_activation,
)
from cerebras.modelzoo.layers.create_initializer import create_initializer


class StaticDualExpertLinear(nn.Module):
    """
    Linear op that sends each token to one of two experts based on
    `token_modality_idx`: an "image" linear layer and a "text" linear layer.

    Both experts are `nn.Linear` layers built with the same `in_features`,
    `out_features`, `bias`, `device` and `dtype`, but with independent
    parameters.
    """

    def __init__(
        self, in_features, out_features, bias=True, device=None, dtype=None
    ):
        super().__init__()
        self.linear_img = nn.Linear(
            in_features, out_features, bias=bias, device=device, dtype=dtype
        )
        self.linear_text = nn.Linear(
            in_features, out_features, bias=bias, device=device, dtype=dtype
        )

    def forward(self, input, token_modality_idx):
        """
        Project `input` with both experts and pick, per token, the output of
        the expert selected by `token_modality_idx`.

        This can be modeled as an MoE structure with one expert for the text
        portion and one for the image portion, masked by `token_modality_idx`.
        TODO: we can leverage the MOE optimization by using the MOE Sparse
        linear layer.

        Args:
            input: tensor whose last dimension is `in_features`.
            token_modality_idx: per-token mask selecting the expert. Its shape
                must be `input.shape[:-1]` (e.g. (batch, seq) for a 3-D
                input). Value '1' is for an image token and value '0' is for
                a text token.

        Returns:
            Tensor of shape `(*input.shape[:-1], out_features)`.

        Raises:
            AssertionError: if `token_modality_idx` is None.
        """
        assert (
            token_modality_idx is not None
        ), "Expect 'token_modality_idx' when 'static_dual_expert' is True."
        img_out = self.linear_img(input)
        text_out = self.linear_text(input)

        # Add a trailing feature axis so the per-token mask lines up with the
        # (..., out_features) projections; works for any number of leading
        # dimensions (the previous `[:, :, None]` only handled 2-D masks).
        mask = token_modality_idx.to(img_out.dtype).unsqueeze(-1)
        mask = mask.broadcast_to(img_out.shape)
        # Dense blend: both experts run on every token. For a 0/1 mask and
        # finite outputs this reproduces hard per-token selection.
        return text_out * (1 - mask) + img_out * mask


class SingleFeedForwardLayer(nn.Module):
    """
    A single feed-forward layer: linear projection, then an optional
    activation, then optional dropout.

    Args:
        in_features: input feature size of the linear projection(s).
        out_features: output feature size of the linear projection(s).
        use_bias: if True, the linear projection(s) include a bias term.
        activation: activation name or callable, resolved with
            `get_activation`. When `is_glu_activation(activation)` is true, a
            second projection (`linear_layer_for_glu`) is created and both
            projections are passed to the activation.
        dropout: dropout probability; no dropout layer is created when it is
            None or <= 0.
        device: device on which the linear parameters are created.
        static_dual_expert: if True, each projection is a
            `StaticDualExpertLinear` and `forward` reads the
            `token_modality_idx` keyword argument.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        use_bias: bool = False,
        activation: Optional[Union[str, Callable[[Tensor], Tensor]]] = None,
        dropout: Optional[float] = None,
        device=None,
        static_dual_expert: bool = False,
    ):
        super().__init__()

        # Normalize to a real bool: `forward` branches on truthiness, so the
        # linear-op choice here must use the same test (the old `== True`
        # comparison could disagree with it for truthy non-True values).
        self.static_dual_expert = bool(static_dual_expert)
        linear_op = (
            StaticDualExpertLinear if self.static_dual_expert else nn.Linear
        )
        self.linear_layer = linear_op(
            in_features,
            out_features,
            bias=use_bias,
            device=device,
        )

        self.is_glu_activation = is_glu_activation(activation)
        if self.is_glu_activation:
            # GLU-style activations combine two projections of the same input.
            self.linear_layer_for_glu = linear_op(
                in_features,
                out_features,
                bias=use_bias,
                device=device,
            )

        self.act_layer = get_activation(activation) if activation else None

        if dropout and dropout > 0.0:
            self.dropout_layer = nn.Dropout(p=dropout)
        else:
            self.dropout_layer = None

    def forward(self, inputs, **extra_args):
        """
        Apply projection(s), activation and dropout to `inputs`.

        Args:
            inputs: input tensor whose last dimension is `in_features`.
            **extra_args: extra keyword arguments; only `token_modality_idx`
                is read, and only when `static_dual_expert` is enabled.

        Returns:
            Output tensor whose last dimension is `out_features`.
        """
        if self.static_dual_expert:
            linear_inputs = (inputs, extra_args.get("token_modality_idx"))
        else:
            linear_inputs = (inputs,)

        if self.is_glu_activation:
            glu_component_1 = self.linear_layer(*linear_inputs)
            glu_component_2 = self.linear_layer_for_glu(*linear_inputs)
            outputs = self.act_layer(glu_component_1, glu_component_2)
        else:
            outputs = self.linear_layer(*linear_inputs)
            # Explicit None check: a module with __len__ == 0 would be falsy.
            if self.act_layer is not None:
                outputs = self.act_layer(outputs)

        if self.dropout_layer is not None:
            outputs = self.dropout_layer(outputs)
        return outputs


@dataclass
class FeedForwardNetworkConfig:
    """Feed forward network config.

    Args:
        input_unit (int): integer for number of in_features of input.
        layers_units (list[int]): List of units for each layer.
        layers_activation (list[str]): List of activation types (str) for each
            layer. When empty or None, every layer gets no activation.
        layers_dropout_rates (list[float]): List of dropout rates (float) for
            each layer. When empty or None, every layer gets no dropout.
        use_bias (bool): If `True`, use bias throughout all layers.
        kernel_initializer: Kernel initializer. Defaults to
            `"xavier_uniform"`.
        bias_initializer: Bias initializer. Defaults to `"zeros"`.
        output_layer_initializer: If not None, initialize the last projection
            layer with this initializer. Defaults to None, in which case
            `kernel_initializer` is used.
        num_experts (int): The number of experts in the MoE block. Defaults to 1.
        top_k (int): The number of experts to be used on each token. Defaults
            to None; must be in [1, num_experts] when num_experts > 1.
        gate_initializer: Router (gate) initializer. Defaults to `"xavier_uniform"`.
        load_balancing_loss_coef (float): The float coefficient to scale the
            load balancing loss. Must be set and non-negative when
            num_experts > 1. Defaults to `0.01`.
        router_fp32 (bool): If `True`, the router operates in FP32 dtype.
            Defaults to `True`.
        routing_algorithm (str): The routing algorithm used in Mixture-of-Experts
            models. Choose from: "learned", "hash". Defaults to "learned".
            "hash" requires top_k == 1.
        moe_implementation (str): "functional" or "optimized" implementation.
            Defaults to "optimized".
        device (optional): Device to create the model parameters on, can be a
            cuda device or CS device.
        static_dual_expert (bool): If `True`, each linear projection uses
            `StaticDualExpertLinear` (separate image/text experts selected per
            token by `token_modality_idx`). Defaults to `False`.

    Attributes derived in `__post_init__`:
        num_dense_layers (int): `len(layers_units)`.
        input_units (list[int]): in_features of each layer, i.e. `input_unit`
            followed by all but the last entry of `layers_units`.
    """

    input_unit: int
    layers_units: List[int]
    layers_activation: Optional[
        List[Union[str, Callable[[Tensor], Tensor]]]
    ] = None
    layers_dropout_rates: Optional[List[float]] = None
    use_bias: bool = False
    kernel_initializer: str = "xavier_uniform"
    bias_initializer: str = "zeros"
    output_layer_initializer: Optional[str] = None
    num_experts: int = 1
    top_k: Optional[int] = None
    gate_initializer: Optional[str] = "xavier_uniform"
    load_balancing_loss_coef: Optional[float] = 0.01
    router_fp32: bool = True
    routing_algorithm: str = "learned"
    moe_implementation: str = "optimized"
    device: Optional[str] = None
    static_dual_expert: bool = False

    # Class variables.
    MoEImpl: ClassVar[enum.Enum] = enum.Enum(
        "MoEImpl", ["functional", "optimized"]
    )

    def __post_init__(self):
        """Derive per-layer settings and validate the config."""
        self.num_dense_layers = len(self.layers_units)
        # The in_features of layer i is the out_features of layer i - 1.
        # Unpacking (rather than list concatenation) also accepts tuples.
        self.input_units = [self.input_unit, *self.layers_units[:-1]]
        if self.output_layer_initializer is None:
            self.output_layer_initializer = self.kernel_initializer

        assert (
            self.num_dense_layers > 0
        ), "Number of dense layers should be at least 1."

        if self.layers_activation:
            assert len(self.layers_activation) == self.num_dense_layers, (
                "len(layers_activation) should equal the number"
                " of dense layers."
            )
        else:
            self.layers_activation = [None] * self.num_dense_layers

        if self.layers_dropout_rates:
            assert len(self.layers_dropout_rates) == self.num_dense_layers, (
                "len(layers_dropout_rates) should equal the number"
                " of dense layers."
            )
        else:
            self.layers_dropout_rates = [None] * self.num_dense_layers

        if self.num_experts > 1:
            # Check for None first so a missing value fails with a clear
            # message instead of a TypeError from the comparisons below.
            assert (
                self.load_balancing_loss_coef is not None
            ), "load_balancing_loss_coef must be set when num_experts > 1."
            assert (
                self.load_balancing_loss_coef >= 0.0
            ), f"load_balancing_loss_coef cannot be less than 0, got {self.load_balancing_loss_coef}"
            assert (
                self.top_k is not None
            ), "top_k must be set when num_experts > 1."
            assert (
                1 <= self.top_k <= self.num_experts
            ), f"{self.top_k=} should be [1, {self.num_experts=}]"
            if self.routing_algorithm == "hash":
                assert (
                    self.top_k == 1
                ), f"{self.top_k=} but should be 1 for hash routing"

    def moe_optimized_impl(self) -> bool:
        """Return True if optimized implementation is used."""
        return self.MoEImpl[self.moe_implementation] is self.MoEImpl.optimized

    def moe_functional_impl(self) -> bool:
        """Return True if functional implementation is used."""
        return self.MoEImpl[self.moe_implementation] is self.MoEImpl.functional


class FeedForwardNetwork(nn.Module):
    """
    A feed forward network that consists of a stack of fully connected
    layers arranged as [LinearLayer -> Activation -> Dropout] blocks,
    repeated `len(layers_units)` times.

    Args:
        config (FeedForwardNetworkConfig): Feed forward network config.
    """

    def __init__(self, config: FeedForwardNetworkConfig):
        """
        Initialize the FFN object instance: one `SingleFeedForwardLayer` per
        entry of `config.layers_units`, then initialize their weights.
        """
        super().__init__()
        self.config = config
        self.static_dual_expert = config.static_dual_expert
        # This sets the namespace of the layer.
        # Using `nn.ModuleList` to have clear namespace such as
        # `ffn.{layer_num}.weight` and `ffn.{layer_num}.bias`
        # Class attributes cannot have `.` in their names when
        # inheriting from `nn.Module` and therefore cannot generate
        # attribute names on the fly and hence the need to use ModuleList.
        self.ffn = nn.ModuleList(
            [
                SingleFeedForwardLayer(
                    in_features,
                    out_features,
                    use_bias=self.config.use_bias,
                    activation=activation,
                    dropout=dropout,
                    device=config.device,
                    static_dual_expert=config.static_dual_expert,
                )
                for in_features, out_features, activation, dropout in zip(
                    self.config.input_units,
                    self.config.layers_units,
                    self.config.layers_activation,
                    self.config.layers_dropout_rates,
                )
            ]
        )
        # Initialize weights in Linear layers. The name-mangled private method
        # is used so a subclass override of `reset_parameters` is not invoked
        # during construction.
        self.__reset_parameters()

    def reset_parameters(self):
        """Re-initialize the weights and biases of all Linear layers."""
        self.__reset_parameters()

    def __reset_parameters(self):
        # The kernel/output initializers depend only on the config, so build
        # them once instead of once per layer.
        kernel_init = create_initializer(self.config.kernel_initializer)
        output_init = create_initializer(self.config.output_layer_initializer)
        # Created lazily: only needed if some Linear actually has a bias.
        bias_init = None
        last_layer_num = self.config.num_dense_layers - 1

        for layer_num, layer_module in enumerate(self.ffn):
            weight_init = (
                output_init if layer_num == last_layer_num else kernel_init
            )
            # Every exact nn.Linear inside the layer (the main projection, the
            # extra GLU projection, and both experts of a
            # StaticDualExpertLinear) gets this layer's weight initializer.
            for m in layer_module.modules():
                if type(m) is nn.Linear:
                    weight_init(m.weight.data)
                    if m.bias is not None:
                        if bias_init is None:
                            bias_init = create_initializer(
                                self.config.bias_initializer
                            )
                        bias_init(m.bias.data)

    def forward(self, inputs, **extra_args):
        """
        Run `inputs` through every layer in order.

        Args:
            inputs: input tensor whose last dimension is `config.input_unit`.
            **extra_args: keyword arguments forwarded unchanged to each layer
                (e.g. `token_modality_idx` for static dual experts).

        Returns:
            Output of the last layer.
        """
        outputs = inputs
        for ffn_layer in self.ffn:
            outputs = ffn_layer(outputs, **extra_args)
        return outputs