Configure μP for GPT-Style Models
Models affected: GPT-2, GPT-3, GPTJ (Beta), Bloom, Llama, Falcon, Starcoder, and MPT
Note
Beta support for GPTJ indicates that we have confirmed the model’s μP functionality using coordinate checks, but have yet to perform long convergence runs.
Model Params
mup_base_hidden_size
(Required to enable μP): The hidden size of the proxy model in μP transfer, used to calculate the necessary multipliers.
mup_base_filter_size
(Required to enable μP): The filter size of the proxy model in μP transfer, used to calculate the necessary multipliers.
embeddings_scale
: Scales the embedding hidden states (i.e., the tensor after the embeddings and the embedding layer norm are applied). Tuning this value is recommended to stabilize gradient flow during μP training.
output_logits_alpha
: Constant applied to the output logits scalar in μP training. The output logits are scaled by output_logits_alpha * mup_base_hidden_size / hidden_size. Tuning this value is recommended to stabilize the output logits in μP training.
scale_qk_dot_by_d
: Scales the attention QK dot product by d instead of sqrt(d). Must be enabled for μP training.
attention_logits_alpha
: Scales the attention QK dot product by the specified value. Tuning this value is recommended to stabilize the attention logits in μP training.
scale_output_logits_by_d
: Scales the output logits in μP by mup_base_hidden_size / hidden_size if True and by sqrt(mup_base_hidden_size / hidden_size) if False. It is traditionally set to True in the μP implementation of this model.
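The scaling rules described above can be sketched in a few lines of Python. This is only an illustration, not the library's implementation: the function names are hypothetical, `d` is assumed to be the per-head dimension, "scales by d" is assumed to mean dividing the QK dot product by `d`, and `output_logits_alpha` is assumed to multiply the square-root variant as well.

```python
import math


def output_logits_scale(output_logits_alpha, mup_base_hidden_size, hidden_size,
                        scale_output_logits_by_d=True):
    """Multiplier for the output logits (hypothetical helper)."""
    ratio = mup_base_hidden_size / hidden_size
    if not scale_output_logits_by_d:
        # Assumption: alpha still applies when the sqrt variant is used.
        ratio = math.sqrt(ratio)
    return output_logits_alpha * ratio


def attention_logits_scale(head_dim, attention_logits_alpha=1.0,
                           scale_qk_dot_by_d=True):
    """Multiplier for the raw QK dot product (hypothetical helper).

    Assumption: d is the per-head dimension and scaling "by d" means
    dividing by d rather than by sqrt(d).
    """
    denominator = head_dim if scale_qk_dot_by_d else math.sqrt(head_dim)
    return attention_logits_alpha / denominator
```

For instance, with a proxy model of hidden size 256 and a target model of hidden size 1024, `output_logits_scale(1.0, 256, 1024)` evaluates 256 / 1024, so a wider target model gets proportionally smaller output logits.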
Supported LR Adjustment Groups
embedding
: Targets the embedding weights.
decoder_attention
: Targets the dense layers in the decoder (Q, K, V, and output projections).
decoder_input_ffn
: Targets the first of the two FFN blocks in the decoder.
decoder_output_ffn
: Targets the final FFN block in the decoder.
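As a rough sketch of how per-group values might be combined, the snippet below merges user-supplied `adjust_learning_rate` entries onto a set of per-group defaults. The helper name and the idea of passing the defaults in as a dictionary are assumptions made for illustration; the actual default μP scales are computed by the framework and are not reproduced here.

```python
SUPPORTED_LR_GROUPS = (
    "embedding",
    "decoder_attention",
    "decoder_input_ffn",
    "decoder_output_ffn",
)


def resolve_lr_adjustments(defaults, overrides):
    """Merge user overrides onto per-group default LR scales.

    `defaults` must contain every supported group; `overrides` may
    contain any subset of them. Both are hypothetical inputs.
    """
    unknown = set(overrides) - set(SUPPORTED_LR_GROUPS)
    if unknown:
        raise ValueError(f"Unsupported LR adjustment groups: {sorted(unknown)}")
    resolved = {group: defaults[group] for group in SUPPORTED_LR_GROUPS}
    resolved.update(overrides)  # only the groups named by the user change
    return resolved
```

Rejecting unknown group names early is a design choice of this sketch; it guards against a misspelled key being silently ignored.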
Example Configuration
model:
    ...
    # muP
    scale_qk_dot_by_d: True
    scale_output_logits_by_d: True
    mup_base_hidden_size: ...
    mup_base_filter_size: ...
    output_logits_alpha: ...
    attention_logits_alpha: ...
    embeddings_scale: ...
optimizer:
    ...
    adjust_learning_rate:
        embedding: ...
        decoder_input_ffn: ...
...
Note that not all of the LR adjustment groups are specified. You only need to provide a value in the config for a group when you want to override the traditional μP LR scaling for that group.
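The requirements stated on this page can be expressed as a small sanity check over a config loaded into a Python dictionary. This is a sketch under stated assumptions: `check_mup_config` is a hypothetical helper, not part of any library, and the numeric values in `example_config` are placeholders chosen only so the example runs, not recommended settings.

```python
ALLOWED_LR_GROUPS = {
    "embedding",
    "decoder_attention",
    "decoder_input_ffn",
    "decoder_output_ffn",
}


def check_mup_config(config):
    """Return a list of problems found in a μP config dict (empty if none)."""
    problems = []
    model = config.get("model", {})
    # Both base sizes are required to enable μP.
    for key in ("mup_base_hidden_size", "mup_base_filter_size"):
        if model.get(key) is None:
            problems.append(f"model.{key} is required to enable μP")
    # QK scaling by d must be enabled for μP training.
    if model.get("scale_qk_dot_by_d") is not True:
        problems.append("model.scale_qk_dot_by_d must be True for μP training")
    # LR overrides are optional, but every key must name a supported group.
    overrides = config.get("optimizer", {}).get("adjust_learning_rate", {})
    for name in overrides:
        if name not in ALLOWED_LR_GROUPS:
            problems.append(f"unknown LR adjustment group: {name}")
    return problems


# Placeholder values for illustration only.
example_config = {
    "model": {
        "scale_qk_dot_by_d": True,
        "scale_output_logits_by_d": True,
        "mup_base_hidden_size": 256,
        "mup_base_filter_size": 1024,
    },
    "optimizer": {"adjust_learning_rate": {"embedding": 0.1}},
}
print(check_mup_config(example_config))
```

Returning a list of problems rather than raising on the first one lets all config issues be reported in a single pass.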