cerebras.pytorch.amp#
Automatic mixed precision#
The following classes and functions facilitate automatic mixed precision on the Cerebras Wafer Scale Cluster.
GradScaler#
- class cerebras.pytorch.amp.GradScaler(loss_scale=None, init_scale=None, steps_per_increase=None, min_loss_scale=None, max_loss_scale=None, overflow_tolerance=0.0, max_gradient_norm=None)[source]#
Facilitates mixed precision training with dynamic loss scaling (DLS) and DLS combined with global gradient clipping (DLS + GCC).
For more details, please see the docs for amp.initialize.
- Parameters
loss_scale (Union[str, float]) – If loss_scale == “dynamic”, then configure dynamic loss scaling. Otherwise, it is the loss scale value used in static loss scaling.
init_scale (float) – The initial loss scale value if loss_scale == “dynamic”
steps_per_increase (int) – The number of steps after which to increase the loss scale when using dynamic loss scaling
min_loss_scale (float) – The minimum loss scale value that can be chosen by dynamic loss scaling
max_loss_scale (float) – The maximum loss scale value that can be chosen by dynamic loss scaling
overflow_tolerance (float) – The maximum fraction of steps with infinite or NaN gradient values that is tolerated. The loss scale is reduced if this tolerance is exceeded
max_gradient_norm (float) – The maximum gradient norm to use for global gradient clipping. This only applies in the DLS + GCC case; if GCC is not enabled, this parameter has no effect
Example usage:
grad_scaler = cstorch.amp.GradScaler(loss_scale="dynamic")

loss: torch.Tensor = ...

optimizer.zero_grad()

# Scale the loss before calling the backward pass
grad_scaler.scale(loss).backward()

# Unscales the gradients of optimizer's assigned params in-place
# to facilitate things like gradient clipping
grad_scaler.unscale_(optimizer)

# Global gradient clipping
torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    1.0,  # max gradient norm
)

# Step the optimizer using the grad scaler
grad_scaler.step(optimizer)

# update the grad scaler once all optimizers have been stepped
grad_scaler.update()
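The example above relies on the default dynamic loss scaling behavior. Below is a minimal sketch of constructing a GradScaler with the dynamic loss scaling parameters set explicitly; the numeric values are illustrative only, not recommended defaults.

# Illustrative values only; tune these for your model and precision settings
grad_scaler = cstorch.amp.GradScaler(
    loss_scale="dynamic",
    init_scale=2.0**16,       # starting loss scale
    steps_per_increase=2000,  # steps before attempting to raise the scale
    min_loss_scale=1.0,       # lower bound for the dynamic loss scale
    max_loss_scale=2.0**32,   # upper bound for the dynamic loss scale
    overflow_tolerance=0.05,  # fraction of overflowing steps tolerated
)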
- state_dict(destination=None)[source]#
Returns a dictionary containing the state to be saved to a checkpoint
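For example, the returned dictionary can be stored alongside the model and optimizer state when assembling a checkpoint. This is a sketch; the checkpoint layout and the use of cstorch.save are illustrative.

# Sketch: bundle the GradScaler state into a checkpoint dictionary
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "grad_scaler": grad_scaler.state_dict(),  # loss scale state to restore later
}
cstorch.save(checkpoint, "checkpoint.mdl")  # illustrative save call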
- step_if_finite(optimizer, *args, **kwargs)[source]#
Calls optimizer.step(*args, **kwargs), but only if this GradScaler detected finite gradients.
- Parameters
optimizer (torch.optim.Optimizer) – Optimizer that applies the gradients.
args – Any arguments passed to the optimizer.step() call.
kwargs – Any keyword arguments passed to the optimizer.step() call.
- Returns
The result of optimizer.step()
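A minimal sketch of using step_if_finite in place of the separate step call from the earlier example; it assumes the same scale/backward/unscale_ flow shown above.

grad_scaler.scale(loss).backward()
grad_scaler.unscale_(optimizer)  # unscale in-place before inspecting gradients

# Step the optimizer only if the gradients are finite;
# the return value of optimizer.step() is passed through
result = grad_scaler.step_if_finite(optimizer)
grad_scaler.update()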
- clip_gradients_and_return_isfinite(optimizers)[source]#
Clips the gradients of the optimizers' parameters and returns whether or not the gradient norm is finite
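A hedged sketch of calling this directly when managing clipping manually. It assumes the argument is a list of optimizers (inferred from the parameter name) and that max_gradient_norm was set when constructing the GradScaler so that clipping applies.

# Assumes the GradScaler was constructed with max_gradient_norm set, and
# that a list of optimizers is expected (inferred from the parameter name)
grad_scaler.scale(loss).backward()
grad_scaler.unscale_(optimizer)
is_finite = grad_scaler.clip_gradients_and_return_isfinite([optimizer])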
- step(optimizer, *args, **kwargs)[source]#
Step carries out the following two operations:
1. Internally invokes unscale_(optimizer) (unless unscale_ was explicitly called for optimizer earlier in the iteration). As part of the unscale_, gradients are checked for infs/NaNs.
2. Invokes optimizer.step() using the unscaled gradients. Ensures that previous optimizer state or params carry over if we encounter NaNs in the gradients.
*args and **kwargs are forwarded to optimizer.step(). Returns the return value of optimizer.step(*args, **kwargs).
- Parameters
optimizer (cerebras.pytorch.optim.Optimizer) – Optimizer that applies the gradients.
args – Any arguments.
kwargs – Any keyword arguments.
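Since the return value of optimizer.step(*args, **kwargs) is passed through, it can be captured directly, as in this minimal sketch.

# step() forwards its arguments to optimizer.step() and returns its result
step_result = grad_scaler.step(optimizer)
grad_scaler.update()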
set_half_dtype#
- cerebras.pytorch.amp.set_half_dtype(value)[source]#
Sets the underlying 16-bit floating point dtype to use.
- Parameters
value (Union[Literal['float16', 'bfloat16', 'cbfloat16'], torch.dtype]) – Either a 16-bit floating point torch dtype or one of “float16”, “bfloat16”, or “cbfloat16” string.
- Returns
The proxy torch dtype to use for the model. For dtypes that have a torch representation, this returns the same as value passed in. Otherwise, it returns a proxy dtype to use in the model. On CSX, these proxy dtypes are automatically and transparently converted to the real dtype during compilation.
- Return type
torch.dtype
By default, automatic mixed precision uses float16. If you want to use cbfloat16 or bfloat16 instead of float16, call this function.
Example usage:
cstorch.amp.set_half_dtype("cbfloat16")
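The value may be passed either as a string or as a torch dtype, and the returned (possibly proxy) dtype can be captured if needed; a short sketch follows.

import torch

# Equivalent ways to select the 16-bit dtype (string or torch dtype)
cstorch.amp.set_half_dtype("bfloat16")
cstorch.amp.set_half_dtype(torch.bfloat16)

# cbfloat16 has no torch representation, so a proxy dtype is returned instead
proxy_dtype = cstorch.amp.set_half_dtype("cbfloat16")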
optimizer_step#
- cerebras.pytorch.amp.optimizer_step(loss, optimizer, grad_scaler, max_gradient_norm=None, max_gradient_value=None)[source]#
Performs loss scaling, gradient scaling and optimizer step
- Parameters
loss (torch.Tensor) – The loss value to scale. loss.backward should be called before this function
optimizer (cerebras.pytorch.optim.optimizer.Optimizer) – The optimizer to step
grad_scaler (cerebras.pytorch.amp.grad_scaler.GradScaler) – The gradient scaler to use to scale the parameter gradients
max_gradient_norm (Optional[float]) – the max gradient norm to use for gradient clipping
max_gradient_value (Optional[float]) – the max gradient value to use for gradient clipping
Example usage:
cstorch.amp.optimizer_step(
    loss,
    optimizer,
    grad_scaler,
    max_gradient_norm=1.0,
)
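A sketch of a full training step built around this helper, following the loss parameter description above, which states that loss.backward should be called before this function. Construction of model, optimizer, and grad_scaler is omitted, and batch is a placeholder input.

optimizer.zero_grad()
loss = model(batch)  # assumes the model returns the loss directly
loss.backward()      # per the loss parameter description above
cstorch.amp.optimizer_step(
    loss,
    optimizer,
    grad_scaler,
    max_gradient_norm=1.0,
)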