# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The ModelZoo Trainer class is the main entry point for training models in ModelZoo.
It is responsible for setting up the training environment, running the training/validation loop,
and saving the checkpoint.
"""
from __future__ import annotations
import os
import uuid
from collections import Counter, OrderedDict
from contextlib import ExitStack, contextmanager, nullcontext
from copy import copy
from functools import wraps
from logging import Logger as PythonLogger
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Optional, Union, final
from warnings import warn
from weakref import finalize
import torch
import cerebras.pytorch as cstorch
from cerebras.modelzoo.trainer.callbacks import (
GLOBAL_CALLBACK_REGISTRY,
ArtifactDirCallback,
BackendCallback,
Callback,
Checkpoint,
DataLoaderCallback,
GradientAccumulationCallback,
Logging,
LoopCallback,
ModelCallback,
OptimizerCallback,
Precision,
Reproducibility,
SchedulersCallback,
SchedulersInput,
SparsityCallback,
TrainingLoop,
ValidationCallback,
ValidationLoop,
)
from cerebras.modelzoo.trainer.loggers import Logger
from cerebras.modelzoo.trainer.utils import convert_output_to_dict
from cerebras.pytorch.backend import Backend
from cerebras.pytorch.optim import Optimizer
from cerebras.pytorch.sparse import SparsityAlgorithm
UNSPECIFIED = object()
[docs]class Trainer:
"""The Trainer class is the main entry point for training models in ModelZoo."""
@final
def __init__(
self,
device: Optional[str] = None,
backend: Optional[Backend] = None,
model_dir: str = UNSPECIFIED,
model: Union[
Callable[[], torch.nn.Module], torch.nn.Module
] = UNSPECIFIED,
optimizer: Union[
Optimizer,
Callable[[torch.nn.Module], Optimizer],
None,
] = None,
schedulers: SchedulersInput = None,
precision: Optional[Precision] = None,
sparsity: Optional[SparsityAlgorithm] = None,
# Training args
loop: Optional[LoopCallback] = None,
checkpoint: Optional[Checkpoint] = None,
logging: Optional[Logging] = None,
# Trainer args
callbacks: Optional[List[Callback]] = None,
loggers: Optional[List[Logger]] = None,
seed: Optional[int] = None,
):
"""
Args:
device: The device to train the model on. It must be one of "CSX",
"CPU", or "GPU".
backend: The backend used to train the model. This argument is mutually
exclusive with `device`.
model_dir: The directory where the model artifacts are saved.
model: The model to train. It must be one of the following:
- If a callable is passed, it is assumed to be a function that
takes in no arguments returns a torch.nn.Module.
- If a torch.nn.Module is passed, it is used as is.
optimizer: The optimizer used to optimize the model. It must be one of the following:
- If a :py:class:`~cerebras.pytorch.optim.Optimizer` is passed, it is used as is.
- If a callable is passed, it is assumed to be a function that
takes in a torch.nn.Module and returns a
:py:class:`~cerebras.pytorch.optim.Optimizer`.
- If not passed, then assume that only validation will be run.
schedulers: The set of optimizer schedulers to be used. Common schedulers include LR
schedulers. It must be a list of these items:
- If a cstorch.optim.scheduler.Scheduler is passed, it is used as is.
- A callable that is assumed to be a function that takes in a
:py:class:`~cerebras.pytorch.optim.Optimizer` and returns a
cstorch.optim.scheduler.Scheduler.
- If None, there is no optimizer param group scheduling.
precision: The Precision callback used during training
sparsity: The sparsity algorithm used to sparsify weights during training/validation
It must be one of the following:
- If a callable is passed, it is assumed to be a function that
takes in no arguments returns a
:py:class:`~cerebras.pytorch.sparse.SparsityAlgorithm`.
- If a :py:class:`~cerebras.pytorch.sparse.SparsityAlgorithm` is
passed, it is used as is.
loop: The loop callback to use for training. It must be an instance of LoopCallback.
If not provided, the default loop is TrainingLoop(num_epochs=1).
checkpoint: The checkpoint callback to use for saving/loading checkpoints. It must
be an instance of Checkpoints. If not provided, then no checkpoints are saved.
logging: The logging callback used to set up python logging. This callback also
controls when logs are supposed to be logged. If not provided, the default
logging settings ``Logging(log_steps=1, log_level="INFO")`` are used.
callbacks: A list of callbacks to used by the trainer. The order in
which the callbacks are provided is important as it determines
the order in which the callback's hooks are executed.
loggers: A list of loggers to use for logging.
seed: Initial seed for the torch random number generator.
"""
super().__init__()
if model_dir is UNSPECIFIED:
raise ValueError("model_dir is a required argument")
self.model_dir = Path(model_dir)
if model is UNSPECIFIED:
raise ValueError("model is a required argument")
if not isinstance(model, torch.nn.Module) and isinstance(
optimizer, cstorch.optim.Optimizer
):
raise ValueError(
f"Expected optimizer to be a callable that takes in a torch.nn.Module "
f"and returns a cstorch.optim.Optimizer. Got: {type(optimizer)}"
)
# Attributes set by core callbacks
self.artifact_dir: Path
self.summary_dir: Path
self.backend: "cstorch.backend.Backend"
self.model: torch.nn.Module
self.compiled_model: Callable
self.dataloader: cstorch.utils.data.DataLoader
self.optimizer: Optional[cstorch.optim.Optimizer]
self.schedulers: Optional[List[cstorch.optim.scheduler.Scheduler]]
self.executor: cstorch.utils.data.DataExecutor
self.global_step: int
# Other attributes that callbacks may set
self.activation_steps: Optional[int] = None
if precision is not None and not isinstance(precision, Precision):
raise TypeError(
f"Expected precision to be an instance of Precision. "
f"Got: {type(precision)}"
)
if loop is None:
loop = TrainingLoop(num_epochs=1)
elif not isinstance(loop, LoopCallback):
raise TypeError(
f"Expected loop to be an instance of LoopCallback."
f"Got: {type(loop)}"
)
if checkpoint is None:
checkpoint = Checkpoint()
elif not isinstance(checkpoint, Checkpoint):
raise TypeError(
f"Expected checkpoint to be an instance of Checkpoint. "
f"Got: {type(checkpoint)}"
)
if logging is None:
logging = Logging(log_steps=1, log_level="INFO")
elif not isinstance(logging, Logging):
raise TypeError(
f"Expected logging to be an instance of Logging. "
f"Got: {type(logging)}"
)
# Order of the core callbacks is important and should not be changed
self.callbacks = OrderedDict(
{
"artifact_dir": ArtifactDirCallback(),
"reproducibility": Reproducibility(seed),
"backend": BackendCallback(backend, device),
"model": ModelCallback(model),
"dataloader": DataLoaderCallback(),
"optimizer": OptimizerCallback(optimizer),
"schedulers": SchedulersCallback(schedulers),
"precision": precision,
"sparsity": SparsityCallback(sparsity),
"loop": loop,
"logging": logging,
"grad_accum": GradientAccumulationCallback(),
"checkpoint": checkpoint,
}
)
core_callback_types = set(
type(callback)
for callback in self.callbacks.values()
if callback is not None
) | {Precision, LoopCallback}
core_callback_types = tuple(core_callback_types)
user_callbacks = callbacks or []
for callback in user_callbacks:
if isinstance(callback, Logger):
warn(
f"Passed logger {type(callback)} as a `callback`. "
f"It will not be used by trainer.log_metrics(). "
f"To use for logging metrics, pass it as a `logger` "
f"instead."
)
self.loggers = loggers or []
for logger in self.loggers:
if not isinstance(logger, Logger):
raise TypeError(
f"logger must be an instance of Logger. Got: {type(logger)}"
)
counter = Counter(self.callbacks.keys())
def get_name(callback):
name = type(callback).__name__
counter[name] += 1
if counter[name] > 1:
return f"{name}_{counter[name]}"
return name
for callback in user_callbacks + self.loggers:
if isinstance(callback, core_callback_types):
raise ValueError(
f"Callback {type(callback).__name__} is a core callback "
f"and cannot be overridden"
)
self.callbacks[get_name(callback)] = callback
# Checkpoint should be the last callback as saving a checkpoint is slow
self.callbacks.move_to_end("checkpoint")
# Whitelist of non-standard hooks that callbacks can implement
self.non_standard_hooks_whitelist = set()
# ID map for validation dataloaders
self._name_scope_stack = []
self._val_dataloader_id_map = {}
self._val_dataloader_id = 0
self.call("pre_setup")
self.call("setup")
def callback_finalize(callbacks):
for callback in callbacks:
if callback is not None:
callback.finalize()
self._finalizer = finalize(
self, callback_finalize, self.callbacks.values()
)
@property
def all_callbacks(self) -> Generator[Callback, None, None]:
"""Get all callback objects available to the trainer."""
yield from self.callbacks.values()
yield from GLOBAL_CALLBACK_REGISTRY.values()
[docs] def get_callbacks(
self, callback_type: type
) -> Generator[Callback, None, None]:
"""Get all callbacks of the given type."""
for callback in self.all_callbacks:
if isinstance(callback, callback_type):
yield callback
[docs] def get_callback(self, callback_type: type) -> Optional[Callback]:
"""Get the first callback of the given type."""
return next(self.get_callbacks(callback_type), None)
@property
def validation_callbacks(self) -> List[ValidationCallback]:
"""Returns all validation callbacks in the Trainer's callback list."""
return list(self.get_callbacks(ValidationCallback))
[docs] def call(self, hook_name: str, *args, **kwargs):
"""
Call the hook with name hook_name for all callbacks
in the Trainer's callback list as well as the callbacks
in the global registry.
The callback's method is passed in the trainer object itself
as well as any args and kwargs that are passed into this method. e.g.
.. code: python
getattr(callback, hook_name)(self, *args, **kwargs)
Args:
hook_name: The name of the hook to call. It must be the name of
a method in the Callback class.
args: Other positional arguments to forward
along to the called hook.
kwargs: Other keyword arguments to forward
along to the called hook.
"""
seen = set()
for callback in self.all_callbacks:
if callback is None:
continue
if callback in seen:
warn(f"Duplicate callback found in the list: {callback}")
else:
seen.add(callback)
hook = getattr(callback, hook_name, None)
if hook:
try:
hook(self, *args, **kwargs)
except Exception as e:
# TODO(SW-128935): Add this to the exception notes once
# we support python 3.11
raise RuntimeError(
f"Encountered error when calling "
f"{type(callback).__name__}.{hook_name}"
) from e
elif hook_name not in self.non_standard_hooks_whitelist:
raise AttributeError(
f"Callback {type(callback)} does not implement {hook_name}"
)
@property
def precision(self) -> Optional[Precision]:
"""Returns the precision callback instance if it exists."""
return self.callbacks["precision"]
@property
def grad_accum(self) -> GradientAccumulationCallback:
"""Returns the gradient accumulation callback instance."""
return self.callbacks["grad_accum"]
@property
def should_run_optimizer_step(self) -> bool:
"""Returns True if we should run the optimizer step.
The gradient accumulation callback may set this to False if we are
accumulating gradients and have not reached the accumulation steps.
Note, this only applies to CPU/GPU runs.
"""
return self.grad_accum.should_run_optimizer_step
@property
def loop(self) -> LoopCallback:
"""Returns the default loop settings."""
return self.callbacks["loop"]
@property
def checkpoint(self) -> Checkpoint:
"""Returns the checkpoint callback."""
return self.callbacks["checkpoint"]
@property
def logging(self) -> Checkpoint:
"""Returns the logging callback."""
return self.callbacks["logging"]
@property
def logger(self) -> PythonLogger:
"""Returns the Trainer's Python logger object."""
return self.logging.logger
@property
def is_log_step(self) -> bool:
"""Returns True if the current step is a log step."""
return self.logging.is_log_step(self)
@property
def is_first_iteration(self) -> bool:
"""Returns True if the executor is on its first iteration."""
return self.executor and self.executor.iteration == 0
@property
def is_final_iteration(self) -> bool:
"""Returns True if the executor is on its final iteration."""
return self.executor and self.executor.on_final_iteration
@final
@property
def is_tracing(self) -> bool:
"""Returns True if we are currently tracing the model."""
return self.backend.is_tracing
[docs] @final
@cstorch.step_closure
def log_metrics_in_step_closure(self, **kwargs):
"""Log the given kwargs inside a step closure."""
# This is a sanity check and should never assert
assert not self.is_tracing
self.log_metrics(**kwargs)
[docs] @final
def log_metrics(self, **kwargs):
"""
Log the given kwargs to all loggers.
Example usage:
.. code:: python
trainer.log_metrics(loss=loss.item())
Args:
kwargs: The key-value pairs to log.
"""
if self.is_tracing:
# If we are tracing, log the kwargs inside a step closure
self.log_metrics_in_step_closure(**kwargs)
else:
# If we're not inside a logging step and an executor is not
# active, don't send to loggers. The executor check is to
# handle the case where a user wants to log something after
# execution.
if self.executor is not None and not self.is_log_step:
return
if not self.loggers:
warn(
"No loggers are attached to the trainer. "
"Call to trainer.log_metrics() will be a no-op."
)
from cerebras.modelzoo.trainer.loggers import TensorBoardLogger
# Don't add prefixes to metric names if the tensorboard logger
# is configured to output logs to legacy event directories
if any(
logger.legacy_event_dirs
for logger in self.get_callbacks(TensorBoardLogger)
):
prefix = ""
else:
prefix = self.name_scope_path
# Otherwise, log the kwargs directly
for logger in self.loggers:
logger.log_metrics(
{f"{prefix}{k}": v for k, v in kwargs.items()},
step=self.global_step,
)
[docs] @final
@contextmanager
def name_scope(self, name: str):
"""Append name to the trainer's name scope stack whilst inside the
context.
Args:
name: The name to append to the name scope stack.
"""
try:
self._name_scope_stack.append(name)
yield name
finally:
self._name_scope_stack.pop()
@final
@property
def name_scope_path(self) -> str:
"""Returns the current name scope path.
This is the the name scope stack joined by '/'.
"""
return os.path.join(*self._name_scope_stack, "")
[docs] @final
def get_val_dataloader_scope(self, val_dataloader):
"""Get the name scope for the given val dataloader"""
if val_dataloader.id not in self._val_dataloader_id_map:
self._val_dataloader_id_map[val_dataloader.id] = (
f"validate_{self._val_dataloader_id}"
)
self._val_dataloader_id += 1
return self._val_dataloader_id_map[val_dataloader.id]
[docs] @final
@cstorch.trace
def training_step(self, batch) -> Dict[str, Any]:
"""Run a single training step on the given batch.
Note that if retrace is off, content of this method will only run on
the first iteration. So any inputs to this method must either be
non-changing or torch tensors.
Args:
batch: The batch of data to train on.
batch_idx: The index of the batch in the dataloader.
Returns:
A dictionary containing the loss and any other outputs.
"""
outputs = self.forward(batch)
self.backward(outputs)
# Only run the optimizer step if an optimizer was defined
# and we should run the optimizer step
if self.optimizer and self.should_run_optimizer_step:
self.optimizer_step()
self.optimizer_zero_grad()
self.schedulers_step()
return outputs
[docs] @final
def forward(self, batch) -> Dict[str, Any]:
"""
Run the forward pass on the given batch.
Args:
batch: The batch of data to run the forward pass on.
Returns:
A dictionary containing the loss and any other outputs.
"""
if self.precision:
ctx = self.precision.autocast_context_manager()
else:
ctx = nullcontext()
with ctx:
args = [batch]
kwargs = {}
self.call("on_before_forward", self.model, batch, args, kwargs)
output = self.compiled_model(*args, **kwargs)
outputs = convert_output_to_dict(output)
self.call("on_after_forward", self.model, outputs, batch)
return outputs
[docs] @final
def backward(self, outputs: dict):
"""
Run the backward pass on the given loss.
Args:
outputs: The outputs of the model. Expect key 'loss' to be present.
"""
self.call("on_before_backward", self.model, outputs)
loss = outputs["loss"]
if self.precision:
self.precision.backward(loss)
else:
loss.backward()
self.call("on_after_backward", self.model, outputs)
[docs] @final
def optimizer_step(self):
"""Run the optimizer step."""
self.call("on_before_optimizer_step", self.model, self.optimizer)
if self.precision:
self.precision.clip_gradients(self.optimizer)
self.precision.optimizer_step(self.optimizer)
else:
self.optimizer.step()
self.call("on_after_optimizer_step", self.model, self.optimizer)
[docs] @final
def optimizer_zero_grad(self):
"""Zero the gradients of the optimizer."""
self.call("on_before_optimizer_zero_grad", self.model, self.optimizer)
self.optimizer.zero_grad()
self.call("on_after_optimizer_zero_grad", self.model, self.optimizer)
[docs] @final
def schedulers_step(self):
"""Step all the schedulers."""
if not self.schedulers:
return
for scheduler in self.schedulers:
self.call(
"on_before_scheduler_step",
self.model,
self.optimizer,
scheduler,
)
scheduler.step()
self.call(
"on_after_scheduler_step",
self.model,
self.optimizer,
scheduler,
)
[docs] @contextmanager
def on_exception(self, hook):
"""Context manager to handle exceptions in the given hook.
Args:
hook: The hook to handle exceptions for.
"""
try:
yield
except Exception as e:
try:
self.call(f"on_{hook}_exception", e)
except Exception as e2:
raise e2 from e
raise
[docs] @final
def fit(
self,
train_dataloader: cstorch.utils.data.DataLoader,
val_dataloader: Union[
cstorch.utils.data.DataLoader,
List[cstorch.utils.data.DataLoader],
None,
] = None,
ckpt_path: Optional[str] = UNSPECIFIED,
):
"""Complete a full training run on the given train and validation dataloaders.
Args:
train_dataloader: The training dataloader.
val_dataloader: The validation dataloader.
If provided, validation is run every `eval_frequency` steps as defined
in the loop callback.
If not provided, only training is run.
If a list of dataloaders is provided, then each dataloader is
validated in sequence.
ckpt_path: The path to the checkpoint to load before starting training.
If not provided and `autoload_last_checkpoint` is True,
then the latest checkpoint is loaded
"""
loop = self.loop
if not isinstance(loop, TrainingLoop):
raise TypeError(
f"Expected loop to be an instance of TrainingLoop. "
f"Got: {type(loop)}"
)
with ExitStack() as fit_stack:
self.call(
"on_enter_fit",
fit_stack,
train_dataloader,
val_dataloader,
loop,
)
self.load_checkpoint(ckpt_path)
self.call("on_fit_start", train_dataloader, val_dataloader, loop)
for loop_idx in range(loop.num_trains):
self._run_train(train_dataloader, loop, loop_idx)
if loop.eval_frequency is not None and loop.eval_frequency != 0:
# Run upstream and downstream validation after each training iteration
self._validate_all(
val_dataloader,
ckpt_paths=None,
loop=None,
# pylint: disable=cell-var-from-loop
run_validation=lambda: self.call(
"run_validation",
loop_idx=loop_idx,
is_last=loop_idx == loop.num_trains - 1,
),
)
self.call("on_fit_end", loop)
@final
def _run_train(self, train_dataloader, loop, loop_idx=0):
if not isinstance(loop, TrainingLoop):
raise TypeError(
f"Expected loop to be an instance of TrainingLoop. "
f"Got: {type(loop)}"
)
with ExitStack() as stack:
self.call("on_enter_train", stack, train_dataloader, loop, loop_idx)
self.call(
"on_train_start", self.model, train_dataloader, loop, loop_idx
)
self.executor = cstorch.utils.data.DataExecutor(
train_dataloader,
num_steps=loop.train_steps,
checkpoint_steps=loop.checkpoint_steps,
activation_steps=self.activation_steps,
profiler_activities=[], # Don't use data executor's profiler
)
for batch_idx, batch in enumerate(self.executor):
self.call("on_train_batch_start", self.model, batch, batch_idx)
outputs = self.training_step(batch)
self.call(
"on_train_batch_end", self.model, outputs, batch, batch_idx
)
self.call("on_train_end", self.model, loop, loop_idx)
self.executor = None
@final
@cstorch.trace
@torch.no_grad()
def validation_step(self, batch: Any) -> Dict[str, Any]:
"""Run a single validation step on the given batch and batch index.
Note that if retrace is off, content of this method will only run on
the first iteration. So any inputs to this method must either be
constant or torch tensors.
Args:
batch: The batch of data to validate on.
Returns:
A dictionary containing the loss and any other outputs.
"""
return self.forward(batch)
[docs] @final
def validate(
self,
val_dataloader: Optional[cstorch.utils.data.DataLoader] = None,
ckpt_path: Optional[str] = UNSPECIFIED,
loop: Optional[ValidationLoop] = None,
):
"""Complete a full validation run on the validation dataloader.
Args:
val_dataloader: The validation dataloader.
If a list of dataloaders is provided, then each dataloader is
validated in sequence.
ckpt_path: The path to the checkpoint to load before starting validation.
If not provided and `autoload_last_checkpoint` is True,
then the latest checkpoint is loaded.
loop: The loop callback to use for validation. If not provided, the default
loop is used. If provided, it must be an instance of ValidationLoop.
Note, this should only be provided if the loop callback provided in
the constructor is not sufficient.
"""
if not isinstance(val_dataloader, cstorch.utils.data.DataLoader):
raise TypeError(
f"val_dataloader must be a cstorch.utils.data.DataLoader. "
f"Got {type(val_dataloader)}"
)
if not loop:
loop = self.loop
if isinstance(loop, TrainingLoop):
loop = loop.val_loop
if not isinstance(loop, ValidationLoop):
raise TypeError(
f"Expected loop to be an instance of ValidationLoop. "
f"Got: {type(loop)}"
)
with ExitStack() as stack:
stack.enter_context(
self.name_scope(self.get_val_dataloader_scope(val_dataloader))
)
if loop not in self.all_callbacks:
stack.enter_context(loop)
self.call("on_enter_validate", stack, val_dataloader, loop)
self.load_checkpoint(ckpt_path)
self.call(loop.on_start_hook, self.model, val_dataloader, loop)
# For every validation run we want to iterate the dataloader from
# the scratch, so we make a shallow copy of the validation dataloader,
# so the streamer will treat it as new dataloader.
self.executor = cstorch.utils.data.DataExecutor(
copy(val_dataloader),
num_steps=loop.eval_steps,
profiler_activities=[], # Don't use data executor's profiler
)
for batch_idx, batch in enumerate(self.executor):
self.call(
loop.on_batch_start_hook, self.model, batch, batch_idx
)
outputs = self.validation_step(batch)
self.call(
loop.on_batch_end_hook,
self.model,
outputs,
batch,
batch_idx,
)
self.call(loop.on_end_hook, self.model, loop)
self.executor = None
[docs] @final
def validate_all(
self,
val_dataloaders: Union[
List[cstorch.utils.data.DataLoader],
cstorch.utils.data.DataLoader,
None,
] = None,
ckpt_paths: Union[List[str], str, None] = UNSPECIFIED,
loop: Optional[ValidationLoop] = None,
):
"""
Runs all upstream and downstream validation permutations.
.. code:: python
for ckpt_path in ckpt_paths:
for val_dataloader in val_dataloaders:
trainer.validate(val_dataloader, ckpt_path)
# run downstream validation
run_validation(...)
Args:
val_dataloaders: A list of validation dataloaders to run validation on.
ckpt_paths: A list of checkpoint paths to run validation on. Each checkpoint
path must be a path to a checkpoint file, or a glob pattern.
loop: The validation loop to use for validation. If not provided, then the
default loop is used.
"""
self._validate_all(
val_dataloaders,
ckpt_paths,
loop,
run_validation=lambda: self.call(
"run_validation", loop_idx=None, is_last=True
),
)
@final
@wraps(validate_all)
def _validate_all(self, val_dataloaders, ckpt_paths, loop, run_validation):
if not callable(run_validation) and run_validation is not None:
raise RuntimeError(
f"Expected run_validation to be a callable or None. "
f"Got: {run_validation}"
)
if val_dataloaders is None:
val_dataloaders = []
elif not isinstance(val_dataloaders, (list, tuple)):
val_dataloaders = [val_dataloaders]
for i, val_dataloader in enumerate(val_dataloaders):
if not isinstance(val_dataloader, cstorch.utils.data.DataLoader):
raise TypeError(
f"Expected val_dataloader to be an instance of DataLoader. "
f"Got {type(val_dataloader)} for element {i}"
)
if ckpt_paths is UNSPECIFIED or ckpt_paths is None:
ckpt_paths = [ckpt_paths]
else:
if not isinstance(ckpt_paths, (list, tuple)):
ckpt_paths = [ckpt_paths]
for ckpt_path in ckpt_paths:
if not isinstance(ckpt_path, (str, Path)):
raise ValueError(
f"Expected ckpt_path to be a path to a checkpoint file, "
f"or a glob pattern. Got: {ckpt_path}"
)
ckpt_path = Path(ckpt_path)
if not any(ckpt_path.parent.glob(ckpt_path.name)):
raise FileNotFoundError(
f"Checkpoint file(s) not found at: {ckpt_path}"
)
# Flatten all ckpt paths into a list
ckpt_paths = (
checkpoint_file
for ckpt_path in map(Path, ckpt_paths)
for checkpoint_file in ckpt_path.parent.glob(ckpt_path.name)
)
with ExitStack() as stack:
self.call("on_enter_validate_all", stack, val_dataloaders, loop)
for ckpt_path in ckpt_paths:
# Load the checkpoint
self.load_checkpoint(ckpt_path)
# Run upstream validation
for val_dataloader in val_dataloaders:
self.validate(val_dataloader, ckpt_path=None, loop=loop)
# Run downstream validation
if run_validation is not None:
run_validation()
[docs] @final
@cstorch.checkpoint_closure
def save_checkpoint(self):
"""Save a checkpoint at the current global step.
The checkpoint state dict is constructed by various callbacks
that implement the `on_save_checkpoint` method.
"""
state_dict = {}
self.call("on_save_checkpoint", state_dict)
self.call("postprocess_checkpoint", state_dict)
ckpt_path = self.checkpoint.get_checkpoint_path(
self.model_dir, self.global_step
)
# atomic checkpoint save by first writing to a temp file, then renaming
tmp_ckpt_path = Path(f"{ckpt_path}.{str(uuid.uuid4())[:8]}.tmp")
cstorch.save(state_dict, tmp_ckpt_path)
tmp_ckpt_path.rename(ckpt_path)
self.call("on_after_save_checkpoint", ckpt_path)
[docs] @final
def load_checkpoint(self, ckpt_path: Optional[str] = None):
"""Load a checkpoint from the given path.
The checkpoint state dict is loaded and processed by various callbacks
that implement the `on_load_checkpoint` method.
Args:
ckpt_path: The path to the checkpoint to load
If not provided and `autoload_last_checkpoint` is True,
then the latest checkpoint is loaded
"""
# Don't load a checkpoint if compile/validate only
if not self.backend.is_e2e_execution:
return
if (
ckpt_path is UNSPECIFIED
and self.checkpoint.autoload_last_checkpoint
):
ckpt_path = self.checkpoint.get_latest_checkpoint(self)
if not ckpt_path or ckpt_path is UNSPECIFIED:
self.call("on_before_load_checkpoint", None)
return
self.call("on_before_load_checkpoint", ckpt_path)
state_dict = cstorch.load(ckpt_path)
self.call("preprocess_checkpoint", state_dict)
self.call("on_load_checkpoint", state_dict)