cerebras.pytorch.amp#

Automatic mixed precision#

The following classes and subclasses are designed to facilitate automatic mixed precision on the Cerebras Wafer-Scale Cluster.

GradScaler#

class cerebras.pytorch.amp.GradScaler(loss_scale=None, init_scale=None, steps_per_increase=None, min_loss_scale=None, max_loss_scale=None, overflow_tolerance=0.0, max_gradient_norm=None)[source]#

Facilitates mixed precision training with dynamic loss scaling (DLS), optionally combined with global gradient clipping (DLS + GCC).

For more details, please see the docs for amp.initialize.

Parameters
  • loss_scale (Union[str, float]) – If loss_scale == “dynamic”, then configure dynamic loss scaling. Otherwise, it is the loss scale value used in static loss scaling.

  • init_scale (float) – The initial loss scale value if loss_scale == “dynamic”

  • steps_per_increase (int) – The number of steps after which to increase the loss scale when using dynamic loss scaling

  • min_loss_scale (float) – The minimum loss scale value that can be chosen by dynamic loss scaling

  • max_loss_scale (float) – The maximum loss scale value that can be chosen by dynamic loss scaling

  • overflow_tolerance (float) – The maximum fraction of steps with infinite or NaN values in the gradients that we allow. The loss scale is reduced if this tolerance is exceeded

  • max_gradient_norm (float) – The maximum gradient norm to use for global gradient clipping. Only applies in the DLS + GCC case. If GCC is not enabled, this parameter has no effect.
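
A GradScaler configured for dynamic loss scaling with explicit bounds might, for instance, look like the sketch below; the specific values are illustrative only, not recommended defaults:

import cerebras.pytorch as cstorch

# Dynamic loss scaling with illustrative (not recommended) settings.
grad_scaler = cstorch.amp.GradScaler(
    loss_scale="dynamic",
    init_scale=2.0**16,
    steps_per_increase=2000,
    min_loss_scale=1.0,
    max_loss_scale=2.0**32,
    overflow_tolerance=0.05,
)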

Example usage:

import torch
import cerebras.pytorch as cstorch

grad_scaler = cstorch.amp.GradScaler(loss_scale="dynamic")

loss: torch.Tensor = ...

optimizer.zero_grad()
# Scale the loss before calling the backward pass
grad_scaler.scale(loss).backward()

# Unscales the gradients of the optimizer's assigned params in-place
# to facilitate things like gradient clipping
grad_scaler.unscale_(optimizer)

# Global gradient clipping
torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    1.0,  # max gradient norm
)

# Step the optimizer using the grad scaler
grad_scaler.step(optimizer)

# Update the grad scaler once all optimizers have been stepped
grad_scaler.update()

state_dict(destination=None)[source]#

Returns a dictionary containing the state to be saved to a checkpoint

load_state_dict(state_dict)[source]#

Loads the state dictionary into the current params
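
These two methods allow the scaler state to be checkpointed and restored together with the rest of the training state. A minimal sketch, assuming the cstorch.save and cstorch.load checkpoint utilities and an already-constructed model, optimizer, and grad_scaler:

import cerebras.pytorch as cstorch

# Save the GradScaler state alongside the model and optimizer state.
state_dict = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "grad_scaler": grad_scaler.state_dict(),
}
cstorch.save(state_dict, "checkpoint.mdl")  # path is illustrative

# ... later, when resuming ...
state_dict = cstorch.load("checkpoint.mdl")
grad_scaler.load_state_dict(state_dict["grad_scaler"])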

scale(loss)[source]#

Scales the loss in preparation for the backward pass

get_scale()[source]#

Returns the loss scale
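
For instance, the current scale can be read back and recorded alongside other training metrics (a trivial sketch; how the value is reported is left to the surrounding training loop):

# Inspect the current loss scale, e.g. for debugging dynamic loss scaling.
current_scale = grad_scaler.get_scale()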

unscale_(optimizer)[source]#

Unscales the optimizer's parameter gradients in place

step_if_finite(optimizer, *args, **kwargs)[source]#

Conditionally calls optimizer.step(*args, **kwargs), executing the step only if this GradScaler detected finite gradients.

Parameters
  • optimizer (torch.optim.Optimizer) – Optimizer that applies the gradients.

  • args – Any arguments passed to the optimizer.step() call.

  • kwargs – Any keyword arguments passed to the optimizer.step() call.

Returns

The result of optimizer.step()
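
A hypothetical sketch of where step_if_finite could slot into a manual flow in place of step (the standard path is grad_scaler.step, as shown in the class-level example above):

grad_scaler.scale(loss).backward()
grad_scaler.unscale_(optimizer)

# Apply the parameter update only if this GradScaler detected finite gradients.
grad_scaler.step_if_finite(optimizer)
grad_scaler.update()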

clip_gradients_and_return_isfinite(optimizers)[source]#

Clips the gradients of the optimizers' parameters and returns whether or not the gradient norm is finite

step(optimizer, *args, **kwargs)[source]#

Step carries out the following two operations:

  1. Internally invokes unscale_(optimizer) (unless unscale_ was explicitly called for optimizer earlier in the iteration). As part of the unscale_, gradients are checked for infs/NaNs.

  2. Invokes optimizer.step() using the unscaled gradients. Ensures that the previous optimizer state and params carry over if NaNs are encountered in the gradients.

*args and **kwargs are forwarded to optimizer.step(). Returns the return value of optimizer.step(*args, **kwargs).

Parameters
  • optimizer (cerebras.pytorch.optim.Optimizer) – Optimizer that applies the gradients.

  • args – Any arguments passed to the optimizer.step() call.

  • kwargs – Any keyword arguments passed to the optimizer.step() call.

update_scale(optimizers)[source]#

Updates the scales of the optimizers

update(new_scale=None)[source]#

Updates the gradient scaler after all optimizers have been stepped

set_half_dtype#

cerebras.pytorch.amp.set_half_dtype(value)[source]#

Sets the underlying 16-bit floating point dtype to use.

Parameters

value (Union[Literal['float16', 'bfloat16', 'cbfloat16'], torch.dtype]) – Either a 16-bit floating point torch dtype or one of the strings "float16", "bfloat16", or "cbfloat16".

Returns

The proxy torch dtype to use for the model. For dtypes that have a torch representation, this returns the same dtype as the value passed in. Otherwise, it returns a proxy dtype to use in the model. On CSX, these proxy dtypes are automatically and transparently converted to the real dtype during compilation.

Return type

torch.dtype

By default, automatic mixed precision uses float16. If you want to use cbfloat16 or bfloat16 instead of float16, call this function.

Example usage:

cstorch.amp.set_half_dtype("cbfloat16")
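
The return value can also be captured for explicit casts inside the model. A minimal sketch, assuming the returned (possibly proxy) dtype is valid wherever a torch.dtype is accepted:

import torch
import cerebras.pytorch as cstorch

# Configure cbfloat16 and capture the (possibly proxy) dtype to use in the model.
half_dtype = cstorch.amp.set_half_dtype("cbfloat16")

x = torch.randn(4, 4)
x_half = x.to(half_dtype)  # explicit cast to the configured half dtype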

optimizer_step#

cerebras.pytorch.amp.optimizer_step(loss, optimizer, grad_scaler, max_gradient_norm=None, max_gradient_value=None)[source]#

Performs loss scaling, gradient scaling, and the optimizer step

Parameters
  • loss (torch.Tensor) – The loss value to scale. loss.backward should be called before this function

  • optimizer (cerebras.pytorch.optim.optimizer.Optimizer) – The optimizer to step

  • grad_scaler (cerebras.pytorch.amp.grad_scaler.GradScaler) – The gradient scaler to use to scale the parameter gradients

  • max_gradient_norm (Optional[float]) – The max gradient norm to use for gradient clipping

  • max_gradient_value (Optional[float]) – The max gradient value to use for gradient clipping

Example usage:

cstorch.amp.optimizer_step(
    loss,
    optimizer,
    grad_scaler,
    max_gradient_norm=1.0,
)