Source code for cerebras.modelzoo.tools.checkpoint_converters.llama

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import re
from typing import Tuple

import torch

from cerebras.modelzoo.tools.checkpoint_converters.base_converter import (
    BaseCheckpointConverter_HF_CS,
    BaseConfigConverter,
    BaseConfigConverter_HF_CS,
    ConfigConversionError,
    ConversionRule,
    EquivalentSubkey,
    FormatVersions,
)
from cerebras.modelzoo.tools.checkpoint_converters.gpt2_hf_cs import (
    ConfigConverter_GPT2Model_CS18_CS20,
    Converter_GPT2LMHeadModel_CS18_CS20,
    Converter_GPT2LMHeadModel_CS20_CS21,
)


class Converter_LlamaAttention_HF_CS(BaseCheckpointConverter_HF_CS):
    """Checkpoint converter for the Llama attention projections (HF <-> CS).

    The value and output projections are only renamed. The query and key
    projections are additionally re-ordered along their per-head output
    dimension, since the HF and CS RoPE implementations differ.
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            ConversionRule(
                [
                    EquivalentSubkey("q_proj", "proj_q_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.convert_with_interleaving_query,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("k_proj", "proj_k_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.convert_with_interleaving_key,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("v_proj", "proj_v_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                [
                    EquivalentSubkey("o_proj", "proj_output_dense_layer"),
                    r"\.(?:weight|bias)",
                ],
                action=self.replaceKey,
            ),
        ]

    def _convert_with_interleaving(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        cs_config,
        num_groups,
    ):
        """Re-order one projection tensor and store it under ``new_key``.

        Shared implementation for the query and key converters.

        Args:
            old_key: Key of the tensor in ``old_state_dict``.
            new_key: Key to write in ``new_state_dict``.
            old_state_dict: Source state dict.
            new_state_dict: Destination state dict (mutated in place).
            from_index: 0 when converting HF -> CS, otherwise CS -> HF.
            cs_config: CS config; ``cs_config["model"]["rotary_dim"]`` is read
                by the interleave helpers.
            num_groups: Number of groups the leading dimension is split into
                (attention heads for queries, KV groups for grouped keys).
        """
        tensor = old_state_dict[old_key]
        initial_shape = tensor.size()
        if from_index == 0:
            # Expose the per-group dimension so the helper can re-order
            # within each group. Other ranks are rejected by the helper.
            if tensor.dim() == 2:
                tensor = tensor.view(
                    num_groups, tensor.size(0) // num_groups, tensor.size(-1)
                )
            elif tensor.dim() == 1:
                tensor = tensor.view(num_groups, tensor.size(0) // num_groups)
            tensor = self.interleave_helper(tensor, cs_config)
        else:
            tensor = self.reverse_interleave_helper(
                tensor, cs_config, num_groups
            )
        # The stored tensor always keeps the shape it had in the source.
        new_state_dict[new_key] = tensor.view(*initial_shape)

    def convert_with_interleaving_query(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Convert a query projection, grouping rows by ``num_heads``."""
        # Query & Keys should be interleaved since HF and CS RoPE differ
        cs_config = action_fn_args["configs"][1]
        self._convert_with_interleaving(
            old_key,
            new_key,
            old_state_dict,
            new_state_dict,
            from_index,
            cs_config,
            cs_config["model"]["num_heads"],
        )

    def convert_with_interleaving_key(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Convert a key projection.

        With ``aiayn_attention`` (the default when ``attention_module`` is
        absent) keys are handled exactly like queries. With
        ``multiquery_attention`` rows are grouped by
        ``extra_attention_params["num_kv_groups"]`` instead.

        Raises:
            AssertionError: If ``attention_module`` is any other value.
        """
        # Query & Keys should be interleaved since HF and CS RoPE differ
        cs_config = action_fn_args["configs"][1]
        attention_module = cs_config["model"].get(
            "attention_module", "aiayn_attention"
        )
        if attention_module == "aiayn_attention":
            self.convert_with_interleaving_query(
                old_key,
                new_key,
                old_state_dict,
                new_state_dict,
                from_index,
                action_fn_args,
            )
        elif attention_module == "multiquery_attention":
            self._convert_with_interleaving(
                old_key,
                new_key,
                old_state_dict,
                new_state_dict,
                from_index,
                cs_config,
                cs_config["model"]["extra_attention_params"]["num_kv_groups"],
            )
        else:
            # Explicit raise (not `assert False`) so it survives `python -O`.
            raise AssertionError(
                f"attention_module {attention_module} is not supported for "
                f"llama"
            )

    def interleave_helper(self, t, cs_config):
        """Interleave the two halves of the rotary slice of each group.

        ``t`` is already split per group: rank 3 for weights, rank 2 for
        biases. Along dim 1, the first ``rotary_dim`` entries are viewed as
        two halves and interleaved element-wise; the remaining entries are
        passed through unchanged.

        Raises:
            AssertionError: If ``t`` is not rank 2 or rank 3.
        """
        rotary_dim = cs_config["model"]["rotary_dim"]
        if len(t.shape) == 3:
            to_rotate = t[:, :rotary_dim, :]
            to_pass = t[:, rotary_dim:, :]
            to_rotate = (
                to_rotate.reshape(t.shape[0], 2, -1, t.shape[-1])
                .permute(0, 2, 1, 3)
                .reshape(t.shape[0], -1, t.shape[-1])
            )
        elif len(t.shape) == 2:
            to_rotate = t[:, :rotary_dim]
            to_pass = t[:, rotary_dim:]
            to_rotate = (
                to_rotate.reshape(t.shape[0], 2, -1)
                .permute(0, 2, 1)
                .reshape(t.shape[0], -1)
            )
        else:
            raise AssertionError(
                "shape of query, key, value projection tensor has to have shape of length 2 "
                "(biases) or 3 (weights) when converting from HF to CS."
            )
        return torch.cat((to_rotate, to_pass), dim=1)

    def reverse_interleave_helper(self, t, cs_config, num_heads):
        """Undo :meth:`interleave_helper` on an un-split CS tensor.

        ``t`` is rank 2 for weights or rank 1 for biases; it is first split
        into ``num_heads`` groups, then the rotary slice of each group is
        de-interleaved back into two contiguous halves.

        Raises:
            AssertionError: If ``t`` is not rank 1 or rank 2.
        """
        rotary_dim = cs_config["model"]["rotary_dim"]
        if len(t.shape) == 2:
            t = t.reshape(num_heads, -1, t.shape[-1])
            to_rotate = t[:, :rotary_dim, :]
            to_pass = t[:, rotary_dim:, :]
            restored = (
                to_rotate.reshape(num_heads, -1, 2, t.shape[-1])
                .permute(0, 2, 1, 3)
                .reshape(num_heads, rotary_dim, t.shape[-1])
            )
        elif len(t.shape) == 1:
            t = t.reshape(num_heads, -1)
            to_rotate = t[:, :rotary_dim]
            to_pass = t[:, rotary_dim:]
            restored = (
                to_rotate.reshape(num_heads, -1, 2)
                .permute(0, 2, 1)
                .reshape(num_heads, -1)
            )
        else:
            raise AssertionError(
                "shape of query, key, value projection tensor has to have shape of length 1 "
                "(biases) or 2 (weights) when converting from CS to HF."
            )
        return torch.cat((restored, to_pass), dim=1)

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.9"))

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        # This sub-converter has no config converter of its own.
        return None
class Converter_LlamaModel_HF_CS(BaseCheckpointConverter_HF_CS):
    """Checkpoint converter between the headless HF LlamaModel and the CS
    GPT2LMHeadModel configured as Llama (which has a language model head).
    """

    def __init__(self):
        super().__init__()
        param_suffix = r"\.(?:weight|bias)"

        def per_layer_rename(hf_name, cs_name):
            # Rename one per-decoder-layer parameter between the two layouts.
            return ConversionRule(
                [
                    EquivalentSubkey("layers", "transformer_decoder.layers"),
                    r"\.\d+\.",
                    EquivalentSubkey(hf_name, cs_name),
                    param_suffix,
                ],
                action=self.replaceKey,
            )

        word_embeddings = ConversionRule(
            [
                EquivalentSubkey(
                    "embed_tokens", "embedding_layer.word_embeddings"
                ),
                param_suffix,
            ],
            action=self.replaceKey,
        )
        final_layer_norm = ConversionRule(
            [
                EquivalentSubkey("norm", "transformer_decoder.norm"),
                param_suffix,
            ],
            action=self.replace_final_norm,
        )
        attention = ConversionRule(
            [
                EquivalentSubkey("layers", "transformer_decoder.layers"),
                r"\.\d+\.self_attn\.",
                Converter_LlamaAttention_HF_CS(),
            ],
            action=None,
        )
        rotary_embedding = ConversionRule(
            [r"layers\.\d+\.self_attn\.rotary_emb\.inv_freq"],
            exists="left",
            action=None,
        )

        self.rules = [
            word_embeddings,
            final_layer_norm,
            attention,
            rotary_embedding,
            # attention norm
            per_layer_rename("input_layernorm", "norm1"),
            per_layer_rename("post_attention_layernorm", "norm3"),
            # intermediate ffn
            per_layer_rename("mlp.up_proj", "ffn.ffn.0.linear_layer"),
            per_layer_rename(
                "mlp.gate_proj", "ffn.ffn.0.linear_layer_for_glu"
            ),
            per_layer_rename("mlp.down_proj", "ffn.ffn.1.linear_layer"),
            ConversionRule([r"lm_head\.(?:weight|bias)"], exists="right"),
            ConversionRule([r"ln_f\.(?:weight|bias)"], exists="right"),
        ]

    def replace_final_norm(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Copy the final norm parameter; going HF -> CS also emit ``ln_f``."""
        value = old_state_dict[old_key]
        new_state_dict[new_key] = value
        if from_index != 0:
            return
        # CS 1.7 has both "ln_f" and "transformer_decoder.norm", so the
        # original ("ln_f") name needs a copy of the same value too.
        legacy_key = re.sub(r"transformer_decoder\.norm\.", "ln_f.", new_key)
        new_state_dict[legacy_key] = value

    def post_model_convert(
        self,
        old_state_dict,
        new_state_dict,
        configs,
        converter_indices,
        drop_unmatched_keys,
        key_prefix="",
    ):
        """Create the CS ``lm_head`` when converting HF -> CS, then defer to
        the base class implementation.
        """
        hf_to_cs = converter_indices.direction == 0
        if hf_to_cs:
            # The HF LlamaModel is headless while the CS GPT2LMHeadModel
            # (configured as llama) has a head, so 'lm_head' is created here
            # and initialized to default values.
            hf_format, cs_format = self.formats()
            logging.warning(
                f"{cs_format} has a language model head (lm_head) "
                f"while {hf_format} does not. Initializing lm_head to default."
            )
            hf_config, cs_config = configs[0], configs[1]
            model_params = cs_config["model"]
            with_output_bias = model_params.get("use_bias_in_output", False)
            vocab_size = model_params["vocab_size"]
            hidden_size = model_params["hidden_size"]

            if hf_config["tie_word_embeddings"]:
                head_weight = old_state_dict['embed_tokens.weight']
            else:
                head_weight = torch.zeros((vocab_size, hidden_size)).normal_(
                    mean=0.0, std=0.02
                )
            new_state_dict[key_prefix + "lm_head.weight"] = head_weight

            if with_output_bias:
                new_state_dict[key_prefix + "lm_head.bias"] = torch.zeros(
                    vocab_size
                )

        super().post_model_convert(
            old_state_dict,
            new_state_dict,
            configs,
            converter_indices,
            drop_unmatched_keys,
            key_prefix=key_prefix,
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf_format = FormatVersions("hf")
        cs_format = FormatVersions("cs")
        return (hf_format, cs_format)

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        # No config converter is attached to this intermediate converter.
        return None
class Converter_LlamaModel_HF_CS19(Converter_LlamaModel_HF_CS):
    """HF LlamaModel <-> CS 1.9 GPT2LMHeadModel (configured as Llama).

    Wraps :class:`Converter_LlamaModel_HF_CS` twice: once with identical key
    prefixes and once with a ``model.`` prefix on the CS side.
    """

    def __init__(self):
        super().__init__()
        self.rules = [
            # Catch checkpoints from Pytorch 2.0 API
            ConversionRule(
                [
                    Converter_LlamaModel_HF_CS(),
                ],
                action=None,
            ),
            # Catch checkpoints from 1.7/1.8
            ConversionRule(
                [
                    EquivalentSubkey("", "model."),
                    Converter_LlamaModel_HF_CS(),
                ],
                action=None,
            ),
        ]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.9"))

    @classmethod
    def converter_note(cls) -> str:
        """Return a human-readable description of this converter."""
        # The f-string is already fully interpolated. A trailing `.format()`
        # would re-parse the result and fail on any literal braces.
        return (
            f"{cls.formats()[0]} LlamaModel <-> {cls.formats()[1]} GPT2LMHeadModel (configured as "
            f"Llama)\nThe HF model doesn't contain a language model head while the CS one does. "
            f"When converting to CS, the exported checkpoint will contain a language model head "
            f"initialized to default random values. When converting to HF, the language model head "
            f"will be dropped."
        )

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_LLaMa_HF_CS19
class Converter_LlamaForCausalLM_HF_CS(BaseCheckpointConverter_HF_CS):
    """Checkpoint converter for HF LlamaForCausalLM <-> CS.

    The ``lm_head`` parameters keep their name; everything under the HF
    ``model.`` prefix is delegated to :class:`Converter_LlamaModel_HF_CS`.
    """

    def __init__(self):
        super().__init__()
        lm_head_rule = ConversionRule(
            [r"lm_head\.(?:weight|bias)"],
            action=self.replaceKey,
        )
        backbone_rule = ConversionRule(
            [EquivalentSubkey("model.", ""), Converter_LlamaModel_HF_CS()],
            action=None,
        )
        self.rules = [lm_head_rule, backbone_rule]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf_format = FormatVersions("hf")
        cs_format = FormatVersions("cs")
        return (hf_format, cs_format)

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        # No config converter is attached to this intermediate converter.
        return None
class Converter_LlamaForCausalLM_HF_CS19(BaseCheckpointConverter_HF_CS):
    """HF LlamaForCausalLM <-> CS 1.9 GPT2LMHeadModel (configured as Llama).

    Wraps :class:`Converter_LlamaForCausalLM_HF_CS` twice: once with identical
    key prefixes and once with a ``model.`` prefix on the CS side.
    """

    def __init__(self):
        super().__init__()
        # Catch checkpoints from Pytorch 2.0 API
        unprefixed = ConversionRule(
            [Converter_LlamaForCausalLM_HF_CS()],
            action=None,
        )
        # Catch checkpoints from 1.7/1.8
        model_prefixed = ConversionRule(
            [
                EquivalentSubkey("", "model."),
                Converter_LlamaForCausalLM_HF_CS(),
            ],
            action=None,
        )
        self.rules = [unprefixed, model_prefixed]

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf_format = FormatVersions("hf")
        cs_format = FormatVersions("cs-1.9")
        return (hf_format, cs_format)

    @classmethod
    def converter_note(cls) -> str:
        """Return a one-line description of the two model classes involved."""
        hf_format, cs_format = cls.formats()[0], cls.formats()[1]
        return f"{hf_format} LlamaForCausalLM <-> {cs_format} GPT2LMHeadModel (configured as Llama)"

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_LLaMa_HF_CS19
class ConfigConverter_LLaMa_HF_CS19(BaseConfigConverter_HF_CS):
    """Config converter between the HF Llama config and the CS 1.9 config.

    ``self.model_type`` may be set by a subclass before calling
    ``super().__init__()``; it defaults to ``"llama"`` and is used both for
    the HF-side ``model_type`` check and for the HF-side post-convert default.
    """

    def __init__(self):
        super().__init__()
        if not hasattr(self, "model_type"):
            self.model_type = "llama"
        self.rules = [
            ConversionRule(
                ["model_type"],
                action=BaseConfigConverter.assert_factory_fn(
                    0, self.model_type
                ),
            ),
            # Embedding
            ConversionRule(["vocab_size"], action=self.replaceKey),
            ConversionRule(
                ["position_embedding_type"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, "rotary"),
            ),
            ConversionRule(
                ["use_position_embedding"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
            ConversionRule(
                ["embedding_dropout_rate"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, 0.0),
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "tie_word_embeddings", "share_embedding_weights"
                    )
                ],
                action=self.replaceKey,
            ),
            ConversionRule(
                ["embedding_layer_norm"],
                action=BaseConfigConverter.assert_factory_fn(1, False),
            ),
            # Decoder Block
            ConversionRule(
                ["hidden_size"],
                action=self.replaceKey,
            ),
            ConversionRule(
                [EquivalentSubkey("num_attention_heads", "num_heads")],
                action=self.replaceKey,
            ),
            ConversionRule(
                ["num_hidden_layers"],
                action=self.replaceKey,
            ),
            ConversionRule(
                ["max_position_embeddings"],
                action=self.replaceKey,
            ),
            ConversionRule(
                ["attention_type"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(
                    1, "scaled_dot_product"
                ),
            ),
            ConversionRule(
                ["use_projection_bias_in_attention"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, False),
            ),
            ConversionRule(
                ["use_ffn_bias_in_attention"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, False),
            ),
            ConversionRule(
                ["use_ffn_bias"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, False),
            ),
            ConversionRule(
                [EquivalentSubkey("intermediate_size", "filter_size")],
                action=self.replaceKey,
            ),
            ConversionRule(
                [EquivalentSubkey("hidden_act", "nonlinearity")],
                action=self.convert_nonlinearity,
            ),
            ConversionRule(
                ["attention_dropout_rate"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, 0.0),
            ),
            ConversionRule(
                ["dropout_rate"],
                exists="right",
                action=BaseConfigConverter.assert_factory_fn(1, 0.0),
            ),
            ConversionRule(
                ["rotary_dim"], exists="right", action=self.assert_rotary_dim
            ),
            ConversionRule(["rope_theta"], action=self.replaceKey),
            ConversionRule(
                [EquivalentSubkey("rms_norm_eps", "layer_norm_epsilon")],
                action=self.replaceKey,
            ),
            ConversionRule(
                ["use_bias_in_output"],
                action=BaseConfigConverter.assert_factory_fn(1, False),
            ),
            ConversionRule(["initializer_range"], action=self.replaceKey),
            ConversionRule(
                ["fixed_sparse_attention"],
                action=BaseConfigConverter.assert_factory_fn(1, None),
            ),
            ConversionRule(
                ["norm_first"],
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
            ConversionRule(
                ["use_ff_layer1_dropout"],
                action=BaseConfigConverter.assert_factory_fn(1, False),
            ),
            ConversionRule(
                ["use_rms_norm"],
                action=BaseConfigConverter.assert_factory_fn(1, True),
            ),
        ]

        self.pre_convert_defaults[0].update(
            {
                "vocab_size": 32000,
                "hidden_size": 4096,
                "intermediate_size": 11008,
                "num_hidden_layers": 32,
                "num_attention_heads": 32,
                "hidden_act": "silu",
                "initializer_range": 0.02,
                "rms_norm_eps": 1e-6,
                "tie_word_embeddings": False,
                "max_position_embeddings": 2048,
            }
        )
        self.pre_convert_defaults[1].update(
            {
                "share_embedding_weights": True,
                "use_rms_norm": False,
                "max_position_embeddings": 1024,
                "position_embedding_type": "learned",
                "layer_norm_epsilon": 1.0e-5,
                "use_projection_bias_in_attention": True,
                "use_ffn_bias_in_attention": True,
                "nonlinearity": "gelu",
                "use_ffn_bias": True,
                "use_bias_in_output": False,
                "norm_first": True,
            },
        )

        # Keep the emitted HF model_type consistent with the value the
        # model_type rule above asserts (it is "llama" unless overridden).
        self.post_convert_defaults[0].update({"model_type": self.model_type})
        self.post_convert_defaults[1].update(
            {
                "use_position_embedding": True,
                "position_embedding_type": "rotary",
                "embedding_dropout_rate": 0.0,
                "embedding_layer_norm": False,
                "attention_type": "scaled_dot_product",
                "use_projection_bias_in_attention": False,
                "use_ffn_bias_in_attention": False,
                "use_ffn_bias": False,
                "attention_dropout_rate": 0.0,
                "dropout_rate": 0.0,
                "use_bias_in_output": False,
                "norm_first": True,
                "use_ff_layer1_dropout": False,
                "use_rms_norm": True,
            },
        )

    def convert_nonlinearity(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Map the activation name between HF (plain) and CS (GLU variant).

        Raises:
            ConfigConversionError: If the activation has no counterpart.
        """
        activation = old_state_dict[old_key]
        if from_index == 0:
            gated_hf2cs = {"silu": "swiglu", "relu": "reglu", "gelu": "geglu"}
            if activation not in gated_hf2cs:
                raise ConfigConversionError(
                    "{} is not a GLU-able activation in CS".format(activation)
                )
            activation = gated_hf2cs[activation]
        elif from_index == 1:
            gated_cs2hf = {"swiglu": "silu", "reglu": "relu", "geglu": "gelu"}
            if activation not in gated_cs2hf:
                raise ConfigConversionError(
                    "{} is not a supported GLU activation in HF".format(
                        activation
                    )
                )
            activation = gated_cs2hf[activation]
        new_state_dict[new_key] = activation

    def assert_rotary_dim(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Validate the CS-only ``rotary_dim``; nothing is written to HF.

        Raises:
            AssertionError: If invoked for a key found on the HF side.
            ConfigConversionError: If ``rotary_dim`` differs from
                ``hidden_size // num_heads``.
        """
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if from_index != 1:
            raise AssertionError(
                "{} should only exist in CS config".format(old_key)
            )
        head_dim = old_state_dict["hidden_size"] // old_state_dict["num_heads"]
        if old_state_dict[old_key] != head_dim:
            raise ConfigConversionError(
                "rotary_dim must be hidden_size // num_heads in order to be compatible with HF"
            )

    def pre_config_convert(
        self,
        config,
        converter_indices,
    ):
        """Run the base pre-conversion, then require ``rotary_dim`` for
        CS -> HF conversions.

        Raises:
            ConfigConversionError: If ``rotary_dim`` is missing or ``None``
                when converting from CS.
        """
        config = super().pre_config_convert(config, converter_indices)

        if converter_indices.direction == 1 and (
            "rotary_dim" not in config or config["rotary_dim"] is None
        ):
            raise ConfigConversionError("rotary_dim must be specified")
        return config

    def post_config_convert(
        self,
        original_config,
        old_config,
        new_config,
        converter_indices,
        drop_unmatched_keys,
    ):
        """Fill in the CS ``rotary_dim`` for HF -> CS conversions, then defer
        to the base class implementation.
        """
        if converter_indices.direction == 0:
            # HF has no rotary_dim key, so it is derived as the head size.
            new_config["rotary_dim"] = (
                new_config["hidden_size"] // new_config["num_heads"]
            )

        return super().post_config_convert(
            original_config,
            old_config,
            new_config,
            converter_indices,
            drop_unmatched_keys,
        )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-1.9"))
class Converter_LlamaForCausalLM_CS19_CS20(Converter_GPT2LMHeadModel_CS18_CS20):
    """CS 1.9 <-> CS 2.0 checkpoint converter for Llama.

    Llama uses the GPT2 backbone, so the conversion rules are inherited from
    the GPT2 converter; only the note, formats and config converter differ.
    """

    @classmethod
    def converter_note(cls) -> str:
        note = "GPT2LMHeadModel class (configured as Llama)"
        return note

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        older = FormatVersions("cs-1.9")
        newer = FormatVersions("cs-2.0")
        return (older, newer)

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_LlamaModel_CS19_CS20
class ConfigConverter_LlamaModel_CS19_CS20(ConfigConverter_GPT2Model_CS18_CS20):
    """CS 1.9 <-> CS 2.0 config converter for Llama.

    Llama uses the GPT2 backbone, so everything except the supported format
    versions is inherited from the GPT2 config converter.
    """

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        older = FormatVersions("cs-1.9")
        newer = FormatVersions("cs-2.0")
        return (older, newer)
class Converter_LlamaModel_HF_CS20(Converter_LlamaModel_HF_CS19):
    """HF LlamaModel <-> CS 2.0 checkpoint converter.

    Reuses the CS 1.9 conversion rules; only the target format and the
    config converter class are replaced.
    """

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf_format = FormatVersions("hf")
        cs_format = FormatVersions("cs-2.0")
        return (hf_format, cs_format)

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_LLaMa_HF_CS20
class Converter_LlamaForCausalLM_HF_CS20(Converter_LlamaForCausalLM_HF_CS19):
    """HF LlamaForCausalLM <-> CS 2.0 checkpoint converter.

    Reuses the CS 1.9 conversion rules; only the target format and the
    config converter class are replaced.
    """

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf_format = FormatVersions("hf")
        cs_format = FormatVersions("cs-2.0")
        return (hf_format, cs_format)

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_LLaMa_HF_CS20
class ConfigConverter_LLaMa_HF_CS20(ConfigConverter_LLaMa_HF_CS19):
    """Config converter between the HF Llama config and the CS 2.0 config.

    Extends the CS 1.9 converter with ``norm_type`` (replacing
    ``use_rms_norm``) and with grouped-query attention support.
    """

    def __init__(self):
        super().__init__()
        # The new rules are placed ahead of the inherited ones.
        self.rules = [
            ConversionRule(
                ["norm_type"],
                action=BaseConfigConverter.assert_factory_fn(1, "rmsnorm"),
            ),
            ConversionRule(
                [
                    EquivalentSubkey(
                        "num_key_value_heads", "extra_attention_params"
                    )
                ],
                action=self.convert_gqa,
            ),
            *self.rules,
        ]

        del self.pre_convert_defaults[1]["use_rms_norm"]
        del self.post_convert_defaults[1]["use_rms_norm"]
        self.pre_convert_defaults[1]["norm_type"] = "layernorm"
        self.post_convert_defaults[1]["norm_type"] = "rmsnorm"
        del self.post_convert_defaults[1]["use_position_embedding"]

    def convert_gqa(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Convert between HF ``num_key_value_heads`` and the CS pair
        ``attention_module`` / ``extra_attention_params["num_kv_groups"]``.

        HF -> CS: equal head counts (or an unset ``num_key_value_heads``)
        select ``aiayn_attention``; otherwise ``multiquery_attention`` with
        ``num_kv_groups`` set to the KV head count.

        CS -> HF: ``aiayn_attention`` yields ``num_heads``;
        ``multiquery_attention`` yields ``num_kv_groups``.

        Raises:
            AssertionError: If the head count is not divisible by the KV
                count, if ``aiayn_attention`` is combined with
                ``num_kv_groups``, or if ``attention_module`` is unsupported.
        """
        # Validation uses explicit raises so it is not stripped by `python -O`.
        if from_index == 0:
            num_kv_heads = old_state_dict[old_key]
            num_heads = old_state_dict["num_attention_heads"]
            # check mha or gqa; an unset value falls back to plain MHA
            if num_kv_heads is None or num_kv_heads == num_heads:
                new_state_dict["attention_module"] = "aiayn_attention"
                return
            if num_heads % num_kv_heads != 0:
                raise AssertionError(
                    f"number of attention heads should be divisible by num_key_value_heads but "
                    f"got {num_heads} and {num_kv_heads},"
                )
            new_state_dict[new_key] = {"num_kv_groups": num_kv_heads}
            new_state_dict["attention_module"] = "multiquery_attention"
        elif from_index == 1:
            attention_module = old_state_dict.get(
                "attention_module", "aiayn_attention"
            )
            if attention_module == "aiayn_attention":
                # `extra_attention_params` may be absent or None.
                extra_params = old_state_dict.get(old_key) or {}
                if "num_kv_groups" in extra_params:
                    raise AssertionError(
                        "Conflict between use of multi-query and multi-head attention"
                    )
                new_state_dict[new_key] = old_state_dict["num_heads"]
            elif attention_module == "multiquery_attention":
                num_heads = old_state_dict["num_heads"]
                num_kv_groups = old_state_dict[old_key]["num_kv_groups"]
                if num_heads % num_kv_groups != 0:
                    raise AssertionError(
                        f"number of attention heads should be divisible by num_key_value_heads but "
                        f"got {num_heads} and {num_kv_groups}."
                    )
                new_state_dict[new_key] = num_kv_groups
            else:
                raise AssertionError(
                    f"attention_module {attention_module} is not supported for "
                    f"llama"
                )

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (FormatVersions("hf"), FormatVersions("cs-2.0"))
###########################################################
# In CS 2.1, we refactored the embedding layer, which is what the
# CS 2.0 <> CS 2.1 converter handles. We don't need separate HF <> CS 2.1
# converters since HF only supports RoPE, which doesn't produce any
# checkpoint keys.
###########################################################
class Converter_LlamaForCausalLM_CS20_CS21(Converter_GPT2LMHeadModel_CS20_CS21):
    """CS 2.0 <-> CS 2.1 checkpoint converter for Llama.

    All conversion behavior is inherited from the GPT2 converter; only the
    descriptive note is replaced.
    """

    @classmethod
    def converter_note(cls) -> str:
        note = "GPT2LMHeadModel class (configured as Llama)"
        return note
class Converter_LlamaModel_HF_CS21(Converter_LlamaModel_HF_CS20):
    """HF LlamaModel <-> CS 2.1 / 2.2 / 2.3 checkpoint converter.

    Reuses the CS 2.0 conversion rules; only the target formats and the
    config converter class are replaced.
    """

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf_format = FormatVersions("hf")
        cs_formats = FormatVersions("cs-2.1", "cs-2.2", "cs-2.3")
        return (hf_format, cs_formats)

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_LLaMa_HF_CS21
class Converter_LlamaForCausalLM_HF_CS21(Converter_LlamaForCausalLM_HF_CS20):
    """HF LlamaForCausalLM <-> CS 2.1 / 2.2 / 2.3 checkpoint converter.

    Reuses the CS 2.0 conversion rules; the target formats and the config
    converter class are replaced, and muP conversion is reported as
    supported.
    """

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        hf_format = FormatVersions("hf")
        cs_formats = FormatVersions("cs-2.1", "cs-2.2", "cs-2.3")
        return (hf_format, cs_formats)

    @staticmethod
    def get_config_converter_class() -> BaseConfigConverter:
        return ConfigConverter_LLaMa_HF_CS21

    def supports_mup_conversion(self):
        mup_supported = True
        return mup_supported
class ConfigConverter_LLaMa_HF_CS21(ConfigConverter_LLaMa_HF_CS20):
    """Config converter between the HF Llama config and CS 2.1 / 2.2 / 2.3.

    Extends the CS 2.0 converter with RoPE scaling: HF ``rope_scaling`` maps
    to CS ``pos_scaling_factor``, ``pos_scaling_type`` and
    ``pos_scaling_extra_args``.
    """

    def __init__(self):
        super().__init__()
        # The new rules are placed ahead of the inherited ones.
        self.rules = [
            ConversionRule(
                [EquivalentSubkey("rope_scaling", "pos_scaling_factor")],
                action=self.convert_pi,
            ),
            # These two CS keys are read and written by `convert_pi`, so
            # their own rules carry no action.
            ConversionRule(
                [EquivalentSubkey("", "pos_scaling_extra_args")],
                action=None,
            ),
            ConversionRule(
                [EquivalentSubkey("", "pos_scaling_type")],
                action=None,
            ),
            *self.rules,
        ]

        self.pre_convert_defaults[0].update(
            {
                "rope_scaling": None,
            }
        )
        self.pre_convert_defaults[1].update(
            {
                "pos_scaling_factor": 1.0,
            },
        )

    def convert_pi(
        self,
        old_key,
        new_key,
        old_state_dict,
        new_state_dict,
        from_index,
        action_fn_args,
    ):
        """Convert RoPE position-scaling settings between HF and CS.

        HF -> CS: ``rope_scaling`` of ``None`` becomes a factor of ``1.0``.
        Otherwise the factor and scaling type are copied, and for ``yarn`` /
        ``llama3`` the additional parameters go to ``pos_scaling_extra_args``.

        CS -> HF: a factor of ``1.0`` becomes ``rope_scaling = None``.
        Otherwise a dict is built from the scaling type (default
        ``"linear"``), the factor and any extra args.

        Raises:
            ConfigConversionError: If the HF scaling type is not ``linear``,
                ``yarn`` or ``llama3``.
        """
        if from_index == 0:
            rope_scaling = old_state_dict[old_key]
            if rope_scaling is None:
                new_state_dict[new_key] = 1.0
                return

            # Either "type" or "rope_type" may carry the scaling type.
            if "type" in rope_scaling:
                scaling_type = rope_scaling["type"].lower()
            else:
                scaling_type = rope_scaling["rope_type"].lower()
            if scaling_type not in ["linear", "yarn", "llama3"]:
                raise ConfigConversionError(
                    f"Only `rope_scaling` type `linear`, `yarn`, and `llama3` are currently supported, "
                    f"but got type `{scaling_type}`."
                )

            new_state_dict[new_key] = rope_scaling["factor"]
            new_state_dict["pos_scaling_type"] = scaling_type
            if scaling_type == "yarn":
                new_state_dict["pos_scaling_extra_args"] = {
                    "original_max_position_embeddings": rope_scaling[
                        "original_max_position_embeddings"
                    ],
                }
            elif scaling_type == "llama3":
                pos_scaling_extra_args = copy.deepcopy(rope_scaling)
                # Type and factor have dedicated CS keys. Drop whichever
                # type key(s) are present: a config may have "type",
                # "rope_type", or both.
                for consumed_key in ("type", "rope_type", "factor"):
                    pos_scaling_extra_args.pop(consumed_key, None)
                new_state_dict["pos_scaling_extra_args"] = (
                    pos_scaling_extra_args
                )
        else:
            factor = old_state_dict[old_key]
            if factor == 1.0:
                new_state_dict[new_key] = None
                return

            # A missing `pos_scaling_type` means linear scaling.
            scaling_type = old_state_dict.get("pos_scaling_type", "linear")
            type_key = "rope_type" if scaling_type == "llama3" else "type"
            rope_scaling = {type_key: scaling_type, "factor": factor}
            # The extra args may be absent or None.
            extra_args = old_state_dict.get("pos_scaling_extra_args")
            if extra_args:
                rope_scaling.update(extra_args)
            new_state_dict[new_key] = rope_scaling

    @staticmethod
    def formats() -> Tuple[FormatVersions, FormatVersions]:
        return (
            FormatVersions("hf"),
            FormatVersions("cs-2.1", "cs-2.2", "cs-2.3"),
        )

    def supports_mup_conversion(self):
        return True