# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpointing callback that aids in saving and loading model states."""

import re
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from string import Formatter
from typing import List, Optional, Union

import cerebras.pytorch as cstorch
from cerebras.modelzoo.common.pytorch_utils import (
check_checkpoint_compatibility,
)
from cerebras.modelzoo.trainer.callbacks import Callback


class Checkpoint(Callback):
"""A callback that handles standard checkpointing logic."""

    def __init__(
self,
steps: Optional[int] = None,
autoload_last_checkpoint: bool = True,
disable_strict_checkpoint_loading: bool = False,
save_initial_checkpoint: bool = False,
checkpoint_name: str = "checkpoint_{step}.mdl",
):
"""
Args:
steps: The frequency at which to save a checkpoint. If None, no
checkpoints will be saved. Defaults to None.
autoload_last_checkpoint: Whether to autoload the last
checkpoint in the model directory. Defaults to True.
disable_strict_checkpoint_loading: Whether to disable
strict checkpoint loading. If True, the model will not raise an
error if the checkpoint contains keys that are not present in the
model. Defaults to False.
save_initial_checkpoint: Whether to save the initial
checkpoint at the start of training. Defaults to False.
checkpoint_name: The unformatted name of the checkpoint file. The string will be
formatted with the following keys: `step`
"""
super().__init__()
if not (steps is None or (isinstance(steps, int) and steps >= 0)):
raise ValueError(
f"Checkpoint steps must be None or a non-negative integer, "
f"but got {steps}"
)
self.steps = steps
self.autoload_last_checkpoint = autoload_last_checkpoint
self.disable_strict_checkpoint_loading = (
disable_strict_checkpoint_loading
)
self.save_initial_checkpoint = save_initial_checkpoint
self.checkpoint_name = checkpoint_name
self.model_dir = None
self._stack_size = 0
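        # Extract the replacement-field names from the checkpoint name
        # template (e.g. "checkpoint_{step}.mdl" yields {"step"}) so that
        # templates using unsupported keys can be rejected up front.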
keys = set(
fname
for _, fname, _, _ in Formatter().parse(self.checkpoint_name)
if fname
)
expected_keys = {"step"}
if keys != expected_keys:
raise ValueError(
f"Found invalid keys in checkpoint name. "
f"Expected keys: {expected_keys}. Got: {keys}"
)

    @contextmanager
def _on_enter(self):
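        # Track how deeply we are nested inside fit/validate entry points so
        # that on_before_load_checkpoint only logs its message for the
        # outermost call.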
try:
self._stack_size += 1
yield
finally:
self._stack_size -= 1

    def on_enter_fit(
self, trainer, stack, train_dataloader, val_dataloader, loop
):
stack.enter_context(self._on_enter())

    def on_enter_validate_all(self, trainer, stack, val_dataloaders, loop):
stack.enter_context(self._on_enter())

    def on_enter_validate(self, trainer, stack, val_dataloader, loop):
stack.enter_context(self._on_enter())

    def on_train_start(self, trainer, model, train_dataloader, loop, loop_idx):
if (
loop_idx == 0
and self.save_initial_checkpoint
and trainer.backend.is_e2e_execution
):
trainer.save_checkpoint()

    def on_train_batch_end(self, trainer, model, outputs, batch, batch_idx):
if self.steps:
# Call checkpoint closure every iteration and let the
# closure handle calling it at the correct interval
trainer.save_checkpoint()

    def on_before_load_checkpoint(self, trainer, ckpt_path):
        if ckpt_path is None:
            if self._stack_size <= 1:
                trainer.logger.info(
                    "No checkpoint was provided. "
                    "Using randomly initialized model parameters."
                )
else:
trainer.logger.info(f"Loading weights from checkpoint {ckpt_path}")

    def preprocess_checkpoint(self, trainer, state_dict):
check_checkpoint_compatibility(state_dict)

    def on_save_checkpoint(self, trainer, state_dict):
trainer.logger.info(f"Saving checkpoint at step {trainer.global_step}")

    def on_after_save_checkpoint(self, trainer, ckpt_path):
trainer.logger.info(f"Saved checkpoint {ckpt_path}")

    def get_checkpoint_path(self, ckpt_dir: str, step: int) -> Path:
        """Construct a path to the checkpoint file.

        If a checkpoint already exists inside the given checkpoint directory at
        the given step, append a timestamp to the filename.

        Args:
            ckpt_dir: The directory where the checkpoint will be saved.
            step: The step at which the checkpoint is saved.

        Returns:
            A path to which the checkpoint can be saved.
        """
# Keep in sync with self.get_all_checkpoints().
ckpt_dir = Path(ckpt_dir)
ckpt_path = ckpt_dir / self.checkpoint_name.format(step=step)
if ckpt_path.exists():
ckpt_path = ckpt_dir / self.checkpoint_name.format(
step=f"{step}_{datetime.now():%Y%m%d_%H%M%S}"
)
return ckpt_path

    def get_latest_checkpoint(self, trainer):
"""Return the path to the latest checkpoint."""
trainer.logger.info(
f"Checkpoint autoloading is enabled. Looking for latest checkpoint "
f"in \"{trainer.model_dir}\" directory with the following naming "
f"convention: `checkpoint_(step)(_timestamp)?.mdl`."
)
ckpts = self.get_all_checkpoints(trainer.model_dir)
ckpt_path = ckpts[-1] if ckpts else None
if ckpt_path:
trainer.logger.info(f"Found latest checkpoint at \"{ckpt_path}\".")
else:
trainer.logger.info(
f"No checkpoints were found in \"{trainer.model_dir}\"."
)
return ckpt_path

    def get_all_checkpoints(self, model_dir: str) -> List[Path]:
        """Return the paths to all available checkpoints.

        Args:
            model_dir: The directory where the checkpoints are located.
        """
ckpts = []
# Keep in sync with self.get_checkpoint_path().
pattern = re.compile(
self.checkpoint_name.format(
step=r"(?P<step>\d+)(?:_(?P<timestamp>\d{8}_\d{6}))?"
)
)
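        # Matches names produced by get_checkpoint_path, e.g.
        # "checkpoint_100.mdl" or "checkpoint_100_20240131_120000.mdl".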
for checkpoint in Path(model_dir).glob("*"):
match = pattern.match(checkpoint.name)
if not match:
continue
step = int(match.group("step"))
timestamp = match.group("timestamp")
if timestamp is not None:
try:
date = datetime.strptime(timestamp, "%Y%m%d_%H%M%S")
except ValueError:
continue
else:
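                # Untimestamped checkpoints sort before timestamped ones
                # saved at the same step, since datetime.min is the earliest
                # possible value.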
date = datetime.min
ckpts.append((checkpoint, step, date))
# sort by step and then by timestamp
        return [ckpt[0] for ckpt in sorted(ckpts, key=lambda x: (x[1], x[2]))]
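
# Illustrative usage (a sketch, not part of this module; assumes the Trainer
# from cerebras.modelzoo.trainer and a user-supplied train_dataloader):
#
#     from cerebras.modelzoo.trainer import Trainer
#
#     trainer = Trainer(
#         ...,  # backend, model, optimizer, etc.
#         callbacks=[Checkpoint(steps=100, save_initial_checkpoint=True)],
#     )
#     trainer.fit(train_dataloader)  # saves checkpoint_{step}.mdl every 100 steps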


class LoadCheckpointStates(Callback):
"""
Callback to load specific states of the model.
"""

    def __init__(
self,
load_checkpoint_states: Union[str, List[str]] = "all",
):
"""
Args:
load_checkpoint_states: The list of state names to load from the checkpoint.
"""
states = load_checkpoint_states
if isinstance(states, str) and states.lower() == "all":
self.load_checkpoint_states = states
else:
if isinstance(states, str):
states = states.split(",")
if isinstance(states, (list, tuple, set)) and all(
isinstance(s, str) for s in states
):
self.load_checkpoint_states = set(states)
else:
raise TypeError(
f"Expected `load_checkpoint_states` to be one of the following: "
f"\n\t1. \"all\" to load all checkpoint states."
f"\n\t2. A comma-separated string of checkpoint states to load."
f"\n\t3. List of checkpoint state names to load."
f"\nBut got type \"{type(load_checkpoint_states)}\" with value "
f"{load_checkpoint_states}."
)

    def preprocess_checkpoint(self, trainer, state_dict):
if self.load_checkpoint_states == "all":
# Load all states, nothing to do
return
checkpoint_states = set(state_dict)
# Check that the specified states are valid checkpoint states
if not self.load_checkpoint_states.issubset(checkpoint_states):
raise KeyError(
"Unexpected keys specified via `load_checkpoint_states`: "
f"{', '.join(self.load_checkpoint_states - checkpoint_states)} "
"Only the keys in the following list are accepted: "
f"{', '.join(checkpoint_states)}"
)
if keys := (checkpoint_states - self.load_checkpoint_states):
trainer.logger.info(
f"Opting out of loading the following state(s) as they are "
f"not included in \"load_checkpoint_states\": {', '.join(sorted(keys))}"
)
for key in keys:
state_dict.pop(key, None)
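
# Illustrative usage (a sketch): for fine-tuning, one might load only the
# model weights and let the optimizer/scheduler states initialize fresh:
#
#     callbacks = [LoadCheckpointStates(load_checkpoint_states="model")]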


class SaveCheckpointState(Callback):
"""
Callback to save an alternative checkpoint file that contains a subset of states and is not
affected by deletion policies.
"""

    def __init__(
self,
k: int,
checkpoint_states: Union[str, List[str]] = "model",
checkpoint_name: str = "{checkpoint_states}_{ckpt_name}",
):
"""
Args:
            k: Cadence at which the alternative checkpoint is saved.
                Specifies after how many checkpoint saves an alternative
                checkpoint is saved. For example, if a full checkpoint is taken
                every 100 steps and k=5, then an alternative checkpoint is
                saved every 500 steps.
checkpoint_states: List of valid checkpoint states to save. Can be a single state or
list of states or 'all' (all states).
checkpoint_name: Prefix to add to the alternative checkpoint file name.
The name will be formatted with the following keys:
* ``checkpoint_states``: ``_`` separated list of checkpoint states
* ``ckpt_name``: original checkpoint file name
"""
if not isinstance(k, int) or k < 1:
raise ValueError(f"Expected k to be a positive integer. Got: {k}")
self.k = k
self.count = 0
if isinstance(checkpoint_states, str):
if checkpoint_states.lower() == "all":
self.checkpoint_states = ["all"]
else:
self.checkpoint_states = list(checkpoint_states.split(","))
elif isinstance(checkpoint_states, (list, tuple, set)):
self.checkpoint_states = list(checkpoint_states)
else:
raise TypeError(
"Expected `checkpoint_states` to be a string, list, tuple, or set. "
f"Got: {type(checkpoint_states)}"
)
self.checkpoint_name = checkpoint_name
keys = set(
fname
for _, fname, _, _ in Formatter().parse(self.checkpoint_name)
if fname
)
expected_keys = {"checkpoint_states", "ckpt_name"}
if keys != expected_keys:
raise ValueError(
f"Found invalid keys in checkpoint name. "
f"Expected keys: {expected_keys}. Got: {keys}"
)
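        # Track the final training step so that an alternative checkpoint is
        # still written at the end of training even if the cadence `k` has
        # not been reached (see on_after_save_checkpoint).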
self._is_last_loop = False
self._is_last_step = False

    def on_train_start(self, trainer, model, train_dataloader, loop, loop_idx):
self._is_last_loop = loop_idx == loop.num_trains - 1

    def on_train_batch_start(self, trainer, model, batch, batch_idx):
self._is_last_step = (
self._is_last_loop and trainer.executor.on_final_iteration
)

    def on_after_save_checkpoint(self, trainer, ckpt_path):
self.count += 1
if self.count >= self.k or self._is_last_step:
state_dict = cstorch.load(ckpt_path)
ckpt_keys = set(state_dict)
if self.checkpoint_states == ["all"]:
checkpoint_states = ckpt_keys
else:
checkpoint_states = set(self.checkpoint_states)
# Check that the specified states are valid checkpoint states
if not checkpoint_states.issubset(ckpt_keys):
                raise KeyError(
                    "Unexpected keys specified via `checkpoint_states`: "
                    f"{', '.join(checkpoint_states - ckpt_keys)}. "
                    "Only the keys in the following list are accepted: "
                    f"{', '.join(ckpt_keys)}"
                )
# Store all keys specified and all "metadata"-like keys that begin with "__"
subset_dict = {
key: state_dict[key]
for key in ckpt_keys
if key in checkpoint_states or key.startswith("__")
}
ckpt_name = self.checkpoint_name.format(
checkpoint_states="_".join(self.checkpoint_states),
ckpt_name=ckpt_path.name,
)
cstorch.save(subset_dict, ckpt_path.parent / ckpt_name)
self.count = 0
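
# Illustrative usage (a sketch): with a full checkpoint every 100 steps and
# k=5, a model-only checkpoint is additionally written every 500 steps:
#
#     callbacks = [
#         Checkpoint(steps=100),
#         SaveCheckpointState(k=5, checkpoint_states="model"),
#     ]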


class KeepNCheckpoints(Callback):
"""Callback to regulate the maximum number of checkpoints retained."""

    def __init__(self, n: Optional[int] = None):
"""
Args:
n: Number of checkpoint files to keep. If the number of checkpoint files saved exceeds
this number, checkpoint files are deleted starting with the oldest one. Does not
affect checkpoints taken from previous runs.
If n is None, no checkpoints are deleted.
"""
if n is None:
n = float("inf")
elif not isinstance(n, int) or n < 1:
raise ValueError(f"Expected n to be a positive integer. Got: {n}")
self.n = n
self.ckpt_paths = []

    def on_after_save_checkpoint(self, trainer, ckpt_path):
self.ckpt_paths.append(ckpt_path)
if len(self.ckpt_paths) > self.n:
ckpt_path = self.ckpt_paths.pop(0)
ckpt_path.unlink(missing_ok=True)
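
# Illustrative usage (a sketch): keep only the three most recent checkpoints
# written by this run, deleting older ones as new ones are saved:
#
#     callbacks = [Checkpoint(steps=100), KeepNCheckpoints(n=3)]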