Multi-Phase Training#
On this page, you will learn how to set up Multi-Phase training using the Trainer class. Multi-Phase training allows you to combine multiple training phases with different batch sizes or max sequence lengths in a single config file or Python script.
Prerequisites#
Please ensure that you have read through the prerequisite tutorials beforehand, in particular Pretraining with Upstream Validation.
The rest of this page assumes that you already have at least a cursory understanding of what the Cerebras Model Zoo Trainer is and how to use the Python API.
Multi-Phase Training#
In Multi-Phase training, you may want to define several distinct training phases. For example, the training pipeline for the Llama-3 model might involve varying batch sizes or max sequence lengths across different phases. Each of these phases is defined by an instance of the Trainer.
Let’s consider an example. In Pretraining with Upstream Validation, you learned how to construct the Trainer for the Llama-3 model. Now, let’s add a new training phase with a different batch size and a new max sequence length.
To define each phase, you need to construct a separate Trainer instance. For example:
trainer:
- trainer:
...
- trainer:
...
from cerebras.modelzoo import Trainer
trainer_1 = Trainer(...)
trainer_2 = Trainer(...)
Note
The number of Trainer instances is not limited, and each Trainer can have different parameters, so you can construct arbitrary training/validation pipelines that combine different models, dataloaders, and more.
For each phase, we define a different batch size and a different max sequence length.
trainer:
- trainer:
init: &init
backend:
backend_type: CSX
cluster_config:
num_csx: 16
seed: 2024
model:
# Embedding
vocab_size: 128256
hidden_size: 4096
position_embedding_type: "rotary"
pos_scaling_factor: 1.0
rope_theta: 500000.0
rotary_dim: 128
share_embedding_weights: false
max_position_embeddings: 8192
embedding_dropout_rate: 0.0
embedding_layer_norm: false
# Decoder
num_hidden_layers: 32
dropout_rate: 0.0
layer_norm_epsilon: 1.0e-5
norm_type: "rmsnorm"
# Decoder - Attention
num_heads: 32
attention_type: "scaled_dot_product"
attention_module: "multiquery_attention"
attention_dropout_rate: 0.0
use_projection_bias_in_attention: false
use_ffn_bias_in_attention: false
extra_attention_params:
num_kv_groups: 8
# Decoder - ffn
filter_size: 14336
nonlinearity: "swiglu"
use_ffn_bias: false
# Task-specific
use_bias_in_output: false
loss_scaling: "num_tokens"
loss_weight: 1.0
# Initializer
initializer_range: 0.02
# Cerebras parameters
mixed_precision: True
fp16_type: "cbfloat16"
optimizer:
AdamW:
betas: [0.9, 0.95]
correct_bias: True
weight_decay: 0.1
schedulers:
- CosineDecayLR:
initial_learning_rate: 3.0e-5
end_learning_rate: 3.0e-6
total_iters: 528
precision:
fp16_type: cbfloat16
loss_scaling_factor: dynamic
max_gradient_norm: 1.0
loop:
num_steps: 10000
eval_frequency: 1000
eval_steps: 1000
checkpoint:
steps: 1000
callbacks:
- ComputeNorm: {}
- CheckLoss: {}
- ModelEvalMetrics: {}
loggers:
- ProgressLogger: {}
- TensorBoardLogger: {}
fit:
train_dataloader:
data_processor: GptHDF5MapDataProcessor
data_dir: "/data/llama_v3_dataset_vocab128256_msl8192/train"
batch_size: 80
shuffle: False
shuffle_seed: 1337
num_workers: 8
prefetch_factor: 10
persistent_workers: True # Important to avoid seeding at each epoch
val_dataloader:
- data_processor: GptHDF5MapDataProcessor
data_dir: "/data/llama_v3_dataset_vocab128256_msl8192/val"
batch_size: 80
shuffle: False
shuffle_seed: 1337
num_workers: 8
prefetch_factor: 10
persistent_workers: True # Important to avoid seeding at each epoch
- trainer:
init:
<<: *init
fit:
train_dataloader:
data_processor: GptHDF5MapDataProcessor
data_dir: "/data/llama_v3_dataset_vocab128256_msl512/train"
batch_size: 40
shuffle: False
shuffle_seed: 1337
num_workers: 8
prefetch_factor: 10
persistent_workers: True # Important to avoid seeding at each epoch
val_dataloader:
- data_processor: GptHDF5MapDataProcessor
data_dir: "/data/llama_v3_dataset_vocab128256_msl512/val"
batch_size: 40
shuffle: False
shuffle_seed: 1337
num_workers: 8
prefetch_factor: 10
persistent_workers: True # Important to avoid seeding at each epoch
import cerebras.pytorch as cstorch
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import (
CheckLoss,
Checkpoint,
ComputeNorm,
MixedPrecision,
ModelEvalMetrics,
TrainingLoop,
)
from cerebras.modelzoo.trainer.loggers import (
    ProgressLogger,
TensorBoardLogger,
)
from cerebras.modelzoo.models.nlp.gpt2.model import Gpt2Model as LLaMA3
trainer = Trainer(
backend=cstorch.backend(
backend_type="CSX",
cluster_config=cstorch.distributed.ClusterConfig(
num_csx=16,
),
),
seed=2024,
model=lambda: LLaMA3(
# Embedding
vocab_size=128256,
hidden_size=4096,
position_embedding_type="rotary",
pos_scaling_factor=1.0,
rope_theta=500000.0,
rotary_dim=128,
share_embedding_weights=False,
max_position_embeddings=8192,
embedding_dropout_rate=0.0,
embedding_layer_norm=False,
# Decoder
num_hidden_layers=32,
dropout_rate=0.0,
layer_norm_epsilon=1.0e-5,
norm_type="rmsnorm",
# Decoder - Attention
num_heads=32,
attention_type="scaled_dot_product",
attention_module="multiquery_attention",
attention_dropout_rate=0.0,
use_projection_bias_in_attention=False,
use_ffn_bias_in_attention=False,
extra_attention_params=dict(
num_kv_groups=8,
),
# Decoder - ffn
filter_size=14336,
nonlinearity="swiglu",
use_ffn_bias=False,
# Task-specific
use_bias_in_output=False,
loss_scaling="num_tokens",
loss_weight=1.0,
# Initializer
initializer_range=0.02,
# Cerebras parameters
mixed_precision=True,
fp16_type="cbfloat16",
),
optimizer=lambda model: cstorch.optim.AdamW(
model.parameters(),
lr=0.01, # This is a placeholder
betas=[0.9, 0.95],
correct_bias=True,
        weight_decay=0.1,
),
schedulers=[
        lambda optimizer: cstorch.optim.lr_scheduler.CosineDecayLR(
            optimizer,
            initial_learning_rate=3.0e-5,
            end_learning_rate=3.0e-6,
            total_iters=528,
        )
],
precision=MixedPrecision(
fp16_type="cbfloat16",
loss_scaling_factor="dynamic",
max_gradient_norm=1.0,
),
loop=TrainingLoop(
num_steps=10000,
eval_frequency=1000,
eval_steps=1000,
),
checkpoint=Checkpoint(steps=1000),
callbacks=[
ComputeNorm(),
CheckLoss(),
ModelEvalMetrics(),
],
loggers=[
        ProgressLogger(),
TensorBoardLogger(),
],
)
# Running training phase with batch_size=80 and MSL=8192.
trainer.fit(
train_dataloader=cstorch.utils.data.DataLoader(
registry.get_data_processor("GptHDF5MapDataProcessor"),
data_dir="/data/llama_v3_dataset_vocab128256_msl8192/train",
batch_size=80,
shuffle=False,
shuffle_seed=1337,
num_workers=8,
prefetch_factor=10,
persistent_workers=True, # Important to avoid seeding at each epoch
),
val_dataloader=[
cstorch.utils.data.DataLoader(
registry.get_data_processor("GptHDF5MapDataProcessor"),
data_dir="/data/llama_v3_dataset_vocab128256_msl8192/val",
batch_size=80,
shuffle=False,
shuffle_seed=1337,
num_workers=8,
prefetch_factor=10,
persistent_workers=True, # Important to avoid seeding at each epoch
),
]
)
# Running training phase with batch_size=40 and MSL=512.
trainer.fit(
train_dataloader=cstorch.utils.data.DataLoader(
registry.get_data_processor("GptHDF5MapDataProcessor"),
data_dir="/data/llama_v3_dataset_vocab128256_msl512/train",
batch_size=40,
shuffle=False,
shuffle_seed=1337,
num_workers=8,
prefetch_factor=10,
persistent_workers=True, # Important to avoid seeding at each epoch
),
val_dataloader=[
cstorch.utils.data.DataLoader(
registry.get_data_processor("GptHDF5MapDataProcessor"),
data_dir="/data/llama_v3_dataset_vocab128256_msl512/val",
batch_size=40,
shuffle=False,
shuffle_seed=1337,
num_workers=8,
prefetch_factor=10,
persistent_workers=True, # Important to avoid seeding at each epoch
),
]
)
Note
It’s important to note that when using YAML, you have to construct a Trainer instance for each phase, which adds some overhead to your run due to time spent on compilation and weight transfer. If you are using the Python API, you can construct a single Trainer object and call fit with different DataLoader objects.
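For reference, the single-Trainer pattern boils down to constructing the Trainer once and then calling fit once per phase. Below is a minimal sketch; the run_phases helper and the phases list are hypothetical, and the Trainer and dataloaders are assumed to be constructed as in the example above.
def run_phases(trainer, phases):
    # Reuse a single Trainer: each phase is just another `fit` call with its
    # own train and validation dataloaders.
    for train_dataloader, val_dataloaders in phases:
        trainer.fit(
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloaders,
        )
For example, phases could hold two entries, the MSL-8192 and MSL-512 dataloaders from the example above, run back to back against the same Trainer.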
Multi-Phase Training (Advanced)#
A more advanced example of Multi-Phase training involves changing model parameters between training phases. For instance, you might want to switch the learning rate scheduler from CosineDecayLR to ConstantLR. To accomplish this, you need to create two instances of the Trainer and carefully manage checkpoint loading between phases to account for the changes in model parameters.
Note
In the example below, the model, optimizer, and other parameters are similar to those in the previous example; they have been omitted to keep the example concise.
trainer:
- trainer:
init:
...
schedulers:
- CosineDecayLR:
initial_learning_rate: 3.0e-5
end_learning_rate: 3.0e-6
total_iters: 528
- trainer:
init:
...
schedulers:
- ConstantLR:
learning_rate: 1.0e-6
callbacks:
- LoadCheckpointStates:
load_checkpoint_states: "model,grad_scaler,optimizer,global_step"
import cerebras.pytorch as cstorch
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import (
LoadCheckpointStates,
)
trainer = Trainer(
...
schedulers=[
        lambda optimizer: cstorch.optim.lr_scheduler.CosineDecayLR(
            optimizer,
            initial_learning_rate=3.0e-5,
            end_learning_rate=3.0e-6,
            total_iters=528,
        )
],
)
trainer.fit(...)
trainer = Trainer(
...
schedulers=[
        lambda optimizer: cstorch.optim.lr_scheduler.ConstantLR(
            optimizer,
            learning_rate=1.0e-6,
        )
],
callbacks=[
LoadCheckpointStates(
load_checkpoint_states="model,grad_scaler,optimizer,global_step",
)
],
)
trainer.fit(...)
In this example, each Trainer constructs and compiles its own model. In the second phase, we changed the scheduler to ConstantLR, so to avoid any issues with checkpoint loading we specify which checkpoint states need to be loaded. For further reading, see Checkpointing.
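As a concrete illustration of managing checkpoints between phases, the second-phase Trainer can point fit at the checkpoint written at the end of the first phase. The sketch below is only an assumption-laden example: the checkpoint path is a hypothetical placeholder, it assumes fit accepts a ckpt_path argument as described in Checkpointing, and the trainer passed in is the second-phase Trainer configured with ConstantLR and LoadCheckpointStates above.
def run_second_phase(trainer, train_dataloader, val_dataloader):
    # Hypothetical path to the final checkpoint written by the first phase.
    phase_1_checkpoint = "model_dir/checkpoint_10000.mdl"

    # The LoadCheckpointStates callback on this trainer restricts which states
    # are restored (model, grad scaler, optimizer, global step), so the new
    # ConstantLR scheduler starts fresh rather than loading CosineDecayLR state.
    trainer.fit(
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        ckpt_path=phase_1_checkpoint,
    )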
Caveats#
When running Multi-Phase training using the Python API, you may hit the following error:
RuntimeError: Cannot instantiate multiple backends. A backend with type CSX has already been instantiated.
Please ensure that when you construct multiple Trainer instances, you instantiate only a single backend and share it between them. For example:
backend = cstorch.backend(
"CSX",
...
)
trainer1 = Trainer(
backend=backend,
...
)
trainer2 = Trainer(
backend=backend,
...
)
Conclusion#
This tutorial showcases some of the use cases where Multi-Phase training can be applied. However, you are not limited to these examples and can construct as many Trainers as you need, combining different models, schedulers, optimizers, dataloaders, and more.