Configure μP for GPT-Style Models

Models affected: GPT-2, GPT-3, GPTJ (Beta), Bloom, Llama, Falcon, Starcoder, and MPT

Note

Beta support for GPTJ indicates that we have confirmed the model’s μP functionality using coordinate checks, but have not yet performed long convergence runs.

Model Params

  • mup_base_hidden_size (Required to enable μP):

    The hidden size of the proxy (base) model in μP transfer, used to calculate the necessary scaling multipliers.

  • mup_base_filter_size (Required to enable μP):

    The filter size of the proxy (base) model in μP transfer, used to calculate the necessary scaling multipliers.

  • embeddings_scale:

    Scales the embedding hidden states (i.e., the tensor after the embeddings and embedding layer norm are applied). Recommended to tune for stabilizing gradient flow during μP training.

  • output_logits_alpha:

    A constant multiplier applied to the output logits in μP training. The output logits are scaled by output_logits_alpha * mup_base_hidden_size/hidden_size. Recommended to tune for stabilizing the output logits in μP training.

  • scale_qk_dot_by_d:

    Scales the attention QK dot product by d instead of sqrt(d). Must be enabled for μP training.

  • attention_logits_alpha:

    Scales the attention QK dot product by the specified value. Recommended to tune for stabilizing attention logits in μP training.

  • scale_output_logits_by_d:

    Scales the output logits in μP by mup_base_hidden_size/hidden_size if True and by sqrt(mup_base_hidden_size/hidden_size) if False. It is traditionally set to True in the μP implementation of this model. See the sketch after this list for how these multipliers combine.

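To make the scaling rules above concrete, the following is a minimal Python sketch of how these parameters combine into multipliers. It is illustrative only: the function is hypothetical (not the Model Zoo API), and d_head is assumed to be the per-head attention dimension.

import math

def mup_multipliers(hidden_size, mup_base_hidden_size,
                    embeddings_scale=1.0, output_logits_alpha=1.0,
                    scale_output_logits_by_d=True):
    # Hypothetical helper for illustration; not the Model Zoo API.
    width_ratio = mup_base_hidden_size / hidden_size  # < 1 when the target is wider

    # embeddings_scale multiplies the embedding hidden states directly.
    embedding_mult = embeddings_scale

    # Output logits: alpha * ratio if scale_output_logits_by_d is True,
    # else alpha * sqrt(ratio).
    if scale_output_logits_by_d:
        logits_mult = output_logits_alpha * width_ratio
    else:
        logits_mult = output_logits_alpha * math.sqrt(width_ratio)

    return embedding_mult, logits_mult

# e.g., transferring from a 256-wide proxy to a 2048-wide target:
# mup_multipliers(2048, 256) -> (1.0, 0.125)

# With scale_qk_dot_by_d enabled, attention logits use 1/d instead of
# 1/sqrt(d), with attention_logits_alpha applied on top:
#   attn_logits = attention_logits_alpha * (Q @ K.T) / d_head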
Supported LR Adjustment Groups

  • embedding: Targets the embedding weights.

  • decoder_attention: Targets the dense attention layers in the decoder (Q, K, V, and output projections).

  • decoder_input_ffn: Targets the first of the two FFN blocks in the decoder.

  • decoder_output_ffn: Targets the second of the two FFN blocks in the decoder. See the sketch after this list for the conventional default LR scaling of each group.

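For intuition, the sketch below shows one common μP convention for the default per-group LR scales: the LRs of matrix-like weights shrink with the fan-in growth ratio, while vector-like weights such as embeddings are left unscaled. The exact defaults here are an assumption for illustration; consult the implementation for the authoritative behavior.

# Illustrative defaults only; assumed for this sketch, not taken from the
# implementation. Matrix-like weights scale the LR by the inverse of their
# fan-in growth; vector-like weights are unscaled.
hidden_size, mup_base_hidden_size = 2048, 256
filter_size, mup_base_filter_size = 8192, 1024

default_lr_scales = {
    "embedding": 1.0,                                          # vector-like: unscaled
    "decoder_attention": mup_base_hidden_size / hidden_size,   # fan-in: hidden_size
    "decoder_input_ffn": mup_base_hidden_size / hidden_size,   # fan-in: hidden_size
    "decoder_output_ffn": mup_base_filter_size / filter_size,  # fan-in: filter_size
}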
Example Configuration

model:
  ...
  # muP
  scale_qk_dot_by_d: True
  scale_output_logits_by_d: True
  mup_base_hidden_size: ...
  mup_base_filter_size: ...
  output_logits_alpha: ...
  attention_logits_alpha: ...
  embeddings_scale: ...

optimizer:
  ...
  adjust_learning_rate:
    embedding: ...
    decoder_input_ffn: ...
    ...

Note that not all of the LR adjustment groups need to be specified: you only need to provide a value in the config if you want to override the traditional μP LR scaling for that group.
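Continuing the sketch above, and assuming the configured values act as per-group LR scaling factors, the override behavior could be pictured as follows (the numbers are placeholders, not recommendations):

# Hypothetical illustration: groups listed under adjust_learning_rate
# replace the default μP scale; unlisted groups keep their defaults.
overrides = {"embedding": 10.0, "decoder_input_ffn": 1.5}  # placeholder values
effective_lr_scales = {**default_lr_scales, **overrides}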