# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
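
"""Split RedPajama jsonl.zst shards into per-source lm_dataformat archives.

Producer processes stream documents from the input shards and route each
document, via its "redpajama_set_name" metadata field, to a per-dataset
queue; one consumer process per dataset drains its queue into compressed
archives on disk.
"""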
import argparse
import os
import time
from glob import glob
from multiprocessing import Process, Queue

from lm_dataformat import Archive, Reader
from more_itertools import divide


def generate_samples(files, queues, process_id):
    """Stream documents from jsonl.zst shards and route each one to the
    queue of its RedPajama source dataset."""
    for fp in files:
        reader = Reader(fp)
        for doc, meta in reader._stream_data(get_meta=True, jsonl_key="text"):
            queues[meta["redpajama_set_name"]].put({"text": doc, "meta": meta})
    print(f"process {process_id} is done!")


def write_samples(q, dataset, out_dir):
    """Drain one dataset queue into compressed lm_dataformat archives."""
    output_dir = os.path.join(out_dir, dataset.replace("RedPajama", ""))
    os.makedirs(output_dir, exist_ok=True)
    ar = Archive(output_dir, threads=10)
    i = 0
    start_time = time.time()
    while True:
        try:
            doc = q.get(timeout=30)
            ar.add_data(doc["text"], doc["meta"])
        except TypeError:
            # The "Done!" sentinel is a plain string, so doc["text"]
            # raises TypeError; flush the final archive and stop.
            assert doc == "Done!"
            ar.commit(archive_name="slimpajama" + str(ar.i))
            break
        i += 1
        # Commit a new archive file every 100k documents.
        if i % 100000 == 0:
            ar.commit(archive_name="slimpajama" + str(ar.i))
            print(
                f"Total number of processed documents: {i} ",
                f"Total time: {time.time() - start_time}",
            )
    print(f"Finished writing documents for {dataset}.")


def main(args):
    # Recursively collect all jsonl.zst shards under the input directory.
    files = glob(os.path.join(args.input_dir, "**/*.jsonl.zst"), recursive=True)
    files = sorted(files)
    n_process = args.processes
    # Split the shard list into one roughly equal chunk per producer.
    files = divide(n_process, files)
    chunks = [list(f) for f in files]
    datasets = [
        "RedPajamaCommonCrawl",
        "RedPajamaC4",
        "RedPajamaGithub",
        "RedPajamaBook",
        "RedPajamaArXiv",
        "RedPajamaWikipedia",
        "RedPajamaStackExchange",
    ]
    producers = []
    # One bounded queue per source dataset, shared by all producers.
    queues = {dataset: Queue(64 * 10000) for dataset in datasets}
    for process_id in range(n_process):
        p = Process(
            target=generate_samples,
            args=(
                chunks[process_id],
                queues,
                process_id,
            ),
        )
        producers.append(p)
    consumers = []
    for dataset in datasets:
        p = Process(
            target=write_samples,
            args=(
                queues[dataset],
                dataset,
                args.output_dir,
            ),
        )
        consumers.append(p)
    for p in producers:
        p.start()
    for p in consumers:
        p.start()
    for p in producers:
        p.join()
    # All producers have exited, so enqueue a sentinel that tells each
    # consumer to flush its archive and stop.
    for q in queues.values():
        q.put("Done!")
    for p in consumers:
        p.join()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", help="directory containing jsonl.zst shards")
    parser.add_argument("--output_dir", help="directory for per-dataset archives")
    parser.add_argument("--processes", type=int, help="number of producer processes")
    args = parser.parse_args()
    main(args)
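
# Example invocation (script and directory names are hypothetical):
#   python separate_datasets.py \
#       --input_dir ./RedPajama --output_dir ./redpajama_split --processes 16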