Configure μP for BERT Pretrain (Beta)#

Note

Beta support indicates that we have confirmed the model’s μP functionality using coordinate checks, but have yet to perform long convergence runs.

Model Params#

  • mup_base_hidden_size (Required to enable μP):

    The hidden size of the proxy model in μP transfer, used to calculate the necessary multipliers.

  • mup_base_filter_size (Required to enable μP):

    The filter size of the proxy model in μP transfer, used to calculate the necessary multipliers.

  • embeddings_scale:

    Scales the embedding hidden states (i.e., the tensor after the embeddings and embedding layer norm are applied). Recommended to tune for stabilizing gradient flow during μP training.

  • output_logits_alpha:

    A constant multiplier applied to the output logits in μP training. The output logits are scaled by output_logits_alpha * mup_base_hidden_size / hidden_size. Recommended to tune for stabilizing the output logits in μP training (see the worked example after this list).

  • scale_qk_dot_by_d:

    Scales the attention QK dot product by d instead of sqrt(d). Must be enabled for μP training.

  • attention_logits_alpha:

    Scales the attention QK dot product by the specified value. Recommended to tune for stabilizing attention logits in μP training.

  • scale_output_logits_by_d:

    Scales the output logits in μP by mup_base_hidden_size / hidden_size if True, and by sqrt(mup_base_hidden_size / hidden_size) if False. It is traditionally set to True in the μP implementation of this model.
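
To make these multipliers concrete, the sketch below walks through a hypothetical proxy/target pair. All numeric values (256, 1024, and so on) are illustrative assumptions, not tuned recommendations.

model:
  hidden_size: 1024            # target model width
  filter_size: 4096            # target FFN width
  mup_base_hidden_size: 256    # proxy (base) model width
  mup_base_filter_size: 1024   # proxy (base) FFN width
  # Width ratio used by the multipliers:
  #   mup_base_hidden_size / hidden_size = 256 / 1024 = 0.25
  scale_output_logits_by_d: True
  output_logits_alpha: 1.0
  # With scale_output_logits_by_d: True, the output logits are scaled by
  #   output_logits_alpha * (mup_base_hidden_size / hidden_size) = 1.0 * 0.25 = 0.25
  scale_qk_dot_by_d: True
  # The attention QK dot product is divided by d instead of sqrt(d)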

Supported LR Adjustment Groups#

  • embedding: Targets the embedding weights.

  • encoder_attention: Targets the attention projection layers in the encoder (Q, K, V, and output projections).

  • encoder_input_ffn: Targets the first of the two FFN blocks in the encoder.

  • encoder_output_ffn: Targets the second of the two FFN blocks in the encoder.

  • pooler: Targets the pooler FFN.

Example Configuration#

model:
  ...
  # μP
  scale_qk_dot_by_d: True
  scale_output_logits_by_d: True
  mup_base_hidden_size: ...
  mup_base_filter_size: ...
  output_logits_alpha: ...
  attention_logits_alpha: ...
  embeddings_scale: ...

optimizer:
  ...
  adjust_learning_rate:
    embedding: ...
    encoder_input_ffn: ...
    ...

Note that not all of the LR adjustment groups are specified: you only need to provide these values in the config if you would like to override the traditional μP LR scaling.
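
For reference, a fully filled-in configuration might look like the following. Every numeric value here is a hypothetical placeholder chosen for illustration, not a tuned recommendation; the adjust_learning_rate block lists all five supported groups.

model:
  ...
  # μP (all values below are illustrative assumptions)
  mup_base_hidden_size: 256
  mup_base_filter_size: 1024
  scale_qk_dot_by_d: True
  scale_output_logits_by_d: True
  output_logits_alpha: 1.0
  attention_logits_alpha: 1.0
  embeddings_scale: 10.0

optimizer:
  ...
  adjust_learning_rate:
    # Hypothetical per-group LR multipliers, overriding the defaults
    embedding: 1.0
    encoder_attention: 0.25
    encoder_input_ffn: 0.25
    encoder_output_ffn: 0.25
    pooler: 0.25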