Restartable dataloaders#
Overview#
In the Cerebras PyTorch API 2.0, we provide revamped support for deterministically restarting any custom input-generating dataloaders used for a run. This feature enables the saving and loading of the dataloader state and seamlessly integrates with our existing mechanism of capturing checkpoints for a run.
Saving DataLoader State#
Similar to how you call state_dict on components such as the model and optimizer to fetch and save state information in our Cerebras H5-based checkpoint format, you can save the state of your dataloader by calling state_dict on the Cerebras PyTorch dataloader wrapper, which must be initialized at the beginning of a custom training loop.
cerebras_dataloader = cstorch.utils.data.DataLoader(input_fn, *args, **kwargs)
...
state_dict = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "dataloader": cerebras_dataloader.state_dict(),
    ...
}
cstorch.save(state_dict, "<path to save checkpoint to>")
Note
Our typical workflow in Model Zoo already includes this call on the Cerebras PyTorch dataloader wrapper to save the state of the dataloader being used for the run.
The dataloader state can only be saved at a checkpoint step, i.e. you should wrap the function that saves the dataloader state in the checkpoint_closure decorator.
Loading DataLoader State#
Upon restarting a run from a Cerebras checkpoint file, you can fetch the saved dataloader state (if it
exists) from the loaded checkpoint and pass it to the load_state_dict
method on the Cerebras PyTorch dataloader wrapper to load your dataloader’s state, e.g.
state_dict = cstorch.load("<path to H5 checkpoint file>")
model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])
if "dataloader" in state_dict:
    cerebras_dataloader.load_state_dict(state_dict["dataloader"])
And that is all!
Now, to specify what “state” information of your dataloader should be saved in a checkpoint when state_dict is called on the Cerebras PyTorch dataloader wrapper, and how this state information should be loaded to rewind your dataloader, your dataloader must conform to the RestartableDataLoader protocol class.
Restartable DataLoader API#
By implementing the methods state_dict, aggregate_state_dict, deaggregate_state_dict, and load_state_dict with the appropriate method signatures, your dataloader is guaranteed to be restartable.
That is, you are able to save the state of your dataloader in a checkpoint and load it by the mechanism described above.
Recall that in a distributed setting, each input worker per CSX creates its own instance of the dataloader for parallelism. Thus, implementing these four methods will determine how your dataloader’s state should be saved and loaded to enable deterministic restarts for such settings.
To illustrate the usage of this protocol with an example, we define our CustomRestartableDataLoader class below. The following subsections describe each method signature more generally and within the context of our custom class.
from typing import Any, Dict

import cerebras.pytorch as cstorch
import torch

class CustomRestartableDataLoader(torch.utils.data.DataLoader):
    def state_dict(self) -> Dict[str, Any]:
        worker_state = cstorch.distributed.get_worker_state()
        state_dict = {
            "worker_step": worker_state.worker_step,
            "worker_id": worker_state.global_worker_id,
            "some_state_info": <any other state info>
            ...
        }
        return state_dict

    def aggregate_state_dict(self, worker_states):
        return {
            "step_worker_0": worker_states[0]["worker_step"],
            "id_worker_1": worker_states[1]["worker_id"],
            "combined_step_sum": worker_states[0]["worker_step"] + worker_states[1]["worker_step"]
        }

    def deaggregate_state_dict(self, aggregated_state_dict):
        if "combined_step_sum" not in aggregated_state_dict:
            raise RuntimeError(
                "The aggregated state dict must contain key `combined_step_sum`. "
                "This means that the dataloader state in the checkpoint you are "
                "loading from is not compatible with the dataloader currently "
                "in use."
            )
        return {
            "combined_step": aggregated_state_dict["combined_step_sum"]
        }

    def load_state_dict(self, state_dict):
        if "combined_step" not in state_dict:
            raise RuntimeError(
                "The state dict must contain key `combined_step`, but it does not. "
                "This means that the dataloader state in the checkpoint you are "
                "loading from is not compatible with the dataloader currently "
                "in use."
            )
        print(f"Loading state using combined steps: {state_dict['combined_step']}")
        ...
state_dict#
Use this method to specify what state information each input-generating worker should capture at an appliance checkpoint step. By default, each worker captures some internal state info using our new Cerebras dataloader checkpoint format defined by the DataLoaderCheckpoint dataclass. Please refer to the linked docs on this class for detailed information on each attribute. Essentially, in your definition of state_dict you may choose to save any of the aforementioned internal state info per worker.
We expose an API method get_worker_state that you may utilize in your implementation of state_dict to fetch the worker’s internal state info, e.g.
def state_dict(self) -> Dict[str, Any]:
    worker_state = cstorch.distributed.get_worker_state()
    state_dict = {
        "worker_step": worker_state.worker_step,
        "worker_id": worker_state.global_worker_id,
        "some_state_info": <any other state info>
        ...
    }
    return state_dict
Note
The call to get_worker_state is well-defined only inside your implementation of state_dict; calling this method anywhere else will result in a RuntimeError exception. Ensure that any other state info you choose to save is picklable using the dill package.
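If you do include extra objects in the state dict, a quick serializability check can catch problems early. Below is a minimal sketch of such a check; the sampler_rng_state attribute is a hypothetical piece of extra state, and dill.pickles simply reports whether dill can serialize a given object:
import dill

def state_dict(self) -> Dict[str, Any]:
    worker_state = cstorch.distributed.get_worker_state()
    # Hypothetical extra state: RNG state kept by a custom sampler.
    extra_state = {"sampler_rng_state": self.sampler_rng_state}
    for key, value in extra_state.items():
        if not dill.pickles(value):
            raise TypeError(f"State entry `{key}` is not picklable with dill")
    return {
        "worker_step": worker_state.worker_step,
        "worker_id": worker_state.global_worker_id,
        **extra_state,
    }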
aggregate_state_dict#
This method accepts the list of individual worker state dicts as an argument. Each state dict inside this list holds per-worker state information as defined in your implementation of the state_dict signature method.
Use this method to specify how to combine the state information of all workers into a single, consolidated state dict, e.g.
def aggregate_state_dict(self, worker_states):
    return {
        "step_worker_0": worker_states[0]["worker_step"],
        "id_worker_1": worker_states[1]["worker_id"],
        "combined_step_sum": worker_states[0]["worker_step"] + worker_states[1]["worker_step"]
    }
Note
The aggregated state dict represents the state of your dataloader and will eventually be saved in our Cerebras H5 checkpoint file when state_dict is invoked on the Cerebras PyTorch dataloader wrapper to save your dataloader’s state.
In the example above, we assume a total of two workers are used for the run. In the aggregated state dict, we choose to save worker 0’s step, worker 1’s global worker id, and the summed step count of both workers as the state of our dataloader.
You can expect the worker_states list to be ordered by the global worker id of each worker.
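If you prefer not to hard-code the worker count, the aggregation can iterate over the full worker_states list instead. A minimal sketch, assuming the per-worker dicts contain the worker_step key from the state_dict example above:
def aggregate_state_dict(self, worker_states):
    # worker_states is ordered by global worker id; sum every worker's step
    # count so the aggregated state no longer depends on the worker count.
    return {
        "num_workers": len(worker_states),
        "combined_step_sum": sum(ws["worker_step"] for ws in worker_states),
    }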
deaggregate_state_dict#
This method accepts an aggregated state dict as an argument. The aggregated state dict represents the state of your dataloader, as specified in the aggregate_state_dict method signature of your dataloader.
To load your dataloader’s state, use this method to specify how the consolidated dataloader state loaded from a checkpoint should be disaggregated into a single state dict defining how each worker should load its state, e.g.
def deaggregate_state_dict(self, aggregated_state_dict):
    if "combined_step_sum" not in aggregated_state_dict:
        raise RuntimeError(
            "The aggregated state dict must contain key `combined_step_sum`. "
            "This means that the dataloader state in the checkpoint you are "
            "loading from is not compatible with the dataloader currently "
            "in use."
        )
    return {
        "combined_step": aggregated_state_dict["combined_step_sum"]
    }
In the example above, our implementation has an explicit check to ensure that we’re loading state captured by this dataloader. Upon restart, we assume that each worker cares about the combined step count of all workers in the previous run at the checkpoint we’re loading from; thus, the deaggregation method constructs and returns a single state holding the combined step info.
Note
This method will be particularly useful when the number of workers per box changes between subsequent runs; use this to specify which state dict should be loaded by each worker upon restart.
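As a minimal sketch of that idea, the deaggregation below splits the combined step count evenly across the workers of the new run; the num_workers attribute is a hypothetical value assumed to be set when the dataloader is constructed for the restart, and the combined_step_sum key follows the generalized aggregation sketch above:
def deaggregate_state_dict(self, aggregated_state_dict):
    combined_step = aggregated_state_dict["combined_step_sum"]
    # Hypothetical attribute: the worker count of the *new* run, which may
    # differ from the count in the run that produced the checkpoint.
    per_worker_step = combined_step // self.num_workers
    return {"resume_step": per_worker_step}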
load_state_dict#
This method accepts a disaggregated state dict as an argument, as defined in your implementation of deaggregate_state_dict.
Use this method to specify how the worker should load its state from the provided, disaggregated state dict, e.g.
def load_state_dict(self, state_dict):
    if "combined_step" not in state_dict:
        raise RuntimeError(
            "The state dict must contain key `combined_step`, but it does not. "
            "This means that the dataloader state in the checkpoint you are "
            "loading from is not compatible with the dataloader currently "
            "in use."
        )
    print(f"Loading state using combined steps: {state_dict['combined_step']}")
    ...
Again, we have an explicit check to ensure that the disaggregated state dict each worker uses to load its state upon restart matches the format specified by our dataloader’s implementation of deaggregate_state_dict. For this example, each worker simply prints the combined step count from the previous run, but you can imagine using this step count to set other properties on your dataloader that enable it to restart deterministically.
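As a sketch of what such a deterministic restart could look like, the variant below uses the loaded step count to fast-forward the iterator past already-consumed samples; the resume_step key follows the deaggregation sketch above, and samples_to_skip and batch_size are hypothetical attributes of the dataloader:
def load_state_dict(self, state_dict):
    resume_step = state_dict.get("resume_step", 0)
    # Hypothetical bookkeeping: remember how many samples were already
    # streamed so __iter__ can skip over them after the restart.
    self.samples_to_skip = resume_step * self.batch_size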
Putting it All Together#
Combining all of the above, we have the following setup for our custom restartable dataloader, whose state can be captured via checkpointing:
def restartable_torch_dataloader(batch_size):
    from torchvision import datasets, transforms

    train_dataset = datasets.MNIST(
        "/path/to/data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
                transforms.Lambda(
                    lambda x: torch.as_tensor(x, dtype=torch.float16)
                ),
            ]
        ),
        target_transform=transforms.Lambda(
            lambda x: torch.as_tensor(x, dtype=torch.int32)
        )
    )
    return CustomRestartableDataLoader(train_dataset, batch_size=batch_size, shuffle=True)

cerebras_dataloader = cstorch.utils.data.DataLoader(restartable_torch_dataloader, batch_size=64)
...

@cstorch.checkpoint_closure
def save_checkpoint():
    state_dict = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "dataloader": cerebras_dataloader.state_dict(),
        ...
    }
    cstorch.save(state_dict, "<path to save checkpoint to>")

...

def load_checkpoint(checkpoint_file_path):
    state_dict = cstorch.load(checkpoint_file_path)
    model.load_state_dict(state_dict["model"])
    optimizer.load_state_dict(state_dict["optimizer"])
    if "dataloader" in state_dict:
        cerebras_dataloader.load_state_dict(state_dict["dataloader"])
Note
It is not necessary for your dataloader to be of type torch.utils.data.DataLoader to enable the saving and loading of its state; in fact, any iterable that returns a structure comprising torch tensors can be programmed to be restartable as long as it implements the four signature methods conforming to the Cerebras RestartableDataLoader protocol class.
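For illustration, here is a minimal sketch of a plain iterable over synthetic tensors that conforms to the protocol. It is not a torch.utils.data.DataLoader, and for simplicity it assumes a single worker so the aggregated step count maps directly onto a resume index:
class RestartableIterable:
    def __init__(self, num_batches, batch_size):
        self.num_batches = num_batches
        self.batch_size = batch_size
        self.start_index = 0

    def __iter__(self):
        # Resume from wherever load_state_dict told us to start.
        for _ in range(self.start_index, self.num_batches):
            yield {"data": torch.ones(self.batch_size, 16)}

    def state_dict(self):
        worker_state = cstorch.distributed.get_worker_state()
        return {"worker_step": worker_state.worker_step}

    def aggregate_state_dict(self, worker_states):
        return {"total_step": sum(ws["worker_step"] for ws in worker_states)}

    def deaggregate_state_dict(self, aggregated_state_dict):
        return {"start_index": aggregated_state_dict["total_step"]}

    def load_state_dict(self, state_dict):
        self.start_index = state_dict["start_index"]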