# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Script to write HDF5 files for UNet datasets.

Usage:

# For help:
python create_hdf5_files.py -h

# Step-1:
Set the image shape to the desired shape, i.e. [H, W, 1], in
`train_input.image_shape` and `eval_input.image_shape` in the config:
/path_to_modelzoo/vision/pytorch/unet/configs/params_severstal_binary.yaml
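
For example, the relevant config entries might look like the following
(the shape values below are hypothetical; use the shape your data requires):

    train_input:
        image_shape: [256, 256, 1]
    eval_input:
        image_shape: [256, 256, 1]
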
# Step-2: Run the script
python modelzoo.data_preparation.vision.unet.create_hdf5_files.py --params=/path_to_modelzoo/vision/pytorch/unet/configs/params_severstal_binary.yaml --output_dir=/path_to_outdir/severstal_binary_classid_3_hdf --num_output_files=10 --num_processes=5
"""
import argparse
import json
import logging
import os
import sys
from collections import defaultdict
from itertools import repeat
from multiprocessing import Pool, cpu_count

import h5py
from tqdm import tqdm

# isort: off
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../"))
# isort: on
from cerebras.modelzoo.common.utils.run.cli_parser import read_params_file
from cerebras.modelzoo.common.utils.utils import check_and_create_output_dirs
from cerebras.modelzoo.data.vision.classification.dataset_factory import (
    VisionSubset,
)
from cerebras.modelzoo.data_preparation.utils import split_list
from cerebras.modelzoo.models.internal.vision.unet.utils import set_defaults
from cerebras.modelzoo.data.vision.segmentation.SeverstalBinaryClassDataProcessor import (  # noqa
    SeverstalBinaryClassDataProcessor,
)


def update_params_from_args(args, params):
    """
    Sets command line arguments from args into params.

    :param argparse namespace args: Command line arguments
    :param dict params: runconfig dict we want to update
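
    Example (illustrative; the values below are hypothetical). Note that a
    ``None`` argument does not overwrite an existing entry in ``params``::

        >>> import argparse
        >>> args = argparse.Namespace(name="my_dataset", output_dir=None)
        >>> params = {"name": "default", "output_dir": "/data"}
        >>> update_params_from_args(args, params)
        >>> params == {"name": "my_dataset", "output_dir": "/data"}
        True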
"""
if args:
for k, v in list(vars(args).items()):
params[k] = v if v is not None else params.get(k)


def _get_dataset(params, is_training):
    """Instantiate the data processor named in `params["data_processor"]`
    (looked up in this module's namespace) and build its dataset."""
    params["use_worker_cache"] = False
    return getattr(sys.modules[__name__], params["data_processor"])(
        params
    ).create_dataset(is_training)


def _get_data_generator(params, is_training, dataset_range):
    """Yield (image, label, image.shape, label.shape) tuples for the subset
    of the dataset selected by `dataset_range`."""
    dataset = _get_dataset(params, is_training)
    sub_dataset = VisionSubset(dataset, dataset_range)
    sub_dataset.set_transforms()

    for idx, feature in enumerate(sub_dataset):
        image, label = feature
        yield (image, label, image.shape, label.shape)


def create_h5(params):
    """Write the examples indexed by `dataset_range` into HDF5 files; run by
    a single worker process."""
    dataset_range, data_params, args, process_no = params
    n_docs = len(dataset_range)

    num_output_files = max(args.num_output_files // args.num_processes, 1)
    output_files = [
        os.path.join(
            args.output_dir,
            f"{args.name}-{fidx + num_output_files*process_no}_p{process_no}.h5",
        )
        for fidx in range(num_output_files)
    ]
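    # Example filenames, using the arguments from the module docstring
    # (num_output_files=10, num_processes=5, i.e. 2 files per process):
    #   process 0 -> preprocessed_data-0_p0.h5, preprocessed_data-1_p0.h5
    #   process 1 -> preprocessed_data-2_p1.h5, preprocessed_data-3_p1.h5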

    ## Create hdf5 writers for each hdf5 file
    writers = []
    meta_data = defaultdict(int)
    writer_num_examples = 0
    for output_file in output_files:
        w = h5py.File(output_file, "w")
        w.attrs["n_examples"] = 0
        writers.append([w, writer_num_examples, output_file])

    writer_index = 0
    total_written = 0

    ## Names of keys of instance dictionary
    fieldnames = ["image", "label"]

    is_training = "train" in args.split
    data_generator = lambda: _get_data_generator(
        data_params, is_training, dataset_range
    )
    for features in tqdm(data_generator(), total=n_docs):
        image, label, image_shape, label_shape = features

        ## write dictionary into hdf5
        writer, writer_num_examples, output_file = writers[writer_index]
        grp_name = f"example_{writer_num_examples}"
        writer.create_dataset(
            f"{grp_name}/image", data=image, shape=image_shape
        )
        writer.create_dataset(
            f"{grp_name}/label", data=label, shape=label_shape
        )
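        # Resulting per-file HDF5 layout (illustrative):
        #   example_0/image -> array of shape image_shape
        #   example_0/label -> array of shape label_shape
        #   example_1/image, example_1/label, ...
        # with the file attribute "n_examples" holding the example count.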

        total_written += 1
        writers[writer_index][1] += 1
        writer_index = (writer_index + 1) % len(writers)

        ## Track the number of examples written to each output file.
        meta_data[output_file] += 1

    for writer, writer_num_examples, output_file in writers:
        assert len(writer) == writer_num_examples
        assert len(writer) == meta_data[output_file]
        writer.attrs["n_examples"] = writer_num_examples
        writer.flush()
        writer.close()

    return {
        "total_written": total_written,
        "meta_data": meta_data,
        "n_docs": n_docs,
        "dataset_range": {process_no: (min(dataset_range), max(dataset_range))},
    }


def create_h5_mp(dataset_range, data_params, args):
    """Split `dataset_range` across `args.num_processes` workers, write the
    HDF5 files in parallel, and aggregate each worker's metadata."""
    try:
        sub_dataset_range = split_list(
            dataset_range, len(dataset_range) // args.num_processes
        )
    except ValueError as e:
        # We hit errors in two potential scenarios:
        # 1) The dataset range is an empty list, in which case there is
        #    nothing to split.
        # 2) There are more processes than examples, in which case we cannot
        #    split the work across processes correctly, as there would be
        #    many idle processes which are not doing anything.
        print(e)
        raise

    with Pool(processes=args.num_processes) as pool:
        results = pool.imap(
            create_h5,
            zip(
                sub_dataset_range,
                repeat(data_params),
                repeat(args),
                range(len(sub_dataset_range)),
            ),
        )
        meta = {
            "total_written": 0,
            "n_docs": 0,
            "meta_data": {},
            "dataset_range": {},
        }
        for r in results:
            for k, v in r.items():
                if not isinstance(v, dict):
                    meta[k] += v
                else:
                    # Valid for both `Counter` and `dict` objects:
                    # for `Counter` objects, values corresponding
                    # to the same key are added;
                    # for `dict` objects, values corresponding
                    # to the same key are updated with the new value `v`.
                    meta[k].update(v)

    return meta


def get_parser_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--num_processes",
        type=int,
        default=0,
        help="Number of parallel processes to use, defaults to cpu count",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Directory where HDF5 files will be stored.",
    )
    parser.add_argument(
        "--num_output_files",
        type=int,
        default=10,
        help="Total number of output files, i.e. each process writes "
        "num_output_files//num_processes files. Defaults to 10.",
    )
    parser.add_argument(
        "--name",
        type=str,
        default="preprocessed_data",
        help="name of the dataset; i.e. prefix to use for hdf5 file names. "
        "Defaults to 'preprocessed_data'.",
    )
    parser.add_argument(
        "--params", type=str, required=True, help="params config yaml file"
    )
    return parser


def main():
    args = get_parser_args().parse_args()

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    if args.output_dir is None:
        args.output_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "hdf5_dataset",
        )
    # Base output directory; per-split subdirectories are created below.
    output_dir = args.output_dir
    check_and_create_output_dirs(args.output_dir, filetype="h5")

    json_params_file = os.path.join(args.output_dir, "data_params.json")
    print(
        f"\nStarting writing data to {args.output_dir}."
        + f" User arguments can be found at {json_params_file}."
    )

    # write initial params to file
    params = read_params_file(args.params)
    set_defaults(params)
    params["input_args"] = {}
    update_params_from_args(args, params["input_args"])
    with open(json_params_file, 'w') as _fout:
        json.dump(params, _fout, indent=4, sort_keys=True)

    if args.num_processes == 0:
        # if nothing is specified, then set number of processes to CPU count.
        args.num_processes = cpu_count()

    splits = ["train_input", "eval_input"]
    for split in splits:
        # set split specific output dir
        args.output_dir = os.path.join(output_dir, split)
        check_and_create_output_dirs(args.output_dir, filetype="h5")
        args.split = split

        dataset = _get_dataset(params[split], is_training="train" in split)
        len_dataset = len(dataset)
        dataset_range = list(range(len_dataset))

        # Set defaults
        # Data augmentation should be done on the fly when training the model.
        params[split]["augment_data"] = False
        # Write generic data; the data gets converted to appropriate dtypes in
        # the `transform_image_and_mask` fcn.
        params[split]["mixed_precision"] = False
        # Write data without hardcoding normalization.
        # This helps use the same files with HDFDataProcessor
        # and different normalization schemes.
        params[split]["normalize_data_method"] = None

        if args.num_processes > 1:
            results = create_h5_mp(dataset_range, params[split], args)
        else:
            # Single-process run, with process number set to 0.
            results = create_h5((dataset_range, params[split], args, 0))

        ## Update data_params file with new fields
        with open(json_params_file, 'r') as _fin:
            data = json.load(_fin)

        data[split].update(params[split])
        _key = f"{split}_hdf"
        data[_key] = {}
        data[_key]["n_docs"] = results["n_docs"]
        data[_key]["total_written"] = results["total_written"]
        data[_key]["dataset_range"] = results["dataset_range"]

        with open(json_params_file, 'w') as _fout:
            json.dump(data, _fout, indent=4, sort_keys=True)

        print(
            f"\nFinished writing {split} data to HDF5 files in {args.output_dir}."
            + f" Runtime arguments and outputs can be found at {json_params_file}."
        )

        ## Store meta file.
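        # Each line of the meta file records "<hdf5_file_path> <num_examples>",
        # e.g. (hypothetical values):
        #   /path_to_outdir/train_input/preprocessed_data-0_p0.h5 1250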
        meta_file = os.path.join(output_dir, f"meta_{split}.dat")
        with open(meta_file, "w") as fout:
            for output_file, num_lines in results["meta_data"].items():
                fout.write(f"{output_file} {num_lines}\n")


if __name__ == "__main__":
    main()