# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
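
"""Deduplicate a SlimPajama training chunk against the holdout (eval) set.

Collects a SHA-256 hash for the text of every document in the holdout
shards (cached in ``hashes.txt`` in the working directory), then rewrites
one training chunk, dropping any document whose hash was seen in the
holdout set.
"""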
import argparse
import os
from glob import glob

from lm_dataformat import Reader
from tqdm import tqdm

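# Make the repository root importable so the cerebras.modelzoo imports below
# resolve when this script is run directly from its source tree.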
# isort: off
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../"))
# isort: on
from cerebras.modelzoo.data_preparation.nlp.slimpajama.utils import (
    rm_if_exists,
    sha256str,
    write_lmd_dataset,
)


def deduplicate_train_holdout_sets(
    train_path, holdout_path, deduped_train_path, chunk_id
):
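    """Remove train documents whose text also appears in the holdout set.

    This is exact-match deduplication: a document is dropped only when the
    SHA-256 of its raw text is identical to a holdout document's hash.
    """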
    # Phase 1: collect hashes for the holdout set. The hashes are cached in
    # hashes.txt in the working directory, so repeated runs on other chunks
    # can skip this pass.
    seen = set()
    if os.path.exists("hashes.txt"):
        with open("hashes.txt") as fh:
            for line in tqdm(fh):
                seen.add(line.strip())
    else:
        with open("hashes.txt", "w") as hashf:
            for f in tqdm(glob(f"{holdout_path}/*/*.zst")):
                reader = Reader(f)
                # _stream_data is lm_dataformat's internal streaming
                # iterator; here it yields each document's "text" field.
                for text in reader._stream_data(jsonl_key="text"):
                    doc_hash = sha256str(text)
                    hashf.write(doc_hash + "\n")
                    seen.add(doc_hash)
print("Finished collecting hashes for eval", len(seen))
rm_if_exists(f"{deduped_train_path}/chunk{chunk_id}")
os.makedirs(f"{deduped_train_path}/chunk{chunk_id}")
total_written = 0
# Remove elements from train set with hashes seen in eval set.
for f in tqdm(glob(f"{train_path}/chunk{chunk_id}/*.zst")):
def filtered_docs():
reader = Reader(f)
for doc_id, doc in enumerate(reader._stream_data(get_meta=True)):
text, meta = doc
hash = sha256str(text)
if hash not in seen:
yield text, meta
else:
print("Found an intersection!!!")
with open(
f"{deduped_train_path}/chunk{chunk_id}/" + f.split("/")[-1], "wb"
) as fout_dedup_train:
total_written += write_lmd_dataset(
fout_dedup_train,
filtered_docs(),
indices=None,
return_total_written=True,
)
print(f"Total written: {total_written}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "chunk_id", type=int, help="index of the train chunk to deduplicate"
    )
    parser.add_argument(
        "--src_dir",
        type=str,
        help="root of the train set (contains chunk<N> subdirectories)",
    )
    parser.add_argument(
        "--tgt_dir", type=str, help="root of the holdout/eval set"
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        help="destination directory for the deduplicated train chunk",
    )
    args = parser.parse_args()

    deduplicate_train_holdout_sets(
        args.src_dir,
        args.tgt_dir,
        args.out_dir,
        args.chunk_id,
    )
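
# Example invocation (a sketch; the script name and directory layout below
# are illustrative, not taken from this file):
#
#   python dedup_train.py 0 \
#       --src_dir /data/slimpajama/train \
#       --tgt_dir /data/slimpajama/holdout \
#       --out_dir /data/slimpajama/train_deduped
#
# This deduplicates train chunk 0 against every holdout shard and writes the
# filtered .zst shards to /data/slimpajama/train_deduped/chunk0.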