# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
""" Accuracy metric for PyTorch """
import warnings
import torch
from cerebras.pytorch.metrics.metric import Metric
class AccuracyMetric(Metric):
    """Computes the accuracy of the model's predictions.

    Accuracy is the fraction of predictions that exactly match the labels.
    If ``weights`` is passed to ``update``, correct predictions are weighted
    accordingly and only positions with a positive weight count toward the
    total (e.g. to ignore padding tokens).

    Args:
        name: Name of the metric.
    """

    def reset(self):
        # Running sums accumulated across calls to update().
        self.register_state(
            "total_correct_predictions", torch.tensor(0, dtype=torch.float32)
        )
        self.register_state(
            "total_num_tokens", torch.tensor(0, dtype=torch.float32)
        )
        self._dtype = None

    def update(
        self, labels, predictions, weights=None, dtype=None
    ):  # pylint: disable=arguments-differ
        if labels.shape != predictions.shape:
            warnings.warn(
                "Shape mismatch in accuracy metric"
                f"\n labels: {labels.shape}"
                f"\n predictions: {predictions.shape}"
            )
            # Fall back to reshaping predictions to the label shape so the
            # elementwise comparison below is well defined.
            predictions = predictions.reshape(labels.shape)
        correct_predictions = (labels == predictions).float()
        if weights is None:
            # Unweighted: every element counts as one token.
            num_correct_predictions = correct_predictions.sum()
            num_tokens = torch.tensor(
                correct_predictions.numel(),
                dtype=torch.float32,
                device=predictions.device,
            )
        else:
            # Weighted: scale correct predictions by the weights and count
            # only positions with a positive weight (e.g. non-padding tokens).
            correct_predictions = correct_predictions * weights
            num_correct_predictions = correct_predictions.sum()
            num_tokens = (weights > 0).float().sum()

        self.total_correct_predictions.add_(num_correct_predictions)
        self.total_num_tokens.add_(num_tokens)
        self._dtype = dtype

    def compute(self) -> torch.Tensor:
        """Returns the accuracy as correct predictions over counted tokens."""
        result = self.total_correct_predictions / self.total_num_tokens
        if self._dtype is not None:
            result = result.to(self._dtype)
        return result
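

# --- Usage sketch (illustrative, not part of the library source) ---
# A minimal example of driving this metric, assuming it can be constructed
# with a name and updated eagerly with plain CPU tensors. The tensor values
# and the `pad_mask` name below are made up for illustration; in a real run
# the metric would typically be updated once per batch in the training loop.
if __name__ == "__main__":
    accuracy = AccuracyMetric("accuracy")

    labels = torch.tensor([[1, 2, 3, 0]])
    predictions = torch.tensor([[1, 2, 0, 0]])
    pad_mask = (labels != 0).float()  # count only non-padding positions

    accuracy.update(labels, predictions, weights=pad_mask)
    print(accuracy.compute())  # 2 correct out of 3 counted tokens -> ~0.6667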