# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the base Callback class and the global callback registry."""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections import OrderedDict
from contextlib import ExitStack
from typing import TYPE_CHECKING, Any, Dict, List
import torch
import cerebras.pytorch
if TYPE_CHECKING:
from ..trainer import Trainer
from .loop import TrainingLoop, ValidationLoop
class Callback:
    """
    Base class for all callbacks.

    Every hook defined here is a no-op by default; subclasses override only
    the hooks they care about. An instance can also be used as a context
    manager, which registers it as a global callback for the duration of
    the ``with`` block (see ``__enter__`` / ``__exit__``).
    """

    def pre_setup(self, trainer: Trainer):
        """Called before the trainer setup.

        Args:
            trainer: Trainer instance.
        """

    def setup(self, trainer: Trainer):
        """Setup the callback using the trainer.

        Args:
            trainer: Trainer instance.
        """

    def finalize(self):
        """Clean up the callback.

        This method is called when the trainer is destructed.
        """

    def on_enter_fit(
        self,
        trainer: Trainer,
        stack: ExitStack,
        train_dataloader: cerebras.pytorch.utils.data.DataLoader,
        val_dataloader: cerebras.pytorch.utils.data.DataLoader,
        loop: TrainingLoop,
    ):
        """Hook that allows arbitrary context managers to be entered
        at the beginning of the fit method.

        Args:
            trainer: Trainer instance.
            stack: ExitStack object.
            train_dataloader: Train dataloader.
            val_dataloader: Validation dataloader.
            loop: TrainingLoop object.
        """

    def on_fit_start(
        self,
        trainer: Trainer,
        train_dataloader: cerebras.pytorch.utils.data.DataLoader,
        val_dataloader: cerebras.pytorch.utils.data.DataLoader,
        loop: TrainingLoop,
    ):
        """Called at the beginning of the fit method.

        Args:
            trainer: Trainer instance.
            train_dataloader: Train dataloader.
            val_dataloader: Validation dataloader.
            loop: TrainingLoop object.
        """

    def on_fit_end(self, trainer: Trainer, loop: TrainingLoop):
        """Called at the end of the fit method.

        Args:
            trainer: Trainer instance.
            loop: TrainingLoop object.
        """

    def on_fit_exception(self, trainer: Trainer, exception: Exception):
        """Called if an exception is raised during fit.

        Args:
            trainer: Trainer instance.
            exception: Exception object.
        """

    def on_enter_train(
        self,
        trainer: Trainer,
        stack: ExitStack,
        train_dataloader: cerebras.pytorch.utils.data.DataLoader,
        loop: TrainingLoop,
        loop_idx: int,
    ):
        """Hook that allows arbitrary context managers to be entered
        at the beginning of every training iteration.

        Args:
            trainer: Trainer instance.
            stack: ExitStack object.
            train_dataloader: Train dataloader.
            loop: TrainingLoop object.
            loop_idx: Training loop index.
        """

    def on_train_start(
        self,
        trainer: Trainer,
        model: torch.nn.Module,
        train_dataloader: cerebras.pytorch.utils.data.DataLoader,
        loop: TrainingLoop,
        loop_idx: int,
    ):
        """Called at the beginning of the train loop.

        Args:
            trainer: Trainer instance.
            model: Model instance.
            train_dataloader: Train dataloader.
            loop: TrainingLoop object.
            loop_idx: Training loop index.
        """

    def on_train_end(
        self,
        trainer: Trainer,
        model: torch.nn.Module,
        loop: TrainingLoop,
        loop_idx: int,
    ):
        """Called at the end of the train loop.

        Args:
            trainer: Trainer instance.
            model: Model instance.
            loop: TrainingLoop object.
            loop_idx: Training loop index.
        """

    def on_train_exception(self, trainer: Trainer, exception: Exception):
        """Called if an exception is raised during a training iteration.

        Args:
            trainer: Trainer instance.
            exception: Exception object.
        """

    def on_train_batch_start(
        self,
        trainer: Trainer,
        model: torch.nn.Module,
        batch: Any,
        batch_idx: int,
    ):
        """Called at the beginning of every training iteration.

        Args:
            trainer: Trainer instance.
            model: Model instance.
            batch: Batch data.
            batch_idx: Batch index.
        """

    def on_train_batch_end(
        self,
        trainer: Trainer,
        model: torch.nn.Module,
        outputs: Dict[str, Any],
        batch: Any,
        batch_idx: int,
    ):
        """Called at the end of every training iteration.

        Args:
            trainer: Trainer instance.
            model: Model instance.
            outputs: Model outputs.
            batch: Batch data.
            batch_idx: Batch index.
        """

    def run_validation(
        self,
        trainer: Trainer,
        loop_idx: int,
        is_last: bool,
    ):
        """Perform a validation run.

        Override this method to perform a custom validation run. The base
        implementation does nothing.

        Args:
            trainer: Trainer instance.
            loop_idx: Training loop index.
            is_last: Whether the last training iteration just happened.
        """

    def on_enter_validate(
        self,
        trainer: Trainer,
        stack: ExitStack,
        val_dataloader: cerebras.pytorch.utils.data.DataLoader,
        loop: ValidationLoop,
    ):
        """Hook that allows arbitrary context managers to be entered
        at the beginning of every validation run.

        Args:
            trainer: Trainer instance.
            stack: ExitStack object.
            val_dataloader: Validation dataloader.
            loop: ValidationLoop object.
        """

    def on_validate_start(
        self,
        trainer: Trainer,
        model: torch.nn.Module,
        val_dataloader: cerebras.pytorch.utils.data.DataLoader,
        loop: ValidationLoop,
    ):
        """Called at the beginning of the validation loop.

        Args:
            trainer: Trainer instance.
            model: Model instance.
            val_dataloader: Validation dataloader.
            loop: ValidationLoop object.
        """

    def on_validate_end(
        self,
        trainer: Trainer,
        model: torch.nn.Module,
        loop: ValidationLoop,
    ):
        """Called at the end of the validation loop.

        Args:
            trainer: Trainer instance.
            model: Model instance.
            loop: ValidationLoop object.
        """

    def on_validate_exception(self, trainer: Trainer, exception: Exception):
        """Called if an exception is raised during validation.

        Args:
            trainer: Trainer instance.
            exception: Exception object.
        """

    def on_validate_batch_start(
        self,
        trainer: Trainer,
        model: torch.nn.Module,
        batch: Any,
        batch_idx: int,
    ):
        """Called at the beginning of every validation iteration.

        Args:
            trainer: Trainer instance.
            model: Model instance.
            batch: Batch data.
            batch_idx: Batch index.
        """

    def on_validate_batch_end(
        self,
        trainer: Trainer,
        model: torch.nn.Module,
        outputs: Dict[str, Any],
        batch: Any,
        batch_idx: int,
    ):
        """Called at the end of every validation iteration.

        Args:
            trainer: Trainer instance.
            model: Model instance.
            outputs: Model outputs.
            batch: Batch data.
            batch_idx: Batch index.
        """

    def on_enter_validate_all(
        self,
        trainer: Trainer,
        stack: ExitStack,
        val_dataloaders: cerebras.pytorch.utils.data.DataLoader,
        loop: ValidationLoop,
    ):
        """Hook that allows arbitrary context managers to be entered
        at the beginning of every validate all run.

        Args:
            trainer: Trainer instance.
            stack: ExitStack object.
            val_dataloaders: Validation dataloaders.
            loop: ValidationLoop object.
        """

    def on_before_forward(
        self,
        trainer: Trainer,
        model: torch.nn.Module,
        batch: Any,
        args: List[Any],
        kwargs: dict,
    ):
        """Called before the forward pass.

        Callbacks may append to ``args`` or add entries to ``kwargs`` to
        provide additional arguments to the forward method.

        Args:
            trainer: Trainer instance.
            model: Model instance.
            batch: Batch data.
            args: Forward pass arguments.
            kwargs: Forward pass keyword arguments.
        """

    def on_after_forward(
        self,
        trainer: Trainer,
        model: torch.nn.Module,
        outputs: Dict[str, Any],
        batch: Any,
    ):
        """Called after the forward pass.

        Args:
            trainer: Trainer instance.
            model: Model instance.
            outputs: Model outputs.
            batch: Batch data.
        """

    def on_before_backward(
        self,
        trainer: Trainer,
        model: torch.nn.Module,
        outputs: Dict[str, Any],
    ):
        """Called before the backward pass.

        Args:
            trainer: Trainer instance.
            model: Model instance.
            outputs: Model outputs.
        """

    def on_after_backward(
        self,
        trainer: Trainer,
        model: torch.nn.Module,
        outputs: Dict[str, Any],
    ):
        """Called after the backward pass.

        Args:
            trainer: Trainer instance.
            model: Model instance.
            outputs: Model outputs.
        """

    def on_before_optimizer_step(
        self,
        trainer: Trainer,
        model: torch.nn.Module,
        optimizer: cerebras.pytorch.optim.Optimizer,
    ):
        """Called before the optimizer step.

        Args:
            trainer: Trainer instance.
            model: Model instance.
            optimizer: Optimizer instance.
        """

    def on_after_optimizer_step(
        self,
        trainer: Trainer,
        model: torch.nn.Module,
        optimizer: cerebras.pytorch.optim.Optimizer,
    ):
        """Called after the optimizer step.

        Args:
            trainer: Trainer instance.
            model: Model instance.
            optimizer: Optimizer instance.
        """

    def on_before_optimizer_zero_grad(
        self,
        trainer: Trainer,
        model: torch.nn.Module,
        optimizer: cerebras.pytorch.optim.Optimizer,
    ):
        """Called before the optimizer zero_grad.

        Args:
            trainer: Trainer instance.
            model: Model instance.
            optimizer: Optimizer instance.
        """

    def on_after_optimizer_zero_grad(
        self,
        trainer: Trainer,
        model: torch.nn.Module,
        optimizer: cerebras.pytorch.optim.Optimizer,
    ):
        """Called after the optimizer zero_grad.

        Args:
            trainer: Trainer instance.
            model: Model instance.
            optimizer: Optimizer instance.
        """

    def on_before_scheduler_step(
        self,
        trainer: Trainer,
        model: torch.nn.Module,
        optimizer: cerebras.pytorch.optim.Optimizer,
        scheduler: cerebras.pytorch.optim.scheduler.Scheduler,
    ):
        """Called before the scheduler step.

        Args:
            trainer: Trainer instance.
            model: Model instance.
            optimizer: Optimizer instance.
            scheduler: A scheduler instance.
        """

    def on_after_scheduler_step(
        self,
        trainer: Trainer,
        model: torch.nn.Module,
        optimizer: cerebras.pytorch.optim.Optimizer,
        scheduler: cerebras.pytorch.optim.scheduler.Scheduler,
    ):
        """Called after the scheduler step.

        Args:
            trainer: Trainer instance.
            model: Model instance.
            optimizer: Optimizer instance.
            scheduler: A scheduler instance.
        """

    def on_save_checkpoint(self, trainer: Trainer, state_dict: dict):
        """Called before saving the checkpoint.

        Callbacks should override this method to add states to
        the checkpoint.

        Args:
            trainer: Trainer instance.
            state_dict: Trainer state dictionary.
        """

    def postprocess_checkpoint(self, trainer: Trainer, state_dict: dict):
        """Called after constructing the checkpoint.

        Callbacks should override this method to modify the checkpoint
        before saving.

        Args:
            trainer: Trainer instance.
            state_dict: Trainer state dictionary.
        """

    def on_after_save_checkpoint(self, trainer: Trainer, ckpt_path: str):
        """Called after saving the checkpoint.

        Args:
            trainer: Trainer instance.
            ckpt_path: Checkpoint path.
        """

    def on_before_load_checkpoint(self, trainer: Trainer, ckpt_path: str):
        """Called before loading the checkpoint.

        Args:
            trainer: Trainer instance.
            ckpt_path: Checkpoint path.
        """

    def preprocess_checkpoint(self, trainer: Trainer, state_dict: dict):
        """Called after loading the checkpoint.

        Callbacks should override this method to modify the state_dict
        after loading.

        Args:
            trainer: Trainer instance.
            state_dict: Trainer state dictionary.
        """

    def on_load_checkpoint(self, trainer: Trainer, state_dict: dict):
        """Called after loading the checkpoint.

        Callbacks should override this method to load states from
        the checkpoint.

        Args:
            trainer: Trainer instance.
            state_dict: Trainer state dictionary.
        """

    def __enter__(self):
        """Register the callback as a global callback.

        Registration is skipped when a ``_global_callback_handle`` attribute
        is already visible on the instance (including one inherited from
        the class).

        Returns:
            The callback itself.
        """
        if not hasattr(self, "_global_callback_handle"):
            # pylint: disable=attribute-defined-outside-init
            self._global_callback_handle = register_global_callback(self)
        return self

    def __exit__(self, *args):
        """Remove the callback from the global callback registry.

        Only a handle stored on this instance (normally set by
        ``__enter__``) is removed and deleted. A handle stored on the class
        (``register_global_callback`` puts it there when it is given a
        class) is not owned by this instance and is left untouched.
        """
        # vars() inspects only the instance __dict__. A class-level handle
        # would pass a hasattr() check, but ``del self.<attr>`` would then
        # raise AttributeError because the attribute is not on the instance.
        if "_global_callback_handle" in vars(self):
            # pylint: disable=protected-access
            self._global_callback_handle.remove()
            del self._global_callback_handle
# Global callback registry: maps each RemovableHandle id (as produced by
# register_global_callback) to the registered Callback instance. An
# OrderedDict keeps entries in registration order.
GLOBAL_CALLBACK_REGISTRY = OrderedDict()
def register_global_callback(callback):
    """Register a global callback.

    Args:
        callback: the Callback to register.
            If a Callback instance is passed, it is registered as is.
            If a Callback subclass is passed, an instance is created by
            calling the class with no arguments and that instance is
            registered. The handle is stored on the class as
            ``_global_callback_handle``, and the class itself is returned,
            which allows this function to be used as a class decorator.

    Returns:
        A ``torch.utils.hooks.RemovableHandle`` when an instance is passed;
        the class itself when a class is passed.

    Raises:
        TypeError: if ``callback`` is neither a Callback instance nor a
            Callback subclass.
    """
    from inspect import isclass

    from torch.utils import hooks

    # isclass() must short-circuit issubclass(), which raises on non-classes.
    is_callback_class = isclass(callback) and issubclass(callback, Callback)
    if not is_callback_class and not isinstance(callback, Callback):
        raise TypeError(f"Expected a Callback. Got: {type(callback)}")

    # The handle is created before instantiating a class, matching the
    # original ordering; the handle id is the registry key.
    handle = hooks.RemovableHandle(GLOBAL_CALLBACK_REGISTRY)
    GLOBAL_CALLBACK_REGISTRY[handle.id] = (
        callback() if is_callback_class else callback
    )

    if not is_callback_class:
        return handle

    # pylint: disable=protected-access
    callback._global_callback_handle = handle
    return callback
class ValidationCallback(Callback, ABC):
    """
    A special type of callback that indicates to the trainer
    that it will perform some custom validation logic.

    This is useful for callbacks that need to perform downstream validation
    logic that is not covered by the default validation loop.

    All ValidationCallbacks must implement the following methods:
        - run_validation

    Essentially, you are telling the trainer what to run at the end of each
    training run.
    """

    @abstractmethod
    def run_validation(self, trainer: Trainer, loop_idx: int, is_last: bool):
        """Perform the custom validation run.

        Declared with ``@abstractmethod``, so a subclass that does not
        override it cannot be instantiated.

        Args:
            trainer: Trainer instance.
            loop_idx: Training loop index.
            is_last: Whether the last training iteration just happened.
        """