# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import h5py
import numpy as np
import pandas as pd
import torch
from torchvision import transforms
from cerebras.modelzoo.common.registry import registry
from cerebras.modelzoo.data.vision.segmentation.Hdf5BaseDataProcessor import (
Hdf5BaseDataProcessor,
)
from cerebras.modelzoo.data.vision.segmentation.preprocessing_utils import (
adjust_brightness_transform,
normalize_tensor_transform,
rotation_90_transform,
)
@registry.register_datasetprocessor("Hdf5DataProcessor")
class Hdf5DataProcessor(Hdf5BaseDataProcessor):
    """
    A HDF5 dataset processor for UNet HDF dataset.
    Performs on-the-fly augmentation of image and label.

    Samples are addressed by index through `__len__` / `__getitem__`.

    Functionality includes:
        Reading data from HDF5 documents
        Augmenting data

    :param dict params: dict containing training
        input parameters for creating dataset.
    Expects the following fields:

    - "data_dir" (str): Path to the directory holding the dataset HDF5
        (`*.h5`) files. This processor passes the value to `pathlib.Path`,
        so a single directory is expected.
    - "num_classes" (int): Number of segmentation classes. Used as the
        depth of the one-hot encoding for the "multilabel_bce" loss.
    - "image_shape": Expected shape of output images and label, used in
        assert checks (presumably a sequence of ints -- confirm against the
        base class).
    - "loss" (str): Loss type, supported: {"bce", "multilabel_bce", "ssce"}
    - "normalize_data_method" (str): Can be one of {None, "zero_centered", "zero_one"}
    - "batch_size" (int): Batch size.
    - "shuffle" (bool): Flag to enable data shuffling.
    - "shuffle_buffer" (int): Size of shuffle buffer in samples.
    - "shuffle_seed" (int): Shuffle seed.
    - "num_workers" (int): How many subprocesses to use for data loading.
    - "drop_last" (bool): If True and the dataset size is not divisible
        by the batch size, the last incomplete batch will be dropped.
    - "prefetch_factor" (int): Number of samples loaded in advance by each worker.
    - "persistent_workers" (bool): If True, the data loader will not shutdown
        the worker processes after a dataset has been consumed once.

    NOTE(review): the methods below read several attributes that are not
    assigned in this class (`batch_size`, `data_dir`, `num_tasks`, `task_id`,
    `disable_sharding`, `data_partitions`, `normalize_data_method`,
    `normalize_transform`, `augment_data`, `tgt_image_height`,
    `tgt_image_width`, `loss_type`, `num_classes`, `mixed_precision`,
    `mp_type`); they are presumably set by `Hdf5BaseDataProcessor`.
    """

    def _shard_files(self, is_training=False):
        """
        Discover the HDF5 files in `self.data_dir` and assign them to tasks.

        Every `*.h5` file directly inside `self.data_dir` is opened once to
        read its `n_examples` attribute. Files are visited in sorted order and
        file number `i` belongs to this task when
        `i % self.num_tasks == self.task_id`.

        Sets `features_list`, `all_files`, `files_in_this_task`,
        `num_examples` and `num_examples_in_this_task` on the instance.

        :param bool is_training: Not used by this implementation.
        :raises AssertionError: If `self.batch_size` is not positive or
            `self.data_dir` is not a directory.
        :raises RuntimeError: If the directory contains no `*.h5` file.
        """
        # Features in HDF5 record files
        self.features_list = ["image", "label"]

        assert self.batch_size > 0, "Batch size should be positive."

        data_dir = Path(self.data_dir)
        assert data_dir.is_dir(), f"`data_dir` is not a directory: {data_dir}"

        files = sorted(data_dir.glob('*.h5'))
        if not files:
            raise RuntimeError('No hdf5 datasets found')

        all_files = []
        files_in_this_task = []
        self.num_examples = 0
        self.num_examples_in_this_task = 0

        for idx, file in enumerate(files):
            file_path = str(file.resolve())
            with h5py.File(file_path, mode='r') as h5_file:
                # Cast so the running totals (and therefore `__len__`)
                # are plain Python ints rather than numpy scalars.
                num_examples_in_file = int(h5_file.attrs["n_examples"])

            file_details = (file_path, num_examples_in_file)
            all_files.append(file_details)
            self.num_examples += num_examples_in_file

            # Round-robin assignment of files to tasks.
            if idx % self.num_tasks == self.task_id:
                files_in_this_task.append(file_details)
                self.num_examples_in_this_task += num_examples_in_file

        # Prevent CoW which is effectively copy on read behavior for PT,
        # see: https://github.com/pytorch/pytorch/issues/13246
        columns = ["file_path", "num_examples_in_file"]
        self.all_files = pd.DataFrame(all_files, columns=columns)
        self.files_in_this_task = pd.DataFrame(
            files_in_this_task, columns=columns
        )

    def _apply_normalization(self, x):
        """
        Normalize `x` with the method named by `self.normalize_data_method`.

        :param x: Image to normalize.
        :returns: The result of `normalize_tensor_transform`.
        """
        return normalize_tensor_transform(
            x, normalize_data_method=self.normalize_data_method
        )

    def _load_buffer(self, data_partitions):
        """
        Lazily yield the HDF5 entries described by `data_partitions`.

        Each file is kept open only while its examples are being yielded, so
        a yielded entry has to be consumed before the generator is advanced
        past the end of its file.

        NOTE(review): this expects `(file_path, start_idx, num_examples)`
        triples, which is not the structure returned by
        `_maybe_shard_dataset` -- confirm who calls this.

        :param data_partitions: Iterable of
            `(file_path, start_idx, num_examples)` triples.
        :returns: Generator over `h5_file["example_<idx>"]` entries.
        """
        for file_path, start_idx, num_examples in data_partitions:
            with h5py.File(file_path, mode='r') as h5_file:
                for idx in range(start_idx, start_idx + num_examples):
                    yield h5_file[f"example_{idx}"]

    def _maybe_shard_dataset(self, num_workers):
        """
        Build the lookup from a global sample index to its HDF5 location.

        All files are used when `self.disable_sharding` is set, otherwise
        only the files assigned to this task by `_shard_files`.

        :param int num_workers: Not used: the returned mapping is the same
            for any worker count. (The previous implementation divided by
            this value for a split that was never used, which raised
            `ZeroDivisionError` for `num_workers=0`.)
        :returns: dict mapping consecutive indices, starting at 0, to
            `(file_path, "example_<i>")` tuples, in file order. `__getitem__`
            reads entries of this form from `self.data_partitions`.
        """
        files = (
            self.all_files if self.disable_sharding else self.files_in_this_task
        )

        sample_index = {}
        for file_path, num_examples_in_file in zip(
            files["file_path"], files["num_examples_in_file"]
        ):
            for file_idx in range(num_examples_in_file):
                sample_index[len(sample_index)] = (
                    file_path,
                    f"example_{file_idx}",
                )

        return sample_index

    def __len__(self):
        """
        Return the number of examples visible to this task: all examples
        when sharding is disabled, otherwise only this task's share.
        """
        if self.disable_sharding:
            return self.num_examples
        return self.num_examples_in_this_task

    def __getitem__(self, index):
        """
        Get item at a particular index.

        :param index: Key into `self.data_partitions`.
        :returns: `(image, label)` as returned by `transform_image_and_mask`.
        """
        file_path, sample_name = self.data_partitions[index]

        with h5py.File(file_path, mode='r') as h5_file:
            example = h5_file[sample_name]
            # `np.array` copies the data into memory, so the tensors stay
            # valid after the file is closed.
            example_dict = {
                feature: torch.from_numpy(np.array(example[feature]))
                for feature in self.features_list
            }

        return self.transform_image_and_mask(
            example_dict["image"], example_dict["label"]
        )

    def transform_image_and_mask(self, image, mask):
        """
        Normalize, augment and type-cast one `(image, mask)` pair.

        Steps, in order:
            1. Normalize the image if `self.normalize_data_method` is set.
            2. If `self.augment_data` is set, apply the same randomly chosen
               flip/rotation to image and mask; random brightness is applied
               to the image only.
            3. Reshape/cast the mask according to `self.loss_type`. A loss
               type other than "bce", "multilabel_bce" or "ssce" leaves the
               mask unchanged.
            4. Cast the image to `self.mp_type` if `self.mixed_precision`.

        :param image: Image tensor.
        :param mask: Label tensor; for "multilabel_bce" and "ssce" it is
            squeezed along dim 0 (presumably shape (1, H, W) -- confirm).
        :returns: Tuple `(image, mask)`.
        """
        if self.normalize_data_method:
            image = self.normalize_transform(image)

        if self.augment_data:
            # Draw the random parameters once so that image and mask get
            # the same geometric transformation.
            do_horizontal_flip = torch.rand(size=(1,)).item() > 0.5
            # n_rots in range [0, 3)
            n_rotations = torch.randint(low=0, high=3, size=(1,)).item()

            if self.tgt_image_height != self.tgt_image_width:
                # For a rectangle image: use an even number of rotations
                # (presumably so that the output keeps its height/width).
                n_rotations = n_rotations * 2

            augment_transform_image = self.get_augment_transforms(
                do_horizontal_flip=do_horizontal_flip,
                n_rotations=n_rotations,
                do_random_brightness=True,
            )
            augment_transform_mask = self.get_augment_transforms(
                do_horizontal_flip=do_horizontal_flip,
                n_rotations=n_rotations,
                do_random_brightness=False,
            )

            image = augment_transform_image(image)
            mask = augment_transform_mask(mask)

        # Handle dtypes and mask shapes based on `loss_type`
        # and `mixed_precision`
        if self.loss_type == "bce":
            mask = mask.to(self.mp_type)
        elif self.loss_type == "multilabel_bce":
            mask = torch.squeeze(mask, 0)
            # Only long tensors are accepted by one_hot fcn.
            mask = mask.to(torch.long)
            # out shape: (H, W, num_classes)
            mask = torch.nn.functional.one_hot(
                mask, num_classes=self.num_classes
            )
            # out shape: (num_classes, H, W)
            mask = torch.permute(mask, [2, 0, 1])
            mask = mask.to(self.mp_type)
        elif self.loss_type == "ssce":
            # out shape: (H, W) with each value in [0, num_classes)
            mask = torch.squeeze(mask, 0)
            # TODO: Add MZ tags here when supported.
            # SW-82348 workaround: Pass `labels` in `int32`
            # PT crossentropy loss takes in `int64`,
            # view and typecast does not change the original `labels`.
            mask = mask.to(torch.int32)

        if self.mixed_precision:
            image = image.to(self.mp_type)

        return image, mask

    def get_augment_transforms(
        self, do_horizontal_flip, n_rotations, do_random_brightness
    ):
        """
        Compose the augmentation pipeline for one tensor.

        The enabled transforms are applied in the order: horizontal flip,
        rotation, random brightness. With nothing enabled the returned
        composition has no transforms.

        :param bool do_horizontal_flip: Add a horizontal flip.
        :param int n_rotations: Value passed as `num_rotations` to
            `rotation_90_transform`; no rotation is added unless it is > 0.
        :param bool do_random_brightness: Add `adjust_brightness_transform`
            with `p=0.5, delta=0.2`.
        :returns: A `transforms.Compose` of the selected transforms.
        """
        augment_transforms_list = []

        if do_horizontal_flip:
            augment_transforms_list.append(
                transforms.Lambda(transforms.functional.hflip)
            )

        if n_rotations > 0:
            augment_transforms_list.append(
                transforms.Lambda(
                    lambda x: rotation_90_transform(
                        x, num_rotations=n_rotations
                    )
                )
            )

        if do_random_brightness:
            augment_transforms_list.append(
                transforms.Lambda(
                    lambda x: adjust_brightness_transform(x, p=0.5, delta=0.2)
                )
            )

        return transforms.Compose(augment_transforms_list)