Customizing the Trainer with Callbacks#
The Trainer class was designed to be easily extensible using Callback classes. The Trainer exposes a number of hooks which can be overridden using a Callback.
On this page, you will learn about the basic Callback mechanism. By the end, you should be able to write and use your own custom Callback.
Prerequisites#
Please ensure that you have read through the Cerebras Model Zoo Trainer Overview beforehand. The rest of this page assumes that you already have at least a cursory understanding of what the Cerebras Model Zoo Trainer is and how to use its Python API.
Callbacks#
The callback mechanism is the backbone of the Trainer’s implementation. Much of the heavy lifting in the Trainer is actually done by various Core Callbacks.
In general, the Callback mechanism exposes a number of useful hooks that allow you to inject custom behaviour into the Trainer. These hooks include (but are not limited to):
setup
on_{fit,train,validate}_{start,end}
on_{train,validate}_batch_{start,end}
on_{after,before}_{forward,backward}
on_{after,before}_optimizer_{step,zero_grad}
on_{after,before}_scheduler_step
on_{save,load}_checkpoint
on_after_save_checkpoint
on_before_load_checkpoint
The following pseudocode describes the structure of the fit call and where the various hooks get called.
class Trainer:
    ...
    def fit(self, train_dataloader, val_dataloader, ckpt_path):
        on_before_load_checkpoint()
        load_checkpoint(ckpt_path)
        on_fit_start(...)
        for loop_idx in range(num_trains):
            on_train_start(...)
            for batch in train_dataloader:
                on_train_batch_start(...)
                on_before_forward(...)
                outputs = model(batch)
                on_after_forward(...)
                on_before_backward(...)
                outputs["loss"].backward()
                on_after_backward(...)
                on_before_optimizer_step(...)
                optimizer.step()
                on_after_optimizer_step(...)
                on_before_optimizer_zero_grad(...)
                optimizer.zero_grad()
                on_after_optimizer_zero_grad(...)
                for scheduler in schedulers:
                    on_before_scheduler_step(...)
                    scheduler.step()
                    on_after_scheduler_step(...)
                on_train_batch_end(...)
            on_train_end(...)
            for batch in val_dataloader:
                on_validate_batch_start(...)
                on_before_forward(...)
                outputs = model(batch)
                on_after_forward(...)
                on_validate_batch_end(...)
        on_fit_end(...)
For a comprehensive list of all supported hooks (as well as the arguments they accept), see the API docs for the Callback class.
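To make the dispatch pattern concrete, here is a minimal, self-contained sketch. It does not use the Model Zoo: the Callback base class, ToyTrainer, and Recorder below are hypothetical stand-ins, and the hook signatures are simplified for illustration. It only shows the idea that the loop calls each hook on every callback at a fixed point, and that a callback overrides just the hooks it cares about.

```python
class Callback:
    """Hypothetical stand-in for a callback base class: every hook is a no-op."""

    def on_fit_start(self, trainer):
        pass

    def on_train_batch_start(self, trainer, batch):
        pass

    def on_train_batch_end(self, trainer, batch):
        pass

    def on_fit_end(self, trainer):
        pass


class ToyTrainer:
    """Toy loop that only illustrates where hooks fire; not the real Trainer."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def _call(self, hook, *args):
        # Invoke the named hook on every callback, in list order.
        for callback in self.callbacks:
            getattr(callback, hook)(self, *args)

    def fit(self, batches):
        self._call("on_fit_start")
        for batch in batches:
            self._call("on_train_batch_start", batch)
            # Forward, backward, and optimizer work would happen here.
            self._call("on_train_batch_end", batch)
        self._call("on_fit_end")


class Recorder(Callback):
    """Overrides two hooks and records the order in which they fire."""

    def __init__(self):
        self.events = []

    def on_fit_start(self, trainer):
        self.events.append("fit_start")

    def on_train_batch_end(self, trainer, batch):
        self.events.append(f"batch_end:{batch}")


recorder = Recorder()
ToyTrainer([recorder]).fit(batches=[0, 1])
print(recorder.events)
```

Hooks that Recorder does not override fall through to the no-op versions on the base class, which is why a custom callback only needs to define the hooks it uses.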
Pre-packaged Callbacks#
There are many callbacks that come pre-packaged with the Model Zoo. See Add-on Callbacks for a complete list of all the callbacks available out of the box.
You can use any number of them to enhance the Trainer for your run.
For example,
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import (
    ComputeNorm,
    CheckLoss,
    ...
)
trainer = Trainer(
    ...,
    callbacks=[
        ComputeNorm(),
        CheckLoss(),
        ...
    ],
    ...
)
...
You can configure any pre-packaged callback inside a YAML configuration file as follows:
trainer:
  init:
    ...
    callbacks:
      - ComputeNorm: {}
      - CheckLoss: {}
      ...
Global Callbacks#
Any callback can be registered globally so that all Trainer instances know about it and will invoke that callback’s hooks.
There are two ways to globally register a callback. The first way is to treat the callback as a context manager. For example,
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import CheckLoss
with CheckLoss():
    trainer = Trainer(...)
    trainer.fit(...)
In the above example, every trainer fit call made inside the CheckLoss context will check the loss values that come out of the model.
The other way to register a callback is to call register_global_callback. For example,
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import CheckLoss, register_global_callback
callback = CheckLoss()
handle = register_global_callback(callback)
trainer = Trainer(...)
trainer.fit(...)
handle.remove()
In the above example, every trainer fit call made before handle.remove() is called will check the loss values that come out of the model.
register_global_callback returns a removable handle object that can be used to remove the added callback by calling handle.remove().
Callback Ordering#
The Trainer is composed of many different callbacks that all serve to enhance its functionality.
All of these callbacks share common hooks. These hooks must be called in a specific order. The order in which callbacks are invoked is as follows:
Core Callbacks: The callbacks that implement the most fundamental behaviour of the Trainer get called first.
User-defined callbacks: The callbacks that are passed into the callbacks argument of the Trainer’s constructor are called next.
Global callbacks: Finally, the callbacks that are registered globally are called.
For example,
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import (
    CheckLoss,
    ComputeNorm,
    TrainingLoop,
)
trainer = Trainer(
    ...,
    loop=TrainingLoop(num_steps=1),
    callbacks=[ComputeNorm()],
    ...
)
with CheckLoss():
    trainer.fit(...)
Let’s consider the on_fit_start hook. For the three callbacks highlighted in the above example, the on_fit_start hooks are invoked in the following order:
TrainingLoop.on_fit_start: As TrainingLoop is a core callback.
ComputeNorm.on_fit_start: As ComputeNorm was passed into the Trainer’s constructor.
CheckLoss.on_fit_start: As it is a globally registered callback.
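The ordering rule amounts to walking three groups of callbacks in sequence. The sketch below illustrates that with plain Python; the Named class and the three lists are hypothetical, and the names merely mirror the example above rather than using the real Model Zoo classes.

```python
from itertools import chain

calls = []  # records the order in which on_fit_start is invoked


class Named:
    """Hypothetical callback that logs its own name when its hook fires."""

    def __init__(self, name):
        self.name = name

    def on_fit_start(self):
        calls.append(self.name)


core_callbacks = [Named("TrainingLoop")]    # built into the trainer
user_callbacks = [Named("ComputeNorm")]     # passed to the constructor
global_callbacks = [Named("CheckLoss")]     # registered globally

# Core callbacks first, then user-defined callbacks, then global callbacks.
for callback in chain(core_callbacks, user_callbacks, global_callbacks):
    callback.on_fit_start()

print(calls)
```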
Writing a Custom Callback#
To write your own custom callback class, all you need to do is inherit from the base Callback class and override the hooks that you need.
For example, let’s implement a simple callback that scales the loss value by some constant value before we call loss.backward():
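from cerebras.modelzoo.trainer.callbacks import Callback
class ScaleLoss(Callback):
    def __init__(self, value):
        # Constant factor applied to the loss before the backward pass.
        self.value = value
    def on_before_backward(self, trainer, model, outputs):
        # Scale the loss in place using the stored factor.
        outputs["loss"] *= self.value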
That is all there is to it. This callback can now be used inside the Trainer as follows:
from cerebras.modelzoo import Trainer
trainer = Trainer(
    ...,
    callbacks=[ScaleLoss(value=0.95)],
    ...
)
...
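Because a hook is an ordinary method, a custom callback can be exercised on its own before it is attached to a Trainer. The sketch below assumes nothing about the Model Zoo: it substitutes a stub Callback base class so that it runs anywhere, and passes a plain dictionary with a float loss in place of real model outputs.

```python
class Callback:
    """Stub base class so the sketch runs without the Model Zoo installed."""


class ScaleLoss(Callback):
    def __init__(self, value):
        self.value = value

    def on_before_backward(self, trainer, model, outputs):
        # Scale the loss using the factor stored on the instance.
        outputs["loss"] *= self.value


# Call the hook directly with placeholder arguments.
outputs = {"loss": 2.0}
ScaleLoss(value=0.5).on_before_backward(trainer=None, model=None, outputs=outputs)
print(outputs["loss"])  # 1.0
```

Note that the hook reads the factor from self.value; referring to a bare value inside the method would raise a NameError at call time.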
As long as the callback class exists in the Python global namespace, you can add any custom callback to a YAML configuration file in exactly the same way you would any other pre-packaged callback.
trainer:
  init:
    ...
    callbacks:
      - ScaleLoss:
          value: 0.95
      ...
In order for the callback class to exist in the Python global namespace, the Python interpreter must have seen it at some point. Implementing your custom callback in run.py or in the same file as the model class are two ways to ensure that the callback is seen by the Python interpreter and loaded into the Python global namespace.
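Why the interpreter must have "seen" the class can be illustrated with a small sketch. How the Model Zoo actually resolves a YAML key to a class is not described on this page, so the lookup below is purely an assumption for illustration: it searches the subclasses of a hypothetical Callback base class by name, which can only find classes whose definitions have already been executed.

```python
class Callback:
    """Hypothetical base class standing in for the real one."""


def find_callback_class(name, base=Callback):
    """Search the subclasses of `base` that Python has already defined."""
    pending = list(base.__subclasses__())
    while pending:
        cls = pending.pop()
        if cls.__name__ == name:
            return cls
        pending.extend(cls.__subclasses__())
    return None


# Before the class definition has run, the lookup finds nothing.
missing = find_callback_class("ScaleLoss")


class ScaleLoss(Callback):
    def __init__(self, value):
        self.value = value


# Mimics a parsed YAML entry of the form `- ScaleLoss: {value: 0.95}`.
config = {"ScaleLoss": {"value": 0.95}}
(name, kwargs), = config.items()
callback = find_callback_class(name)(**kwargs)
```

Under this assumption, defining the callback in a module that is never imported would leave the name unresolvable, which matches the advice above to place it in run.py or alongside the model class.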
Conclusion#
By this point, you should have a cursory understanding of how Callbacks can be used to enhance the Trainer. There are many useful callbacks that come pre-packaged inside the Model Zoo. If you need some functionality that is not covered, you should be comfortable writing your own callback to implement it.