# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
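"""Download the raw Pile dataset and the tokenizer/vocabulary files used for
pre-processing it."""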
import argparse
import os
# isort: off
import sys
# isort: on
import subprocess
import warnings
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../"))
from cerebras.modelzoo.common.utils.utils import check_and_create_output_dirs


def parse_args():
    """Argparser definition for command line arguments from user.

    Returns:
        Argparse namespace object with command line arguments.
    """
    parser = argparse.ArgumentParser(
        description="Download the raw Pile data and associated vocabulary for pre-processing."
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Base directory where raw data is to be downloaded.",
    )
    parser.add_argument(
        "--name",
        type=str,
        default="pile",
        help=(
            "Sub-directory where raw data is to be downloaded."
            " Defaults to `pile`."
        ),
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Checks if a given split exists in remote location.",
    )
    return parser.parse_args()
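# Example invocation (the script path below is illustrative):
#   python <path/to/this/script>.py --data_dir /path/to/raw_data --name pile
# Passing --debug only checks that the remote files exist (via `wget --spider`)
# instead of downloading them.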


def get_urls_from_split(split):
    """Get urls given split of dataset.

    Args:
        split (str): Split of dataset to get urls for.

    Returns:
        List of urls, containing jsonl.zst file names for downloading.
    """
    if split == "train":
        warnings.warn(
            message=(
                "Starting a large download process for the full training"
                " data. This process takes time and requires at least"
                " 500GB of free storage."
            ),
            category=UserWarning,
        )
        urls = [
            f"https://mystic.the-eye.eu/public/AI/pile/train/{i:02}.jsonl.zst"
            for i in range(30)
        ]
    elif split == "val":
        urls = ["https://mystic.the-eye.eu/public/AI/pile/val.jsonl.zst"]
    elif split == "test":
        urls = ["https://mystic.the-eye.eu/public/AI/pile/test.jsonl.zst"]
    else:
        raise ValueError(
            f"Unknown split `{split}`, expected one of `train`, `val`, `test`."
        )
    return urls
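# For reference: the train split above resolves to 30 shard URLs
# (00.jsonl.zst through 29.jsonl.zst), while val and test are single files.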


def get_urls_for_tokenizer_files():
    """Get urls for downloading files for tokenization.

    Returns:
        A dictionary containing urls for the original GPT2 tokenization and
        GPT-NeoX tokenization schemes.
    """
    return {
        "gpt2-vocab.bpe": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
        "gpt2-encoder.json": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
        "neox-20B-tokenizer.json": "https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/20B_tokenizer.json",
    }
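# The dictionary above maps the local filename each tokenizer asset is saved
# under (inside `<data_dir>/<name>/vocab/`, see download_tokenizer_files) to
# the remote URL it is fetched from.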


def debug_or_download_individual_file(url, filepath, debug=False):
    """Download a single file from url to specified filepath.

    Args:
        url (str): Url to download the data from.
        filepath (str): Filename (with path) to download the data to.
        debug (bool): Check if remote file exists. Defaults to `False`.
    """
    if debug:
        # use --no-check-certificate as eye.ai throws the below error:
        # `cannot verify mystic.the-eye.eu's certificate, issued by ‘/C=US/O=Let's Encrypt/CN=R3’,
        # Issued certificate has expired.`
        cmd = f"wget --no-check-certificate --spider {url}"
        subprocess.run(cmd.split(" "), check=True)
        return

    execute = False
    # check each individual file, because the train split has 30 individual
    # files and, in a potential previous download attempt, some files may not
    # have been downloaded to the specified path.
    if not os.path.isfile(filepath):
        execute = True
    elif os.stat(filepath).st_size == 0:
        # A previous download attempt failed, but wget still created an empty
        # file. If the file size is 0, delete it and run the download again.
        execute = True
        print(f"Got empty file at {filepath}, deleting and downloading again.")
        cmd = f"rm -rf {filepath}"
        subprocess.run(cmd.split(" "), check=True)
    else:
        print(
            f"{os.path.basename(filepath)} exists at {os.path.dirname(filepath)}"
            ", skipping download."
        )

    # use --no-check-certificate as eye.ai throws the below error:
    # `cannot verify mystic.the-eye.eu's certificate, issued by ‘/C=US/O=Let's Encrypt/CN=R3’,
    # Issued certificate has expired.`
    if execute:
        cmd = f"wget --no-check-certificate {url} -O {filepath}"
        subprocess.run(cmd.split(" "), check=True)
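# Note: downloads shell out to `wget`, so it must be available on PATH; with
# `-O`, a partially downloaded or empty target file is simply overwritten on
# the next attempt.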


def download_pile(args, split):
    """Download The Pile dataset from eye.ai website.

    Args:
        args (argparse namespace): Arguments for downloading the dataset.
        split (str): The subset of the Pile dataset to download.
    """
    check_and_create_output_dirs(
        os.path.join(args.data_dir, args.name, split),
        filetype="jsonl.zst",
    )
    urls = get_urls_from_split(split)
    for url in urls:
        filepath = os.path.join(
            args.data_dir, args.name, split, os.path.basename(url)
        )
        debug_or_download_individual_file(url, filepath, args.debug)
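# Downloaded files for each split land in `<data_dir>/<name>/<split>/`, keeping
# the remote file names (e.g. `00.jsonl.zst` for the first train shard).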


def download_tokenizer_files(args):
    """Download files needed for tokenization for dataset creation.

    Args:
        args (argparse namespace): Arguments for downloading the tokenizer files.
    """
    check_and_create_output_dirs(
        os.path.join(args.data_dir, args.name, "vocab"),
        filetype="json",
    )
    check_and_create_output_dirs(
        os.path.join(args.data_dir, args.name, "vocab"),
        filetype="bpe",
    )
    urls_to_download = get_urls_for_tokenizer_files()
    for key, value in urls_to_download.items():
        if args.debug:
            cmd = f"wget --no-check-certificate --spider {value}"
            subprocess.run(cmd.split(" "), check=True)
            # continue since we want to run only debug, but for all items
            # in the url dictionary
            continue
        filepath = os.path.join(args.data_dir, args.name, "vocab", key)
        cmd = f"wget --no-check-certificate {value} -O {filepath}"
        subprocess.run(cmd.split(" "), check=True)
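# Note: check_and_create_output_dirs is invoked twice on the same `vocab`
# directory above, once per filetype present in the url dictionary
# (`json` and `bpe`).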


def main():
    """Main function for execution."""
    args = parse_args()
    # download all subsets and the corresponding tokenizer files
    for split in ["train", "val", "test"]:
        download_pile(args, split)
    download_tokenizer_files(args)


if __name__ == "__main__":
    main()