Source code for cerebras.modelzoo.common.utils.model.lora

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This code is adapted from
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
#
# Copyright 2022 Cerebras Systems.
#
#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------

import logging
import math
from dataclasses import dataclass
from typing import List, Optional, Set, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class LoraConfig:
    r"""Settings that control how a model's layers are converted to LoRA.

    Attributes:
        r: Rank of the LoRA matrix projections.
        alpha: Scaling factor (see the LoRA paper for additional details).
        dropout: Dropout to apply to LoRA updates.
        fan_in_fan_out: Forwarded to ``LoRA_Linear``. When set, that layer
            keeps its weight transposed and transposes it back whenever the
            weight is used or a LoRA update is merged into it.
        merge_weights: Determines whether LoRA weights should be
            merged/folded into the underlying layers.
        target_modules: A list of module names that must all exist in layers
            that will be converted to LoRA. For example, setting
            target_modules to ["TransformerDecoderLayer", "Linear"] would
            mean that all linear layers that were children of a
            TransformerDecoderLayer would be converted to LoRA. ``None``
            places no restriction on the module names.
    """

    # Low-rank decomposition settings.
    r: int = 0
    alpha: int = 1
    dropout: float = 0.0

    # Layer behaviour switches.
    fan_in_fan_out: bool = False
    merge_weights: bool = False

    # Module-name filter; None matches every module.
    target_modules: Optional[list] = None
def disable_lora_merge_weights(lora_params_dict: Union[dict, List[dict]]):
    r"""Sets merge_weights=False in LoRA parameters.

    This is helpful during eval mode to ensure that the weights don't get
    folded prior to checkpoint loading.

    Args:
        lora_params_dict: A single LoRA parameter dict or a list of them.
            Every dict is modified in place so that ``merge_weights`` is
            ``False`` afterwards.

    A single warning is logged per call if any dict had ``merge_weights``
    enabled. A dict without a ``merge_weights`` key is treated as having it
    disabled.
    """
    if isinstance(lora_params_dict, list):
        param_groups = lora_params_dict
    else:
        param_groups = [lora_params_dict]

    # Start un-warned so that the first group that actually had merging
    # enabled triggers the (single) warning.
    warned = False
    for params in param_groups:
        if params.get("merge_weights") and not warned:
            logging.warning(
                "Automatically switching LoRA merge_weights to False in order "
                "to run evals."
            )
            warned = True
        params["merge_weights"] = False
class LoRALayer:
    r"""
    Base LoRA layer

    Holds the state shared by the LoRA layer variants: the rank, the alpha
    value, the dropout callable applied to LoRA inputs, and the
    merge bookkeeping flags.

    From https://github.com/microsoft/LoRA/blob/main/loralib/layers.py.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        # Merge bookkeeping: a freshly built layer always starts unmerged.
        self.merge_weights = merge_weights
        self.merged = False

        # Dropout is optional; without it the input passes through untouched.
        use_dropout = lora_dropout > 0.0
        self.lora_dropout = (
            nn.Dropout(p=lora_dropout) if use_dropout else (lambda x: x)
        )

        self.lora_alpha = lora_alpha
        self.r = r
class LoRA_Embedding(nn.Embedding, LoRALayer):
    r"""
    LoRA embedding layer

    An ``nn.Embedding`` whose frozen weight is augmented by a trainable
    low-rank update ``(lora_B @ lora_A)^T * scaling`` when ``r > 0``.

    From https://github.com/microsoft/LoRA/blob/main/loralib/layers.py.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        merge_weights: bool = True,
        **kwargs,
    ):
        """
        Args:
            num_embeddings: Number of rows of the embedding table.
            embedding_dim: Size of each embedding vector.
            r: LoRA rank; ``0`` creates no LoRA parameters.
            lora_alpha: Numerator of the ``lora_alpha / r`` scaling factor.
            merge_weights: Whether ``train()`` folds the LoRA update into
                ``weight`` when leaving training mode (and unfolds it when
                re-entering).
            **kwargs: Forwarded to ``nn.Embedding.__init__``.
        """
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(
            self,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=0,
            merge_weights=merge_weights,
        )
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r, num_embeddings))
            )
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((embedding_dim, r))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # nn.Embedding.__init__ already ran reset_parameters once, but
            # before the LoRA parameters existed; run it again for them.
            self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the embedding weight and, if present, the LoRA
        parameters."""
        nn.Embedding.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # A is zeroed and B is drawn from a normal distribution, so the
            # product lora_B @ lora_A (the LoRA update) starts at zero.
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

    def train(self, mode: bool = True):
        """Switch training mode, merging/unmerging LoRA weights if enabled.

        With ``merge_weights`` set, leaving training mode adds the scaled
        LoRA update to ``weight`` and entering training mode subtracts it
        again.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, as ``nn.Module.train`` does, so that calls such as
            ``layer = layer.eval()`` keep working.
        """
        nn.Embedding.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= (self.lora_B @ self.lora_A).transpose(
                        0, 1
                    ) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += (self.lora_B @ self.lora_A).transpose(
                        0, 1
                    ) * self.scaling
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Embed ``x``, adding the LoRA update unless it is already merged.

        Args:
            x: Indices to look up, as accepted by ``nn.Embedding.forward``.

        Returns:
            The embedding lookup result, plus the scaled low-rank term when
            ``r > 0`` and the weights are unmerged.
        """
        if self.r > 0 and not self.merged:
            result = nn.Embedding.forward(self, x)
            # Look up the rank-r rows with the same embedding options as the
            # base layer, then project them to embedding_dim through lora_B.
            after_A = F.embedding(
                x,
                self.lora_A.transpose(0, 1),
                self.padding_idx,
                self.max_norm,
                self.norm_type,
                self.scale_grad_by_freq,
                self.sparse,
            )
            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return nn.Embedding.forward(self, x)
class LoRA_Linear(nn.Linear, LoRALayer):
    r"""
    LoRA linear layer

    An ``nn.Linear`` whose frozen weight is augmented by a trainable
    low-rank update ``lora_B @ lora_A * scaling`` when ``r > 0``.

    From https://github.com/microsoft/LoRA/blob/main/loralib/layers.py.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs,
    ):
        """
        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            r: LoRA rank; ``0`` creates no LoRA parameters.
            lora_alpha: Numerator of the ``lora_alpha / r`` scaling factor.
            lora_dropout: Dropout probability applied to the input of the
                LoRA branch.
            fan_in_fan_out: When set, ``weight`` is stored transposed and is
                transposed back whenever it is used.
            merge_weights: Whether ``train()`` folds the LoRA update into
                ``weight`` when leaving training mode (and unfolds it when
                re-entering).
            **kwargs: Forwarded to ``nn.Linear.__init__``.
        """
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(
            self,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            merge_weights=merge_weights,
        )

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def _maybe_transpose(self, w: torch.Tensor) -> torch.Tensor:
        """Return ``w`` transposed when ``fan_in_fan_out`` is set, else
        ``w`` unchanged."""
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def reset_parameters(self):
        """Re-initialize the linear weight/bias and, if present, the LoRA
        parameters."""
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        """Switch training mode, merging/unmerging LoRA weights if enabled.

        With ``merge_weights`` set, leaving training mode adds the scaled
        LoRA update to ``weight`` and entering training mode subtracts it
        again.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, as ``nn.Module.train`` does, so that calls such as
            ``layer = layer.eval()`` keep working.
        """
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= (
                        self._maybe_transpose(self.lora_B @ self.lora_A)
                        * self.scaling
                    )
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += (
                        self._maybe_transpose(self.lora_B @ self.lora_A)
                        * self.scaling
                    )
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Apply the linear layer, adding the LoRA update unless merged.

        Args:
            x: Input tensor, as accepted by ``F.linear``.

        Returns:
            ``F.linear`` of ``x`` with the (possibly transposed-back)
            weight, plus the scaled low-rank term when ``r > 0`` and the
            weights are unmerged.
        """
        base_weight = self._maybe_transpose(self.weight)
        if self.r > 0 and not self.merged:
            result = F.linear(x, base_weight, bias=self.bias)
            # Low-rank branch: dropout(x) @ A^T @ B^T, scaled by alpha / r.
            result += (
                self.lora_dropout(x)
                @ self.lora_A.transpose(0, 1)
                @ self.lora_B.transpose(0, 1)
            ) * self.scaling
            return result
        else:
            return F.linear(x, base_weight, bias=self.bias)
def get_lora_config_for_module(
    lora_params: Union[LoraConfig, List[LoraConfig]], module_names: List[str]
) -> Optional[LoraConfig]:
    r"""
    Pick the LoRA config that applies to a particular module.

    Config groups are examined in order and the first match wins. A group
    matches when its ``target_modules`` is ``None`` or when every one of its
    entries appears in ``module_names``.

    Args:
        lora_params: LoRA top-level config (a single config or a list of
            config groups).
        module_names: Hierarchical list of module names.

    Returns:
        lora parameters (LoraConfig) for the given module if applicable or
        None if the module is not targeted.
    """
    if isinstance(lora_params, list):
        candidates = lora_params
    else:
        candidates = [lora_params]

    for config in candidates:
        wanted = config.target_modules
        # No filter at all: this group applies to every module.
        if wanted is None:
            return config
        if all(entry in module_names for entry in wanted):
            return config

    return None
def make_model_lora(
    model: nn.Module, lora_params_dict: Union[dict, List[dict]]
):
    r"""
    Create a Low Rank Adaptation (LoRA) model from a non-LoRA model. Note that
    the original non-LoRA model may be modified through this process.

    After conversion, every parameter other than the ``lora_A`` / ``lora_B``
    parameters is frozen (``requires_grad = False``).

    Args:
        model: Initial model to make LoRA
        lora_params_dict: LoRA parameters (in the form of a dict or list of
            dicts) which dictate how the supplied model will be converted into a
            LoRA model. The parameters should align with LoraConfig.

    Returns:
        LoRA model

    Raises:
        RuntimeError: If no module was converted to LoRA.
    """
    if isinstance(lora_params_dict, list):
        lora_params = [LoraConfig(**e) for e in lora_params_dict]
    else:
        lora_params = LoraConfig(**lora_params_dict)

    loraified_modules = set()
    lora_model = make_model_lora_helper(
        model, lora_params, [], loraified_modules
    )
    if not loraified_modules:
        raise RuntimeError(
            "No modules were converted to LoRA. Please ensure that the "
            "target_modules listed in the lora_params are valid."
        )
    logging.info(
        f"All layers matching the following module names were converted to LoRA"
        f": {loraified_modules}"
    )

    # Compare the last path component rather than a ".lora_A" suffix: when the
    # root model itself was converted, its LoRA parameters are named plain
    # "lora_A" / "lora_B" (no leading dot) and must stay trainable too.
    lora_param_names = ("lora_A", "lora_B")
    for name, param in lora_model.named_parameters():
        if name.rpartition(".")[2] not in lora_param_names:
            param.requires_grad = False

    return lora_model
def make_model_lora_helper(
    model: nn.Module,
    lora_params: Union[LoraConfig, List[LoraConfig]],
    module_names: List[str],
    loraified_modules: Set[str],
):
    r"""
    Recursively replace targeted ``nn.Embedding`` / ``nn.Linear`` modules
    with their LoRA counterparts.

    Children are converted first and re-registered on ``model`` under their
    original names; then ``model`` itself is converted if a LoRA config
    applies to it and it is an embedding or linear layer.

    Args:
        model: Module to convert (its children may be replaced in place).
        lora_params: LoRA config or list of config groups.
        module_names: Class names of the ancestors of ``model``, outermost
            first. Not modified; a new list including ``model``'s class name
            is built for the recursion.
        loraified_modules: Set that collects the dot-joined class-name path
            of every converted module. Updated in place.

    Returns:
        The LoRA replacement for ``model`` if it was converted, otherwise
        ``model`` itself.
    """
    module_names = module_names + [type(model).__name__]

    # Snapshot the children up front since they are re-registered (and
    # possibly replaced) while walking them.
    for name, child in list(model.named_children()):
        model.add_module(
            name,
            make_model_lora_helper(
                child, lora_params, module_names, loraified_modules
            ),
        )

    module_lora_params = get_lora_config_for_module(lora_params, module_names)
    if module_lora_params is None:
        return model

    if isinstance(model, nn.Embedding):
        loraified_modules.add(".".join(module_names))
        lora_embedding = LoRA_Embedding(
            # Embedding Args:
            model.num_embeddings,
            model.embedding_dim,
            padding_idx=model.padding_idx,
            max_norm=model.max_norm,
            norm_type=model.norm_type,
            scale_grad_by_freq=model.scale_grad_by_freq,
            sparse=model.sparse,
            device=model.weight.device,
            dtype=model.weight.dtype,
            # LoRA Args:
            r=module_lora_params.r,
            lora_alpha=module_lora_params.alpha,
            merge_weights=module_lora_params.merge_weights,
        )
        with torch.no_grad():
            lora_embedding.weight.copy_(model.weight)
        return lora_embedding

    if isinstance(model, nn.Linear):
        loraified_modules.add(".".join(module_names))
        lora_linear = LoRA_Linear(
            # Linear Args:
            model.in_features,
            model.out_features,
            bias=model.bias is not None,
            device=model.weight.device,
            dtype=model.weight.dtype,
            # LoRA Args:
            r=module_lora_params.r,
            lora_alpha=module_lora_params.alpha,
            lora_dropout=module_lora_params.dropout,
            fan_in_fan_out=module_lora_params.fan_in_fan_out,
            merge_weights=module_lora_params.merge_weights,
        )
        # With fan_in_fan_out, LoRA_Linear keeps its weight transposed and
        # transposes it back in forward(), so the source nn.Linear weight
        # must be transposed before copying. Otherwise the copy fails for
        # non-square layers and yields a transposed weight for square ones.
        source_weight = model.weight
        if module_lora_params.fan_in_fan_out:
            source_weight = source_weight.transpose(0, 1)
        with torch.no_grad():
            lora_linear.weight.copy_(source_weight)
            if model.bias is not None:
                lora_linear.bias.copy_(model.bias)
        return lora_linear

    return model