# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains the CheckLoss callback."""
from math import prod
import torch
import cerebras.pytorch as cstorch
from cerebras.modelzoo.trainer.callbacks import Callback
class CheckLoss(Callback):
    """Callback class that checks for NaN or inf loss values.

    It also checks whether the model output contains a scalar loss value.
    """

    def on_after_forward(self, trainer, model, outputs, batch):
        """Validate that ``outputs["loss"]``, if present, is a one-element tensor.

        Args:
            trainer: The trainer running this callback (unused here).
            model: The model being run (unused here).
            outputs: Model outputs, presumably a dict; only the ``"loss"``
                key is inspected.
            batch: The input batch (unused here).

        Raises:
            TypeError: If the loss is not a ``torch.Tensor`` or does not
                contain exactly one element.
        """
        if "loss" not in outputs:
            return

        loss = outputs["loss"]
        if not isinstance(loss, torch.Tensor):
            raise TypeError(
                "Expected loss to be a scalar torch.Tensor, "
                f"but got {type(loss)} instead."
            )
        # Accept any tensor holding exactly one element (e.g. shape () or
        # (1,)). Compare with ``!= 1`` rather than ``> 1`` so that empty
        # tensors (zero elements), which carry no loss value, are rejected
        # as well instead of silently passing the NaN/inf check later on.
        if prod(loss.shape) != 1:
            raise TypeError(
                "Expected loss to be a scalar torch.Tensor, "
                f"but got tensor with shape {loss.shape} instead."
            )

    @cstorch.step_closure
    def check_loss(self, loss: torch.Tensor):  # pylint: disable=no-self-use
        """Checks for NaN or inf loss values.

        Wrapped in ``cstorch.step_closure``. NOTE(review): this presumably
        defers the check until the step's loss value is available. Confirm.

        Args:
            loss: Scalar loss tensor.

        Raises:
            ApplianceNanError: If ``loss`` contains any NaN or inf value.
        """
        msg_postfix = (
            "This could potentially be due to selected hyperparameters "
            "such as the learning rate, batch size, etc. or it could be due "
            "to an internal error. Please try with a different set of "
            "hyperparameters and contact Cerebras Support if the issue "
            "persists."
        )
        # Imported lazily; presumably to avoid a module-level dependency on
        # cerebras.appliance. Confirm before hoisting.
        from cerebras.appliance.errors import ApplianceNanError

        if torch.isnan(loss).any().item():
            raise ApplianceNanError(f"NaN loss detected. {msg_postfix}")
        if torch.isinf(loss).any().item():
            raise ApplianceNanError(f"inf loss detected. {msg_postfix}")

    def on_train_batch_end(self, trainer, model, outputs, batch, batch_idx):
        """Run ``check_loss`` on the training loss, if ``outputs`` has one."""
        if "loss" in outputs:
            self.check_loss(outputs["loss"])

    def on_validate_batch_end(self, trainer, model, outputs, batch, batch_idx):
        """Run ``check_loss`` on the validation loss, if ``outputs`` has one."""
        if "loss" in outputs:
            self.check_loss(outputs["loss"])