Using the CerebrasEstimator#
The CerebrasEstimator is a critical part of your main Python program when running on the CS system. The CerebrasEstimator launches the Cerebras Graph Compiler (CGC) when one of its methods, such as compile or train, is called with the IP address of the CS system provided through cs_ip. See also The CerebrasEstimator Interface.
In this section, an example run.py template is used to show how the CerebrasEstimator interacts with the key code segments of your Python program.
Shown below is a highly simplified example run.py script for neural network training:
1 # Example run.py script for neural network training
2 from cerebras.models.common.estimator.tf.cs_estimator import CerebrasEstimator
3 from cerebras.models.common.estimator.tf.run_config import CSRunConfig
4 from cerebras.tf.cs_slurm_cluster_resolver import CSSlurmClusterResolver
5
6 def model_fn(features, labels, mode, params):
7
8     ...
9
10     return spec
11
12 def input_fn(params):
13
14     ...
15
16     return dataset
17
18 config = CSRunConfig(
19     cs_ip=ip,  # ip: IP address of the CS system, defined elsewhere
20     save_checkpoints_steps=1000,
21     log_step_count_steps=10000,
22 )
23 params = {
24     "batch_size": 32,
25     "lr": 0.1,
26     "use_cbfloat16": True
27 }
28
29 est = CerebrasEstimator(
30     model_fn,
31     config=config,
32     params=params,
33     model_dir='./out',
34     use_cs=True
35 )
36
37 est.train(input_fn, steps=100000)
Calling the CerebrasEstimator#
In the est=CerebrasEstimator(...) call (line 29), the model_fn argument is a callback function. When the CerebrasEstimator receives this argument, it does not call the function immediately. Instead, the CerebrasEstimator API waits until one of its methods, such as train, is invoked.
Note
The model_fn argument to the CerebrasEstimator interface is passed without the parentheses (), that is, as a function reference rather than a function call.
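The difference between passing a function and passing the result of calling it can be sketched in plain Python. ToyEstimator below is a hypothetical stand-in written only for this illustration. It is not the CerebrasEstimator and assumes nothing about its internals beyond the deferred-callback behavior described above.

```python
# ToyEstimator is a hypothetical stand-in, not part of the Cerebras or
# TensorFlow APIs. It only mimics "store the callback now, call it later".
class ToyEstimator:
    def __init__(self, model_fn):
        if not callable(model_fn):
            raise TypeError(
                "model_fn must be a function object, not the result of calling it"
            )
        self.model_fn = model_fn  # stored, not called
        self.invocations = 0

    def train(self):
        # The callback runs only when a method such as train is invoked.
        self.invocations += 1
        return self.model_fn(features=[1.0], labels=[0], mode="train", params={})


def model_fn(features, labels, mode, params):
    return {"mode": mode}


est = ToyEstimator(model_fn)  # correct: no parentheses after model_fn
invocations_before_train = est.invocations
result = est.train()

# Incorrect usage: model_fn(...) runs immediately and passes its return
# value (a dict here) instead of the function itself.
try:
    ToyEstimator(model_fn([1.0], [0], "train", {}))
    wrong_usage_error = None
except TypeError as exc:
    wrong_usage_error = str(exc)
```

In this sketch the constructor only stores the function, so invocations_before_train stays at zero until train is called.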
Callback input function#
The est.train(input_fn, steps=100000) statement (line 37) is a train method call to the CerebrasEstimator with the input_fn argument as a callback function. The CerebrasEstimator then calls the input_fn with the params argument.
Note
The input_fn argument to the train method is passed without the parentheses ().
Both the CerebrasEstimator and the TensorFlow Estimator API expect the input function to accept a standard group of input parameters through the params argument and to return a tf.data.Dataset that yields tensor pairs in the predefined format: a tensor with features and a tensor with labels.
Any params passed to the CerebrasEstimator are passed on to the input_fn and to the model_fn when the CerebrasEstimator calls these functions.
The input_fn should return a tf.data.Dataset (see the Dataset API for documentation).
The input function builds the input pipeline and yields the batched data in the form of (features, labels) pairs, where features can be a tensor or a dictionary of tensors, and labels can be a tensor, a dictionary of tensors, or None.
Example#
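def input_fn(params):
    ...  # build the dataset ds and define batch_size and buffer_size here
    ds = ds.shuffle(buffer_size)
    ds = ds.repeat()
    # drop_remainder=True keeps every batch at exactly batch_size elements
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.prefetch(buffer_size)
    return ds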
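The effect of drop_remainder=True in the example above can be illustrated without TensorFlow. The following is a plain-Python sketch of the batching rule only, applied to a finite list. The helper name batch_drop_remainder is hypothetical, and the sketch deliberately omits shuffle, repeat, and prefetch.

```python
# Plain-Python analogue of batching with drop_remainder=True on a finite
# sequence. batch_drop_remainder is a hypothetical helper for illustration;
# it is not a TensorFlow or Cerebras function.
def batch_drop_remainder(examples, batch_size):
    examples = list(examples)
    batches = []
    # Stop before any slice that would hold fewer than batch_size elements.
    for start in range(0, len(examples) - batch_size + 1, batch_size):
        batches.append(examples[start:start + batch_size])
    return batches


batches = batch_drop_remainder(range(10), batch_size=4)
batch_sizes = [len(b) for b in batches]
```

With ten examples and a batch size of four, the last two examples do not fill a batch and are dropped, so every batch has the same size.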
Callback model function#
The model function model_fn is used to generate the graph for your neural network model.
The features and labels returned from the input_fn are the handles to the batched data that your model will use. After the input_fn returns features and labels, the CerebrasEstimator calls the model_fn, passing these two handles along with the following arguments: the mode argument, which indicates whether the caller is requesting training, and the params object that was passed in the est=CerebrasEstimator(...) call.
Important
The functions input_fn and model_fn are called by the CerebrasEstimator, because both are passed to the CerebrasEstimator as callback functions. You should not directly call either of these two functions in your TensorFlow code.
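The call order described in this section can be sketched in plain Python. SketchEstimator is a hypothetical stand-in, not the CerebrasEstimator. The real estimator builds a TensorFlow graph from symbolic tensors, whereas this sketch uses lists and only mimics which callback is called, in what order, and with which arguments.

```python
# SketchEstimator is a hypothetical stand-in used only to show the call
# order; it is not part of the Cerebras or TensorFlow APIs.
trace = []  # records the order in which the callbacks run


class SketchEstimator:
    def __init__(self, model_fn, params):
        self.model_fn = model_fn
        self.params = params

    def train(self, input_fn, steps):
        # 1. The estimator, not user code, calls input_fn with params.
        dataset = input_fn(self.params)
        # 2. The (features, labels) handles come from the dataset.
        features, labels = next(iter(dataset))
        # 3. The estimator calls model_fn with the handles, mode, and params.
        return self.model_fn(features, labels, "train", self.params)


def input_fn(params):
    trace.append("input_fn")
    batch = params["batch_size"]
    # One (features, labels) pair per batch; plain lists stand in for tensors.
    return [([0.5] * batch, [1] * batch)]


def model_fn(features, labels, mode, params):
    trace.append("model_fn")
    return {"mode": mode, "lr": params["lr"], "batch": len(features)}


est = SketchEstimator(model_fn, params={"batch_size": 32, "lr": 0.1})
spec = est.train(input_fn, steps=1)
```

The user code only defines the two functions and hands them over. Both calls happen inside train, and the same params object reaches both callbacks.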
Both the CerebrasEstimator and the TensorFlow Estimator API expect the model function to accept a standard group of input parameters and return a standard group of output values.
Currently, the CerebrasEstimator supports usage of the TensorFlow Keras Layers API in the model function. However, the TensorFlow Metrics API is not supported.
Syntax#
def model_fn(
    features,  # This is batch_features from input_fn
    labels,    # This is batch_labels from input_fn
    mode,      # An instance of tf.estimator.ModeKeys
    params     # Additional configuration
):
Example#
The following is an example of a model_fn definition.
import tensorflow as tf

def model_fn(features, labels, mode=tf.estimator.ModeKeys.TRAIN, params=None):
    """Model definition."""
    logits = build_model(features, params)  # build_model: your own function that returns the logits
    learning_rate = tf.constant(params["lr"])
    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        loss_op = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits)
        )
        train_op = tf.train.GradientDescentOptimizer(
            learning_rate=learning_rate
        ).minimize(loss_op, global_step=tf.train.get_global_step())
        spec = tf.estimator.EstimatorSpec(mode=mode, loss=loss_op, train_op=train_op)
    else:
        raise ValueError(f"Unsupported mode: {mode}")  # spec would otherwise be undefined
    return spec
Setting the runtime configuration#
Runtime and environment options are typically information that is not captured in the model_fn and input_fn. Use the CSRunConfig object to set these Cerebras-specific options. These options are an extension of the TensorFlow RunConfig.
Important
Make sure to add the following import statement to your Slurm-orchestrated TensorFlow code so that Slurm cluster resolving is done automatically.
from cerebras.tf.cs_slurm_cluster_resolver import CSSlurmClusterResolver
CSRunConfig#
The Cerebras CSRunConfig class inherits from the standard TensorFlow RunConfig class. You can pass to the CSRunConfig the same parameters as those of the TensorFlow RunConfig, and also pass additional parameters that specify the configurations for a CerebrasEstimator run, including the IP address of the CS system. Such additional parameters include:
cs_ip: IP address of the CS system, provided by Cerebras.
system_name: Name of the CS system.
The full list of options for the TensorFlow RunConfig can be found in the TensorFlow documentation.
Example#
from cerebras.models.common.estimator.tf.run_config import CSRunConfig
from cerebras.tf.cs_slurm_cluster_resolver import CSSlurmClusterResolver

config = CSRunConfig(
    cs_ip=ip,  # ip: IP address of the CS system, defined elsewhere
    save_checkpoints_steps=1000,
    log_step_count_steps=10000,
    save_summary_steps=1000
)
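The example above leaves the ip variable undefined. One possible way to supply it, sketched here with the standard argparse module, is to read it from the command line and collect the options into a dictionary of keyword arguments. The --cs_ip flag name, the build_run_config_kwargs helper, and the placeholder address are assumptions made for this sketch and are not part of the Cerebras API.

```python
import argparse


# build_run_config_kwargs and the --cs_ip flag are hypothetical names chosen
# for this sketch; they are not defined by the Cerebras packages.
def build_run_config_kwargs(argv):
    parser = argparse.ArgumentParser(description="Collect CSRunConfig options")
    parser.add_argument("--cs_ip", required=True,
                        help="IP address of the CS system")
    parser.add_argument("--save_checkpoints_steps", type=int, default=1000)
    args = parser.parse_args(argv)
    # Keys match the keyword arguments used in the CSRunConfig example above.
    return {
        "cs_ip": args.cs_ip,
        "save_checkpoints_steps": args.save_checkpoints_steps,
        "log_step_count_steps": 10000,
        "save_summary_steps": 1000,
    }


# "10.0.0.1" is a placeholder address, not a real CS system.
kwargs = build_run_config_kwargs(["--cs_ip", "10.0.0.1"])

# On a machine with the Cerebras packages installed, the dictionary could
# then be unpacked into the config object:
# config = CSRunConfig(**kwargs)
```

Keeping the address out of the source file lets the same run.py target different CS systems without edits.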