Configure μP for T5 (Beta)#

Note

Beta support indicates that we have confirmed the model’s μP functionality using coordinate checks, but have yet to perform long convergence runs.

μP support for T5 places more limitations on the model than our other μP models:

  • μP is only supported with use_transformer_initialization set to False.

  • Using Adafactor, as is typical for T5, instead of the Adam-family optimizers used in our other μP models may cause instability.

  • Varying d_kv and num_heads between the proxy model and the target model may cause instability.

Model Params#

  • mup_base_d_model (Required to enable μP):

    The d_model of the base model in μP transfer, used to calculate the necessary multipliers.

  • mup_base_d_ff (Required to enable μP):

    The d_ff of the base model in μP transfer, used to calculate the necessary multipliers.

  • mup_base_d_kv (Highly Experimental):

    The d_kv of the base model in μP transfer, used to calculate the necessary multipliers. Only use when varying d_kv alongside d_model.

  • embeddings_alpha:

    Scales the embedding hidden states (i.e. the tensor after embeddings & embedding layer norm are applied). The embeddings are scaled by embeddings_alpha * sqrt(d_model). Recommended to tune for stabilizing gradient flow during μP training.

  • output_logits_alpha:

    Constant applied to the output logits scalar in μP training. The output logits are scaled by output_logits_alpha * mup_base_d_model/d_model (or by the square root of that ratio, depending on scale_output_logits_by_d). Recommended to tune for stabilizing output logits in μP training.

  • scale_encoder_qk_dot_by_d:

    Scales the encoder attention QK dot product by 1/d instead of 1/sqrt(d). Must be enabled for μP training.

  • scale_decoder_qk_dot_by_d:

    Scales the decoder attention QK dot product by 1/d instead of 1/sqrt(d). Must be enabled for μP training.

  • encoder_attention_logits_alpha:

    Additionally scales the encoder attention QK dot product by the specified value. Recommended to tune for stabilizing attention logits in μP training.

  • decoder_attention_logits_alpha:

    Additionally scales the decoder attention QK dot product by the specified value. Recommended to tune for stabilizing attention logits in μP training.

  • scale_output_logits_by_d:

    Scales the output logits in μP by mup_base_d_model/d_model if True and by sqrt(mup_base_d_model/d_model) if False. Unlike other models, the default for the μP implementation in this model is False. The sketch after this list illustrates how these multipliers combine.
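
These multipliers can be summarized in a short, framework-agnostic sketch. The code below is illustrative only and is not the ModelZoo implementation: the widths and alpha values are placeholders, the variable names simply mirror the config options, and it assumes the "d" in the QK scaling refers to the per-head dimension d_kv.

import math

# Illustrative values only; substitute your own proxy (base) and target sizes.
mup_base_d_model = 256     # d_model of the proxy (base) model
d_model = 1024             # d_model of the target model
d_kv = 64                  # per-head dimension (assumed to be the "d" in QK scaling)

embeddings_alpha = 1.0                # tunable constant
output_logits_alpha = 1.0             # tunable constant
encoder_attention_logits_alpha = 1.0  # tunable constant

# Embedding hidden states are scaled by embeddings_alpha * sqrt(d_model).
embedding_scale = embeddings_alpha * math.sqrt(d_model)

# Output logits: the width ratio is applied directly when
# scale_output_logits_by_d is True, or via its square root when False
# (False being the default for this model).
scale_output_logits_by_d = False
width_ratio = mup_base_d_model / d_model
output_logits_scale = output_logits_alpha * (
    width_ratio if scale_output_logits_by_d else math.sqrt(width_ratio)
)

# Attention logits: with scale_encoder_qk_dot_by_d (and its decoder analog)
# enabled, the QK dot product is scaled by 1/d rather than 1/sqrt(d), with
# the *_attention_logits_alpha constant applied on top.
encoder_attention_scale = encoder_attention_logits_alpha / d_kv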

Supported LR Adjustment Groups#

  • embedding: Targets the embedding weights.

  • encoder_qkv_projection: Targets the Q, K, V projection dense layers in the encoder.

  • encoder_output_projection: Targets the output projection dense layer in the encoder.

  • encoder_input_ffn: Targets the first of the two FFN blocks in the encoder.

  • encoder_output_ffn: Targets the final FFN block in the encoder.

  • decoder_qkv_projection: Targets the Q, K, V projection dense layers in the decoder.

  • decoder_output_projection: Targets the output projection dense layer in the decoder.

  • decoder_input_ffn: Targets the first of the two FFN blocks in the decoder.

  • decoder_output_ffn: Targets the final FFN block in the decoder.

  • lm_head: Targets the weights of the lm_head.

Example Configuration#

model:
  ...
  # muP
  scale_encoder_qk_dot_by_d: True
  scale_decoder_qk_dot_by_d: True
  scale_output_logits_by_d: False
  mup_base_d_model: ...
  mup_base_d_ff: ...
  output_logits_alpha: ...
  encoder_attention_logits_alpha: ...
  decoder_attention_logits_alpha: ...
  embeddings_alpha: ...

optimizer:
  ...
  adjust_learning_rate:
    embedding: ...
    decoder_input_ffn: ...
    lm_head: ...
    ...

Note that not all of the LR adjustment groups are specified; you only need to provide these values in the config if you want to override the traditional μP LR scaling.
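
As a rough illustration, the sketch below (plain Python, not the ModelZoo optimizer code) assumes the configured values act as multiplicative factors on a base learning rate. The group names come from the list of supported LR adjustment groups above; the base learning rate and the factors are hypothetical.

base_lr = 1e-3

# Hypothetical override factors, keyed by the supported LR adjustment group
# names. Only the groups you want to override need to appear.
adjust_learning_rate = {
    "embedding": 0.1,
    "decoder_input_ffn": 0.5,
    "lm_head": 0.1,
}

supported_groups = [
    "embedding",
    "encoder_qkv_projection", "encoder_output_projection",
    "encoder_input_ffn", "encoder_output_ffn",
    "decoder_qkv_projection", "decoder_output_projection",
    "decoder_input_ffn", "decoder_output_ffn",
    "lm_head",
]

for group in supported_groups:
    if group in adjust_learning_rate:
        # Explicit override: base LR times the configured factor.
        print(f"{group}: LR = {base_lr * adjust_learning_rate[group]}")
    else:
        # No entry in the config: the traditional μP LR scaling applies.
        print(f"{group}: traditional μP LR scaling")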