# Source code for cerebras.pytorch.utils.data.dataloader

# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause

"""Defines the Cerebras DataLoader class and RestartableDataLoader protocol class."""
import inspect
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum, auto
from functools import cached_property
from typing import Any, Callable, Dict, Iterable, List, Union

import torch
from typing_extensions import Protocol, runtime_checkable

from cerebras.appliance.log import ClassLogger, named_class_logger
from cerebras.pytorch.backend import Backend, current_backend_impl
from cerebras.pytorch.utils.data.utils import infer_batch_size


class _DataLoaderState(Enum):
    UNKNOWN = auto()
    UNAVAILABLE = auto()


@named_class_logger
class DataLoader(ClassLogger):
    """
    Wrapper around torch.utils.data.DataLoader that facilitates
    moving data generated by the dataloader to a Cerebras system

    Args:
        input_fn: A callable that returns a torch.utils.data.DataLoader
            instance or an iterable that returns a structure containing torch
            tensors.
        *args, **kwargs: Any other positional or keyword arguments
            are passed into the input_fn when each worker instantiates
            their respective dataloaders
    """

    # Class-wide counter used to hand out a unique ``id`` to every instance.
    _id_counter = 0
    STATE_UNKNOWN = _DataLoaderState.UNKNOWN
    STATE_UNAVAILABLE = _DataLoaderState.UNAVAILABLE

    def __init__(
        self,
        input_fn: Callable[..., Union[torch.utils.data.DataLoader, Iterable]],
        *args,
        **kwargs,
    ):
        if not callable(input_fn):
            raise TypeError(
                "Expected a callable that constructs and returns a "
                "`torch.utils.data.DataLoader` or an iterable that "
                "returns a structure containing torch tensors."
            )

        # Properties accessed by the backend
        DataLoader._id_counter += 1
        self.id = DataLoader._id_counter
        self.input_fn = input_fn
        # Deep-copied so later mutation of the caller's arguments cannot change
        # what ``__copy__`` re-invokes ``input_fn`` with.
        self.input_fn_params = deepcopy((args, kwargs))
        self.cached_state = self.STATE_UNKNOWN

        self.dataloader = input_fn(*args, **kwargs)
        if isinstance(self.dataloader, torch.utils.data.DataLoader):
            self._use_single_process_loading()

        # Call this once here to error out early if the dataloader is malformed
        _ = self._has_strict_kwarg

        self._load_state_kwargs = dict()
        self.batch_size = None

    def _use_single_process_loading(self) -> None:
        """Reconfigure the wrapped torch DataLoader to load in-process.

        Sets ``num_workers = 0`` and ``persistent_workers = False``. If the
        original loader used worker processes and has a ``worker_init_fn``,
        that function is called once with ``worker_id=0`` so the dataloader is
        initialized as it would have been inside a worker. If anything in this
        sequence fails, the original settings are restored and a warning is
        logged (best effort, construction is not aborted).
        """
        loader = self.dataloader
        original_persistent_workers = loader.persistent_workers
        original_num_workers = loader.num_workers
        try:
            # torch's DataLoader.__setattr__ refuses to change
            # `persistent_workers` once its private `__initialized` flag is set.
            # Inside torch that flag is name-mangled to
            # `_DataLoader__initialized`; spell the mangled name out explicitly
            # instead of relying on this class also being named `DataLoader`.
            loader._DataLoader__initialized = False
            loader.persistent_workers = False
            loader.num_workers = 0

            # If the original num workers is greater than zero and a
            # worker_init_fn was provided, we need to call it with
            # worker_id=0 to ensure that the dataloader is initialized
            # correctly.
            if loader.worker_init_fn is not None and original_num_workers > 0:
                loader.worker_init_fn(0)
        except Exception as e:
            # Restore the original values of persistent_workers and
            # num_workers while the loader still accepts attribute changes,
            # and surface the failure instead of silently swallowing it.
            loader.persistent_workers = original_persistent_workers
            loader.num_workers = original_num_workers
            self.logger.warning(
                f"Could not configure the DataLoader for single-process "
                f"loading on the user node; its original `num_workers` and "
                f"`persistent_workers` settings were restored. "
                f"The reason is: {e}"
            )
        finally:
            loader._DataLoader__initialized = True

    def __copy__(self):
        """Return a fresh DataLoader built by re-invoking ``input_fn``.

        The loaded-state kwargs and the cached state are deep-copied onto the
        new instance.
        """
        args, kwargs = self.input_fn_params
        dl = DataLoader(self.input_fn, *args, **kwargs)
        # Copy both in one deepcopy call so that, when ``cached_state`` is a
        # ``(state, load_state_kwargs)`` tuple referencing
        # ``_load_state_kwargs``, the copy keeps that same aliasing.
        dl._load_state_kwargs, dl.cached_state = deepcopy(
            (self._load_state_kwargs, self.cached_state)
        )
        return dl

    @property
    def is_restartable(self) -> bool:
        """Returns True if dataloader is restartable."""
        return isinstance(self.dataloader, RestartableDataLoader)

    @cached_property
    def _has_strict_kwarg(self) -> bool:
        """Check if the dataloader accepts a `strict` named argument.

        Originally, we did not accept `strict` named argument in the interface. In later
        versions, this argument was added, but we still maintain backwards compatibility.

        Raises:
            TypeError: If only one of `load_state_dict` and
                `deaggregate_state_dict` accepts `strict`.
        """
        if not self.is_restartable:
            return False

        has_kwarg = [
            "strict" in inspect.signature(method).parameters
            for method in [
                self.dataloader.load_state_dict,
                self.dataloader.deaggregate_state_dict,
            ]
        ]

        if any(has_kwarg):
            if not all(has_kwarg):
                raise TypeError(
                    "Either both or none of `load_state_dict` and `deaggregate_state_dict` "
                    "should have the `strict` kwargs in their signature."
                )
            return True
        return False

    @property
    def _backend(self) -> Backend:
        """Returns the current backend implementation."""
        return current_backend_impl()

    def state_dict(self) -> Dict[str, Any]:
        """Returns dataloader state to save in a checkpoint
        by invoking the saving mechanism of the
        :py:class:`~cerebras.pytorch.utils.data.RestartableDataLoader` API.

        When no concrete state is cached, the state is computed (from the
        appliance workers on CSX, or from the local dataloader otherwise) and
        cached as ``(state, load_state_kwargs)``. The exception is the CSX
        pre-initial step, where the local state is returned without caching.
        Callers always receive a deep copy.

        Returns:
            `dict` capturing dataloader state as specified in the
            implementation of the dataloader's `aggregate_state_dict`
            method

        Raises:
            RuntimeError: If the dataloader is not restartable, if called
                outside a checkpoint / pre-initial / final step, or if the
                cached state is STATE_UNAVAILABLE at the CSX pre-initial step.
        """
        if not self.is_restartable:
            raise RuntimeError(
                f"DataLoader is not configured for getting state. "
                f"Please implement {RestartableDataLoader.__name__} interface "
                f"to enable `state_dict()` and `load_state_dict()` methods."
            )

        run_context = self._backend.run_context
        if not (
            run_context.is_checkpoint_step
            or run_context.is_pre_initial_step
            or run_context.is_final_step
        ):
            raise RuntimeError(
                "DataLoader state can only be requested at a checkpoint step. Please "
                "ensure that `state_dict` is called on the `cstorch.utils.DataLoader` "
                "at a checkpoint step. If you're calling it inside of a method, please "
                "decorate it with the `cstorch.checkpoint_closure` method decorator."
            )

        # If the state is not known, we need to query it from somewhere, cache it, and return it.
        if self.cached_state in [self.STATE_UNKNOWN, self.STATE_UNAVAILABLE]:
            self._configure_worker_state(
                self._backend.run_context.user_iteration
            )

            if self._backend.backend_type.is_csx:
                if self._backend.run_context.is_pre_initial_step:
                    if self.cached_state is not self.STATE_UNKNOWN:
                        raise RuntimeError(
                            "Invalid dataloader cached state! At the pre-initial step, the cached "
                            "state should not be STATE_UNAVAILABLE."
                        )

                    # If no state was loaded but a state is requested before the run has started,
                    # call the dataloader directly to return its current state.
                    return deepcopy(
                        self.dataloader.aggregate_state_dict(
                            [self.dataloader.state_dict()]
                        )
                    )

                # Fetch state from the appliance workers
                worker_states: List[DataLoaderCheckpoint] = (
                    self._backend.appliance.grpc_client.fetch_dataloader_state(
                        self._backend.run_context.user_iteration
                    )
                )
                # For aggregation, we only pass the per WRK state dict users explicitly
                # chose to save in their `state_dict` implementation.
                self.cached_state = self.dataloader.aggregate_state_dict(
                    [worker_state.user_state_dict for worker_state in worker_states]
                )
            else:
                self.cached_state = self.dataloader.aggregate_state_dict(
                    [self.dataloader.state_dict()]
                )

            self.cached_state = (self.cached_state, self._load_state_kwargs)

        return deepcopy(self.cached_state[0])

    def load_state_dict(
        self, state_dict: Dict[str, Any], strict: bool = True
    ) -> None:
        """Loads dataloader state from the provided `state_dict`
        by invoking the loading mechanism of the
        :py:class:`~cerebras.pytorch.utils.data.RestartableDataLoader` API.

        Args:
            state_dict: dict capturing dataloader state loaded from a
                checkpoint
            strict: Whether to enforce strict matching of the incoming state_dict.
                Only forwarded when the dataloader's methods accept `strict`.

        Raises:
            RuntimeError: If the dataloader is not restartable, or if called
                inside a run context after the pre-initial step.
            Exception: Any error from the dataloader's own loading methods is
                re-raised on non-CSX backends (only logged on CSX).
        """
        if not self.is_restartable:
            raise RuntimeError(
                f"DataLoader is not configured for setting state. "
                f"Please implement {RestartableDataLoader.__name__} interface "
                f"to enable `state_dict()` and `load_state_dict()` methods."
            )

        if (
            self._backend.in_run_context
            and not self._backend.run_context.is_pre_initial_step
        ):
            raise RuntimeError(
                "DataLoader state can only be loaded onto before execution. "
                "Please make sure to call `load_state_dict()` only before "
                "iterating the data executor."
            )

        if self._has_strict_kwarg:
            self._load_state_kwargs["strict"] = strict

        is_csx = self._backend.backend_type.is_csx
        if is_csx:
            self.cached_state = (state_dict, self._load_state_kwargs)

        # Strictly speaking, there's no need to load state onto the dataloader
        # running on user node when on CSX. But dataloader could return different
        # data types depending on the loaded state, in which case we may get
        # mismatches between the input spec return on USR node vs WRK nodes.
        # So opportunistically load state in all cases, but don't error out
        # when running on CSX. If it's a real failure, let it fail on the workers.
        self._configure_worker_state(0)
        try:
            self.dataloader.load_state_dict(
                self.dataloader.deaggregate_state_dict(
                    state_dict, **self._load_state_kwargs
                ),
                **self._load_state_kwargs,
            )
        except Exception as e:
            if is_csx:
                self.logger.warning(
                    f"Could not load state onto the DataLoader running on "
                    f"user node. This is not strictly an error because the "
                    f"actual dataloader running on the workers may still be "
                    f"able to load the state. However, this may potentially "
                    f"cause a mismatch between the input specification seen "
                    f"in compile vs. what the dataloader returns to workers, "
                    f"which may cause issues further during execution. "
                    f"The reason for not loading state is: {e}"
                )
            else:
                raise

    def __len__(self):
        """Return the length of the wrapped dataloader."""
        return len(self.dataloader)

    def __iter__(self):
        """Yield batches from the wrapped dataloader.

        Configures a step-0 worker state first, and records the batch size
        (via ``infer_batch_size``) as batches are produced.
        """
        self._configure_worker_state(0)

        for batch in self.dataloader:
            self.batch_size = infer_batch_size(batch, self.batch_size)
            yield batch

    def _configure_worker_state(self, step: int):
        """Configure ``WorkerState`` with a single-worker checkpoint at `step`.

        The checkpoint describes one worker on one CSX, with no user state.

        .. note::
            For ``step > 0`` this multiplies by ``self.batch_size``, so that
            attribute must already have been inferred (by iterating);
            otherwise this raises ``TypeError``.
        """
        from cerebras.pytorch.distributed.worker_state import WorkerState

        WorkerState.configure(
            DataLoaderCheckpoint(
                local_worker_id=0,
                num_workers_per_csx=1,
                num_csx=1,
                wse_id=0,
                appliance_step=step,
                worker_step=step,
                samples_streamed=step * self.batch_size if step > 0 else 0,
                user_state_dict=None,
            )
        )


@runtime_checkable
class RestartableDataLoader(Protocol):
    """Defines interface for the restartable dataloader protocol."""

    def state_dict(self) -> Dict[str, Any]:
        """Use this method to specify what state information should be saved by each CSX Worker.

        Returns:
            dict holding state information for the CSX Worker

        In order to access Cerebras internal data checkpoint info per CSX
        Worker at some checkpoint step, follow the steps in the example below.
        Cerebras internal data checkpoint format is recorded in the
        :py:class:`~cerebras.pytorch.utils.data.DataLoaderCheckpoint`
        dataclass.

        Usage:
        ::

            import cerebras.pytorch as cstorch
            ...
            def state_dict(self) -> Dict[str, Any]:
                worker_state = cstorch.distributed.get_worker_state()
                state_dict = {}
                if worker_state:
                    state_dict["worker_step"] = worker_state.worker_step
                    state_dict["worker_id"] = worker_state.global_worker_id
                return state_dict

        .. note::
            The call to
            :py:func:`~cerebras.pytorch.distributed.get_worker_state` is
            well-defined only inside of the `state_dict` method; using this
            anywhere else will result in a RuntimeError exception. See linked
            docs for more details.
        """

    def load_state_dict(
        self, state_dict: Dict[str, Any], strict: bool = True
    ) -> None:
        """Use this method to load CSX Worker state for the dataloader instance,
        as captured from a previous run.

        Args:
            state_dict: dict holding worker state info, specified in
                :py:meth:`~cerebras.pytorch.utils.data.RestartableDataLoader.deaggregate_state_dict`
            strict: Whether to enforce strict matching of the incoming
                state_dict. It is up to the implementation to decide what
                "strict matching" is.

        Usage:
        ::

            def load_state_dict(self, state_dict, strict=True):
                wrk_state_dict = state_dict.get("worker_0", {})
                worker_step = wrk_state_dict.get("worker_step", 0)
                worker_id = wrk_state_dict.get("worker_id")
                print(f"WRK {worker_id} loaded step: {worker_step}")
        """

    def aggregate_state_dict(
        self, worker_states: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Use this method to specify how to combine the list of CSX Worker state dicts.

        Each CSX Worker state in the `worker_states` list is to be specified
        in :py:meth:`~cerebras.pytorch.utils.data.RestartableDataLoader.state_dict`

        Returns:
            The consolidated state dict that will be saved in a checkpoint.

        Usage:
        ::

            def aggregate_state_dict(self, worker_states):
                return {
                    "worker_0": worker_states[0],
                    "worker_1": worker_states[1],
                }
        """

    def deaggregate_state_dict(
        self, aggregated_state_dict: Dict[str, Any], strict: bool = True
    ) -> Dict[str, Any]:
        """Use this method to specify how to load an individual CSX Worker state
        given a consolidated list of state dicts, as specified in
        :py:meth:`~cerebras.pytorch.utils.data.RestartableDataLoader.aggregate_state_dict`.

        Args:
            aggregated_state_dict: The aggregated state dict to deaggregate.
            strict: Whether to enforce strict matching of the incoming
                state_dict. It is up to the implementation to decide what
                "strict matching" is.

        Returns:
            The state dict will be passed to the above-defined
            :py:meth:`~cerebras.pytorch.utils.data.RestartableDataLoader.load_state_dict`
            method.

        Usage:
        ::

            def deaggregate_state_dict(self, aggregated_state_dict, strict=True):
                return {
                    "worker_0": aggregated_state_dict.get("worker_0", {})
                }
        """
@dataclass
class DataLoaderCheckpoint:
    """Dataclass representing the Cerebras internal dataloader checkpoint format.

    Each CSX Worker captures its state information via this class at a
    checkpoint step.

    Attributes:
        global_worker_id: ID of this worker amongst all other workers across
            all boxes
        local_worker_id: ID of this worker amongst all other workers across
            the same box
        total_num_workers: The total number of workers for the run across all
            boxes
        num_workers_per_csx: The total number of workers per box for the run
        num_csx: The total number of CSXs (boxes) for the run
        wse_id: ID of the Wafer-Scale Engine (CSX) to which this worker
            streams data
        appliance_step: The appliance step at which this checkpoint state
            info is captured
        worker_step: The worker step at which this state info is captured.
            Note that this is simply equal to `appliance_step` if
            `num_workers_per_csx = 1`; for the multi-worker scenario, the
            appliance step is distributed across workers on a single box in a
            round-robin fashion based on the local worker id
        samples_streamed: The total number of samples streamed by this worker
            at checkpoint step. This is simply
            `worker_step` * `per_box_batch_size`
        user_state_dict: User-defined state dict for the CSX Worker, or None
            when no user state is attached.

    .. note::
        `appliance_step`, `worker_step` and `samples_streamed` are the
        attributes that vary across different steps; whereas the other
        attributes provide constant state information for the current run.
    """

    local_worker_id: int
    num_workers_per_csx: int
    num_csx: int
    wse_id: int
    appliance_step: int
    worker_step: int
    samples_streamed: int
    # User-defined state dict for the CSX Worker. This object must be picklable.
    # None is allowed: the user-node dataloader constructs checkpoints with
    # `user_state_dict=None`.
    user_state_dict: Union[Dict[str, Any], None]

    @property
    def global_worker_id(self) -> int:
        """Worker index across all boxes: ``wse_id * num_workers_per_csx + local_worker_id``."""
        return self.wse_id * self.num_workers_per_csx + self.local_worker_id

    @property
    def total_num_workers(self) -> int:
        """Total worker count: ``num_workers_per_csx * num_csx``."""
        return self.num_workers_per_csx * self.num_csx