Source code for cerebras.modelzoo.data_preparation.nlp.slimpajama.preprocessing.datasets

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import os
import pickle
import random
from glob import glob
from multiprocessing import Manager, Process

from lm_dataformat import Reader
from tqdm import tqdm

from cerebras.modelzoo.data_preparation.nlp.slimpajama.utils import (
    cycle_documents,
    utf8len,
)


class Dataset(abc.ABC):
    def dir_path(self):
        """Path to the directory"""

    def short_documents_path(self):
        """Path to the file with short documents"""

    def name(self):
        """Human-readable name of the dataset"""
    def documents(self, process_id, n_process, dup_sh, short_sh):
        """A generator producing all documents in the dataset."""
        filtered = 0
        total_count = 0
        files = glob(self.dir_path())
        random.shuffle(files)
        for file_path in files:
            reader = Reader(file_path)
            file_name = file_path.replace(self.stem_dir_path(), "")
            duplicates_set = dup_sh.get(file_name, set())
            short_set = short_sh.get(file_name, set())
            for doc_id, doc in enumerate(reader._stream_data(jsonl_key="text")):
                if doc_id % n_process == process_id:
                    if doc_id not in short_set and doc_id not in duplicates_set:
                        total_count += 1
                        yield {"doc": doc, "meta": {}}
                    else:
                        filtered += 1
        print(
            f"Total number of documents: {total_count}",
            f"Filtered documents: {filtered}",
        )
    def size(self):
        """Return an estimate of the dataset size. Implementations may use
        a faster, less accurate estimate."""
        size = sum(
            map(
                lambda x: utf8len(x["doc"]),
                tqdm(self.documents(), total=self.num_docs()),
            )
        )
        return size
    def num_docs(self):
        """Return an estimate of the number of documents in the dataset.
        Implementations may use a faster, less accurate estimate."""
        num_docs = sum(
            map(
                lambda x: 1,
                tqdm(self.documents(), total=self.num_docs()),
            )
        )
        return num_docs
    def already_shuffled(self):
        """Datasets where the source is already shuffled should override this
        to return True so that it isn't shuffled again."""
        return False

class RedPajamaBooksDataset(Dataset):
    def __init__(self, input_dir):
        self.stem_dir_path_ = input_dir
        self.dir_path_ = os.path.join(input_dir, "book/*.jsonl")

    def dir_path(self):
        return self.dir_path_

    def stem_dir_path(self):
        return self.stem_dir_path_

    def name(self):
        return "RedPajamaBook"

    def size(self):
        return 102851843814

    def size_duplicate_docs(self):
        return 2106014751

    def size_short_docs(self):
        return 0

    def num_docs(self):
        return 200242

    def num_duplicate_docs(self):
        return 5502

    def num_short_docs(self):
        return 0

class RedPajamaArXivDataset(Dataset):
    def __init__(self, input_dir):
        self.stem_dir_path_ = input_dir
        self.dir_path_ = os.path.join(input_dir, "arxiv/*.jsonl")

    def dir_path(self):
        return self.dir_path_

    def stem_dir_path(self):
        return self.stem_dir_path_

    def name(self):
        return "RedPajamaArXiv"

    def size(self):
        return 89018875739

    def size_duplicate_docs(self):
        return 54749418

    def size_short_docs(self):
        return 574293

    def num_docs(self):
        return 1546641

    def num_duplicate_docs(self):
        return 1979

    def num_short_docs(self):
        return 9686

class RedPajamaCommonCrawlDataset(Dataset):
    def __init__(self, input_dir):
        self.stem_dir_path_ = input_dir
        self.dir_path_ = os.path.join(input_dir, "common_crawl/*/*.jsonl.zst")

    def dir_path(self):
        return self.dir_path_

    def stem_dir_path(self):
        return self.stem_dir_path_

    def name(self):
        return "RedPajamaCommonCrawl"

    def size(self):
        return 1384835073956

    def size_duplicate_docs(self):
        return 2436638659265

    def size_short_docs(self):
        return 6867259

    def num_docs(self):
        return 187084822

    def num_duplicate_docs(self):
        return 289100390

    def num_short_docs(self):
        return 90807

class RedPajamaC4Dataset(Dataset):
    def __init__(self, input_dir):
        self.stem_dir_path_ = input_dir
        self.dir_path_ = os.path.join(input_dir, "c4/*.jsonl")

    def dir_path(self):
        return self.dir_path_

    def stem_dir_path(self):
        return self.stem_dir_path_

    def name(self):
        return "RedPajamaC4"

    def size(self):
        return 734903985384

    def size_duplicate_docs(self):
        return 53403692569

    def size_short_docs(self):
        return 664163266

    def num_docs(self):
        return 324686115

    def num_duplicate_docs(self):
        return 23015691

    def num_short_docs(self):
        return 17167086

class RedPajamaWikipediaDataset(Dataset):
    def __init__(self, input_dir):
        self.stem_dir_path_ = input_dir
        self.dir_path_ = os.path.join(input_dir, "wikipedia/*.jsonl")

    def dir_path(self):
        return self.dir_path_

    def stem_dir_path(self):
        return self.stem_dir_path_

    def name(self):
        return "RedPajamaWikipedia"

    def size(self):
        return 78649866316

    def size_duplicate_docs(self):
        return 1798885899

    def size_short_docs(self):
        return 0

    def num_docs(self):
        return 26967854

    def num_duplicate_docs(self):
        return 2866317

    def num_short_docs(self):
        return 0

class RedPajamaGithubDataset(Dataset):
    def __init__(self, input_dir):
        self.stem_dir_path_ = input_dir
        self.dir_path_ = os.path.join(input_dir, "github/*.jsonl")

    def dir_path(self):
        return self.dir_path_

    def stem_dir_path(self):
        return self.stem_dir_path_

    def name(self):
        return "RedPajamaGithub"

    def size(self):
        return 105581774510

    def size_duplicate_docs(self):
        return 90515346113

    def size_short_docs(self):
        return 0

    def num_docs(self):
        return 21232084

    def num_duplicate_docs(self):
        return 7561228

    def num_short_docs(self):
        return 0

class RedPajamaStackExchangeDataset(Dataset):
    def __init__(self, input_dir):
        self.stem_dir_path_ = input_dir
        self.dir_path_ = os.path.join(input_dir, "stackexchange/*.jsonl")

    def dir_path(self):
        return self.dir_path_

    def stem_dir_path(self):
        return self.stem_dir_path_

    def name(self):
        return "RedPajamaStackExchange"

    def size(self):
        return 71278349386

    def size_duplicate_docs(self):
        return 139373830

    def size_short_docs(self):
        return 3987870

    def num_docs(self):
        return 29702946

    def num_duplicate_docs(self):
        return 25975

    def num_short_docs(self):
        return 96165

class RedPajamaReplication(Dataset):
    def __init__(self, datasets, duplicates, short_docs):
        self.datasets = datasets
        self.duplicates = duplicates
        self.short_docs = short_docs
        self.rnd_docs = random.Random(42)
        self.rnd_queues = random.Random(420)

    def name(self):
        return "RedPajama"

    def size(self):
        return int(sum([weight * ds.size() for ds, weight in self.datasets]))
    def num_docs(self):
        """Return an estimate of the number of documents in the dataset.
        Implementations may use a faster, less accurate estimate."""
        return int(
            sum([ds.num_docs() * weight for ds, weight in self.datasets])
        )
    def sample_documents(
        self, weights, k, queues, process_id, n_process, dup_sh, short_sh
    ):
        # each process samples documents in chunks of size k;
        # sampling happens globally across all available documents
        datasets = []
        for dataset, _ in self.datasets:
            datasets.append(
                (
                    dataset.name(),
                    cycle_documents(
                        dataset, process_id, n_process, dup_sh, short_sh
                    ),
                )
            )
        for j in range(self.num_docs() // k // n_process):
            if j % 1000 == 0:
                print(f"Sampling chunk of documents {j}")
            chunk = self.rnd_docs.choices(
                population=datasets,
                weights=weights,
                k=k,
            )
            for name, documents in chunk:
                document = next(documents)
                text, meta = document["doc"], document["meta"]
                meta["redpajama_set_name"] = name
                q = self.rnd_queues.choice(queues)
                q.put({"doc": text, "meta": meta})
        print("Finished sampling documents.")

    def documents(self, queues):
        # calculate the relative weight of each dataset
        weights = []
        total_weight = sum([x[1] * x[0].num_docs() for x in self.datasets])
        for dataset, weight in self.datasets:
            relative_weight = weight * dataset.num_docs() / total_weight
            weights.append(relative_weight)

        with open(self.duplicates, "rb") as fin:
            dup = pickle.load(fin)
        with open(self.short_docs, "rb") as fin:
            short = pickle.load(fin)

        manager = Manager()
        dup_sh = manager.dict(dup)
        short_sh = manager.dict(short)

        # create processes here to speed up read and write in shuffle_holdout.py;
        # queues are given by shuffle_holdout to populate with documents
        n_process = 2 * len(queues)
        k = 1000
        procs = []
        for process_id in range(n_process):
            p = Process(
                target=self.sample_documents,
                args=(
                    weights,
                    k,
                    queues,
                    process_id,
                    n_process,
                    dup_sh,
                    short_sh,
                ),
            )
            procs.append(p)
        return procs, manager
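
# Note (illustrative, not part of the module): sample_documents() above draws
# chunks of k dataset names with random.Random.choices(), using the
# document-count-proportional weights computed in documents(), then pulls the
# next record from the chosen dataset's cycling iterator. A toy version of
# that weighted draw, with made-up names and weights:
#
#     rnd = random.Random(42)
#     names = ["RedPajamaC4", "RedPajamaBook", "RedPajamaArXiv"]
#     weights = [0.7, 0.2, 0.1]  # hypothetical relative weights
#     chunk = rnd.choices(population=names, weights=weights, k=5)
#     # -> a list of 5 names, drawn roughly in proportion to the weights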

def redpj_datasets(input_dir):
    return [
        (RedPajamaWikipediaDataset(input_dir), 1.0),
        (RedPajamaC4Dataset(input_dir), 1.0),
        (RedPajamaCommonCrawlDataset(input_dir), 1.0),
        (RedPajamaStackExchangeDataset(input_dir), 1.0),
        (RedPajamaBooksDataset(input_dir), 1.0),
        (RedPajamaGithubDataset(input_dir), 1.0),
        (RedPajamaArXivDataset(input_dir), 1.0),
    ]
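
# ---------------------------------------------------------------------------
# Usage sketch (not part of the module). The real driver lives in
# shuffle_holdout.py; the paths, the queue count, and the consumer step below
# are hypothetical and only illustrate how the classes above fit together.
# ---------------------------------------------------------------------------
#
#     from multiprocessing import Queue
#
#     datasets = redpj_datasets("/path/to/RedPajama")  # hypothetical path
#     replication = RedPajamaReplication(
#         datasets,
#         duplicates="/path/to/duplicates.pickle",  # hypothetical path
#         short_docs="/path/to/short_docs.pickle",  # hypothetical path
#     )
#
#     queues = [Queue(maxsize=10000) for _ in range(4)]  # assumed queue count
#     procs, manager = replication.documents(queues)
#     for p in procs:
#         p.start()
#
#     # A consumer (e.g. shuffle_holdout.py) would then drain the queues and
#     # write the interleaved, filtered documents to disk before joining the
#     # sampler processes.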