# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from typing import Tuple
import torch
from cerebras.modelzoo.tools.checkpoint_converters.base_converter import (
BaseCheckpointConverter_HF_CS,
BaseConfigConverter,
ConversionRule,
EquivalentSubkey,
FormatVersions,
)
from cerebras.modelzoo.tools.checkpoint_converters.helper import (
Build_HF_CS_Converter_WithOptionalModel,
)
from cerebras.modelzoo.tools.checkpoint_converters.llama import (
ConfigConverter_LLaMa_HF_CS21,
Converter_LlamaAttention_HF_CS,
)
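# Marker inserted into the CS-side key patterns of the per-expert MoE rules so they
# do not collide with the fused key produced by `convert_expert_weights`; it is
# stripped back out in `assert_already_converted`.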
MAGIC_STR = "__"
class Converter_MixtralModel_HF_CS(BaseCheckpointConverter_HF_CS):
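    """Converts checkpoints between HF MixtralModel (headless) and the CS model
    (GPT2LMHeadModel configured as Mixtral), matching keys via `self.rules`.
    """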
def __init__(self):
super().__init__()
self.rules = [
# word embeddings
ConversionRule(
[
EquivalentSubkey(
"embed_tokens", "embedding_layer.word_embeddings"
),
r"\.(?:weight|bias)",
],
action=self.replaceKey,
),
# final layer norm
ConversionRule(
[
EquivalentSubkey("norm", "transformer_decoder.norm"),
r"\.(?:weight|bias)",
],
action=self.replace_final_norm,
),
# attention
ConversionRule(
[
EquivalentSubkey("layers", "transformer_decoder.layers"),
r"\.\d+\.self_attn\.",
Converter_LlamaAttention_HF_CS(),
],
action=None,
),
# Rotary embedding
ConversionRule(
[r"layers\.\d+\.self_attn\.rotary_emb\.inv_freq"],
exists="left",
action=None,
),
# attention norm
ConversionRule(
[
EquivalentSubkey("layers", "transformer_decoder.layers"),
r"\.\d+\.",
EquivalentSubkey("input_layernorm", "norm1"),
r"\.(?:weight|bias)",
],
action=self.replaceKey,
),
ConversionRule(
[
EquivalentSubkey("layers", "transformer_decoder.layers"),
r"\.\d+\.",
EquivalentSubkey("post_attention_layernorm", "norm3"),
r"\.(?:weight|bias)",
],
action=self.replaceKey,
),
# moe ffn
ConversionRule(
[
EquivalentSubkey("layers", "transformer_decoder.layers"),
r"\.\d+\.",
EquivalentSubkey("block_sparse_moe.gate", "ffn.gate"),
r"\.(?:weight|bias)",
],
action=self.replaceKey,
),
*self.moe_rules(),
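            # CS-only keys: matched here so they are dropped when converting to HF
            # (lm_head is re-created in post_model_convert when going HF -> CS).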
ConversionRule([r"lm_head\.(?:weight|bias)"], exists="right"),
ConversionRule([r"ln_f\.(?:weight|bias)"], exists="right"),
]
def replace_final_norm(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
new_state_dict[new_key] = old_state_dict[old_key]
# CS 1.7 has both "ln_f" and "transformer_decoder.norm"
# we need to copy the original ("ln_f") too:
if from_index == 0:
ln_f_key = re.sub(r"transformer_decoder\.norm\.", "ln_f.", new_key)
new_state_dict[ln_f_key] = old_state_dict[old_key]
def moe_rules(self):
return self.moe_optimized_impl_rules()
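    # Per-expert (non-fused) expert rules, kept as an alternative to the fused
    # implementation selected by `moe_rules()` above.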
def moe_functional_impl_rules(self):
return [
ConversionRule(
[
EquivalentSubkey("layers", "transformer_decoder.layers"),
r"\.\d+\.",
EquivalentSubkey(
"block_sparse_moe.experts", "ffn.experts.experts"
),
r"\.\d+\.",
EquivalentSubkey("w1", "ffn.0.linear_layer_for_glu"),
r"\.weight",
],
action=self.replaceKey,
),
ConversionRule(
[
EquivalentSubkey("layers", "transformer_decoder.layers"),
r"\.\d+\.",
EquivalentSubkey(
"block_sparse_moe.experts", "ffn.experts.experts"
),
r"\.\d+\.",
EquivalentSubkey("w3", "ffn.0.linear_layer"),
r"\.weight",
],
action=self.replaceKey,
),
ConversionRule(
[
EquivalentSubkey("layers", "transformer_decoder.layers"),
r"\.\d+\.",
EquivalentSubkey(
"block_sparse_moe.experts", "ffn.experts.experts"
),
r"\.\d+\.",
EquivalentSubkey("w2", "ffn.1.linear_layer"),
r"\.weight",
],
action=self.replaceKey,
),
]
def moe_optimized_impl_rules(self):
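        # Fused expert layout on the CS side: HF w1 -> fused_ffns.0.linear_layer_for_glu,
        # w3 -> fused_ffns.0.linear_layer, w2 -> fused_ffns.1.linear_layer, with all
        # experts stacked into a single `expert_weights` tensor. Only the expert-0 key
        # triggers the fusion; the MAGIC_STR rules match the remaining experts and
        # merely assert that they were already folded in.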
return [
ConversionRule(
[
EquivalentSubkey("layers", "transformer_decoder.layers"),
r"\.\d+\.",
EquivalentSubkey(
"block_sparse_moe.experts",
"ffn.experts",
),
EquivalentSubkey(
".0.w1.weight",
".fused_ffns.0.linear_layer_for_glu.expert_weights",
),
],
action=self.convert_expert_weights,
),
ConversionRule(
[
EquivalentSubkey("layers", "transformer_decoder.layers"),
r"\.\d+\.",
EquivalentSubkey(
"block_sparse_moe.experts",
f"ffn.experts{MAGIC_STR}",
),
r"\.\d+",
EquivalentSubkey(
".w1.weight",
".fused_ffns.0.linear_layer_for_glu.expert_weights",
),
],
action=self.assert_already_converted,
),
ConversionRule(
[
EquivalentSubkey("layers", "transformer_decoder.layers"),
r"\.\d+\.",
EquivalentSubkey(
"block_sparse_moe.experts",
"ffn.experts",
),
EquivalentSubkey(
".0.w3.weight",
".fused_ffns.0.linear_layer.expert_weights",
),
],
action=self.convert_expert_weights,
),
ConversionRule(
[
EquivalentSubkey("layers", "transformer_decoder.layers"),
r"\.\d+\.",
EquivalentSubkey(
"block_sparse_moe.experts",
f"ffn.experts{MAGIC_STR}",
),
r"\.\d+",
EquivalentSubkey(
".w3.weight",
".fused_ffns.0.linear_layer.expert_weights",
),
],
action=self.assert_already_converted,
),
ConversionRule(
[
EquivalentSubkey("layers", "transformer_decoder.layers"),
r"\.\d+\.",
EquivalentSubkey(
"block_sparse_moe.experts",
"ffn.experts",
),
EquivalentSubkey(
".0.w2.weight",
".fused_ffns.1.linear_layer.expert_weights",
),
],
action=self.convert_expert_weights,
),
ConversionRule(
[
EquivalentSubkey("layers", "transformer_decoder.layers"),
r"\.\d+\.",
EquivalentSubkey(
"block_sparse_moe.experts",
f"ffn.experts{MAGIC_STR}",
),
r"\.\d+",
EquivalentSubkey(
".w2.weight",
".fused_ffns.0.linear_layer.expert_weights",
),
],
action=self.assert_already_converted,
),
]
def convert_expert_weights(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
num_experts = action_fn_args['configs'][1]['model']['moe'][
'num_experts'
]
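        # HF stores one weight matrix per expert; CS stores a single tensor with the
        # expert dimension stacked at dim 1 (nominally [out_features, num_experts,
        # in_features] for 2-D Linear weights). Fuse or unfuse accordingly.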
if from_index == 0:
# Fuse weights across experts.
expert_weights = []
for expert in range(num_experts):
curr_old_key = re.sub(
r"experts\.0", f"experts.{expert}", old_key
)
expert_weights.append(old_state_dict[curr_old_key])
expert_weights = [v.unsqueeze(1) for v in expert_weights]
new_state_dict[new_key] = torch.concat(expert_weights, dim=1)
else:
# Unfuse weights.
for expert in range(num_experts):
curr_new_key = re.sub(
r"experts\.0", f"experts.{expert}", new_key
)
new_state_dict[curr_new_key] = old_state_dict[old_key][
:, expert
]
def assert_already_converted(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
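        # HF -> CS: experts other than 0 map onto the same fused CS key, which the
        # expert-0 rule has already produced, so just verify that it exists. The
        # reverse direction can never match these rules because real CS keys never
        # contain MAGIC_STR, hence the unreachable branch.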
if from_index == 0:
new_key = re.sub(f"{MAGIC_STR}.\d+", "", new_key)
assert (
new_key in new_state_dict
), f"Expected {new_key} to be in new_state_dict"
else:
assert False, "Unreachable"
def post_model_convert(
self,
old_state_dict,
new_state_dict,
configs,
converter_indices,
drop_unmatched_keys,
key_prefix="",
):
if converter_indices.direction == 0:
            # We are converting from HF MixtralModel (which is headless) ->
            # CS GPT2LMHeadModel configured as Mixtral (which has a head).
            # We need to create 'lm_head' and initialize it to default values.
logging.warning(
f"{self.formats()[1]} has a language model head (lm_head) "
f"while {self.formats()[0]} does not. Initializing lm_head to default."
)
hf_config = configs[0]
cs_config = configs[1]
use_bias_in_output = cs_config["model"].get(
"use_bias_in_output", False
)
vocab_size = cs_config["model"]["vocab_size"]
embed_dim = cs_config["model"]["hidden_size"]
if hf_config["tie_word_embeddings"]:
lm_head_weight = old_state_dict['embed_tokens.weight']
else:
lm_head_weight = torch.zeros((vocab_size, embed_dim))
lm_head_weight.normal_(mean=0.0, std=0.02)
new_state_dict[key_prefix + "lm_head.weight"] = lm_head_weight
if use_bias_in_output:
lm_head_bias = torch.zeros(vocab_size)
new_state_dict[key_prefix + "lm_head.bias"] = lm_head_bias
super().post_model_convert(
old_state_dict,
new_state_dict,
configs,
converter_indices,
drop_unmatched_keys,
key_prefix=key_prefix,
)
@staticmethod
def formats() -> Tuple[FormatVersions, FormatVersions]:
return (FormatVersions("hf"), FormatVersions("cs"))
@staticmethod
def get_config_converter_class() -> BaseConfigConverter:
return None
class Converter_MixtralForCausalLM_HF_CS(BaseCheckpointConverter_HF_CS):
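    """Converts checkpoints between HF MixtralForCausalLM and CS GPT2LMHeadModel
    (configured as Mixtral): handles `lm_head` directly and delegates the
    `model.*` keys to Converter_MixtralModel_HF_CS.
    """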
def __init__(self):
super().__init__()
self.rules = [
ConversionRule(
[r"lm_head\.(?:weight|bias)"],
action=self.replaceKey,
),
ConversionRule(
[
EquivalentSubkey("model.", ""),
Converter_MixtralModel_HF_CS(),
],
action=None,
),
]
@staticmethod
def formats() -> Tuple[FormatVersions, FormatVersions]:
return (FormatVersions("hf"), FormatVersions("cs"))
@staticmethod
def get_config_converter_class() -> BaseConfigConverter:
return None
class Converter_MixtralModel_WithoutOptionalModel_HF_CS23(
Converter_MixtralModel_HF_CS
):
@staticmethod
def formats() -> Tuple[FormatVersions, FormatVersions]:
return (FormatVersions("hf"), FormatVersions("cs-2.3"))
@staticmethod
def get_config_converter_class() -> BaseConfigConverter:
return ConfigConverter_Mixtral_HF_CS23
@classmethod
def converter_note(cls) -> str:
return (
f"{cls.formats()[0]} MixtralModel <-> {cls.formats()[1]} GPT2LMHeadModel (configured as "
f"Mixtral)\nThe HF model doesn't contain a language model head while the CS one does. "
f"When converting to CS, the exported checkpoint will contain a language model head "
f"initialized to default random values. When converting to HF, the language model head "
f"will be dropped."
        )
class Converter_MixtralLMHeadModel_WithoutOptionalModel_HF_CS23(
Converter_MixtralForCausalLM_HF_CS
):
@staticmethod
def formats() -> Tuple[FormatVersions, FormatVersions]:
return (FormatVersions("hf"), FormatVersions("cs-2.3"))
@staticmethod
def get_config_converter_class() -> BaseConfigConverter:
return ConfigConverter_Mixtral_HF_CS23
@classmethod
def converter_note(cls) -> str:
return "{} MixtralForCausalLM <-> {} GPT2LMHeadModel (configured as Mixtral)".format(
cls.formats()[0], cls.formats()[1]
)
class ConfigConverter_Mixtral_HF_CS23(ConfigConverter_LLaMa_HF_CS21):
def __init__(self):
self.model_type = "mixtral"
super().__init__()
self.rules = [
ConversionRule(
[
EquivalentSubkey(
"sliding_window", "attention_sliding_window_length"
)
],
action=self.replaceKey,
),
ConversionRule(
[EquivalentSubkey("num_local_experts", "moe")],
action=self.convert_moe_params,
),
*self.rules,
]
self.post_convert_defaults[0].update({"model_type": "mixtral"})
def convert_moe_params(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
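        # Maps the HF MoE fields onto the nested CS `moe` config section:
        # num_local_experts <-> moe.num_experts, num_experts_per_tok <-> moe.top_k,
        # router_aux_loss_coef <-> moe.load_balancing_loss_coef.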
if from_index == 0:
if "moe" not in new_state_dict:
new_state_dict["moe"] = {}
new_state_dict["moe"]["num_experts"] = old_state_dict[
"num_local_experts"
]
new_state_dict["moe"]["top_k"] = old_state_dict[
"num_experts_per_tok"
]
new_state_dict["moe"]["load_balancing_loss_coef"] = old_state_dict[
"router_aux_loss_coef"
]
else:
new_state_dict["num_local_experts"] = old_state_dict["moe"][
"num_experts"
]
new_state_dict["num_experts_per_tok"] = old_state_dict["moe"][
"top_k"
]
new_state_dict["router_aux_loss_coef"] = old_state_dict["moe"][
"load_balancing_loss_coef"
]
@staticmethod
def formats() -> Tuple[FormatVersions, FormatVersions]:
return (FormatVersions("hf"), FormatVersions("cs-2.3"))
Converter_MixtralModel_HF_CS23 = Build_HF_CS_Converter_WithOptionalModel(
"Converter_MixtralModel_HF_CS23",
Converter_MixtralModel_WithoutOptionalModel_HF_CS23,
derived_class=Converter_MixtralModel_WithoutOptionalModel_HF_CS23,
)
Converter_MixtralForCausalLM_HF_CS23 = Build_HF_CS_Converter_WithOptionalModel(
"Converter_MixtralForCausalLM_HF_CS23",
Converter_MixtralLMHeadModel_WithoutOptionalModel_HF_CS23,
derived_class=Converter_MixtralLMHeadModel_WithoutOptionalModel_HF_CS23,
)