# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This is Dataset process for processing Raw data set on the fly
This contains methods for loading the dataset, tokenizing the dataset
and all data transformations are handled as part of the collator function
"""
import random
from typing import Any, Dict, Iterator, List
import numpy as np
import torch
from torch.utils.data import DataLoader, default_collate
from cerebras.modelzoo.common.registry import registry
from cerebras.modelzoo.data.common.input_utils import (
    num_tasks,
    shard_list_contiguous,
    task_id,
)
from cerebras.modelzoo.data_preparation.data_preprocessing.data_preprocessor import (
    DataPreprocessor,
)
from cerebras.modelzoo.data_preparation.raw_dataset_processor.utils import (
    Reader,
)


@registry.register_datasetprocessor("RawDatasetProcessor")
class RawDatasetProcessor(torch.utils.data.IterableDataset):
    """
    Iterable dataset that streams raw input files, tokenizes them on the fly,
    and yields feature dictionaries ready to be collated into batches.
    """

    def __init__(self, params: Dict[str, Any]):
        super().__init__()
        self.params = params
        self.preprocessing_params = self.params.get("preprocessing", None)
        self.dataset_processor = DataPreprocessor(self.preprocessing_params)
        self.features_list = self.preprocessing_params["processing"].get(
            "features_list", ["input_ids", "attention_mask", "labels"]
        )
        self.num_workers = params.get("num_workers", 0)
        self.drop_last = params.get("drop_last", True)
        if self.num_workers == 0:
            self.prefetch_factor = None
        else:
            self.prefetch_factor = params.get("prefetch_factor", 10)
        self.persistent_workers = params.get("persistent_workers", True)
        self.batch_size = params.get("batch_size", None)
        self.seed = self.params.pop("seed", None)
        self.rng = random.Random(self.seed)
        self.reader = Reader(
            self.dataset_processor.input_files,
            keys=self.dataset_processor.data_keys,
            format_hook_fn=self.dataset_processor.format_hook_fn,
        )
        # Shard the input files contiguously across data-parallel tasks.
        self.num_tasks = num_tasks()
        self.task_id = task_id()
        self.input_files_in_this_task = shard_list_contiguous(
            self.dataset_processor.input_files, self.task_id, self.num_tasks
        )
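
    # Note on the expected ``params`` dictionary (as read by this constructor):
    #   "preprocessing"      -> config passed verbatim to DataPreprocessor
    #                           (its schema is defined by the data preprocessing
    #                           pipeline, not by this class)
    #   "batch_size"         -> samples per batch (no default)
    #   "num_workers"        -> DataLoader worker processes (default 0)
    #   "drop_last"          -> drop the final partial batch (default True)
    #   "prefetch_factor"    -> batches prefetched per worker (default 10)
    #   "persistent_workers" -> keep workers alive between epochs (default True)
    #   "seed"               -> optional shuffle/sharding seed (popped from params)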

    def _worker_init_fn(self, worker_id: int):
        """
        Initialization function for each worker in a DataLoader.

        Args:
            worker_id (int): The ID of the current worker.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
        else:
            # Single-process data loading
            worker_id = 0
            num_workers = 1

        if self.seed is not None:
            # Use a unique seed for each worker.
            random.seed(self.seed + worker_id)

        # Shard this task's input files between the DataLoader workers.
        self.input_files_in_this_worker = shard_list_contiguous(
            self.input_files_in_this_task, worker_id, num_workers
        )
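
    # Illustrative sharding example (assuming shard_list_contiguous slices a
    # list into contiguous, roughly equal shards): with 8 input files,
    # num_tasks=2 and 2 DataLoader workers per task, task 0 would hold
    # files[0:4]; within task 0, worker 0 would be assigned files[0:2] and
    # worker 1 files[2:4]. Exact remainder handling is left to
    # shard_list_contiguous itself.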

    def __iter__(self) -> Iterator[Dict[str, np.ndarray]]:
        """
        Returns an iterator over the items of the dataset.

        Returns:
            Iterator[Dict[str, np.ndarray]]: An iterator yielding dictionaries
            with string keys and NumPy array values.
        """
        return self.get_next_item()

    def get_next_item(self) -> Iterator[Dict[str, np.ndarray]]:
        """
        Yields the next item in the iteration.

        This generator streams records from the reader, tokenizes them, and
        yields dictionaries containing features as keys and NumPy arrays as
        values.

        Returns:
            Iterator[Dict[str, np.ndarray]]: An iterator yielding dictionaries
            with string keys and NumPy array values.
        """
        for data in self.reader.stream_data():
            data_array = self.dataset_processor.format_hook_fn(data)
            # Tokenize the data and collect stats
            tokenized_data, stats = self.dataset_processor.token_generator.encode(
                data_array
            )
            # Skip records that produced no tokenized data
            if "data" not in tokenized_data:
                continue
            # Yield one feature dictionary per tokenized sample
            for d in tokenized_data["data"]:
                yield {
                    feature: np.array(d[i], np.int32)
                    for i, feature in enumerate(self.features_list)
                }
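
    # With the default features_list ["input_ids", "attention_mask", "labels"],
    # each yielded sample is a dict of int32 NumPy arrays, e.g.
    #   {"input_ids": d[0], "attention_mask": d[1], "labels": d[2]}
    # where d is one tokenized sequence produced by the token generator.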

    def collate_fn(self, batch: List[Dict[str, np.ndarray]]) -> Any:
        """
        Collates a list of feature dictionaries into a batch.

        Args:
            batch (List[Dict[str, np.ndarray]]): A list of dictionaries, where
                each dictionary contains string keys and NumPy array values.

        Returns:
            Any: The collated batch.
        """
        # Shuffle samples within the batch if shuffling is enabled.
        if self.dataset_processor.shuffle:
            random.shuffle(batch)
        return default_collate(batch)
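
    # ``default_collate`` converts the per-sample NumPy arrays to torch tensors
    # and stacks them along a new leading batch dimension: a batch of B samples
    # whose "input_ids" arrays have shape [msl] collates to a [B, msl] tensor.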

    def create_dataloader(self) -> DataLoader:
        """
        Creates the dataloader object for this dataset.

        Returns:
            DataLoader: A DataLoader object for the dataset.
        """
        # Create the DataLoader object with the specified parameters.
        dataloader = DataLoader(
            self,
            batch_size=self.batch_size,  # Number of samples per batch
            drop_last=self.drop_last,  # Drop the last incomplete batch
            collate_fn=self.collate_fn,  # Merge a list of samples into a mini-batch
            num_workers=self.num_workers,  # Number of subprocesses used for data loading
            prefetch_factor=(
                self.prefetch_factor if self.num_workers > 0 else None
            ),  # Number of batches loaded in advance by each worker
            persistent_workers=(
                self.persistent_workers if self.num_workers > 0 else False
            ),  # Keep worker processes alive between uses of the dataloader
            worker_init_fn=(
                self._worker_init_fn
                if self.num_workers > 0 and self.seed is not None
                else None
            ),  # Function to initialize each worker process
        )
        # The DataLoader never calls worker_init_fn when num_workers == 0,
        # so run the per-worker file sharding here in that case.
        if self.num_workers == 0:
            self._worker_init_fn(0)
        return dataloader
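

# A minimal, hypothetical usage sketch. The contents of the "preprocessing"
# block depend on the data preprocessing configs consumed by DataPreprocessor;
# the ellipsis below is a placeholder that must be replaced with a valid
# configuration before this would run.
#
# if __name__ == "__main__":
#     params = {
#         "preprocessing": {...},  # DataPreprocessor config (placeholder)
#         "batch_size": 8,
#         "num_workers": 2,
#         "seed": 42,
#     }
#     dataloader = RawDatasetProcessor(params).create_dataloader()
#     for batch in dataloader:
#         print({name: tensor.shape for name, tensor in batch.items()})
#         break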