Configure μP for T5 (Beta)
Note
Beta support indicates that we have confirmed the model’s μP functionality using coordinate checks, but have yet to perform long convergence runs.
T5 support places more limitations on the model than our other μP models:
μP is only supported with use_transformer_initialization set to False.
Using Adafactor, as is usual with T5, rather than the Adam optimizers used in our other μP models may cause instability.
Varying d_kv and num_heads between the proxy model and the target model may cause instability.
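The limitations above can be checked before launching a run. The following is a minimal sketch, not part of any library: the function name check_t5_mup_limitations and the idea of passing the model sections as plain dictionaries and the optimizer as a name string are assumptions made for illustration. Only the keys use_transformer_initialization, d_kv, and num_heads come from the text above.

```python
def check_t5_mup_limitations(proxy_model, target_model, optimizer_name):
    """Return warnings for the T5 muP limitations (hypothetical helper).

    proxy_model / target_model: dicts holding the `model` section of each config.
    optimizer_name: name of the optimizer as a string (assumed representation).
    """
    warnings = []

    # muP is only supported with use_transformer_initialization set to False.
    # A missing key is flagged too, since no default is assumed here.
    if target_model.get("use_transformer_initialization") is not False:
        warnings.append("use_transformer_initialization is not explicitly False")

    # Adafactor (the usual T5 optimizer) may cause instability under muP.
    if optimizer_name.lower() == "adafactor":
        warnings.append("Adafactor may cause instability with muP")

    # d_kv and num_heads should match between the proxy and target models.
    for key in ("d_kv", "num_heads"):
        if proxy_model.get(key) != target_model.get(key):
            warnings.append(f"{key} differs between proxy and target models")

    return warnings
```

An empty list means none of the three limitations was triggered. It does not guarantee that training will be stable.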
Model Params
mup_base_d_model (Required to enable μP): The d_model of the base model in μP transfer, used to calculate the necessary multipliers.
mup_base_d_ff (Required to enable μP): The d_ff of the base model in μP transfer, used to calculate the necessary multipliers.
mup_base_d_kv (Highly Experimental): The d_kv of the base model in μP transfer, used to calculate the necessary multipliers. Only use when varying d_kv alongside d_model.
embeddings_alpha: Scales the embedding hidden states (i.e. the tensor after the embeddings and embedding layer norm are applied). The embeddings are scaled by embeddings_alpha * sqrt(d_model). Recommended to tune for stabilizing gradient flow during μP training.
output_logits_alpha: Constant applied to the output logits scalar in μP training. The output logits are scaled by output_logits_alpha * mup_base_d_model/d_model. Recommended to tune for stabilizing output logits in μP training.
scale_encoder_qk_dot_by_d: Scales the encoder attention QK dot product by 1/d instead of 1/sqrt(d). Must be enabled for μP training.
scale_decoder_qk_dot_by_d: Scales the decoder attention QK dot product by 1/d instead of 1/sqrt(d). Must be enabled for μP training.
encoder_attention_logits_alpha: Additionally scales the encoder attention QK dot product by the specified value. Recommended to tune for stabilizing attention logits in μP training.
decoder_attention_logits_alpha: Additionally scales the decoder attention QK dot product by the specified value. Recommended to tune for stabilizing attention logits in μP training.
scale_output_logits_by_d: Scales the output logits in μP by mup_base_d_model/d_model if True and by sqrt(mup_base_d_model/d_model) if False. Unlike in other models, the default for the μP implementation in this model is False.
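To make the scaling rules in this parameter list concrete, here is a small sketch that computes the resulting scalar multipliers. It is an illustration under stated assumptions, not the library's implementation: the function name mup_multipliers is hypothetical, d in the QK scaling is taken to be the per-head dimension d_kv, the QK dot product is assumed to be divided by d (or sqrt(d)), and output_logits_alpha is assumed to multiply whichever width factor scale_output_logits_by_d selects.

```python
import math


def mup_multipliers(
    d_model,
    mup_base_d_model,
    d_kv,
    embeddings_alpha=1.0,
    output_logits_alpha=1.0,
    attention_logits_alpha=1.0,
    scale_qk_dot_by_d=True,
    scale_output_logits_by_d=False,
):
    """Compute illustrative muP scalars from the documented formulas."""
    # Embedding hidden states are scaled by embeddings_alpha * sqrt(d_model).
    embedding_scale = embeddings_alpha * math.sqrt(d_model)

    # Width factor for the output logits: base/width if True, sqrt(base/width) if False.
    width_ratio = mup_base_d_model / d_model
    if scale_output_logits_by_d:
        logits_width_factor = width_ratio
    else:
        logits_width_factor = math.sqrt(width_ratio)
    # Assumption: alpha multiplies the selected width factor.
    output_logits_scale = output_logits_alpha * logits_width_factor

    # Assumption: d is d_kv, and the QK dot product is divided by d or sqrt(d).
    qk_divisor = d_kv if scale_qk_dot_by_d else math.sqrt(d_kv)
    attention_scale = attention_logits_alpha / qk_divisor

    return {
        "embedding_scale": embedding_scale,
        "output_logits_scale": output_logits_scale,
        "attention_scale": attention_scale,
    }
```

For example, with d_model=1024, mup_base_d_model=256, and d_kv=64, the output logits factor is 0.5 with scale_output_logits_by_d set to False and 0.25 with it set to True. The same function can be called once with the encoder alpha and once with the decoder alpha.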
Supported LR Adjustment Groups
embedding: Targets the embedding weights.
encoder_qkv_projection: Targets the Q, K, V projection dense layers in the encoder.
encoder_output_projection: Targets the output projection dense layer in the encoder.
encoder_input_ffn: Targets the first of the two FFN blocks in the encoder.
encoder_output_ffn: Targets the final FFN block in the encoder.
decoder_qkv_projection: Targets the Q, K, V projection dense layers in the decoder.
decoder_output_projection: Targets the output projection dense layer in the decoder.
decoder_input_ffn: Targets the first of the two FFN blocks in the decoder.
decoder_output_ffn: Targets the final FFN block in the decoder.
lm_head: Targets the weights of the lm_head.
Example Configuration
model:
    ...
    # muP
    scale_encoder_qk_dot_by_d: True
    scale_decoder_qk_dot_by_d: True
    scale_output_logits_by_d: False
    mup_base_d_model: ...
    mup_base_d_ff: ...
    output_logits_alpha: ...
    encoder_attention_logits_alpha: ...
    decoder_attention_logits_alpha: ...
    embeddings_alpha: ...
optimizer:
    ...
    adjust_learning_rate:
        embedding: ...
        decoder_input_ffn: ...
        lm_head: ...
    ...
Note that not all LR adjustment groups are specified in this example. A group only needs a value in the config if you would like to override the traditional μP LR scaling for it.
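The override behavior can be pictured as a per-group lookup that falls back to a default. The sketch below is hypothetical and only illustrates that idea: resolve_lr_adjustments is not a library function, and the default values are supplied by the caller because this page does not state how the traditional μP LR scaling is computed. The group names are the ones listed under Supported LR Adjustment Groups.

```python
SUPPORTED_LR_GROUPS = (
    "embedding",
    "encoder_qkv_projection",
    "encoder_output_projection",
    "encoder_input_ffn",
    "encoder_output_ffn",
    "decoder_qkv_projection",
    "decoder_output_projection",
    "decoder_input_ffn",
    "decoder_output_ffn",
    "lm_head",
)


def resolve_lr_adjustments(defaults, overrides):
    """Merge user overrides onto caller-supplied default LR scales.

    defaults: dict mapping every supported group to its default scale
              (assumed to be computed elsewhere by the muP rules).
    overrides: dict with the groups set under adjust_learning_rate.
    """
    # Reject group names that are not in the supported list.
    unknown = sorted(set(overrides) - set(SUPPORTED_LR_GROUPS))
    if unknown:
        raise ValueError(f"Unsupported LR adjustment groups: {unknown}")

    # An override wins; otherwise the default for that group is kept.
    return {
        group: overrides.get(group, defaults[group])
        for group in SUPPORTED_LR_GROUPS
    }
```

With overrides for only embedding, decoder_input_ffn, and lm_head, as in the example configuration, the other seven groups keep their default scales.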