Model#
This page covers how to pass a model into the Trainer.
The model is the main Module that all training and validation is run on. It is required by all Trainer instances.
Prerequisites#
Make sure you have read Trainer Overview and Trainer Configuration Overview, which provide the basic overview of how to run Model Zoo models. This document uses the tools and configurations outlined in those pages.
Configure the model#
To set the model to train/validate using the Trainer, use the model argument.
All model subkeys are passed as arguments to the model class. The model class is decided by the model_fn in your run script.
trainer:
  init:
    ...
    model:
      vocab_size: 1024
      max_position_embeddings: 1024
      ...
    ...
  ...
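As a rough illustration of how the model subkeys can map onto constructor arguments, the sketch below expands a loaded config dictionary into keyword arguments. It is not the Trainer's actual implementation: `ToyModel` and `build_model` are hypothetical names introduced here, and `ToyModel` is a plain-Python stand-in for whatever class your model_fn selects.

```python
# Hypothetical stand-in for the model class chosen by model_fn in a run script.
class ToyModel:
    def __init__(self, vocab_size, max_position_embeddings):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings


# The config as it might look after the YAML is loaded into a dict.
# Only the keys shown in the example above are included.
config = {
    "trainer": {
        "init": {
            "model": {
                "vocab_size": 1024,
                "max_position_embeddings": 1024,
            },
        },
    },
}


def build_model(model_cls, config):
    """Hypothetical helper: pass every `model` subkey as a keyword argument."""
    model_params = config["trainer"]["init"]["model"]
    return model_cls(**model_params)


model = build_model(ToyModel, config)
print(model.vocab_size, model.max_position_embeddings)
```

Under this reading, a subkey that the model class does not accept would fail at construction time with a `TypeError`, so the config keys need to match the constructor's parameter names.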
The model can be passed as either:
- a callable, assumed to be a function that takes no arguments and returns a Module
- a Module, which is used as is
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.models.nlp.gpt2.model import Gpt2Model

trainer = Trainer(
    ...,
    model=lambda: Gpt2Model(
        vocab_size=1024,
        max_position_embeddings=1024,
        ...,
    ),
    ...,
)
...
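The two accepted forms could be normalized along the following lines. This is a minimal sketch under stated assumptions, not the Trainer's real logic: `resolve_model` is a hypothetical helper, and the `Module` class here is a small stand-in for `torch.nn.Module` so that the example runs without PyTorch. The point it shows is that a module instance is itself callable, so the instance check has to come before the callable check.

```python
class Module:
    """Stand-in for torch.nn.Module (assumption, for a dependency-free sketch)."""

    def __call__(self, *args, **kwargs):
        # Like torch.nn.Module, calling an instance runs forward().
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError


class ToyModel(Module):
    """Hypothetical model used only for illustration."""

    def __init__(self, vocab_size):
        self.vocab_size = vocab_size

    def forward(self, x):
        return x


def resolve_model(model):
    """Hypothetical helper: return a Module from either accepted form."""
    # Check for an instance first: modules are callable too, so testing
    # callable() first would wrongly invoke forward() with no arguments.
    if isinstance(model, Module):
        return model
    if callable(model):
        built = model()  # zero-argument factory
        if not isinstance(built, Module):
            raise TypeError("model callable must return a Module")
        return built
    raise TypeError("model must be a Module or a zero-argument callable")


existing = ToyModel(vocab_size=1024)
from_instance = resolve_model(existing)
from_callable = resolve_model(lambda: ToyModel(vocab_size=1024))
```

Passing a callable defers construction until the model is actually needed, whereas passing an instance means the caller has already decided where and how the model was built.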
Note
If passing the model as a Module directly, it is best to first initialize the model inside the Cerebras device context.
For example:
import cerebras.pytorch as cstorch
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.models.nlp.gpt2.model import Gpt2Model

# Initialize the Cerebras backend for efficient processing.
backend = cstorch.backend("CSX")

# Use the backend's device context manager for initializing the model.
with backend.device:
    model = Gpt2Model(
        vocab_size=1024,
        max_position_embeddings=1024,
        ...,
    )

# Pass the same backend and the already-initialized model to the Trainer.
trainer = Trainer(
    ...,
    backend=backend,
    model=model,
    ...,
)
...
This ensures that model parameters are automatically moved to the Cerebras device, optimizing memory usage and enhancing initialization speed. For more information, see Efficient weight initialization.
Conclusion#
That covers specifying the model to train/validate with the Trainer.
You should now understand the various ways to configure the model and how the Trainer accepts a model.
Further Reading#
To learn more about how you can use the Trainer in some core workflows, you can check out:
To learn more about how you can extend the capabilities of the Trainer class, you can check out: