# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
""" Perplexity metric for PyTorch """
import torch
from cerebras.pytorch.metrics.metric import Metric
class PerplexityMetric(Metric):
    """Computes the perplexity of the model's predictions.

    Perplexity is ``exp(total_loss / total_num_tokens)``. Both the loss and
    the token count are accumulated across every :meth:`update` call since
    the last :meth:`reset`.

    Args:
        name: Name of the metric (presumably consumed by the ``Metric`` base
            class constructor, which is not defined here).
    """

    def reset(self):
        """Reset the accumulated loss, the token count and the output dtype."""
        # Two float32 scalar accumulators, registered as metric state through
        # the base class's ``register_state``.
        self.register_state("total_loss", torch.tensor(0, dtype=torch.float32))
        self.register_state(
            "total_num_tokens", torch.tensor(0, dtype=torch.float32)
        )
        # No output cast until ``update`` supplies a dtype.
        self._dtype = None

    def update(self, labels, loss, weights=None, dtype=None):
        """Accumulate the loss and the token count of one batch.

        Args:
            labels: Label tensor. When ``weights`` is ``None``, every element
                counts as one token, so only ``labels.numel()`` and
                ``labels.device`` are used.
            loss: Batch loss (tensor or number), added as-is to the running
                total. ``compute`` divides the total by the token count, so
                this is presumably the loss *summed* over tokens rather than
                averaged. Confirm against callers.
            weights: Optional per-token weight/mask tensor. Only the number
                of entries greater than zero is counted; the weight values
                themselves are not applied here.
            dtype: Optional dtype that :meth:`compute` casts its result to.
                The value from the most recent call wins, and ``None``
                disables the cast.
        """
        if weights is None:
            num_tokens = torch.tensor(
                labels.numel(), dtype=torch.float32, device=labels.device
            )
        else:
            num_tokens = (weights > 0).float().sum()

        if isinstance(loss, torch.Tensor):
            # Accumulate a value, not a graph. An in-place add of a tensor
            # that requires grad would attach an autograd history to
            # ``total_loss`` and keep every step's graph alive.
            loss = loss.detach()

        self.total_loss.add_(loss)
        self.total_num_tokens.add_(num_tokens)
        self._dtype = dtype

    def compute(self) -> torch.Tensor:
        """Return ``exp(total_loss / total_num_tokens)``.

        The result is cast to the ``dtype`` given in the last :meth:`update`
        call, if one was given. With no accumulated tokens (e.g. right after
        :meth:`reset`), the division is ``0 / 0`` and the result is NaN.
        """
        result = torch.exp(self.total_loss / self.total_num_tokens)
        if self._dtype is not None:
            result = result.to(self._dtype)
        return result