Customizing the Trainer with Callbacks#

The Trainer class was designed to be easily extensible using Callback classes. The Trainer exposes a number of hooks that can be overridden using a Callback.

On this page, you will learn about the basic Callback mechanism. By the end you should be able to write and use your own custom Callback.

Prerequisites#

Please ensure that you have read through the Cerebras Model Zoo Trainer Overview beforehand. The rest of this page assumes that you already have at least a cursory understanding of what the Cerebras Model Zoo Trainer is and how to use its Python API.

Callbacks#

The callback mechanism is the backbone of the Trainer’s implementation. A lot of the heavy lifting in the Trainer is actually done by various Core Callbacks.

In general, the Callback mechanism exposes a number of useful hooks that allow you to inject custom behaviour into the Trainer. These hooks include (but are not limited to):

  • setup

  • on_{fit,train,validate}_{start,end}

  • on_{train,validate}_batch_{start,end}

  • on_{after,before}_{forward,backward}

  • on_{after,before}_optimizer_{step,zero_grad}

  • on_{after,before}_scheduler_step

  • on_{save,load}_checkpoint

  • on_after_save_checkpoint

  • on_before_load_checkpoint
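
To give a concrete feel for how these hooks are used, here is a minimal sketch of a callback that times each training batch by overriding on_train_batch_start and on_train_batch_end. The argument lists are collapsed into *args/**kwargs purely for illustration; refer to the Callback API docs for the exact arguments each hook accepts.

import logging
import time

from cerebras.modelzoo.trainer.callbacks import Callback

class LogBatchTime(Callback):
    """Illustrative callback that logs the wall-clock time of each training batch."""

    def on_train_batch_start(self, trainer, *args, **kwargs):
        # Record when the batch started.
        self._batch_start = time.time()

    def on_train_batch_end(self, trainer, *args, **kwargs):
        # Log how long this batch took.
        logging.info("Batch took %.3fs", time.time() - self._batch_start)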

The following pseudocode describes the structure of the fit call and where the various hooks get called.

class Trainer:
    ...

    def fit(self, train_dataloader, val_dataloader, ckpt_path):

        on_before_load_checkpoint()
        load_checkpoint(ckpt_path)

        on_fit_start(...)

        for loop_idx in range(num_trains):
            on_train_start(...)

            for batch in train_dataloader:
                on_train_batch_start(...)

                on_before_forward(...)
                outputs = model(batch)
                on_after_forward(...)

                on_before_backward(...)
                outputs["loss"].backward()
                on_after_backward(...)

                on_before_optimizer_step(...)
                optimizer.step()
                on_after_optimizer_step(...)

                on_before_optimizer_zero_grad(...)
                optimizer.zero_grad()
                on_after_optimizer_zero_grad(...)

                for scheduler in schedulers:
                    on_before_scheduler_step(...)
                    scheduler.step()
                    on_after_scheduler_step(...)

                on_train_batch_end(...)

            on_train_end(...)

            on_validate_start(...)

            for batch in val_dataloader:
                on_validate_batch_start(...)

                on_before_forward(...)
                outputs = model(batch)
                on_after_forward(...)

                on_validate_batch_end(...)

            on_validate_end(...)

        on_fit_end(...)

For a comprehensive list of all supported hooks (as well as the arguments they accept), see the API docs for the Callback class.

Pre-packaged Callbacks#

There are many callbacks that come pre-packaged with the Model Zoo. See Add-on Callbacks for a complete list of all the callbacks available out-of-the-box in the Model Zoo.

You can use any number of them to enhance the Trainer for your run.

For example,

from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import (
    ComputeNorm,
    CheckLoss,
    ...
)

trainer = Trainer(
    ...,
    callbacks=[
        ComputeNorm(),
        CheckLoss(),
        ...
    ],
    ...
)
...

Global Callbacks#

Any callback can be registered globally so that all Trainer instances know about it and will invoke that callback’s hooks.

There are two ways to globally register a callback. The first way is to treat the callback as a context manager. For example,

from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import CheckLoss

with CheckLoss():
    trainer = Trainer(...)
    trainer.fit(...)

In the above example, all trainer fit calls made inside the CheckLoss context will check the loss values that come out of the model.

The other way to register a callback is to call register_global_callback. For example,

from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import CheckLoss, register_global_callback

callback = CheckLoss()
handle = register_global_callback(callback)

trainer = Trainer(...)
trainer.fit(...)

handle.remove()

In the above example, all trainer fit calls made between register_global_callback and handle.remove() will check the loss values that come out of the model.

register_global_callback returns a removable handle object that can be used to remove the added callback by calling handle.remove().
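
If the code between registration and removal can raise an exception, one option is to wrap it in a try/finally block so the global callback is always unregistered. A minimal sketch:

from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import CheckLoss, register_global_callback

handle = register_global_callback(CheckLoss())
try:
    trainer = Trainer(...)
    trainer.fit(...)
finally:
    # Ensure the global callback is removed even if fit raises.
    handle.remove()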

Callback Ordering#

The Trainer is composed of many different callbacks that all serve to enhance its functionality.

All of these callbacks share common hooks, and these hooks must be called in a specific order. Callbacks are invoked in the following order:

  1. Core Callbacks: The callbacks that implement the most fundamental behaviour of the Trainer get called first.

  2. User-defined callbacks: The callbacks that are passed into the callbacks argument of the Trainer’s constructor are called next.

  3. Global callbacks: Finally, the callbacks that are registered globally are called.

For example,

from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import (
    CheckLoss,
    ComputeNorm,
    TrainingLoop,
)

trainer = Trainer(
    ...,
    loop=TrainingLoop(num_steps=1),
    callbacks=[ComputeNorm()],
    ...
)

with CheckLoss():
    trainer.fit(...)

Let’s consider the on_fit_start hook. For the three callbacks highlighted in the above example, the order in which each callback’s on_fit_start hook is invoked is as follows:

  1. TrainingLoop.on_fit_start: As TrainingLoop is a core callback.

  2. ComputeNorm.on_fit_start: As ComputeNorm was passed into the Trainer’s constructor.

  3. CheckLoss.on_fit_start: As it is a globally registered callback.

Writing a Custom Callback#

To write your own custom callback class, all you need to do is inherit from the base Callback class and override the hooks that you need.

For example, let’s implement a simple callback that scales the loss by some constant value before loss.backward() is called.

from cerebras.modelzoo.trainer.callbacks import Callback

class ScaleLoss(Callback):
    def __init__(self, value):
        self.value = value

    def on_before_backward(self, trainer, model, outputs):
        # Scale the loss in place before the backward pass runs.
        outputs["loss"] *= self.value

That is all there is to it. This callback can now be used inside the Trainer as follows:

from cerebras.modelzoo import Trainer

trainer = Trainer(
    ...,
    callbacks=[ScaleLoss(value=0.95)],
    ...
)
...
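
Because ScaleLoss inherits from the base Callback class, it can also be registered globally using either mechanism described in the Global Callbacks section above. For example, assuming the context-manager behaviour shown there, a sketch of the same callback applied globally:

from cerebras.modelzoo import Trainer

# Every trainer.fit call inside this context will scale the loss.
with ScaleLoss(value=0.95):
    trainer = Trainer(...)
    trainer.fit(...)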

Conclusion#

By this point, you should have a cursory understanding of how Callbacks can be used to enhance the Trainer. There are many useful callbacks that come pre-packaged inside the Model Zoo. If there is some functionality that you need that is not covered, you should be comfortable writing your own callback to implement it.