# Copyright 2016-2023 Cerebras Systems
# SPDX-License-Identifier: BSD-3-Clause
"""Defines the Cerebras DataLoader class and RestartableDataLoader protocol class."""
import inspect
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum, auto
from functools import cached_property
from typing import Any, Callable, Dict, Iterable, List, Union
import torch
from typing_extensions import Protocol, runtime_checkable
from cerebras.appliance.log import ClassLogger, named_class_logger
from cerebras.pytorch.backend import Backend, current_backend_impl
from cerebras.pytorch.utils.data.utils import infer_batch_size
class _DataLoaderState(Enum):
UNKNOWN = auto()
UNAVAILABLE = auto()
@named_class_logger
class DataLoader(ClassLogger):
    """
    Wrapper around torch.utils.data.DataLoader that facilitates
    moving data generated by the dataloader to a Cerebras system.

    Args:
        input_fn: A callable that returns a torch.utils.data.DataLoader
            instance or an iterable that returns a structure containing torch
            tensors.
        *args, **kwargs: Any other positional or keyword arguments
            are passed into the input_fn when each worker instantiates
            their respective dataloaders.
    """

    # Class-wide counter used to give every instance a unique `id`.
    _id_counter = 0

    # Sentinels stored in `cached_state` while no aggregated state is cached.
    STATE_UNKNOWN = _DataLoaderState.UNKNOWN
    STATE_UNAVAILABLE = _DataLoaderState.UNAVAILABLE

    def __init__(
        self,
        input_fn: Callable[..., Union[torch.utils.data.DataLoader, Iterable]],
        *args,
        **kwargs,
    ):
        if not callable(input_fn):
            raise TypeError(
                "Expected a callable that constructs and returns a "
                "`torch.utils.data.DataLoader` or an iterable that "
                "returns a structure containing torch tensors."
            )

        # Properties accessed by the backend
        DataLoader._id_counter += 1
        self.id = DataLoader._id_counter
        self.input_fn = input_fn
        # Deep-copied so later mutation of the caller's arguments cannot
        # change what `__copy__` rebuilds.
        self.input_fn_params = deepcopy((args, kwargs))
        # Either a STATE_* sentinel or a `(state_dict, load_state_kwargs)`
        # tuple.
        self.cached_state = self.STATE_UNKNOWN

        self.dataloader = input_fn(*args, **kwargs)

        if isinstance(self.dataloader, torch.utils.data.DataLoader):
            original_persistent_workers = self.dataloader.persistent_workers
            original_num_workers = self.dataloader.num_workers
            try:
                # torch's DataLoader.__setattr__ rejects changes to
                # `persistent_workers` while its private `__initialized` flag
                # is set. That flag is name-mangled to
                # `_DataLoader__initialized`; spell the mangled name out
                # instead of relying on this class also being called
                # `DataLoader`.
                setattr(self.dataloader, "_DataLoader__initialized", False)
                # Load data in the current process rather than in worker
                # subprocesses.
                self.dataloader.persistent_workers = False
                self.dataloader.num_workers = 0
                # If the original num workers is greater than zero and a
                # worker_init_fn was provided, we need to call it with
                # worker_id=0 to ensure that the dataloader is initialized
                # correctly.
                if (
                    self.dataloader.worker_init_fn is not None
                    and original_num_workers > 0
                ):
                    self.dataloader.worker_init_fn(0)
            except Exception as e:
                # Best effort: if switching to in-process loading fails (e.g.
                # worker_init_fn raises), restore the original values of
                # persistent_workers and num_workers instead of aborting, but
                # don't hide the failure. (Restoring must happen before the
                # `finally` re-arms torch's setattr guard.)
                self.dataloader.persistent_workers = original_persistent_workers
                self.dataloader.num_workers = original_num_workers
                self.logger.warning(
                    f"Could not switch the DataLoader to in-process loading "
                    f"(num_workers=0); keeping the original worker "
                    f"configuration. The reason is: {e}"
                )
            finally:
                setattr(self.dataloader, "_DataLoader__initialized", True)

        # Call this once here to error out early if the dataloader is malformed
        _ = self._has_strict_kwarg

        # Extra kwargs (only `strict` is ever set) forwarded to the
        # dataloader's `deaggregate_state_dict` and `load_state_dict`.
        self._load_state_kwargs = dict()
        # Inferred from the batches yielded by `__iter__`.
        self.batch_size = None

    def __copy__(self):
        """Builds a new DataLoader from the same `input_fn` and arguments,
        carrying over the load-state kwargs and the cached state."""
        dl = DataLoader(
            self.input_fn, *self.input_fn_params[0], **self.input_fn_params[1]
        )
        # Assign (not compare) so the copy keeps e.g. the `strict` flag that
        # was passed to `load_state_dict`.
        dl._load_state_kwargs = deepcopy(self._load_state_kwargs)
        dl.cached_state = deepcopy(self.cached_state)
        return dl

    @property
    def is_restartable(self) -> bool:
        """Returns True if dataloader is restartable, i.e. it structurally
        implements the RestartableDataLoader protocol."""
        return isinstance(self.dataloader, RestartableDataLoader)

    @cached_property
    def _has_strict_kwarg(self) -> bool:
        """Check if the dataloader accepts a `strict` named argument.

        Originally, we did not accept `strict` named argument in the interface. In later
        versions, this argument was added, but we still maintain backwards compatibility.

        Returns:
            True if both `load_state_dict` and `deaggregate_state_dict` accept
            `strict`; False if neither does or the dataloader is not
            restartable.

        Raises:
            TypeError: If only one of the two methods accepts `strict`.
        """
        if not self.is_restartable:
            return False

        has_kwarg = [
            "strict" in inspect.signature(method).parameters
            for method in (
                self.dataloader.load_state_dict,
                self.dataloader.deaggregate_state_dict,
            )
        ]
        if all(has_kwarg):
            return True
        if any(has_kwarg):
            raise TypeError(
                "Either both or none of `load_state_dict` and `deaggregate_state_dict` "
                "should have the `strict` kwargs in their signature."
            )
        return False

    @property
    def _backend(self) -> Backend:
        """Returns the current backend implementation."""
        return current_backend_impl()

    def state_dict(self) -> Dict[str, Any]:
        """Returns dataloader state to save in a checkpoint
        by invoking the saving mechanism of the
        :py:class:`~cerebras.pytorch.utils.data.RestartableDataLoader` API.

        Returns:
            `dict` capturing dataloader state as specified in the
            implementation of the dataloader's `aggregate_state_dict`
            method. A deep copy is returned, so callers may mutate it freely.

        Raises:
            RuntimeError: If the dataloader is not restartable, if called
                outside a checkpoint, pre-initial or final step, or if the
                cached state is inconsistent at the pre-initial step on CSX.
        """
        if not self.is_restartable:
            raise RuntimeError(
                f"DataLoader is not configured for getting state. "
                f"Please implement {RestartableDataLoader.__name__} interface "
                f"to enable `state_dict()` and `load_state_dict()` methods."
            )

        if not (
            self._backend.run_context.is_checkpoint_step
            or self._backend.run_context.is_pre_initial_step
            or self._backend.run_context.is_final_step
        ):
            raise RuntimeError(
                "DataLoader state can only be requested at a checkpoint step. Please "
                "ensure that `state_dict` is called on the `cstorch.utils.DataLoader` "
                "at a checkpoint step. If you're calling it inside of a method, please "
                "decorate it with the `cstorch.checkpoint_closure` method decorator."
            )

        # If the state is not known, we need to query it from somewhere,
        # cache it, and return it.
        if self.cached_state in [self.STATE_UNKNOWN, self.STATE_UNAVAILABLE]:
            self._configure_worker_state(
                self._backend.run_context.user_iteration
            )

            if self._backend.backend_type.is_csx:
                if self._backend.run_context.is_pre_initial_step:
                    if self.cached_state is not self.STATE_UNKNOWN:
                        raise RuntimeError(
                            "Invalid dataloader cached state! At the pre-initial step, the cached "
                            "state should not be STATE_UNAVAILABLE."
                        )
                    # If no state was loaded but a state is requested before
                    # the run has started, call the dataloader directly to
                    # return its current state. Note: this is not cached.
                    return deepcopy(
                        self.dataloader.aggregate_state_dict(
                            [self.dataloader.state_dict()]
                        )
                    )

                # Fetch state from the appliance workers
                worker_states: List[DataLoaderCheckpoint] = (
                    self._backend.appliance.grpc_client.fetch_dataloader_state(
                        self._backend.run_context.user_iteration
                    )
                )
                # For aggregation, we only pass the per WRK state dict users
                # explicitly chose to save in their `state_dict`
                # implementation.
                self.cached_state = self.dataloader.aggregate_state_dict(
                    [worker_state.user_state_dict for worker_state in worker_states]
                )
            else:
                # Non-CSX: this process is the only "worker".
                self.cached_state = self.dataloader.aggregate_state_dict(
                    [self.dataloader.state_dict()]
                )

            self.cached_state = (self.cached_state, self._load_state_kwargs)

        return deepcopy(self.cached_state[0])

    def load_state_dict(
        self, state_dict: Dict[str, Any], strict: bool = True
    ) -> None:
        """Loads dataloader state from the provided `state_dict`
        by invoking the loading mechanism of the
        :py:class:`~cerebras.pytorch.utils.data.RestartableDataLoader` API.

        Args:
            state_dict: dict capturing dataloader state loaded from a
                checkpoint
            strict: Whether to enforce strict matching of the incoming
                state_dict. Only forwarded if the dataloader's methods accept
                a `strict` argument.

        Raises:
            RuntimeError: If the dataloader is not restartable or if called
                after execution has started.
            Exception: Any error raised by the dataloader while loading state
                is re-raised on non-CSX backends; on CSX it is logged instead.
        """
        if not self.is_restartable:
            raise RuntimeError(
                f"DataLoader is not configured for setting state. "
                f"Please implement {RestartableDataLoader.__name__} interface "
                f"to enable `state_dict()` and `load_state_dict()` methods."
            )

        if (
            self._backend.in_run_context
            and not self._backend.run_context.is_pre_initial_step
        ):
            raise RuntimeError(
                "DataLoader state can only be loaded onto before execution. "
                "Please make sure to call `load_state_dict()` only before "
                "iterating the data executor."
            )

        if self._has_strict_kwarg:
            self._load_state_kwargs["strict"] = strict

        is_csx = self._backend.backend_type.is_csx
        if is_csx:
            self.cached_state = (state_dict, self._load_state_kwargs)

        # Strictly speaking, there's no need to load state onto the dataloader
        # running on user node when on CSX. But dataloader could return different
        # data types depending on the loaded state, in which case we may get
        # mismatches between the input spec return on USR node vs WRK nodes.
        # So opportunistically load state in all cases, but don't error out
        # when running on CSX. If it's a real failure, let it fail on the workers.
        self._configure_worker_state(0)
        try:
            self.dataloader.load_state_dict(
                self.dataloader.deaggregate_state_dict(
                    state_dict, **self._load_state_kwargs
                ),
                **self._load_state_kwargs,
            )
        except Exception as e:
            # Use the same CSX check as above (previously `self._backend.is_csx`,
            # which is inconsistent with the rest of this class).
            if not is_csx:
                raise
            self.logger.warning(
                f"Could not load state onto the DataLoader running on "
                f"user node. This is not strictly an error because the "
                f"actual dataloader running on the workers may still be "
                f"able to load the state. However, this may potentially "
                f"cause a mismatch between the input specification seen "
                f"in compile vs. what the dataloader returns to workers, "
                f"which may cause issues further during execution. "
                f"The reason for not loading state is: {e}"
            )

    def __len__(self):
        """Returns the length of the wrapped dataloader."""
        return len(self.dataloader)

    def __iter__(self):
        """Yields batches from the wrapped dataloader.

        Worker state is configured for step 0 before iterating, and
        `batch_size` is updated from each batch via `infer_batch_size`.
        """
        self._configure_worker_state(0)
        for batch in self.dataloader:
            self.batch_size = infer_batch_size(batch, self.batch_size)
            yield batch

    def _configure_worker_state(self, step: int):
        """Registers a single-worker `DataLoaderCheckpoint` for `step` with
        `WorkerState`.

        The current process is described as the only worker (local worker 0
        on CSX 0), presumably so that a user's `state_dict` can query worker
        state outside the appliance workers.

        Args:
            step: Used as both the appliance step and the worker step.
        """
        # Deferred import; presumably avoids an import cycle -- TODO confirm.
        from cerebras.pytorch.distributed.worker_state import WorkerState

        WorkerState.configure(
            DataLoaderCheckpoint(
                local_worker_id=0,
                num_workers_per_csx=1,
                num_csx=1,
                wse_id=0,
                appliance_step=step,
                worker_step=step,
                # NOTE(review): for step > 0 this requires `batch_size` to
                # have been inferred by iterating first; otherwise the
                # multiplication with None raises TypeError.
                samples_streamed=step * self.batch_size if step > 0 else 0,
                user_state_dict=None,
            )
        )
@runtime_checkable
class RestartableDataLoader(Protocol):
    """Defines interface for the restartable dataloader protocol.

    The protocol is runtime-checkable, so any object that defines all four
    methods below passes ``isinstance(obj, RestartableDataLoader)``.
    """

    def state_dict(self) -> Dict[str, Any]:
        """Use this method to specify what state information should be saved
        by each CSX Worker.

        Returns:
            dict holding state information for the CSX Worker

        In order to access Cerebras internal data checkpoint info per
        CSX Worker at some checkpoint step, follow the steps in the example
        below. Cerebras internal data checkpoint format is recorded in the
        :py:class:`~cerebras.pytorch.utils.data.DataLoaderCheckpoint` dataclass.

        Usage:
        ::

            import cerebras.pytorch as cstorch
            ...
            def state_dict(self) -> Dict[str, Any]:
                worker_state = cstorch.distributed.get_worker_state()
                state_dict = {}
                if worker_state:
                    state_dict["worker_step"] = worker_state.worker_step
                    state_dict["worker_id"] = worker_state.global_worker_id
                return state_dict

        .. note::
            The call to :py:func:`~cerebras.pytorch.distributed.get_worker_state`
            is well-defined only inside of the `state_dict` method; using this
            anywhere else will result in a RuntimeError exception. See linked
            docs for more details.
        """

    def load_state_dict(
        self, state_dict: Dict[str, Any], strict: bool = True
    ) -> None:
        """Use this method to load CSX Worker state for the dataloader instance,
        as captured from a previous run.

        Args:
            state_dict: dict holding worker state info, specified in
                :py:meth:`~cerebras.pytorch.utils.data.RestartableDataLoader.deaggregate_state_dict`
            strict: Whether to enforce strict matching of the incoming state_dict. It is up to the
                implementation to decide what "strict matching" is.

        Usage:
        ::

            def load_state_dict(self, state_dict, strict=True):
                wrk_state_dict = state_dict.get("worker_0", {})
                worker_step = wrk_state_dict.get("worker_step", 0)
                worker_id = wrk_state_dict.get("worker_id")
                print(f"WRK {worker_id} loaded step: {worker_step}")
        """

    def aggregate_state_dict(
        self, worker_states: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Use this method to specify how to combine the list of CSX Worker state dicts.

        Each CSX Worker state in the `worker_states` list is to be specified in
        :py:meth:`~cerebras.pytorch.utils.data.RestartableDataLoader.state_dict`

        Returns:
            The consolidated state dict that will be saved in a checkpoint.

        Usage:
        ::

            def aggregate_state_dict(self, worker_states):
                return {
                    "worker_0": worker_states[0],
                    "worker_1": worker_states[1]
                }
        """

    def deaggregate_state_dict(
        self, aggregated_state_dict: Dict[str, Any], strict: bool = True
    ) -> Dict[str, Any]:
        """Use this method to specify how to load an individual CSX Worker state given
        a consolidated list of state dicts, as specified in
        :py:meth:`~cerebras.pytorch.utils.data.RestartableDataLoader.aggregate_state_dict`.

        Args:
            aggregated_state_dict: The aggregated state dict to deaggregate.
            strict: Whether to enforce strict matching of the incoming state_dict. It is up to the
                implementation to decide what "strict matching" is.

        Returns:
            The state dict will be passed to the above-defined
            :py:meth:`~cerebras.pytorch.utils.data.RestartableDataLoader.load_state_dict` method.

        Usage:
        ::

            def deaggregate_state_dict(self, aggregated_state_dict, strict=True):
                return {
                    "worker_0": aggregated_state_dict.get("worker_0", {})
                }
        """
@dataclass
class DataLoaderCheckpoint:
    """Cerebras-internal record of one CSX Worker's dataloader position.

    One instance is captured per CSX Worker at a checkpoint step.

    Attributes:
        local_worker_id:
            Index of this worker among the workers attached to the same box.
        num_workers_per_csx:
            How many workers feed each box in this run.
        num_csx:
            How many CSXs (boxes) take part in this run.
        wse_id:
            Index of the Wafer-Scale Engine (CSX) this worker streams data to.
        appliance_step:
            Appliance step at which this state was captured.
        worker_step:
            Worker step at which this state was captured. With a single worker
            per box it equals `appliance_step`; with several workers per box,
            appliance steps are handed out round-robin across the box's
            workers according to their local worker id.
        samples_streamed:
            Number of samples this worker had streamed at the checkpoint step,
            i.e. `worker_step` * `per_box_batch_size`.
        user_state_dict:
            The user-defined per-worker state; must be picklable.
        global_worker_id:
            (derived) Index of this worker across all boxes.
        total_num_workers:
            (derived) Number of workers summed over all boxes.

    .. note::
        Only `appliance_step`, `worker_step` and `samples_streamed` change from
        step to step; every other field stays constant for the whole run.
    """

    # Fixed, run-wide topology information.
    local_worker_id: int
    num_workers_per_csx: int
    num_csx: int
    wse_id: int
    # Per-step progress counters.
    appliance_step: int
    worker_step: int
    samples_streamed: int
    # User-defined state dict for the CSX Worker. This object must be picklable.
    user_state_dict: Dict[str, Any]

    @property
    def global_worker_id(self) -> int:
        # Workers are numbered box by box: skip all workers of earlier boxes,
        # then add this worker's position within its own box.
        workers_on_earlier_boxes = self.num_workers_per_csx * self.wse_id
        return workers_on_earlier_boxes + self.local_worker_id

    @property
    def total_num_workers(self) -> int:
        boxes = self.num_csx
        per_box = self.num_workers_per_csx
        return boxes * per_box