cerebras.modelzoo.data.common.restartable_dataloader.RestartableDataLoader
- class cerebras.modelzoo.data.common.restartable_dataloader.RestartableDataLoader(*args, **kwargs)
Bases: torch.utils.data.DataLoader
Restartable dataloader for a torch.utils.data.Dataset.
The only state needed to restart instances of Dataset deterministically is the total number of samples streamed globally, which is what the sampler consumes. Accordingly, each worker saves the number of samples it has streamed in state_dict(). These per-worker counts are aggregated by summation into the global number of samples streamed across all workers, and that same global count is used to set the state of the sampler when a state dict is loaded.
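Why summation is enough can be shown with a small pure-Python simulation. This is a minimal sketch, not the Cerebras implementation: the helper names (`sampler_order`, `worker_state_dict`, `aggregate`, `resume_order`) are hypothetical, the sampler is assumed to be a seeded shuffle, and workers are assumed to pull from one global order in round-robin fashion so that together they have consumed a contiguous prefix of it.

```python
import random
from typing import Dict, List


def sampler_order(num_samples: int, seed: int) -> List[int]:
    """Hypothetical deterministic sampler: a seeded shuffle of all indices."""
    order = list(range(num_samples))
    random.Random(seed).shuffle(order)
    return order


def worker_state_dict(samples_streamed: int) -> Dict[str, int]:
    """Per-worker state: only the number of samples this worker has streamed."""
    return {"samples_streamed": samples_streamed}


def aggregate(worker_states: List[Dict[str, int]]) -> Dict[str, int]:
    """Sum per-worker counts into the global number of samples streamed."""
    return {"samples_streamed": sum(s["samples_streamed"] for s in worker_states)}


def resume_order(num_samples: int, seed: int, state: Dict[str, int]) -> List[int]:
    """Rebuild the same order and skip the samples already consumed globally."""
    return sampler_order(num_samples, seed)[state["samples_streamed"]:]


# Assumption: 3 workers pull from one global order in round-robin fashion.
NUM_SAMPLES, SEED, NUM_WORKERS, STOP_AFTER = 20, 0, 3, 8
order = sampler_order(NUM_SAMPLES, SEED)
consumed = order[:STOP_AFTER]
counts = [len(consumed[w::NUM_WORKERS]) for w in range(NUM_WORKERS)]  # [3, 3, 2]

state = aggregate([worker_state_dict(c) for c in counts])
remaining = resume_order(NUM_SAMPLES, SEED, state)

# Consumed + resumed reproduces the uninterrupted stream exactly.
assert consumed + remaining == order
print(state)  # {'samples_streamed': 8}
```

The per-worker split is irrelevant after a restart: only the sum tells the rebuilt sampler how far into the shared order to skip, which is why a single global count suffices under these assumptions.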
Constructs a RestartableDataLoader instance.
Methods
Aggregates states across all dataloaders into a single state.
Deaggregates state from all dataloaders.
Loads given state into the dataloader.
Returns the state of the current dataloader.
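The four operations above form a save/restore cycle: each dataloader reports its state, the states are aggregated into one checkpoint, and on restart the checkpoint is deaggregated and loaded back. The toy class below sketches that cycle under stated assumptions; it is not the Cerebras class, its method names (`get_state`, `set_state`, `aggregate`, `deaggregate`) are hypothetical, and it assumes deaggregation simply hands the summed global count back to each dataloader.

```python
import json
import random
from typing import Any, Dict, Iterator, List


class ToyRestartableLoader:
    """Hypothetical stand-in for a restartable loader; not the Cerebras API."""

    def __init__(self, num_samples: int, seed: int) -> None:
        self.order = list(range(num_samples))
        random.Random(seed).shuffle(self.order)  # deterministic sample order
        self.samples_streamed = 0  # samples yielded so far

    def __iter__(self) -> Iterator[int]:
        # Continue from the recorded position in the deterministic order.
        while self.samples_streamed < len(self.order):
            idx = self.order[self.samples_streamed]
            self.samples_streamed += 1
            yield idx

    def get_state(self) -> Dict[str, Any]:
        # Mirrors "returns the state of the current dataloader".
        return {"samples_streamed": self.samples_streamed}

    def set_state(self, state: Dict[str, Any]) -> None:
        # Mirrors "loads given state into the dataloader".
        self.samples_streamed = state["samples_streamed"]

    @staticmethod
    def aggregate(states: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Mirrors "aggregates states across all dataloaders" via summation.
        return {"samples_streamed": sum(s["samples_streamed"] for s in states)}

    @staticmethod
    def deaggregate(state: Dict[str, Any]) -> Dict[str, Any]:
        # Assumption: each dataloader receives the same global count back.
        return dict(state)


loader = ToyRestartableLoader(num_samples=10, seed=1)
it = iter(loader)
first = [next(it) for _ in range(4)]

# Checkpoint (a single loader here), serialize, then restore into a fresh loader.
saved = json.dumps(ToyRestartableLoader.aggregate([loader.get_state()]))
restored = ToyRestartableLoader(num_samples=10, seed=1)
restored.set_state(ToyRestartableLoader.deaggregate(json.loads(saved)))
rest = list(restored)

assert first + rest == loader.order
```

Keeping the checkpoint to a single integer makes it trivially serializable and independent of how many workers were running when it was saved.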