Static graphs#
As of the 2.1.0 software release, we do not officially support reprogramming the Cerebras Wafer-Scale cluster after initial programming. This means that multiple compiles are not supported and, therefore, the PyTorch compute graph must not change between iterations.
The way to define a training/evaluation step is by decorating a function with cerebras.pytorch.trace.
For example:
import torch
import cerebras.pytorch as cstorch

loss_fn = torch.nn.CrossEntropyLoss()

# compiled_model and optimizer are assumed to have been constructed earlier
@cstorch.trace
def training_step(inputs, targets):
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss
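Once defined, the traced step is called like a regular Python function for each batch. A minimal sketch, assuming the batches of inputs and targets come from whatever data pipeline feeds the run:
for inputs, targets in dataloader:  # dataloader is a placeholder for your data pipeline
    loss = training_step(inputs, targets)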
By default, the training_step
function is only ever traced a single time.
This means that the contents of the traced function must represent a static
computation graph. If there are any conditionals, the branch encountered in the
first iteration is what gets encoded into the graph. If there are any loops,
they are unrolled according to the number of times they ran in the first iteration.
In addition, any other side effects, such as print statements and changes to Python scalars, only happen once, while the function is being traced.
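For illustration, a hypothetical sketch of this trace-once behavior (the use_aux_loss flag, aux_loss_fn, and the print statement are illustrative additions, not part of the example above):
use_aux_loss = True  # plain Python flag, evaluated only at trace time

@cstorch.trace
def training_step(inputs, targets):
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)
    # The branch taken on the first iteration is baked into the graph;
    # flipping use_aux_loss on later iterations has no effect.
    if use_aux_loss:
        loss = loss + aux_loss_fn(outputs)
    # Executes only once, while the function is being traced.
    print("tracing training_step")
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss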
Retracing every iteration#
There is an option to enable retracing every iteration. To do this, specify the retrace_every_iteration flag while constructing the backend.
For example:
backend = cstorch.backend("CSX", ..., retrace_every_iteration=True)
Setting this flag to True means that the function decorated with cerebras.pytorch.trace will be traced every single iteration. The benefit of retracing every iteration is that side effects, such as print statements and changes to Python scalars, now happen at every iteration.
It is important to note that dynamic graph logic will still not be captured: Python conditionals are resolved at trace time and Python loops are unrolled. If the computation graph changes in any way between iterations, a compile error is thrown.
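As a hypothetical sketch (the step counter and print statement are illustrative), with retracing enabled the Python-side effects now run on every iteration, although the captured graph must remain identical from one iteration to the next:
backend = cstorch.backend("CSX", retrace_every_iteration=True)

step_count = 0  # plain Python scalar, updated each time the function is traced

@cstorch.trace
def training_step(inputs, targets):
    global step_count
    step_count += 1                      # increments every iteration now
    print(f"tracing step {step_count}")  # prints every iteration now
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss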
Note
Retracing every iteration can have a performance impact for smaller models, where tracing time outweighs the time it takes to execute the model.
For larger models, where execution time is significantly larger, the retracing time should be negligible.