Source code for cerebras.pytorch.utils.data.utils

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""Utility and helper functions used by the Cerebras dataloader."""
import math
from dataclasses import dataclass
from typing import List, Optional
from warnings import warn

import numpy as np
import torch

import cerebras.pytorch as cstorch
from cerebras.appliance.data.dtypes import bf16, is_bf16
from cerebras.pytorch.utils._num import ceildiv
from cerebras.pytorch.utils.nest import visit_torch_tensors


@dataclass
class Schedule:
    """Generic schedule object that represents a collection of step intervals.

    Args:
        intervals: List of non-overlapping ranges, in ascending step order.
    """

    @dataclass
    class Range:
        """A range of steps.

        Args:
            start: The starting index (inclusive).
            end: The end index (exclusive).
            step: The stride between consecutive indices.
            include_last: Whether `end - 1` is included in the interval even if
                it does not fall on a stride boundary (i.e. `start + step * N`).
        """

        start: int
        end: int
        step: int
        include_last: bool

        def __post_init__(self):
            if not isinstance(self.start, int) or self.start < 0:
                raise ValueError(f"start ({self.start}) must an integer >= 0.")
            if not isinstance(self.end, int) or self.end <= self.start:
                raise ValueError(
                    f"end ({self.end}) must be an integer greater than "
                    f"start {self.start}."
                )
            if not isinstance(self.step, int) or self.step < 1:
                raise ValueError(
                    f"step ({self.step}) must be an integer greater than zero."
                )
            if not isinstance(self.include_last, bool):
                raise ValueError(
                    f"include_last must be a bool, got {type(self.include_last)}"
                )

        def match(self, index: int) -> bool:
            """Returns whether the given index belongs to this interval."""
            return (self.start <= index < self.end) and (
                ((index - self.start) % self.step == 0)
                or (index == self.end - 1 and self.include_last)
            )

        def __iter__(self):
            yield from self.range()
            # `end` is exclusive, so the last index in the interval is `end - 1`.
            # Only yield it if the stepped range above did not already produce it.
            last = self.end - 1
            if self.include_last and (last - self.start) % self.step != 0:
                yield last

        def range(self):
            return range(self.start, self.end, self.step)

    intervals: List[Range]

    def __post_init__(self):
        if any(
            not isinstance(interval, Schedule.Range)
            for interval in self.intervals
        ):
            raise ValueError(f"interval must of type {Schedule.Range}")

        for interval1, interval2 in zip(self.intervals, self.intervals[1:]):
            if interval2.start < interval1.end:
                raise ValueError(
                    f"Intervals must be non-overlapping ranges, but got {self.intervals}."
                )

    def match(self, index: int) -> bool:
        """Returns whether the given index belongs to this schedule."""
        return any(interval.match(index) for interval in self.intervals)

    def __iter__(self):
        for interval in self.intervals:
            yield from interval

    def range(self):
        return [interval.range() for interval in self.intervals]
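

# Example usage (editor's illustrative sketch, not part of the library source;
# the concrete values below are arbitrary):
#
#   sched = Schedule(
#       intervals=[
#           Schedule.Range(start=0, end=10, step=5, include_last=True),
#           Schedule.Range(start=20, end=30, step=2, include_last=False),
#       ]
#   )
#   sched.match(5)   # True: 0 + 5 * 1
#   sched.match(9)   # True: `end - 1` of the first range and include_last=True
#   sched.match(15)  # False: falls between the two ranges
#   list(sched)      # [0, 5, 9, 20, 22, 24, 26, 28]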


def compute_num_steps(
    dataloader: torch.utils.data.DataLoader,
    initial_step: int = 0,
    num_steps: Optional[int] = None,
    max_steps: Optional[int] = None,
    num_epochs: Optional[int] = None,
    steps_per_epoch: Optional[int] = None,
    grad_accum_steps: int = 1,
):
    """
    Computes the number of steps to execute on the system based on the
    provided step information.

    Args:
        dataloader: The dataloader, used to determine the length of the
            dataset if it is available
        initial_step: The step to begin on. An error is raised if the initial
            step exceeds the maximum steps calculated below
        num_steps: The number of steps to run
        max_steps: The maximum number of steps to run
        num_epochs: The number of epochs to run
        steps_per_epoch: The number of steps to run each epoch
        grad_accum_steps: The number of batches over which gradients are
            accumulated before taking an optimizer step

    Note:
        At least one of num_steps, max_steps, or num_epochs must be specified

    Returns:
        The calculated total number of steps to execute
    """

    def _check_steps(name, value, allow_none=False, allow_zero=False):
        if value is None:
            if not allow_none:
                raise ValueError(f"`{name}` cannot be None.")
        else:
            if not isinstance(value, int):
                raise ValueError(
                    f"`{name}` must be an integer, but got {type(value)}."
                )
            if value == 0 and not allow_zero:
                raise ValueError(f"`{name}` must be greater than zero.")
            if value < 0:
                raise ValueError(
                    f"`{name}` cannot be negative, but got {value}."
                )

    if num_epochs is not None and num_steps is not None:
        raise ValueError(
            "Only one of `num_epochs` or `num_steps` can be specified."
        )

    _check_steps(
        "initial_step", initial_step, allow_none=False, allow_zero=True
    )
    _check_steps("num_steps", num_steps, allow_none=True)
    _check_steps("max_steps", max_steps, allow_none=True)
    _check_steps("num_epochs", num_epochs, allow_none=True)
    _check_steps("steps_per_epoch", steps_per_epoch, allow_none=True)
    _check_steps("grad_accum_steps", grad_accum_steps, allow_none=False)

    try:
        # Dataset length is known
        dataloader_size = len(dataloader)
        assert dataloader_size > 0, "Dataloader does not generate any batches."
        if steps_per_epoch is not None:
            if steps_per_epoch > dataloader_size:
                raise ValueError(
                    f"The requested steps per epoch of {steps_per_epoch} "
                    f"exceeds total steps in an epoch, which is "
                    f"{dataloader_size}."
                )
        else:
            steps_per_epoch = dataloader_size

        # With grad accumulation, the global step is incremented every Nth
        # batch, so our effective steps per epoch needs to be adjusted.
        if grad_accum_steps > steps_per_epoch:
            raise ValueError(
                f"Gradient accumulation steps of {grad_accum_steps} is "
                f"greater than batches per epoch of {steps_per_epoch}."
            )

        steps_per_epoch //= grad_accum_steps
    except TypeError:
        # Dataset length is not known
        if num_epochs is not None:
            raise ValueError(
                "Specifying num_epochs for datasets with unknown length is "
                "not allowed. Please control training behavior through "
                "number of steps instead."
            )
        steps_per_epoch = 1

    # Calculate total steps
    total_steps = math.inf
    if num_epochs is not None:
        total_steps = min(total_steps, num_epochs * steps_per_epoch)
    if num_steps is not None:
        total_steps = min(total_steps, num_steps)
    if max_steps is not None:
        remaining_steps = max_steps - initial_step
        if remaining_steps <= 0:
            raise RuntimeError(
                f"Initial global step {initial_step} already exceeds "
                f"max step {max_steps}."
            )
        total_steps = min(total_steps, remaining_steps)

    # At least one of the above if blocks must have been entered, so
    # total_steps should be finite. Raise a helpful error otherwise.
    if math.isinf(total_steps):
        raise ValueError(
            "One of num_epochs, num_steps, or max_steps must be provided"
        )

    if num_epochs is None:
        steps_per_epoch = total_steps

    # Override steps_per_epoch depending on the num_epochs computation
    num_epochs = ceildiv(total_steps, steps_per_epoch)
    steps_per_epoch = ceildiv(total_steps, num_epochs)

    return total_steps
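

# Example usage (editor's illustrative sketch, not part of the library source;
# the dataset and step counts below are arbitrary). For a map-style dataset of
# known length, the step limits interact as follows:
#
#   loader = torch.utils.data.DataLoader(range(100), batch_size=10)  # 10 batches
#
#   compute_num_steps(loader, num_epochs=3)                      # 30
#   compute_num_steps(loader, num_steps=7)                       # 7
#   compute_num_steps(loader, num_steps=50, max_steps=20)        # 20
#   compute_num_steps(loader, num_epochs=2, grad_accum_steps=5)  # 4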


def infer_batch_size(data, batch_size=None) -> Optional[int]:
    """Infers the batch size from a dataloader batch.

    Args:
        data: A nested structure of tensors.
        batch_size: The batch size to compare against.
            If None, the batch size is inferred from the data.
    Returns:
        If all tensors have the same batch size, it is returned.
        If inconsistent batch sizes are seen across tensors in the batch,
        None is returned in the CPU/GPU case and an error is raised in
        the CSX case.
    """
    inferred_batch_sizes = set(
        1 if len(tensor.size()) == 0 else tensor.size()[0]
        for _, tensor in visit_torch_tensors(data)
    )
    if len(inferred_batch_sizes) > 1:
        if cstorch.use_cs():
            raise RuntimeError(
                f"Only uniform batch sizes are supported in CS runs, but "
                f"the dataloader returned a batch with batch sizes "
                f"{inferred_batch_sizes}. "
            )
        warn(
            f"Detected non-uniform batch sizes within the same batch: "
            f"{inferred_batch_sizes}. While this is allowed in non-CSX "
            f"runs, it may throw off metrics such as rate profiling. "
            f"The run will proceed assuming no batch size."
        )
        return None

    if len(inferred_batch_sizes) == 1:
        inferred_batch_size = inferred_batch_sizes.pop()

        if batch_size is not None and inferred_batch_size != batch_size:
            if cstorch.use_cs():
                raise RuntimeError(
                    f"Only uniform batch sizes are supported in CS runs, but "
                    f"the dataloader returned two different batches with "
                    f"batch sizes {batch_size} and {inferred_batch_size}. "
                    f"Make sure to set `drop_last=True` in the dataloader."
                )
            else:
                warn(
                    f"Detected non-uniform batch sizes between batches "
                    f"({batch_size} vs {inferred_batch_size}). "
                    f"While this is allowed in non-CSX runs, it may throw off "
                    f"metrics such as rate profiling. "
                )

        return inferred_batch_size

    raise RuntimeError(
        "We could not detect any torch tensors in the input data "
        "returned by the dataloader. We expect the dataloader to "
        "return a nested dict/list/tuple of tensors. If there are "
        "custom types that internally hold tensors, we are not "
        "currently able to detect them. Please ensure that the "
        "dataloader returns tensors in the expected format."
    )
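

# Example usage (editor's illustrative sketch, not part of the library source;
# tensor shapes are arbitrary). A nested batch whose tensors share a leading
# dimension yields that dimension as the batch size:
#
#   batch = {"inputs": torch.zeros(8, 128), "labels": torch.zeros(8)}
#   infer_batch_size(batch)  # 8
#
# Mismatched leading dimensions warn and return None on CPU/GPU runs, and
# raise a RuntimeError on CSX runs.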


def to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Converts a torch tensor to a numpy array."""
    if tensor.dtype == torch.bfloat16:
        assert bf16.itemsize == 2  # Sanity check
        return tensor.view(torch.int16).numpy().view(bf16)
    return tensor.numpy()


def from_numpy(array: np.ndarray) -> torch.Tensor:
    """Converts a numpy array to a torch tensor."""
    # Copy non-writeable array to make it writable for torch.from_numpy
    if not array.flags.writeable:
        array = array.copy()
    if is_bf16(array.dtype):
        return torch.from_numpy(array).view(torch.bfloat16)
    return torch.from_numpy(array)
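

# Example usage (editor's illustrative sketch, not part of the library source).
# bfloat16 tensors round-trip through numpy via a 16-bit view, since numpy has
# no native bfloat16 dtype:
#
#   t = torch.ones(4, dtype=torch.bfloat16)
#   a = to_numpy(t)     # numpy array with the custom `bf16` dtype
#   t2 = from_numpy(a)  # back to a torch.bfloat16 tensor
#   assert torch.equal(t, t2)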