cerebras.modelzoo.common.pytorch_utils.BufferedShuffleDataset
class cerebras.modelzoo.common.pytorch_utils.BufferedShuffleDataset(*args, **kwargs)
Bases: torch.utils.data.IterableDataset
Dataset shuffled from the original dataset.
This class is useful for shuffling an existing instance of an IterableDataset. The buffer, of size buffer_size, is first filled with items from the dataset. Then each item is yielded from the buffer by reservoir sampling via the iterator.
buffer_size must be greater than 0. For buffer_size == 1, the dataset is not shuffled. To fully shuffle the whole dataset, buffer_size must be greater than or equal to the size of the dataset.
When it is used with DataLoader, each item in the dataset is yielded from the DataLoader iterator, and the way to set the random seed depends on num_workers. For single-process mode (num_workers == 0), the random seed must be set in the main process before the DataLoader is used.
Parameters
dataset (IterableDataset) – The original IterableDataset.
buffer_size (int) – The buffer size for shuffling.
Example
For single-process mode (num_workers == 0), the random seed is set in the main process:
>>> ds = BufferedShuffleDataset(dataset)
>>> random.seed(...)
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
For multi-process mode (num_workers > 0), the random seed is set by a callable function in each worker:
>>> ds = BufferedShuffleDataset(dataset)
>>> def init_fn(worker_id):
...     random.seed(...)
>>> print(list(torch.utils.data.DataLoader(ds, ..., num_workers=n, worker_init_fn=init_fn)))
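The buffering scheme described above can be sketched in plain Python as a generator. This is a minimal illustration, not the class's actual source: the helper name buffered_shuffle is hypothetical, and the sketch assumes the global random module is the source of randomness (matching the random.seed(...) calls in the examples). It fills a buffer with the first buffer_size items, then for each further item emits a randomly chosen buffered item and puts the new item in its slot, and finally drains the remaining buffer in random order.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: int) -> Iterator[T]:
    """Hypothetical helper: yield `items` in buffer-shuffled order.

    Uses the global `random` module, so `random.seed(...)` before
    iterating makes the order reproducible.
    """
    if buffer_size <= 0:
        raise ValueError("buffer_size must be greater than 0")
    buf: List[T] = []
    for item in items:
        if len(buf) < buffer_size:
            # Fill phase: collect the first buffer_size items.
            buf.append(item)
        else:
            # Emit a randomly chosen buffered item and store the new one in its slot.
            idx = random.randrange(buffer_size)
            yield buf[idx]
            buf[idx] = item
    # Input exhausted: yield whatever is left in the buffer in random order.
    random.shuffle(buf)
    yield from buf


data = list(range(10))
print(list(buffered_shuffle(data, buffer_size=1)))   # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
random.seed(0)
print(list(buffered_shuffle(data, buffer_size=20)))  # a full permutation of data
```

With buffer_size == 1 the single slot is always the one replaced, so the input order is preserved, as the description states. With buffer_size at least the dataset size, every item is buffered before anything is emitted, so the final shuffle produces a full permutation.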
Methods