cerebras.modelzoo.common.input_utils.bucketed_batch
- cerebras.modelzoo.common.input_utils.bucketed_batch(data_iterator, batch_size, buckets=None, element_length_fn=None, collate_fn=None, drop_last=False, seed=None)
Batch the data from an iterator so that samples of similar length end up in the same batch. If buckets is not supplied, the data is batched normally, with no bucketing.
- Parameters
data_iterator – An iterator that yields data one sample at a time.
batch_size (int) – The number of samples in a batch.
buckets (list) – A list of bucket boundaries. If None, no bucketing happens and data is batched normally. Otherwise, data is grouped into len(buckets) + 1 buckets: a sample s goes into bucket i if buckets[i-1] <= element_length_fn(s) < buckets[i], where 0 and inf are the implied lowest and highest boundaries, respectively. buckets must be sorted in ascending order, and all of its elements must be non-zero.
element_length_fn (callable) – A function that takes a single sample and returns an int representing the length of that sample.
collate_fn (callable) – The function to use to collate samples into a batch. Defaults to PyTorch’s default collate function.
drop_last (bool) – Whether to drop incomplete batches at the end of the dataset. When bucketing, every bucket that is not completely full is also dropped, even if the remaining buckets together hold more than batch_size samples.
seed (int) – When drop_last=False, leftover samples should not be emitted in an order correlated with their lengths, so they are shuffled before being batched and yielded. This seed makes that shuffle deterministic. It is only used when buckets is not None and drop_last=False.
- Yields
Batches of samples, of the type returned by collate_fn, or batches of PyTorch tensors if the default collate function is used.
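The bucket-assignment rule for buckets maps directly onto Python's standard bisect module. The helper below is hypothetical (bucket_index is not part of modelzoo) and assumes the boundaries are sorted in ascending order. bisect.bisect_right returns the index i for which buckets[i-1] <= length < buckets[i], with 0 and inf acting as the outer boundaries.

```python
import bisect
from typing import Sequence


def bucket_index(length: int, buckets: Sequence[int]) -> int:
    """Hypothetical helper: return the bucket i for which
    buckets[i-1] <= length < buckets[i], treating 0 and inf as the
    implied outer boundaries. Assumes buckets is sorted ascending."""
    # bisect_right counts how many boundaries are <= length, which is
    # exactly the index of the half-open interval that contains length.
    return bisect.bisect_right(buckets, length)


boundaries = [8, 16, 32]  # len(boundaries) + 1 == 4 buckets: [0,8), [8,16), [16,32), [32,inf)
for n in (5, 8, 15, 16, 31, 32, 100):
    print(n, "->", bucket_index(n, boundaries))
# 5 -> 0, 8 -> 1, 15 -> 1, 16 -> 2, 31 -> 2, 32 -> 3, 100 -> 3
```

A length equal to a boundary goes into the higher bucket, matching the half-open intervals in the buckets description.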
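The documented behavior can be summarized as a small, dependency-free generator. This is a hypothetical sketch, not the modelzoo implementation. bucketed_batch_sketch is an illustrative name, and list stands in for PyTorch's default collate function. Leftover samples from all buckets are assumed to be pooled and shuffled with random.Random(seed) before the final batches are formed, and raising ValueError when element_length_fn is missing is a choice made only for this sketch.

```python
import bisect
import random
from typing import Any, Callable, Iterable, Iterator, List, Optional, Sequence


def bucketed_batch_sketch(
    data_iterator: Iterable[Any],
    batch_size: int,
    buckets: Optional[Sequence[int]] = None,
    element_length_fn: Optional[Callable[[Any], int]] = None,
    collate_fn: Optional[Callable[[List[Any]], Any]] = None,
    drop_last: bool = False,
    seed: Optional[int] = None,
) -> Iterator[Any]:
    # Assumption: list stands in for PyTorch's default collate function.
    if collate_fn is None:
        collate_fn = list

    if buckets is None:
        # No bucketing: plain sequential batching.
        batch: List[Any] = []
        for sample in data_iterator:
            batch.append(sample)
            if len(batch) == batch_size:
                yield collate_fn(batch)
                batch = []
        if batch and not drop_last:
            yield collate_fn(batch)
        return

    # Sketch-only choice: fail early instead of calling None later.
    if element_length_fn is None:
        raise ValueError("element_length_fn is required when buckets is set")

    # len(buckets) + 1 bins; a bin is emitted as soon as it holds batch_size samples.
    bins: List[List[Any]] = [[] for _ in range(len(buckets) + 1)]
    for sample in data_iterator:
        i = bisect.bisect_right(buckets, element_length_fn(sample))
        bins[i].append(sample)
        if len(bins[i]) == batch_size:
            yield collate_fn(bins[i])
            bins[i] = []

    if drop_last:
        # Partially filled bins are discarded, even if together they hold >= batch_size samples.
        return

    # Assumption: pool and shuffle leftovers so their order is not correlated with length.
    leftovers = [sample for b in bins for sample in b]
    random.Random(seed).shuffle(leftovers)
    for start in range(0, len(leftovers), batch_size):
        yield collate_fn(leftovers[start:start + batch_size])


words = ["a", "bb", "cccc", "ddddd", "e", "ffffff", "g"]
print(list(bucketed_batch_sketch(iter(words), batch_size=2, buckets=[3], element_length_fn=len)))
# [['a', 'bb'], ['cccc', 'ddddd'], ['e', 'g'], ['ffffff']]
```

Emitting a bucket as soon as it fills keeps memory bounded: between yields no bucket holds more than batch_size - 1 samples, so at most (len(buckets) + 1) * (batch_size - 1) samples are buffered at once.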