# Copyright 2022 Cerebras Systems.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import pickle
import queue
import time
from collections import defaultdict
from glob import glob
from multiprocessing import Process, Queue

from datasketch.lean_minhash import LeanMinHash
from more_itertools import divide

def _H(hs):
    """Turn a slice of MinHash values into an immutable ``bytes`` band key.

    The array is byte-swapped (a new array; ``hs`` is left untouched) and its
    raw buffer, in C order, is copied into a ``bytes`` object so it can be
    hashed and compared, e.g. as a dictionary key.
    """
    swapped = hs.byteswap()
    raw = swapped.data
    return bytes(raw)

def split_files(input_dir, n_proc):
    """Gather MinHash shard paths for every dataset and split them for workers.

    Shards are looked up under ``{input_dir}/{dataset}/minhash_nfc/*``.
    ``common_crawl`` has one extra directory level:
    ``{input_dir}/common_crawl/*/minhash_nfc/*``. The combined list is sorted
    and handed to ``more_itertools.divide`` to make ``n_proc`` order-preserving
    parts.

    Args:
        input_dir: Root directory holding one sub-directory per dataset.
        n_proc: Number of parts (one per producer process).

    Returns:
        A list of ``n_proc`` lists of file paths.
    """
    dataset_names = (
        "arxiv",
        "stackexchange",
        "book",
        "wikipedia",
        "github",
        "c4",
        "common_crawl",
    )
    shard_paths = []
    for name in dataset_names:
        # common_crawl nests its minhash_nfc folders one level deeper.
        nested = "*/minhash_nfc" if name == "common_crawl" else "minhash_nfc"
        shard_paths += glob(f"{input_dir}/{name}/{nested}/*")
    return [list(chunk) for chunk in divide(n_proc, sorted(shard_paths))]
def get_hashes(files, doc_queues, r):
    """Producer: push each document's per-band signature key onto band queues.

    Every file in ``files`` is unpickled. Each record is expected to provide
    ``file_name``, ``doc_id`` and ``hash`` entries. ``record["hash"]`` is
    wrapped in a ``LeanMinHash`` (presumably it is a datasketch MinHash;
    confirm against the writer of these shards). For band ``b`` (the
    ``b``-th queue), ``hashvalues[b * r:(b + 1) * r]`` is serialized with
    ``_H``. The tuple ``("<file_name>@<doc_id>", key)`` is then put on that
    band's queue.

    NOTE(review): shards are read with ``pickle.load``, which can execute
    arbitrary code; only run this on trusted data.
    NOTE(review): if ``len(doc_queues) * r`` exceeds the signature length,
    trailing bands get short or empty slices, and all documents may then
    collide in those bands.

    Args:
        files: Paths of pickled shard files to read.
        doc_queues: One queue per LSH band.
        r: Number of hash values per band.
    """
    for path in files:
        with open(path, "rb") as handle:
            records = pickle.load(handle)
            for record in records:
                doc_key = f"{record['file_name']}@{record['doc_id']}"
                signature = LeanMinHash(record["hash"]).hashvalues
                for band, band_queue in enumerate(doc_queues):
                    lo = band * r
                    band_queue.put((doc_key, _H(signature[lo : lo + r])))
def lsh(out_file, doc_queue, idx, timeout=30, total_docs=931361530):
    """Consumer for one LSH band: emit candidate duplicate pairs.

    Reads ``(key, H)`` tuples from ``doc_queue``. The first key seen for a
    given band key ``H`` is remembered. Every later key with the same ``H``
    is written as a ``"<key> :: <first key>"`` line to this band's output
    file.

    Args:
        out_file: Base output path. The band index is inserted before the
            extension: ``pairs.txt`` -> ``pairs-3.txt`` for ``idx=3``. A path
            without an extension becomes ``pairs-3``. Using ``splitext``
            keeps each band's file distinct even when the name has no
            ``.txt``, and leaves ``.txt`` inside directory names alone.
        doc_queue: Queue (anything with ``get(timeout=...)`` raising
            ``queue.Empty``) yielding ``(key, H)`` tuples.
        idx: Band index; used in the output file name and progress output.
        timeout: Seconds to wait on an empty queue before assuming the
            producers are done. There is no end-of-stream sentinel, so a
            producer that stalls longer than this (e.g. while unpickling a
            large shard) makes this consumer stop early and miss later
            documents.
        total_docs: Expected number of documents. It is used only to print
            progress as a percentage. The default is the corpus size that
            was previously hard-coded here.
    """
    root, ext = os.path.splitext(out_file)
    band_path = f"{root}-{idx}{ext}"

    # Band key -> first document key observed with that band key.
    first_key_by_hash = {}
    n_docs = 0
    start_time = time.time()
    with open(band_path, "w") as f:
        while True:
            try:
                key, H = doc_queue.get(timeout=timeout)
            except queue.Empty:
                break
            cand = first_key_by_hash.get(H)
            if cand is not None:
                f.write(f"{key} :: {cand}\n")
            else:
                first_key_by_hash[H] = key
            if n_docs % 100000 == 0:
                print(
                    f"{idx}: Processed {n_docs / total_docs * 100}%.",
                    time.time() - start_time,
                )
            n_docs += 1
    print(f"Total number of documents: {n_docs}")
def generate_pairs(args):
    """Run the multi-process MinHash LSH pipeline configured by ``args``.

    The pipeline has three steps:

    1. ``args.processes`` producer processes run ``get_hashes``, each on its
       share of the shard files from ``split_files``.
    2. The producers push band keys into ``args.bands`` bounded queues.
    3. One ``lsh`` consumer process per band drains its queue and writes
       candidate pairs.

    The call returns once every process has exited.

    Args:
        args: Namespace providing ``input_dir``, ``out_file``, ``range``,
            ``bands`` and ``processes``.
    """
    # Per-band queue capacity, tuned for throughput vs. memory use.
    capacity = 1000000
    band_queues = [Queue(capacity) for _ in range(args.bands)]
    shards = split_files(args.input_dir, args.processes)

    workers = []

    def _launch(target, target_args):
        # Create, record and immediately start one worker process.
        worker = Process(target=target, args=target_args)
        workers.append(worker)
        worker.start()

    # Producers are started before the per-band consumers.
    for pid in range(args.processes):
        _launch(get_hashes, (shards[pid], band_queues, args.range))
    for band in range(args.bands):
        _launch(lsh, (args.out_file, band_queues[band], band))

    for worker in workers:
        worker.join()
def parse_args(argv=None):
    """Parse the command line for MinHash LSH pair generation.

    ``generate_pairs`` and its workers use every option unconditionally.
    All of them are therefore required, so a missing option fails fast with
    a usage message instead of an obscure error inside a worker process.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Generate candidate duplicate document pairs with MinHash LSH."
    )
    parser.add_argument(
        "--input_dir",
        required=True,
        help="Root directory containing the per-dataset minhash_nfc shards.",
    )
    parser.add_argument(
        "--out_file",
        required=True,
        help="Base output path; each band writes its own file derived from it "
        "(e.g. pairs.txt -> pairs-0.txt).",
    )
    parser.add_argument(
        "--range",
        type=int,
        required=True,
        help="Number of MinHash values per band.",
    )
    parser.add_argument(
        "--bands",
        type=int,
        required=True,
        help="Number of LSH bands; one consumer process is started per band.",
    )
    parser.add_argument(
        "--processes",
        type=int,
        required=True,
        help="Number of producer processes reading the shard files.",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    generate_pairs(parse_args())