# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""Utilities for saving and loading checkpoints."""
import os
from typing import IO, Any, Callable, Union
import torch
import cerebras.pytorch as cstorch
from cerebras.appliance import logger
from cerebras.appliance.utils.file import StrPath, get_path_size, is_pathlike
from cerebras.appliance.utils.memory import (
get_available_memory,
with_memory_info_logged,
)
from cerebras.pytorch.backend import current_backend_impl
from cerebras.pytorch.saver.checkpoint_reader import CheckpointReader
from .pt_h5_saver import PyTorchH5Saver
# Type aliases used in the signatures of this module.
#
# A checkpoint source: either a filesystem path or an open file-like object.
# A file-like object has to implement `read`, `readline`, `tell`, and
# `seek` methods.
_CkptFileT = Union[StrPath, IO]
# Types accepted for a `map_location` argument (a device string, a
# `torch.device`, a callable, a dict, or `None`).
_MapLocT = Union[str, torch.device, Callable, dict, None]
# Type of a loaded checkpoint; intentionally left as `Any`.
_StateDictT = Any
def save(obj: dict, checkpoint_file: str) -> None:
    """Save a PyTorch state dict to the given file.

    If ``current_backend_impl`` reports an initialized Cerebras backend, the
    save is delegated to that backend. Otherwise the object is written on CPU
    with ``PyTorchH5Saver``.

    Args:
        obj: The object to save.
        checkpoint_file: The path to save the object to.
    """
    backend = current_backend_impl(raise_exception=False)
    if backend is None:
        logger.debug(
            "No Cerebras backend found. Defaulting to using CPU for "
            "saving."
        )
        saver = PyTorchH5Saver()
        # Note: PyTorchH5Saver.save takes (path, obj), while backend.save
        # takes (obj, path).
        saver.save(checkpoint_file, obj)
    else:
        backend.save(obj, checkpoint_file)
    logger.verbose(f"Successfully saved checkpoint to {checkpoint_file}")
@with_memory_info_logged(
    "loading checkpoint",
    info=["available", "used"],
    logger=logger,
)
def load(
    checkpoint_file: _CkptFileT,
    map_location: _MapLocT = None,
    **kwargs,
) -> _StateDictT:
    """Load a PyTorch checkpoint from a file.

    Path-like inputs that are valid Cerebras HDF5 checkpoints are loaded with
    the HDF5 checkpoint loader. Everything else, including open file-like
    objects, is loaded with vanilla ``torch.load``.

    Args:
        checkpoint_file: The path to the checkpoint to load, or a file-like
            object holding a vanilla torch checkpoint.
        map_location: A mapping of where to load the checkpoint content to.
            If the map_location is `None`, then the tensors will be lazily loaded
            from the checkpoint file every single time the tensor is accessed.
            If the map_location is "cache", then the tensors will be cached
            once they are lazily loaded from the checkpoint file.
            If the map location is "cpu", then the tensors will be eagerly loaded
            into memory from the checkpoint file.
            The lazy/"cache" behaviour applies only to HDF5 checkpoints. For
            vanilla torch checkpoints "cache" is treated as `None`, and any
            other value is forwarded to ``torch.load``.
        **kwargs: Additional keyword arguments to pass to the vanilla torch
            checkpoint loader. These are ignored if the checkpoint is a
            Cerebras HDF5 checkpoint.

    Returns:
        The loaded checkpoint file.

    Raises:
        FileNotFoundError: Propagated unchanged when the vanilla torch loader
            cannot find the checkpoint file.
        TypeError: If an HDF5 checkpoint is loaded with a `map_location` that
            is not `None`, "cache", a string, or a ``torch.device``.
        RuntimeError: If the checkpoint is not a valid HDF5 or vanilla torch
            checkpoint.
    """
    if not (
        is_pathlike(checkpoint_file)
        and PyTorchH5Saver.is_valid_checkpoint(checkpoint_file)
    ):
        logger.debug(
            "Checkpoint is not a valid HDF5 checkpoint. Falling back to "
            "normal PyTorch checkpoint loading."
        )
        if isinstance(map_location, str) and map_location == "cache":
            # "cache" only toggles lazy-tensor caching in the HDF5 loader and
            # is not a device that torch.load can restore onto. Vanilla
            # checkpoints are read eagerly anyway, so apply no remapping.
            map_location = None
        return _torch_load(checkpoint_file, map_location, **kwargs)

    logger.debug(
        "Checkpoint is a valid HDF5 checkpoint. Using the HDF5 checkpoint "
        "loader."
    )
    # _cstorch_load emits its own "Loaded HDF5 checkpoint ..." debug message,
    # so it is not repeated here.
    return _cstorch_load(checkpoint_file, map_location, **kwargs)
def _cstorch_load(
    checkpoint_file: _CkptFileT,
    map_location: _MapLocT = None,
    **kwargs,
) -> _StateDictT:
    """Load a Cerebras HDF5 checkpoint and rebuild its nested structure.

    Args:
        checkpoint_file: Path to an HDF5 checkpoint carrying a Cerebras spec.
        map_location: ``None`` keeps tensors lazily backed by the file.
            ``"cache"`` additionally caches tensors once they are loaded.
            A ``str`` or ``torch.device`` names the device every loaded
            tensor is moved to.
        **kwargs: Accepted for signature parity with the vanilla loader;
            not used here.

    Returns:
        The object produced by unflattening the stored values against the
        checkpoint's spec. Spec entries with no matching tensor in the file
        are filled with ``None``.

    Raises:
        TypeError: If ``map_location`` is not ``None``, ``"cache"``, a string,
            or a ``torch.device``.
        RuntimeError: If the HDF5 file has no Cerebras spec.
    """
    # "cache" is not a device; it only switches deferred-tensor caching on.
    if map_location == "cache":
        cache_tensors, map_location = True, None
    else:
        cache_tensors = False

    if map_location is not None:
        if not isinstance(map_location, (str, torch.device)):
            raise TypeError(
                f"Unsupported `map_location` provided for loading HDF5 "
                f"checkpoint. Expected `None` or a torch device, but got "
                f"`{map_location}`"
            )
        map_location = torch.device(map_location)

    CheckpointReader.saver_cls = PyTorchH5Saver
    reader = CheckpointReader(checkpoint_file)
    stored_names = reader.tensor_names
    structure = reader.spec
    if not structure:
        raise RuntimeError(
            f"Checkpoint `{checkpoint_file}` is an HDF5 file but does not "
            f"conform to the Cerebras HDF5 checkpoint specification. Please "
            f"ensure that the checkpoint was saved using `cstorch.save()`."
        )

    from cerebras.pytorch.utils.nest import recurse_spec

    # Flattened, dot-joined keys in the order tree_unflatten expects them.
    flat_keys = [".".join(path) for path in recurse_spec(structure)]
    wanted = set(flat_keys)
    missing = wanted - set(stored_names)
    present = wanted - missing
    if missing:
        logger.warning(
            f"The checkpoint is missing the following keys that are "
            f"found in the spec: {sorted(missing)}"
        )

    if current_backend_impl(raise_exception=False) is None:
        logger.debug(
            "No backend has been initialized. Loading tensors onto CPU."
        )
        map_location = torch.device("cpu")

    h5_saver = PyTorchH5Saver()
    # The version is read from the file, but its value is not used below.
    h5_saver.extract_version(checkpoint_file)

    with cstorch.saver.storage.cache_deferred_tensors(cache_tensors):
        # Read every present value in a single pass: flocking the h5 file can
        # be expensive on some filesystems. The returned torch.Tensors are not
        # materialized; each one is backed by a DeferredH5Tensor.
        loaded = h5_saver.load(checkpoint_file, present)

        def _resolve(key):
            # Missing keys become None; tensors are moved if a device is set.
            if key not in loaded:
                return None
            item = loaded[key]
            if map_location is not None and isinstance(item, torch.Tensor):
                item = item.to(map_location)
            return item

        flat_values = [_resolve(key) for key in flat_keys]

    # pylint: disable=protected-access
    state_dict = torch.utils._pytree.tree_unflatten(flat_values, structure)
    logger.debug(f"Loaded HDF5 checkpoint {checkpoint_file}.")
    return state_dict
def _torch_load(
    checkpoint_file: _CkptFileT,
    map_location: _MapLocT = None,
    **kwargs,
) -> _StateDictT:
    """Load a PyTorch checkpoint using vanilla torch.load.

    Before loading a checkpoint that exists on disk, logs a warning if the
    file is larger than 10 GB. It also warns if the file is larger than the
    currently available memory, since vanilla torch checkpoints are read into
    memory in their entirety.

    Args:
        checkpoint_file: The path to the checkpoint to load, or a file-like
            object.
        map_location: A mapping of where to load the checkpoint content to.
            Forwarded verbatim to ``torch.load``.
        **kwargs: Additional keyword arguments to pass to torch.load.

    Returns:
        Whatever ``torch.load`` returns for the checkpoint.

    Raises:
        FileNotFoundError: Re-raised unchanged from ``torch.load``.
        RuntimeError: If ``torch.load`` fails for any other reason; the
            original exception is chained as the cause.
    """
    # Checkpoints above this size (in GB) trigger a slow-load warning.
    large_checkpoint_gb = 10

    if is_pathlike(checkpoint_file) and os.path.exists(checkpoint_file):
        unit = "GB"
        file_size = get_path_size(checkpoint_file, unit=unit)
        free_mem = get_available_memory(unit=unit)
        if file_size > large_checkpoint_gb:
            backend = current_backend_impl(raise_exception=False)
            if backend is not None and backend.backend_type.is_csx:
                extra_msg = ", could significantly slow down weight transfer,"
            else:
                extra_msg = ""
            logger.warning(
                f"Checkpoint file is a vanilla torch checkpoint and has "
                f"size {file_size} {unit}. This may take a while to load"
                f"{extra_msg} and may occupy a large amount of memory."
            )
        if file_size > free_mem:
            logger.warning(
                f"Checkpoint file is a vanilla torch checkpoint and has "
                f"size {file_size} {unit}, which is larger than the "
                f"currently available memory {free_mem} {unit}. Since "
                f"torch checkpoints are loaded in their entirety into "
                f"memory, this may cause out-of-memory errors."
            )

    try:
        state_dict = torch.load(
            checkpoint_file, map_location=map_location, **kwargs
        )
    except FileNotFoundError:
        # Error message is already descriptive enough; re-raise as-is so the
        # catch-all below does not wrap it.
        raise
    except Exception as e:
        raise RuntimeError(
            f"Failed to load checkpoint file `{checkpoint_file}`."
        ) from e
    logger.debug(f"Loaded checkpoint {checkpoint_file} into memory.")
    return state_dict
# Public API of this module: a star-import exposes only `save` and `load`.
__all__ = ["save", "load"]