Source code for cerebras.modelzoo.common.utils.model.mup_utils

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from dataclasses import asdict, is_dataclass
from typing import List, Optional, Union

import torch

from cerebras.modelzoo.config_manager.config_classes.base.model_config import (
    InitializerConfig,
)
from cerebras.pytorch.utils.utils import FilterCallable, make_param_filter

SUPPORTED_MUP_INITIALIZERS = [
    "normal",
    "truncated_normal",
]


class LRAdjustmentGroup:
    """
    Stores data for a group of params that share a learning rate scalar.

    Stores a callable that returns True if a given model param corresponds to
    the group. Additionally, it stores the scale that should be applied to the
    LR of the model params that correspond to the group.
    """

    def __init__(
        self,
        param_filter: Union[str, List[str]],
        scale: Optional[float] = 1.0,
    ):
        """
        param_filter: A string or a list of strings containing glob
            expressions used to match whether a given model param name
            belongs to the group.
        scale: The scale that should be applied to the LR of this group.
        """
        # Convert the strings into a callable that returns True if a given
        # param name corresponds to the LR group
        self.param_filter = make_param_filter(param_filter)
        self.scale = scale

    def set_scale(self, scale):
        self.scale = scale
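
A minimal usage sketch (not part of the module): the glob patterns below are
hypothetical placeholders; real patterns depend on the model's parameter
naming scheme.

    from cerebras.modelzoo.common.utils.model.mup_utils import LRAdjustmentGroup

    # Group attention projection weights under one LR scalar
    # (pattern strings are illustrative only).
    attn_lr_group = LRAdjustmentGroup(
        ["*attn*proj_q*weight", "*attn*proj_k*weight"],
        scale=0.5,
    )
    # The scale can also be overridden after construction, e.g. from a config.
    attn_lr_group.set_scale(0.25)
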
def scale_initializers_by_dimension(
    initializers: Union[InitializerConfig, List[InitializerConfig]],
    width_scale: Optional[float] = None,
    depth_scale: Optional[float] = None,
):
    """
    Scales the std of an initializer or list of initializers by the specified
    width and depth scalars. Unsupported initializers raise an error.
    """
    if not width_scale:
        width_scale = 1.0
    if not depth_scale:
        depth_scale = 1.0
    mup_scalar = width_scale * depth_scale

    if not isinstance(initializers, list):
        initializers = [initializers]

    for initializer in initializers:
        if isinstance(initializer, str):
            initializer = {"name": initializer}
        if "name" not in initializer:
            raise ValueError("Initializer name must be provided")
        initializer_name = initializer["name"].lower()

        if initializer_name not in SUPPORTED_MUP_INITIALIZERS:
            raise RuntimeError(
                f"Initializer {initializer} does not support mup scaling. "
                f"Please use a supported initializer from the following: "
                f"{SUPPORTED_MUP_INITIALIZERS}"
            )

        if initializer_name == "normal":
            initializer["std"] = initializer.get("std", 1.0) * mup_scalar
        elif initializer_name == "truncated_normal":
            std = initializer.get("std", 1.0)
            initializer["std"] = std * mup_scalar
            initializer["a"] = initializer.get("a", -2 * std) * mup_scalar
            initializer["b"] = initializer.get("b", 2 * std) * mup_scalar
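
A minimal sketch of how the scaler might be called. The dimensions and the
muP-style convention of width_scale = sqrt(base_width / width) are assumptions
for illustration, not something this function prescribes; plain dicts are used
here because the function reads and writes initializer entries dict-style.

    import math

    from cerebras.modelzoo.common.utils.model.mup_utils import (
        scale_initializers_by_dimension,
    )

    # Hypothetical dimensions: a 1024-wide model proxied by a 256-wide base model.
    base_width, width = 256, 1024
    width_scale = math.sqrt(base_width / width)  # assumed scaling convention

    initializers = [
        {"name": "truncated_normal", "std": 0.02},
        {"name": "normal", "std": 0.02},
    ]
    # Mutates the entries in place, multiplying std (and the a/b bounds for
    # truncated_normal) by width_scale * depth_scale.
    scale_initializers_by_dimension(initializers, width_scale=width_scale)
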
def is_mup(model: Union[dict, torch.nn.Module]):
    """Returns True if the model config/module has any non-None ``mup_base_*`` setting."""
    if is_dataclass(model):
        model = asdict(model)
    if not isinstance(model, dict):
        model = {k: getattr(model, k) for k in dir(model)}
    return any(
        name.startswith('mup_base_') and model[name] is not None
        for name in model
    )
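
A small sketch of the check using a plain params-style dict; the key names are
hypothetical examples of the ``mup_base_*`` convention the function looks for.

    from cerebras.modelzoo.common.utils.model.mup_utils import is_mup

    # Any non-None entry whose name starts with "mup_base_" marks the model
    # as using muP.
    assert is_mup({"hidden_size": 1024, "mup_base_hidden_size": 256})
    assert not is_mup({"hidden_size": 1024, "mup_base_hidden_size": None})
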
def process_lr_adjustment_params(
    model_lr_adjustment_groups, params_lr_adjustment_groups
):
    """
    Parses the model's supported LR adjustment groups and optionally overrides
    any user-set scales.

    Args:
        model_lr_adjustment_groups (dict): Keys are the LR group names and the
            values are LRAdjustmentGroup instances.
        params_lr_adjustment_groups (dict): Keys are the LR group names and the
            values are the scale override values.

    Returns:
        Tuple: A tuple consisting of a list of the adjustment scales and a
        corresponding list of parameter filter callables used to identify the
        params that belong to each scale.
    """
    lr_adjustment_scales: List[float] = []
    lr_adjustment_filters: List[List[FilterCallable]] = []
    for lr_group_name, lr_group in model_lr_adjustment_groups.items():
        if lr_group_name in params_lr_adjustment_groups:
            param_scale = params_lr_adjustment_groups.get(lr_group_name, 1.0)
            if param_scale != lr_group.scale:
                if lr_group.scale != 1.0:
                    logging.warning(
                        f"Overriding the scale for adjust_learning_rate group {lr_group_name} "
                        f"with the provided value of {param_scale}"
                    )
                lr_group.set_scale(param_scale)

        if lr_group.scale != 1.0:
            # Prior to the 2.3 release, we had a single LR adjustment group which
            # matched a number of parameters. In 2.3, however, the adjustment
            # groups were broken up to allow finer-grained control over scaling
            # different parameters. To keep backwards compatibility with older
            # checkpoints, we need to ensure the ordering of param groups remains
            # the same even if old configs/checkpoints are provided. As such, we
            # merge params that have the same scale here, which ensures the same
            # ordering as before. If we treated them as separate groups and merged
            # them later on, the ordering would have been different.
            for idx in range(len(lr_adjustment_scales)):
                if lr_group.scale == lr_adjustment_scales[idx]:
                    lr_adjustment_filters[idx].append(lr_group.param_filter)
                    break
            else:
                lr_adjustment_scales.append(lr_group.scale)
                lr_adjustment_filters.append([lr_group.param_filter])

    return (lr_adjustment_scales, lr_adjustment_filters)
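
An illustrative sketch of wiring the two dicts together; the group names, glob
patterns, and scales are made up for the example.

    from cerebras.modelzoo.common.utils.model.mup_utils import (
        LRAdjustmentGroup,
        process_lr_adjustment_params,
    )

    # Groups declared by the model (hypothetical names and patterns) ...
    model_groups = {
        "embedding": LRAdjustmentGroup("*embedding*weight"),
        "decoder_attention": LRAdjustmentGroup("*attn*dense*weight"),
    }
    # ... and user-provided scale overrides keyed by the same group names.
    user_overrides = {"embedding": 0.1, "decoder_attention": 0.1}

    scales, filters = process_lr_adjustment_params(model_groups, user_overrides)
    # Both groups end up with the same scale (0.1), so they are merged into a
    # single entry: scales == [0.1] and filters[0] holds both groups'
    # param filter callables.
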