# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is adapted from
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
#
# Copyright 2022 Cerebras Systems.
#
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
import math
from dataclasses import dataclass
from typing import List, Optional, Set, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class LoraConfig:
r"""
r: Rank of LoRA matrix projections
alpha: Scaling factor (see paper for additional details)
dropout: Dropout to apply to LoRA updates
fan_in_fan_out:
merge_weights: Determines whether lora weights should be merged/folded
into underlying layers
target_modules: A list of module names that must all exist in layers
that will be converted to LoRA. For example, setting target_modules
to ["TransformerDecoderLayer", "Linear"] would mean that all linear
layers that were children of a TransformerDecoderLayer would be
converted to LoRA.
"""
r: int = 0
alpha: int = 1
dropout: float = 0.0
fan_in_fan_out: bool = False
merge_weights: bool = False
    target_modules: Optional[List[str]] = None
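
# Example (illustrative sketch, values are assumptions): the equivalent of
#     LoraConfig(r=8, alpha=16, dropout=0.05,
#                target_modules=["TransformerDecoderLayer", "Linear"])
# can also be supplied to make_model_lora as a plain dict:
#
#     lora_params_dict = {
#         "r": 8,
#         "alpha": 16,
#         "dropout": 0.05,
#         "fan_in_fan_out": False,
#         "merge_weights": False,
#         "target_modules": ["TransformerDecoderLayer", "Linear"],
#     }
#
# The target module names above are illustrative; use class names that
# actually appear in your model's hierarchy.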


def disable_lora_merge_weights(lora_params_dict: Union[dict, List[dict]]):
r"""Sets merge_weights=False in LoRA parameters. This is helpful during
eval mode to ensure that the weights don't get folded prior to checkpoint
loading.
"""
def _disable_merge_weights(params, printed_already=False):
if params["merge_weights"] and not printed_already:
logging.warning(
"Automatically switching LoRA merge_weights to False in order "
"to run evals."
)
printed_already = True
params["merge_weights"] = False
return printed_already
if isinstance(lora_params_dict, list):
        printed = False  # warn at most once across all parameter groups
for params in lora_params_dict:
printed = _disable_merge_weights(params, printed)
else:
_disable_merge_weights(lora_params_dict)
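
# Usage sketch (assumed workflow): before loading a checkpoint for evaluation,
# disable weight merging in-place on the raw LoRA parameter dict(s), e.g.
#
#     lora_params_dict = {"r": 8, "alpha": 16, "merge_weights": True}
#     disable_lora_merge_weights(lora_params_dict)
#     assert lora_params_dict["merge_weights"] is False
#
# A list of parameter dicts is handled the same way, entry by entry, and the
# warning is logged at most once.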


class LoRALayer:
r"""
Base LoRA layer
From https://github.com/microsoft/LoRA/blob/main/loralib/layers.py.
"""
def __init__(
self,
r: int,
lora_alpha: int,
lora_dropout: float,
merge_weights: bool,
):
self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
if lora_dropout > 0.0:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
# Mark the weight as unmerged
self.merged = False
self.merge_weights = merge_weights


class LoRA_Embedding(nn.Embedding, LoRALayer):
r"""
LoRA embedding layer
From https://github.com/microsoft/LoRA/blob/main/loralib/layers.py.
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
r: int = 0,
lora_alpha: int = 1,
merge_weights: bool = True,
**kwargs,
):
nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
LoRALayer.__init__(
self,
r=r,
lora_alpha=lora_alpha,
lora_dropout=0,
merge_weights=merge_weights,
)
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(
self.weight.new_zeros((r, num_embeddings))
)
self.lora_B = nn.Parameter(
self.weight.new_zeros((embedding_dim, r))
)
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()

    def reset_parameters(self):
nn.Embedding.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # Initialize A to zero and B from a normal distribution so that
            # the LoRA update starts at zero (note: A and B swap init roles
            # relative to LoRA_Linear).
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

    def train(self, mode: bool = True):
nn.Embedding.train(self, mode)
if mode:
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
self.weight.data -= (self.lora_B @ self.lora_A).transpose(
0, 1
) * self.scaling
self.merged = False
else:
if self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0:
self.weight.data += (self.lora_B @ self.lora_A).transpose(
0, 1
) * self.scaling
self.merged = True

    def forward(self, x: torch.Tensor):
if self.r > 0 and not self.merged:
result = nn.Embedding.forward(self, x)
if self.r > 0:
after_A = F.embedding(
x,
self.lora_A.transpose(0, 1),
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
return result
else:
return nn.Embedding.forward(self, x)
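
# Illustrative sketch (not part of the library): with merge_weights=True,
# toggling between eval() and train() folds the low-rank update into and out
# of the frozen embedding table, so inference can use a single weight matrix.
#
#     emb = LoRA_Embedding(1000, 64, r=4, lora_alpha=8, merge_weights=True)
#     emb.eval()       # weight += (lora_B @ lora_A).T * (lora_alpha / r)
#     assert emb.merged
#     emb.train()      # weight -= (lora_B @ lora_A).T * (lora_alpha / r)
#     assert not emb.merged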


class LoRA_Linear(nn.Linear, LoRALayer):
r"""
LoRA linear layer
From https://github.com/microsoft/LoRA/blob/main/loralib/layers.py.
"""
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False,
merge_weights: bool = True,
**kwargs,
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoRALayer.__init__(
self,
r=r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
merge_weights=merge_weights,
)
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.reset_parameters()
if fan_in_fan_out:
self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
nn.Linear.reset_parameters(self)
if hasattr(self, 'lora_A'):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
def T(w):
return w.transpose(0, 1) if self.fan_in_fan_out else w
nn.Linear.train(self, mode)
if mode:
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
if self.r > 0:
self.weight.data -= (
T(self.lora_B @ self.lora_A) * self.scaling
)
self.merged = False
else:
if self.merge_weights and not self.merged:
# Merge the weights and mark it
if self.r > 0:
self.weight.data += (
T(self.lora_B @ self.lora_A) * self.scaling
)
self.merged = True

    def forward(self, x: torch.Tensor):
def T(w):
return w.transpose(0, 1) if self.fan_in_fan_out else w
if self.r > 0 and not self.merged:
result = F.linear(x, T(self.weight), bias=self.bias)
if self.r > 0:
result += (
self.lora_dropout(x)
                    @ self.lora_A.transpose(0, 1)
@ self.lora_B.transpose(0, 1)
) * self.scaling
return result
else:
return F.linear(x, T(self.weight), bias=self.bias)
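
# Illustrative sketch (not part of the library): in the unmerged state the
# output is the frozen projection plus a scaled low-rank term,
#     y = x @ W.T + (dropout(x) @ lora_A.T @ lora_B.T) * (lora_alpha / r),
# and merging folds lora_B @ lora_A into W so a plain F.linear reproduces it.
#
#     lin = LoRA_Linear(16, 32, r=4, lora_alpha=8, merge_weights=True)
#     x = torch.randn(2, 16)
#     y_unmerged = lin(x)   # training-mode path: LoRA term added on the fly
#     lin.eval()            # folds (lora_B @ lora_A) * scaling into lin.weight
#     y_merged = lin(x)
#     # torch.allclose(y_unmerged, y_merged, atol=1e-6) is expected to hold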


def get_lora_config_for_module(
lora_params: Union[LoraConfig, List[LoraConfig]], module_names: List[str]
) -> Optional[LoraConfig]:
r"""
Gets lora parameters for a particular module
Args:
lora_params: LoRA top-level config.
module_names: Hierarchical list of module names.
Returns:
lora parameters (LoraConfig) for the given module if applicable or None
if the module is not targeted.
"""
lora_params_list = (
lora_params if isinstance(lora_params, list) else [lora_params]
)
for group_params in lora_params_list:
target_modules = group_params.target_modules
if target_modules is None or all(
[e in module_names for e in target_modules]
):
return group_params
return None
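
# Matching sketch (module names below are illustrative): a config matches when
# every entry of its target_modules appears somewhere in the class-name chain
# from the model root down to the candidate layer.
#
#     cfg = LoraConfig(r=8, target_modules=["TransformerDecoderLayer", "Linear"])
#     get_lora_config_for_module(
#         cfg, ["GPT2LMHeadModel", "TransformerDecoderLayer", "Linear"]
#     )  # -> cfg (both target names are present in the hierarchy)
#     get_lora_config_for_module(cfg, ["GPT2LMHeadModel", "Embedding"])
#     # -> None (no match; the layer is left unchanged)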


def make_model_lora(
model: nn.Module, lora_params_dict: Union[dict, List[dict]]
):
r"""
Create a Low Rank Adaptation (LoRA) model from a non-LoRA model. Note that
the original non-LoRA model may be modified through this process.
Args:
model: Initial model to make LoRA
lora_params_dict: LoRA parameters (in the form of a dict or list of
dicts) which dictate how the supplied model will be converted into
a LoRA model. The parameters should align with LoraConfig.
Returns:
LoRA model
"""
if isinstance(lora_params_dict, list):
lora_params = [LoraConfig(**e) for e in lora_params_dict]
else:
lora_params = LoraConfig(**lora_params_dict)
loraified_modules = set()
lora_model = make_model_lora_helper(
model, lora_params, [], loraified_modules
)
if len(loraified_modules) == 0:
raise RuntimeError(
f"No modules were converted to LoRA. Please ensure that the "
f"target_modules listed in the lora_params are valid."
)
logging.info(
f"All layers matching the following module names were converted to LoRA"
f": {loraified_modules}"
)
for n, p in lora_model.named_parameters():
if not n.endswith(".lora_A") and not n.endswith(".lora_B"):
p.requires_grad = False
return lora_model
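
# End-to-end usage sketch (the toy model and parameter values are assumptions):
#
#     model = nn.Sequential(nn.Linear(128, 128))
#     lora_model = make_model_lora(
#         model,
#         {"r": 8, "alpha": 16, "target_modules": ["Sequential", "Linear"]},
#     )
#     trainable = [n for n, p in lora_model.named_parameters() if p.requires_grad]
#     # Only the lora_A / lora_B parameters remain trainable; everything else
#     # is frozen.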


def make_model_lora_helper(
model: nn.Module,
lora_params: Union[LoraConfig, List[LoraConfig]],
module_names: List[str],
loraified_modules: Set[str],
):
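    r"""
    Recursively replaces nn.Embedding / nn.Linear sub-modules whose class-name
    hierarchy matches a LoraConfig with LoRA_Embedding / LoRA_Linear layers.
    `module_names` accumulates the class names of the ancestor modules and
    `loraified_modules` records which hierarchies were converted.
    """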
module_names = module_names + [type(model).__name__]
for name, child in model.named_children():
model.add_module(
name,
make_model_lora_helper(
child, lora_params, module_names, loraified_modules
),
)
module_lora_params = get_lora_config_for_module(lora_params, module_names)
if module_lora_params is not None and isinstance(model, nn.Embedding):
loraified_modules.add(".".join(module_names))
lora_embedding = LoRA_Embedding(
# Embedding Args:
model.num_embeddings,
model.embedding_dim,
padding_idx=model.padding_idx,
max_norm=model.max_norm,
norm_type=model.norm_type,
scale_grad_by_freq=model.scale_grad_by_freq,
sparse=model.sparse,
device=model.weight.device,
dtype=model.weight.dtype,
# LoRA Args:
r=module_lora_params.r,
lora_alpha=module_lora_params.alpha,
merge_weights=module_lora_params.merge_weights,
)
with torch.no_grad():
lora_embedding.weight.copy_(model.weight)
del model
return lora_embedding
elif module_lora_params is not None and isinstance(model, nn.Linear):
loraified_modules.add(".".join(module_names))
lora_linear = LoRA_Linear(
# Linear Args:
model.in_features,
model.out_features,
bias=model.bias is not None,
device=model.weight.device,
dtype=model.weight.dtype,
# LoRA Args:
r=module_lora_params.r,
lora_alpha=module_lora_params.alpha,
lora_dropout=module_lora_params.dropout,
fan_in_fan_out=module_lora_params.fan_in_fan_out,
merge_weights=module_lora_params.merge_weights,
)
with torch.no_grad():
lora_linear.weight.copy_(model.weight)
if model.bias is not None:
with torch.no_grad():
lora_linear.bias.copy_(model.bias)
del model
return lora_linear
else:
return model