Defer Weight Initialization#
In this page, you will explore how to defer model weight initialization using the Trainer class. By deferring the initialization of model weights, you can significantly reduce the time-to-first-loss, leading to faster iteration times and a more efficient training process.
By the end of this guide, you will understand how to implement deferred initialization and the advantages it brings to your model training.
Prerequisites#
Please ensure that you have read through the Trainer Overview beforehand. The rest of this page assumes that you already have at least a cursory understanding of what the Cerebras Model Zoo Trainer is and how to use the Python API.
Model Function#
In Basic Usage, it was shown that you can pass any torch.nn.Module into the Trainer. However, to do this, you need a concrete torch.nn.Module object to pass into the Trainer. Due to PyTorch's eager nature, initializing a model can be very time consuming, especially for extremely large models.
To improve your experience, we introduce a mechanism by which you can defer your model's weight initialization. To use it, pass a function to the model argument of the Trainer that takes no arguments and returns a torch.nn.Module object.
import torch

from cerebras.modelzoo import Trainer

def model_fn() -> torch.nn.Module:
    ...

trainer = Trainer(
    device="CSX",
    model=model_fn,
    ...
)
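For illustration, a filled-in model_fn might look like the sketch below. The small two-layer network is purely a hypothetical example; any function that takes no arguments and returns a torch.nn.Module will do, and the module is only constructed when the Trainer calls the function.

import torch

def model_fn() -> torch.nn.Module:
    # Hypothetical example model; the module is not built until this function is called.
    return torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(784, 256),
        torch.nn.ReLU(),
        torch.nn.Linear(256, 10),
    )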
Any callable that matches this signature is accepted. So, you could construct your model inline using a lambda function as follows.
import torch

from cerebras.modelzoo import Trainer

trainer = Trainer(
    device="CSX",
    model=lambda: torch.nn.Linear(784, 10),
    ...
)
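If you prefer not to write a lambda, any other zero-argument callable works too. For example, here is a minimal sketch using the standard-library functools.partial to bind the constructor arguments up front.

import functools

import torch

from cerebras.modelzoo import Trainer

trainer = Trainer(
    device="CSX",
    # functools.partial produces a zero-argument callable that constructs the module lazily.
    model=functools.partial(torch.nn.Linear, 784, 10),
    ...
)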
Passing in a model function allows the Trainer to employ the Efficient weight initialization mechanism built into the Cerebras PyTorch API (see here for more details).
Empirically, deferring model weight initialization can reduce the time-to-first-loss (the amount of time it takes to get the first value back from the Wafer-Scale Cluster) by over 50%. This means less time waiting around and faster iteration overall.
Optimizer/Scheduler Functions#
One question that may be on your mind is, "What about the optimizer?" The optimizer constructor takes in the model parameters, so if model initialization is being delayed, how can the optimizer be constructed?
The answer is that the Trainer can also accept a function for the optimizer argument, which is expected to take in a torch.nn.Module and return an Optimizer object.
import cerebras.pytorch as cstorch
import torch

from cerebras.modelzoo import Trainer

trainer = Trainer(
    device="CSX",
    model=lambda: torch.nn.Linear(784, 10),
    optimizer=lambda model: cstorch.optim.SGD(
        model.parameters(),
        lr=0.01,
        momentum=0.9,
    ),
    ...
)
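For larger configurations, a lambda can become hard to read, so you may prefer a named optimizer function with the same signature. Here is a minimal sketch, reusing the SGD settings from above.

import cerebras.pytorch as cstorch
import torch

from cerebras.modelzoo import Trainer

def optimizer_fn(model: torch.nn.Module) -> cstorch.optim.SGD:
    # Receives the constructed model and builds the optimizer from its parameters.
    return cstorch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

trainer = Trainer(
    device="CSX",
    model=lambda: torch.nn.Linear(784, 10),
    optimizer=optimizer_fn,
    ...
)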
Similarly, the Trainer can also accept a function for the schedulers argument, which is expected to take in an Optimizer object and return a Scheduler object.
import cerebras.pytorch as cstorch
import torch

from cerebras.modelzoo import Trainer

trainer = Trainer(
    device="CSX",
    model=lambda: torch.nn.Linear(784, 10),
    optimizer=lambda model: cstorch.optim.SGD(
        model.parameters(),
        lr=0.01,
        momentum=0.9,
    ),
    schedulers=[
        lambda optimizer: cstorch.optim.lr_scheduler.LinearLR(
            optimizer,
            initial_learning_rate=0.01,
            end_learning_rate=0.01,
            total_iters=100,
        ),
    ],
)
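Putting it all together, a sketch of a fully deferred setup using named functions is shown below. The specific model, optimizer, and scheduler settings are placeholders carried over from the examples above.

import cerebras.pytorch as cstorch
import torch

from cerebras.modelzoo import Trainer

def model_fn() -> torch.nn.Module:
    # Placeholder model; nothing is constructed until the Trainer invokes this function.
    return torch.nn.Linear(784, 10)

def optimizer_fn(model: torch.nn.Module) -> cstorch.optim.SGD:
    return cstorch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

def scheduler_fn(optimizer: cstorch.optim.SGD) -> cstorch.optim.lr_scheduler.LinearLR:
    return cstorch.optim.lr_scheduler.LinearLR(
        optimizer,
        initial_learning_rate=0.01,
        end_learning_rate=0.01,
        total_iters=100,
    )

trainer = Trainer(
    device="CSX",
    model=model_fn,
    optimizer=optimizer_fn,
    schedulers=[scheduler_fn],
)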
Conclusion#
That is all there is to deferring model, optimizer, and scheduler initialization! By simply wrapping these components in callables, you can gain large improvements in iteration time and resource utilization, enhancing your overall experience with the Cerebras Model Zoo and ensuring more efficient training.
Further Reading#
To learn about how you can configure a Trainer instance using a YAML configuration file, you can check out:
Trainer YAML Overview
To learn more about how you can use the Trainer in some core workflows, you can check out:
To learn more about how you can extend the capabilities of the Trainer class, you can check out: