# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import os
import shutil
import ujson as json
import zstandard
[docs]def utf8len(s):
return len(s.encode('utf-8'))
[docs]def cycle_documents(dataset, process_id, n_process, dup_sh, short_sh):
while True:
# https://github.com/EleutherAI/the-pile/blob/df97f8651ae3da658b19659b3ceaa6a34b0fc014/the_pile/utils.py#L104
yield from filter(
lambda x: x,
dataset.documents(process_id, n_process, dup_sh, short_sh),
)
[docs]def sha256str(s):
h = hashlib.sha256()
try:
h.update(s.encode("utf-8"))
except UnicodeEncodeError:
# to avoid things like \ud809\udc50\ud808\udefc\ud808\udedb
h.update(s.encode("utf-8", "replace"))
return h.hexdigest()
[docs]def rm_if_exists(path):
try:
if os.path.exists(path):
shutil.rmtree(path)
except NotADirectoryError:
os.remove(path)
[docs]def write_lmd_dataset(fh, lines, indices=None, return_total_written=False):
cctx = zstandard.ZstdCompressor(level=3, threads=10)
compressor = cctx.stream_writer(fh)
# to not store large lists into memory, use index
total_written = 0
if indices is not None:
for index in indices:
text, meta = lines[index]
compressor.write(
json.dumps({"text": text, "meta": meta}).encode("UTF-8") + b"\n"
)
total_written += 1
else:
for line in lines:
text, meta = line
compressor.write(
json.dumps({"text": text, "meta": meta}).encode("UTF-8") + b"\n"
)
total_written += 1
compressor.flush(zstandard.FLUSH_FRAME)
if return_total_written:
return total_written