# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Tuple
from cerebras.modelzoo.tools.checkpoint_converters.base_converter import (
BaseConfigConverter,
BaseDictionaryConverter,
ConversionRule,
EquivalentSubkey,
FormatVersions,
)
def scale_initializers_by_dimension(
initializers,
width_scale=None,
depth_scale=None,
):
if not width_scale:
width_scale = 1.0
if not depth_scale:
depth_scale = 1.0
mup_scalar = width_scale * depth_scale
if not isinstance(initializers, list):
initializers = [initializers]
for initializer in initializers:
if isinstance(initializer, str):
initializer = {"name": initializer}
if "name" not in initializer:
raise ValueError("Initializer name must be provided")
initializer_name = initializer["name"].lower()
if initializer_name == "normal":
initializer["std"] = initializer.get("std", 1.0) * mup_scalar
elif initializer_name == "truncated_normal":
std = initializer.get("std", 1.0)
initializer["std"] = std * mup_scalar
initializer["a"] = initializer.get("a", -2 * std) * mup_scalar
initializer["b"] = initializer.get("b", 2 * std) * mup_scalar
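# A minimal usage sketch (values below are illustrative, not taken from any
# shipped config):
#
#   init = {"name": "truncated_normal", "std": 0.02}
#   scale_initializers_by_dimension(init, width_scale=0.25**0.5)
#   # init["std"] is now 0.02 * 0.5 = 0.01, and "a"/"b" (which default to
#   # -/+ 2 * std before scaling) become -/+ 0.02.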
class ConfigConverter_sP_muP_pre_CS23(BaseConfigConverter):
"""Transforms a CS 2.2 (and earlier) muP config into a CS sP config."""
def __init__(self):
super().__init__()
self.rules = [
ConversionRule(["output_logits_scale"]),
ConversionRule(["embeddings_scale"]),
ConversionRule(["scale_qk_dot_by_d"]),
ConversionRule(
["share_embedding_weights"],
action=self.set_share_embedding_weights,
),
ConversionRule(
[r".*"], action=self.replaceKey
), # Catch-all for everything else
]
@staticmethod
def formats() -> Tuple[FormatVersions, FormatVersions]:
return ("sP", "muP")
@staticmethod
def file_formats() -> Tuple[str, str]:
return ()
@staticmethod
def is_mup(config):
return _is_mup_CS22(config)
def set_share_embedding_weights(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
if from_index == 1 and (
"output_logits_scale" in old_state_dict
or "embeddings_scale" in old_state_dict
):
new_state_dict[new_key] = False
else:
new_state_dict[new_key] = old_state_dict[old_key]
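# Editorial note on the rules above: the action-less rules appear to consume
# the muP-only keys ("output_logits_scale", "embeddings_scale",
# "scale_qk_dot_by_d") so they are not copied into the sP config, and
# "share_embedding_weights" is forced to False when converting away from a muP
# config, presumably because the matching checkpoint converter folds different
# scale factors into the embedding and lm_head weights, so the two tensors can
# no longer be tied.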
class ConfigConverter_sP_muP_post_CS23(ConfigConverter_sP_muP_pre_CS23):
"""Transforms a CS 2.3 (and later) muP config into a CS sP config."""
def __init__(self):
super().__init__()
self.rules = [
ConversionRule(["mup_base_hidden_size"]),
ConversionRule(["mup_base_filter_size"]),
ConversionRule(["attention_logits_alpha"]),
ConversionRule(["output_logits_alpha"]),
ConversionRule(["scale_output_logits_by_d"]),
ConversionRule(["embeddings_scale"]),
ConversionRule(["scale_qk_dot_by_d"]),
ConversionRule(
["share_embedding_weights"],
action=self.set_share_embedding_weights,
),
ConversionRule(
[r".*"], action=self.replaceKey
), # Catch-all for everything else
]
def set_share_embedding_weights(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
if from_index == 1 and (
_is_mup(old_state_dict) or 'embeddings_scale' in old_state_dict
):
new_state_dict[new_key] = False
else:
new_state_dict[new_key] = old_state_dict[old_key]
@staticmethod
def is_mup(config):
return _is_mup(config)
class ConfigConverter_T5_sP_muP(ConfigConverter_sP_muP_post_CS23):
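"""Transforms a T5 CS muP config into a T5 CS sP config."""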
def __init__(self):
super().__init__()
self.rules = [
ConversionRule(["mup_base_d_model"]),
ConversionRule(["mup_base_d_ff"]),
ConversionRule(["mup_base_d_kv"]),
ConversionRule(["encoder_attention_logits_alpha"]),
ConversionRule(["decoder_attention_logits_alpha"]),
ConversionRule(["output_logits_alpha"]),
ConversionRule(["scale_output_logits_by_d"]),
ConversionRule(["embeddings_alpha"]),
ConversionRule(["scale_encoder_qk_dot_by_d"]),
ConversionRule(["scale_decoder_qk_dot_by_d"]),
ConversionRule(
["share_embedding_weights"],
action=self.set_share_embedding_weights,
),
ConversionRule(
[r".*"], action=self.replaceKey
), # Catch-all for everything else
]
def set_share_embedding_weights(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
if from_index == 1 and (
_is_mup(old_state_dict) or 'embeddings_alpha' in old_state_dict
):
new_state_dict[new_key] = False
else:
new_state_dict[new_key] = old_state_dict[old_key]
@staticmethod
def is_mup(config):
return _is_mup(config)
class ConfigConverter_muP_CS22_CS23(BaseConfigConverter):
"""Transforms a CS 2.2 muP config into a CS 2.3 muP config with the
reworked muP parameters.
"""
def __init__(self):
super().__init__()
self.hidden_size_width_mult = None
self.rules = [
ConversionRule(
[
EquivalentSubkey("initializer", "initializer"),
],
action=self.unscale_input_initializer,
),
ConversionRule(
[
EquivalentSubkey(
"output_layer_initializer", "output_layer_initializer"
),
],
action=self.unscale_output_initializer,
),
ConversionRule(
[
EquivalentSubkey(
"output_logits_scale", "output_logits_alpha"
),
],
action=self.replace_output_logits_scale,
),
ConversionRule(
[r".*"], action=self.replaceKey
), # Catch-all for everything else
]
@classmethod
def convert(
cls,
config,
converter_indices,
drop_unmatched_keys=False,
no_progress_bar=True,
debug=False,
):
instance = cls()
lr_scale = config["optimizer"].pop("adjust_learning_rate")
instance.hidden_size_width_mult = 1 / lr_scale["decoder_kernel"]
return instance.convert_helper(
config["model"],
converter_indices,
drop_unmatched_keys=drop_unmatched_keys,
no_progress_bar=no_progress_bar,
debug=debug,
)
def unscale_input_initializer(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
# self.hidden_size_width_mult = hidden_size / mup_base_hidden_size
# input_layer = input_layer' / sqrt(self.hidden_size_width_mult)
initializer_config = old_state_dict[old_key]
scale_initializers_by_dimension(
initializer_config,
width_scale=self.hidden_size_width_mult**0.5,
)
new_state_dict[new_key] = initializer_config
def unscale_output_initializer(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
# self.hidden_size_width_mult = hidden_size / mup_base_hidden_size
num_hidden_layers = old_state_dict["num_hidden_layers"]
# output_layer = output_layer' / sqrt(2 * num_hidden_layers * self.hidden_size_width_mult)
initializer_config = old_state_dict[old_key]
scale_initializers_by_dimension(
initializer_config,
width_scale=self.hidden_size_width_mult**0.5,
depth_scale=(2 * num_hidden_layers) ** 0.5,
)
new_state_dict[new_key] = initializer_config
def replace_output_logits_scale(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
hidden_size = old_state_dict["hidden_size"]
filter_size = old_state_dict["filter_size"]
# self.hidden_size_width_mult = hidden_size / mup_base_hidden_size
mup_base_hidden_size = hidden_size / self.hidden_size_width_mult
mup_base_filter_size = filter_size / self.hidden_size_width_mult
# output_logits_scale = output_logits_alpha / self.hidden_size_width_mult
output_logits_alpha = (
old_state_dict[old_key] * self.hidden_size_width_mult
)
new_state_dict["mup_base_hidden_size"] = mup_base_hidden_size
new_state_dict["mup_base_filter_size"] = mup_base_filter_size
new_state_dict[new_key] = output_logits_alpha
@staticmethod
def formats() -> Tuple[FormatVersions, FormatVersions]:
return ("cs-2.2", "cs-2.3")
@staticmethod
def file_formats() -> Tuple[str, str]:
return ()
@staticmethod
def is_mup(config):
return _is_mup_CS22(config)
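# Worked example with assumed values (not from any shipped config): if the
# CS 2.2 config had hidden_size=1024, filter_size=4096,
# optimizer.adjust_learning_rate.decoder_kernel=0.25 and
# output_logits_scale=0.0625, then
#   hidden_size_width_mult = 1 / 0.25     = 4.0
#   mup_base_hidden_size   = 1024 / 4.0   = 256
#   mup_base_filter_size   = 4096 / 4.0   = 1024
#   output_logits_alpha    = 0.0625 * 4.0 = 0.25
# and the initializer std values are scaled back up by sqrt(4.0)
# (times sqrt(2 * num_hidden_layers) for the output-layer initializer).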
class Converter_sP_muP_pre_CS23(BaseDictionaryConverter):
"""Transforms a CS 2.2 (and earlier) muP checkpoint into a CS sP checkpoint.
muP: Maximal Update Parametrization.
sP: Standard Parametrization.
"""
def __init__(self):
super().__init__()
self.rules = [
ConversionRule(
[r".+\.proj_k_dense_layer.*"],
action=self.scale_k_projection,
),
ConversionRule(
[r"(?:model\.|)lm_head\.weight"],
action=self.scale_lm_head,
),
ConversionRule(
[r"(?:model\.|)embedding_layer\.word_embeddings\.weight"],
action=self.scale_embeddings,
),
ConversionRule(
[
r"(?:model\.|)embedding_layer\.position_embeddings(?:\.embed)?\.weight"
],
action=self.scale_embeddings,
),
ConversionRule(
[r"(?:model\.|)embedding_ln_f\.(?:weight|bias)"],
action=self.scale_embedding_layernorm,
),
ConversionRule(
[r".*"], action=self.replaceKey
), # Catch-all for everything else
]
def scale_k_projection(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
config = action_fn_args["configs"][1]
if config["model"].get('scale_qk_dot_by_d', False):
d_model = config["model"]["hidden_size"]
n_heads = config["model"]["num_heads"]
d_sqrt = math.sqrt(d_model // n_heads)
new_state_dict[new_key] = old_state_dict[old_key] / d_sqrt
else:
new_state_dict[new_key] = old_state_dict[old_key]
def scale_lm_head(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
config = action_fn_args["configs"][1]
if "output_logits_scale" in config["model"]:
output_scale = config["model"]["output_logits_scale"]
new_state_dict[new_key] = old_state_dict[old_key] * output_scale
else:
new_state_dict[new_key] = old_state_dict[old_key]
def scale_embeddings(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
config = action_fn_args["configs"][1]
# Fold embeddings_scale into word/position embeddings if embedding
# layer norm *is not* enabled
if "embeddings_scale" in config["model"] and not config["model"].get(
"embedding_layer_norm", False
):
emb_scale = config["model"]["embeddings_scale"]
new_state_dict[new_key] = old_state_dict[old_key] * emb_scale
else:
new_state_dict[new_key] = old_state_dict[old_key]
def scale_embedding_layernorm(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
config = action_fn_args["configs"][1]
# Fold embeddings_scale into embedding layer norm if embedding
# layer norm *is* enabled
if "embeddings_scale" in config["model"] and config["model"].get(
"embedding_layer_norm", False
):
emb_scale = config["model"]["embeddings_scale"]
new_state_dict[new_key] = old_state_dict[old_key] * emb_scale
else:
new_state_dict[new_key] = old_state_dict[old_key]
@staticmethod
def is_mup(config):
return _is_mup_CS22(config.get('model', {}))
@staticmethod
def formats():
return ("sP", "muP")
class Converter_sP_muP_post_CS23(Converter_sP_muP_pre_CS23):
"""Transforms a CS 2.3 (and later) muP checkpoint into a CS sP checkpoint.
muP: Maximal Update Parametrization.
sP: Standard Parametrization.
"""
def __init__(self):
super().__init__()
self.rules = [
ConversionRule(
[r".+\.proj_k_dense_layer.*"],
action=self.scale_k_projection,
),
ConversionRule(
[r"(?:model\.|)lm_head\.weight"],
action=self.scale_lm_head,
),
ConversionRule(
[r"(?:model\.|)mlm_head\.weight"],
action=self.scale_lm_head,
),
ConversionRule(
[r"(?:model\.|)cls_head\.weight"],
action=self.scale_lm_head,
),
ConversionRule(
[r"(?:model\.|)embedding_layer\.word_embeddings\.weight"],
action=self.scale_embeddings,
),
ConversionRule(
[
r"(?:model\.|)embedding_layer\.position_embeddings(?:\.embed)?\.weight"
],
action=self.scale_embeddings,
),
ConversionRule(
[r"(?:model\.|)embedding_ln_f\.(?:weight|bias)"],
action=self.scale_embedding_layernorm,
),
ConversionRule(
[r".*"], action=self.replaceKey
), # Catch-all for everything else
]
def scale_k_projection(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
config = action_fn_args["configs"][1]
scale_qk_dot_by_d = config["model"].get("scale_qk_dot_by_d", None)
if scale_qk_dot_by_d is None:
scale_qk_dot_by_d = _is_mup(config["model"])
attention_logits_alpha = config["model"].get(
"attention_logits_alpha", None
)
if attention_logits_alpha is None:
attention_logits_alpha = 1.0
if scale_qk_dot_by_d:
d_model = config["model"]["hidden_size"]
n_heads = config["model"]["num_heads"]
d_sqrt = math.sqrt(d_model // n_heads)
new_state_dict[new_key] = (
attention_logits_alpha * old_state_dict[old_key] / d_sqrt
)
else:
new_state_dict[new_key] = (
attention_logits_alpha * old_state_dict[old_key]
)
def scale_lm_head(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
config = action_fn_args["configs"][1]
mup_base_hidden_size = config["model"].get("mup_base_hidden_size", None)
if mup_base_hidden_size:
output_logits_alpha = config["model"].get(
"output_logits_alpha", None
)
if output_logits_alpha is None:
output_logits_alpha = 1.0
scale_output_logits_by_d = config["model"].get(
"scale_output_logits_by_d", None
)
if scale_output_logits_by_d is None:
scale_output_logits_by_d = True
hidden_size = config["model"]["hidden_size"]
hidden_size_width_mult = hidden_size / mup_base_hidden_size
if scale_output_logits_by_d:
output_logits_scale = (
output_logits_alpha / hidden_size_width_mult
)
else:
output_logits_scale = (
output_logits_alpha / hidden_size_width_mult**0.5
)
new_state_dict[new_key] = (
old_state_dict[old_key] * output_logits_scale
)
else:
new_state_dict[new_key] = old_state_dict[old_key]
@staticmethod
def is_mup(config):
return _is_mup(config["model"])
class Converter_T5_sP_muP(BaseDictionaryConverter):
"""Transforms a T5 CS muP checkpoint into a T5 CS sP checkpoint.
muP: Maximal Update Parametrization.
sP: Standard Parametrization.
"""
def __init__(self):
super().__init__()
self.rules = [
ConversionRule(
[r".+\.proj_q_dense_layer.*"],
action=self.scale_q_projection,
),
ConversionRule(
[r".*encoder.*proj_k_dense_layer.*"],
action=self.scale_encoder_k_projection,
),
ConversionRule(
[r".*decoder.*proj_k_dense_layer.*"],
action=self.scale_decoder_k_projection,
),
ConversionRule(
[r".+\.proj_v_dense_layer.*"],
action=self.scale_v_projection,
),
ConversionRule(
[r".+\.proj_output_dense_layer.*"],
action=self.scale_output_projection,
),
ConversionRule(
[r"(?:model\.|)lm_head\.weight"],
action=self.scale_lm_head,
),
ConversionRule(
[r"(?:model\.|).*embeddings\.word_embeddings\.weight"],
action=self.scale_embeddings,
),
ConversionRule(
[r".*"], action=self.replaceKey
), # Catch-all for everything else
]
def scale_q_projection(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
config = action_fn_args["configs"][1]
d_model = config["model"]["d_model"]
mup_base_d_model = config["model"].get("mup_base_d_model", None)
if mup_base_d_model is None:
mup_base_d_model = d_model
d_model_width_mult = d_model / mup_base_d_model
d_sqrt = math.sqrt(config["model"]["d_kv"])
projection_scale = d_model_width_mult**-0.5
if config["model"].get("mup_base_d_kv", None):
projection_scale = 1.0
total_scale = projection_scale / d_sqrt
new_state_dict[new_key] = old_state_dict[old_key] * total_scale
def scale_encoder_k_projection(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
config = action_fn_args["configs"][1]
scale_qk_dot_by_d = config["model"].get(
"scale_encoder_qk_dot_by_d", None
)
if scale_qk_dot_by_d is None:
scale_qk_dot_by_d = _is_mup(config["model"])
d_model = config["model"]["d_model"]
mup_base_d_model = config["model"].get("mup_base_d_model", None)
if mup_base_d_model is None:
mup_base_d_model = d_model
d_model_width_mult = d_model / mup_base_d_model
attention_logits_alpha = config["model"].get(
"encoder_attention_logits_alpha"
)
if attention_logits_alpha is None:
attention_logits_alpha = 1.0
projection_scale = d_model_width_mult**-0.5
if config["model"].get("mup_base_d_kv", None):
projection_scale = 1.0
if scale_qk_dot_by_d:
d_sqrt = math.sqrt(config["model"]["d_kv"])
total_scale = attention_logits_alpha * projection_scale / d_sqrt
else:
total_scale = attention_logits_alpha * projection_scale
new_state_dict[new_key] = old_state_dict[old_key] * total_scale
def scale_decoder_k_projection(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
config = action_fn_args["configs"][1]
scale_qk_dot_by_d = config["model"].get(
"scale_decoder_qk_dot_by_d", None
)
if scale_qk_dot_by_d is None:
scale_qk_dot_by_d = _is_mup(config["model"])
d_model = config["model"]["d_model"]
mup_base_d_model = config["model"].get("mup_base_d_model", None)
if mup_base_d_model is None:
mup_base_d_model = d_model
d_model_width_mult = d_model / mup_base_d_model
attention_logits_alpha = config["model"].get(
"decoder_attention_logits_alpha"
)
if attention_logits_alpha is None:
attention_logits_alpha = 1.0
projection_scale = d_model_width_mult**-0.5
if config["model"].get("mup_base_d_kv", None):
projection_scale = 1.0
if scale_qk_dot_by_d:
d_sqrt = math.sqrt(config["model"]["d_kv"])
total_scale = attention_logits_alpha * projection_scale / d_sqrt
else:
total_scale = attention_logits_alpha * projection_scale
new_state_dict[new_key] = old_state_dict[old_key] * total_scale
def scale_v_projection(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
config = action_fn_args["configs"][1]
d_model = config["model"]["d_model"]
mup_base_d_model = config["model"].get("mup_base_d_model", None)
if mup_base_d_model is None:
mup_base_d_model = d_model
d_model_width_mult = d_model / mup_base_d_model
projection_scale = d_model_width_mult**-0.5
if config["model"].get("mup_base_d_kv", None):
projection_scale = 1.0
new_state_dict[new_key] = projection_scale * old_state_dict[old_key]
def scale_output_projection(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
config = action_fn_args["configs"][1]
d_model = config["model"]["d_model"]
mup_base_d_model = config["model"].get("mup_base_d_model", None)
if mup_base_d_model is None:
mup_base_d_model = d_model
d_model_width_mult = d_model / mup_base_d_model
projection_scale = d_model_width_mult**0.5
if config["model"].get("mup_base_d_kv", None):
projection_scale = 1.0
new_state_dict[new_key] = projection_scale * old_state_dict[old_key]
def scale_embeddings(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
config = action_fn_args["configs"][1]
# Fold embeddings_scale into word/position embeddings if embedding
# layer norm *is not* enabled
if _is_mup(config["model"]):
emb_alpha = config["model"].get("embeddings_alpha", None)
d_model = config["model"]["d_model"]
if not emb_alpha:
emb_alpha = 1.0
emb_scale = emb_alpha * d_model**0.5
new_state_dict[new_key] = old_state_dict[old_key] * emb_scale
else:
new_state_dict[new_key] = old_state_dict[old_key]
def scale_lm_head(
self,
old_key,
new_key,
old_state_dict,
new_state_dict,
from_index,
action_fn_args,
):
config = action_fn_args["configs"][1]
mup_base_d_model = config["model"].get("mup_base_d_model", None)
if mup_base_d_model:
output_logits_alpha = config["model"].get(
"output_logits_alpha", None
)
if output_logits_alpha is None:
output_logits_alpha = 1.0
scale_output_logits_by_d = config["model"].get(
"scale_output_logits_by_d", None
)
if scale_output_logits_by_d is None:
scale_output_logits_by_d = False
d_model = config["model"]["d_model"]
d_model_width_mult = d_model / mup_base_d_model
if scale_output_logits_by_d:
output_logits_scale = output_logits_alpha / d_model_width_mult
else:
output_logits_scale = (
output_logits_alpha / d_model_width_mult**0.5
)
new_state_dict[new_key] = (
old_state_dict[old_key] * output_logits_scale
)
else:
new_state_dict[new_key] = old_state_dict[old_key]
@staticmethod
def is_mup(config):
return _is_mup(config["model"])
@staticmethod
def formats():
return ("sP", "muP")
def _is_mup(model_config):
return any(
name.startswith('mup_base_') and model_config[name] is not None
for name in model_config
)
def _is_mup_CS22(model_config):
scale_qk_dot_by_d = model_config.get('scale_qk_dot_by_d', False)
embeddings_scale = model_config.get('embeddings_scale', None)
output_logits_scale = model_config.get('output_logits_scale', None)
all_set = scale_qk_dot_by_d and embeddings_scale and output_logits_scale
# embeddings_scale defaults to 1.0 in many models, so only the other two are checked here
any_set = scale_qk_dot_by_d or output_logits_scale
if any_set and not all_set:
raise ValueError(
"This looks like an incomplete muP config. Either all or none of "
"\"scale_qk_dot_by_d\", \"embeddings_scale\", and \"output_logits_scale\" "
"must be specified, but this config specifies only some of them."
)
return all_set
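# Examples (illustrative):
#   _is_mup({"mup_base_hidden_size": 256, "hidden_size": 1024})   -> True
#   _is_mup({"mup_base_hidden_size": None, "hidden_size": 1024})  -> False
#   _is_mup_CS22({"scale_qk_dot_by_d": True, "embeddings_scale": 10.0,
#                 "output_logits_scale": 0.25})                   -> truthy (complete muP config)
#   _is_mup_CS22({"scale_qk_dot_by_d": True})                     -> raises ValueError (incomplete)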