Checkpointing#
On this page, you will learn how to configure the checkpointing behavior of the
Trainer with a Checkpoint object. By the end, you should have a basic
understanding of how to use the Checkpoint class in conjunction with the
Trainer class.
Prerequisites#
Configure Trainer Checkpoint Behavior#
Primary checkpointing functionality is configured through the Checkpoint core
callback. You can control the cadence at which checkpoints are saved, the
naming convention of saved checkpoints, and several other useful behaviors.
For details on all available options, see Checkpoint.
An example of a checkpoint configuration is shown here:
trainer:
  init:
    checkpoint:
      steps: 100
      save_initial_checkpoint: True
      checkpoint_name: "checkpoint_{step}.mdl"
    ...
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import Checkpoint

trainer = Trainer(
    ...,
    checkpoint=Checkpoint(
        steps=100,
        save_initial_checkpoint=True,
        checkpoint_name="checkpoint_{step}.mdl",
    ),
)
In this example, you will save a checkpoint every 100 train steps. You will also save an initial checkpoint prior to training. The saved checkpoints will be named:
["checkpoint_0.mdl", "checkpoint_100.mdl", "checkpoint_200.mdl", ...]
Automatically loading from the most recent checkpoint#
The autoload_last_checkpoint option can be used to automatically load the most
recent checkpoint from model_dir. Suppose model_dir contains the following
checkpoints:
["checkpoint_0.mdl", ..., "checkpoint_19900.mdl", "checkpoint_20000.mdl"]
If you enable autoload_last_checkpoint as in the example below, the run will
automatically load from the checkpoint with the largest step value, in this
case "checkpoint_20000.mdl".
trainer:
  init:
    checkpoint:
      steps: 100
      autoload_last_checkpoint: True
    ...
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import Checkpoint

trainer = Trainer(
    ...,
    checkpoint=Checkpoint(
        steps=100,
        autoload_last_checkpoint=True,
    ),
)
Checkpoint loading strictness#
The disable_strict_checkpoint_loading option can be used to relax the
validation performed when loading a checkpoint. If set to True, the model will
not raise an error if the checkpoint contains keys that are not present in the
model.
trainer:
  init:
    checkpoint:
      steps: 100
      disable_strict_checkpoint_loading: True
    ...
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import Checkpoint

trainer = Trainer(
    ...,
    checkpoint=Checkpoint(
        steps=100,
        disable_strict_checkpoint_loading=True,
    ),
)
Selective Checkpoint State Saving#
You can specify which individual checkpoint states are saved using the
SaveCheckpointState callback.
The SaveCheckpointState callback allows you to:
- Save an alternative checkpoint with a subset of states to conserve storage space.
- Bypass checkpoint deletion policies.
In the example below, you will save an alternative checkpoint every 5
checkpoints saved (that is, every 500 steps) that contains only the "model"
state.
Note
k in SaveCheckpointState refers to taking an alternative checkpoint every k
checkpoints saved, not every k steps.
trainer:
  init:
    checkpoint:
      steps: 100
    callbacks:
      - SaveCheckpointState:
          k: 5
          checkpoint_states: "model"
    ...
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import (
    Checkpoint,
    SaveCheckpointState,
)

trainer = Trainer(
    ...,
    checkpoint=Checkpoint(
        steps=100,
    ),
    callbacks=[
        SaveCheckpointState(k=5, checkpoint_states="model"),
    ],
)
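To make the distinction in the note above concrete, the small standalone sketch below (plain Python, unrelated to the Trainer API) computes which training steps yield an alternative checkpoint under the configuration above, assuming counting starts with the first checkpoint saved during training.

# With checkpoints every 100 steps and k=5, an alternative checkpoint is taken
# at every 5th saved checkpoint, i.e. at steps 500, 1000, 1500, ...
checkpoint_steps = 100
k = 5
total_steps = 2000

alt_checkpoint_steps = [
    step
    for i, step in enumerate(
        range(checkpoint_steps, total_steps + 1, checkpoint_steps), start=1
    )
    if i % k == 0
]
print(alt_checkpoint_steps)  # [500, 1000, 1500, 2000]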
Selective Checkpoint State Loading#
You can specify which individual checkpoint states are loaded using the
LoadCheckpointStates callback.
The LoadCheckpointStates callback allows you to:
- Perform fine-tuning by loading the model state while starting the optimizer state from scratch and the global step from 0.
In the example below, you configure the Trainer to load only the "model" state
from any checkpoint.
trainer:
  init:
    checkpoint:
      steps: 100
    callbacks:
      - LoadCheckpointStates:
          load_checkpoint_states: "model"
    ...
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import (
    Checkpoint,
    LoadCheckpointStates,
)

trainer = Trainer(
    ...,
    checkpoint=Checkpoint(
        steps=100,
    ),
    callbacks=[
        LoadCheckpointStates(load_checkpoint_states="model"),
    ],
)
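As a sketch of the fine-tuning pattern, you could point the run above at a pretrained checkpoint so that only the model weights are restored while the optimizer state and global step start fresh. This assumes the fit method accepts a ckpt_path argument, and the dataloader and checkpoint path are placeholders you supply; see Fine-Tuning with Validation for the complete workflow.

# Hypothetical fine-tuning launch: `trainer` is the instance configured above
# with LoadCheckpointStates(load_checkpoint_states="model"); `train_dataloader`
# and the checkpoint path are placeholders.
trainer.fit(
    train_dataloader,
    ckpt_path="path/to/pretrained_checkpoint.mdl",
)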
Checkpoint Deletion Policy#
For long runs with limited storage space, it is important to control how
checkpoints are deleted or retained. To limit the number of checkpoints
retained, use the KeepNCheckpoints callback.
The KeepNCheckpoints callback allows you to:
- Constrain the amount of storage space checkpoints take up while still keeping recent restart points in case a run is interrupted.
- Keep long-term checkpoints at a larger cadence for validation purposes: checkpoints generated by SaveCheckpointState are ignored by KeepNCheckpoints (see Selective Checkpoint State Saving for more details).
In the example below, only the 5 most recent checkpoints will be retained.
trainer:
  init:
    checkpoint:
      steps: 100
    callbacks:
      - KeepNCheckpoints:
          n: 5
    ...
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import (
    Checkpoint,
    KeepNCheckpoints,
)

trainer = Trainer(
    ...,
    checkpoint=Checkpoint(
        steps=100,
    ),
    callbacks=[
        KeepNCheckpoints(n=5),
    ],
)
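Because checkpoints written by SaveCheckpointState are ignored by KeepNCheckpoints, the two callbacks can be combined so that recent full checkpoints are pruned aggressively while lightweight, model-only checkpoints are kept for the long term. The sketch below reuses the constructor arguments shown earlier; the k=10 cadence is an arbitrary choice.

from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import (
    Checkpoint,
    KeepNCheckpoints,
    SaveCheckpointState,
)

trainer = Trainer(
    ...,
    checkpoint=Checkpoint(
        steps=100,
    ),
    callbacks=[
        # Retain only the 5 most recent full checkpoints for restarts.
        KeepNCheckpoints(n=5),
        # Every 10th checkpoint (every 1000 steps), also save a model-only
        # checkpoint that KeepNCheckpoints will not delete.
        SaveCheckpointState(k=10, checkpoint_states="model"),
    ],
)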
What’s next#
To learn how to use advanced checkpointing to do a fine-tuning run, see Fine-Tuning with Validation.
Further Reading#
To learn about how you can configure a Trainer instance using a YAML configuration file, you can check out:
Trainer YAML Overview
To learn more about how you can use the Trainer in some core workflows, you can check out:
To learn more about how you can extend the capabilities of the Trainer class, you can check out: