# Configure μP for BERT Pretrain (Beta)
> **Note**: Beta support indicates that we have confirmed the model’s μP functionality using coordinate checks, but have yet to perform long convergence runs.
## Model Params
mup_base_hidden_size
: (Required to enable μP) The hidden size of the proxy model in μP transfer, used to calculate the necessary multipliers.
mup_base_filter_size
: (Required to enable μP) The filter size of the proxy model in μP transfer, used to calculate the necessary multipliers.
embeddings_scale
: Scales the embedding hidden states (i.e., the tensor after the embeddings and embedding layer norm are applied). Recommended to tune for stabilizing gradient flow during μP training.
output_logits_alpha
: Constant multiplier applied to the output logits in μP training. The output logits are scaled by output_logits_alpha * mup_base_hidden_size / hidden_size. Recommended to tune for stabilizing the output logits in μP training.
scale_qk_dot_by_d
: Scales the attention QK dot product by d instead of sqrt(d). Must be enabled for μP training.
attention_logits_alpha
: Scales the attention QK dot product by the specified value. Recommended to tune for stabilizing the attention logits in μP training.
scale_output_logits_by_d
: Scales the output logits in μP by mup_base_hidden_size / hidden_size if True, and by sqrt(mup_base_hidden_size / hidden_size) if False. It is traditionally set to True in the μP implementation of this model.
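
Taken together, these knobs define a small set of multipliers. The following is a minimal sketch of the scaling rules described above, assuming PyTorch-style tensors for q and k; the function names are illustrative and do not reflect the library’s internal implementation.

```python
import math

def attention_scores(q, k, attention_logits_alpha, scale_qk_dot_by_d=True):
    # With scale_qk_dot_by_d enabled (required for μP), QK^T is divided
    # by d rather than sqrt(d), then multiplied by attention_logits_alpha.
    d = q.shape[-1]
    denom = d if scale_qk_dot_by_d else math.sqrt(d)
    return attention_logits_alpha * (q @ k.transpose(-1, -2)) / denom

def output_logits_multiplier(hidden_size, mup_base_hidden_size,
                             output_logits_alpha,
                             scale_output_logits_by_d=True):
    # Output logits are multiplied by output_logits_alpha times
    # mup_base_hidden_size / hidden_size, or by output_logits_alpha
    # times the square root of that ratio when scale_output_logits_by_d
    # is False.
    ratio = mup_base_hidden_size / hidden_size
    return output_logits_alpha * (ratio if scale_output_logits_by_d
                                  else math.sqrt(ratio))

# embeddings_scale is applied directly to the embedding hidden states:
# hidden_states = embeddings_scale * embedding_output
```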
## Supported LR Adjustment Groups
embedding
: Targets the embedding weights.

encoder_attention
: Targets the dense layers in the encoder (Q, K, V, and output projections).

encoder_input_ffn
: Targets the first of the two FFN blocks in the encoder.

encoder_output_ffn
: Targets the final FFN block in the encoder.

pooler
: Targets the pooler FFN (see the sketch after this list).
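
Conceptually, each group selects a subset of the model’s parameters and scales that group’s learning rate. The toy sketch below shows how such a mapping could look using PyTorch optimizer parameter groups; the module layout and scale values are assumptions for illustration, not the actual Model Zoo mechanism.

```python
import torch
import torch.nn as nn

# Toy stand-in for the real model, used only to illustrate how the
# adjust_learning_rate groups could map onto optimizer parameter groups.
model = nn.ModuleDict({
    "embedding": nn.Embedding(30522, 128),
    "encoder_attention": nn.Linear(128, 128),
    "pooler": nn.Linear(128, 128),
})

base_lr = 1e-4
# Hypothetical per-group scales, keyed by adjustment group name.
lr_scales = {"embedding": 0.1, "encoder_attention": 0.25, "pooler": 0.25}

# Each group gets its own scaled learning rate.
param_groups = [
    {"params": model[name].parameters(), "lr": base_lr * scale}
    for name, scale in lr_scales.items()
]
optimizer = torch.optim.AdamW(param_groups)
```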
## Example Configuration
```yaml
model:
    ...
    # muP
    scale_qk_dot_by_d: True
    scale_output_logits_by_d: True
    mup_base_hidden_size: ...
    mup_base_filter_size: ...
    output_logits_alpha: ...
    attention_logits_alpha: ...
    embeddings_scale: ...

optimizer:
    ...
    adjust_learning_rate:
        embedding: ...
        encoder_input_ffn: ...
        ...
```
Note that not all LR adjustment groups need to be specified: you only need to provide a value in the config if you want to override the traditional μP LR scaling for that group.
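
For reference, the traditional μP rule (for Adam-style optimizers) scales the learning rate of hidden, matrix-like weights by the inverse width ratio; a minimal sketch under that assumption, with illustrative values:

```python
# Hedged sketch of the conventional μP LR rule for hidden weights:
# the learning rate shrinks linearly as the target model widens
# relative to the proxy model. All values here are illustrative.
mup_base_hidden_size = 256
hidden_size = 1024
base_lr = 6e-4

scaled_lr = base_lr * mup_base_hidden_size / hidden_size  # 1.5e-4
```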