cerebras.modelzoo.data.common.input_utils.ShardedSampler

class cerebras.modelzoo.data.common.input_utils.ShardedSampler(*args, **kwargs)

Bases: torch.utils.data.Sampler

Sampler that restricts data loading to a subset of the dataset. Modified from PyTorch's DistributedSampler: https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html#DistributedSampler

The dataset is assumed to be of constant size.
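Restricting each replica to a subset of the dataset can be illustrated with strided sharding, the scheme PyTorch's DistributedSampler uses. This is a minimal sketch, not ShardedSampler's implementation. The helper name `shard_for_replica` and its explicit `num_replicas` and `rank` arguments are hypothetical, because the documented signature does not expose them.

```python
from typing import List, Sequence


def shard_for_replica(indices: Sequence[int], num_replicas: int, rank: int) -> List[int]:
    """Hypothetical helper: return the strided subset of `indices` owned by `rank`.

    Replica r takes positions r, r + num_replicas, r + 2 * num_replicas, ...,
    so the shards are disjoint and together cover every position exactly once.
    """
    if not 0 <= rank < num_replicas:
        raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
    return list(indices[rank::num_replicas])


# Because the dataset size is constant, every shard can be computed once up front.
dataset_size = 8
shards = [shard_for_replica(range(dataset_size), 4, r) for r in range(4)]
print(shards)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

A constant dataset size matters here because each replica's shard length, and therefore the sampler's `__len__`, can be fixed when the sampler is constructed.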

Parameters
  • dataset (torch.utils.data.Dataset) – Dataset used for sampling.

  • shuffle (bool, optional) – If True (default), the sampler shuffles the indices.

  • seed (int, optional) – Random seed used to shuffle the indices when shuffle=True. It should be identical across all processes in the distributed group. Default: 0.

  • drop_last (bool, optional) – If True, the sampler drops the tail of the data so that it divides evenly across the replicas. If False, the sampler adds extra indices so that the data divides evenly across the replicas. Default: False.
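The following hypothetical helper sketches how shuffle, seed, and drop_last interact. It follows the padding and truncation logic of PyTorch's DistributedSampler but is not ShardedSampler's code. It uses Python's `random.Random` in place of the `torch.Generator` seeding that DistributedSampler uses, so its permutations differ from PyTorch's. The explicit `num_replicas` and `rank` arguments are assumptions and are not among the documented parameters.

```python
import math
import random
from typing import List


def replica_indices(
    dataset_size: int,
    num_replicas: int,
    rank: int,
    shuffle: bool = True,
    seed: int = 0,
    drop_last: bool = False,
) -> List[int]:
    """Hypothetical helper mirroring DistributedSampler-style index selection."""
    if drop_last:
        # Truncate: every replica gets floor(N / R) samples.
        num_samples = dataset_size // num_replicas
    else:
        # Pad: every replica gets ceil(N / R) samples.
        num_samples = math.ceil(dataset_size / num_replicas)
    total_size = num_samples * num_replicas

    indices = list(range(dataset_size))
    if shuffle:
        # The same seed on every replica yields the same permutation everywhere,
        # which keeps the strided shards below disjoint.
        # (random.Random is a stand-in for torch.Generator here.)
        random.Random(seed).shuffle(indices)

    if drop_last:
        # Drop the tail so the length divides evenly by num_replicas.
        indices = indices[:total_size]
    else:
        padding = total_size - len(indices)
        if padding > 0:
            # Repeat the list enough times to cover the padding, then trim.
            indices += (indices * math.ceil(padding / len(indices)))[:padding]

    # Strided shard for this replica.
    return indices[rank:total_size:num_replicas]


# 10 samples over 3 replicas, without shuffling so the effect is visible.
padded = [replica_indices(10, 3, r, shuffle=False) for r in range(3)]
print(padded)   # [[0, 3, 6, 9], [1, 4, 7, 0], [2, 5, 8, 1]]
dropped = [replica_indices(10, 3, r, shuffle=False, drop_last=True) for r in range(3)]
print(dropped)  # [[0, 3, 6], [1, 4, 7], [2, 5, 8]]
```

With drop_last=False, replicas 1 and 2 receive the repeated indices 0 and 1 as padding. With drop_last=True, sample 9 is dropped instead, and every replica gets three samples.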

Methods