# Source code for cerebras.modelzoo.data.vision.segmentation.transforms.utility_transforms
# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from: https://github.com/MIC-DKFZ/batchgenerators (commit id: 01f225d)
#
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
# and Applied Computer Vision Lab, Helmholtz Imaging Platform
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class NumpyToTensor:
    """Convert the ``data`` and ``target`` numpy arrays of a sample to torch tensors.

    Both arrays go through ``torch.from_numpy(...).contiguous()`` and are then
    optionally cast. All other entries of the sample dict are passed through
    unchanged.
    """

    def __init__(self, cast_to=None):
        """
        :param cast_to [list]: optional ``[data_dtype, target_dtype]``. Images are
            cast to ``cast_to[0]`` and targets to ``cast_to[1]``. If ``cast_to``
            itself (or one of its entries) is ``None``, the corresponding tensor
            keeps the dtype of its source array.
        """
        self.cast_to = cast_to

    @staticmethod
    def _to_tensor(array, dtype):
        """Wrap ``array`` as a contiguous tensor, casting only if ``dtype`` is given."""
        tensor = torch.from_numpy(array).contiguous()
        return tensor if dtype is None else tensor.to(dtype)

    def __call__(self, **data_dict):
        """
        :param data_dict: sample dict that must contain numpy arrays under
            ``'data'`` and ``'target'``.
        :return: the same dict with ``'data'`` and ``'target'`` replaced by tensors.
        """
        data_dtype = target_dtype = None
        # The default ``cast_to=None`` used to crash on ``None[0]``;
        # treat it as "no casting" instead.
        if self.cast_to is not None:
            data_dtype, target_dtype = self.cast_to[0], self.cast_to[1]
        data_dict['data'] = self._to_tensor(data_dict['data'], data_dtype)
        data_dict['target'] = self._to_tensor(data_dict['target'], target_dtype)
        return data_dict
class RemoveLabelTransform:
    """Overwrite every occurrence of one label value in a segmentation array.

    Elements of ``data_dict[input_key]`` equal to ``remove_label`` are set to
    ``replace_with``, and the array is stored under ``data_dict[output_key]``.
    The assignment is done in place on the input object. When ``input_key`` and
    ``output_key`` differ, both entries therefore refer to the same modified
    array.
    """

    def __init__(
        self, remove_label, replace_with=0, input_key="seg", output_key="seg"
    ):
        """
        :param remove_label: label value to be replaced.
        :param replace_with: value written in place of ``remove_label``.
        :param input_key [str]: dict key holding the array to edit.
        :param output_key [str]: dict key the edited array is stored under.
        """
        self.remove_label = remove_label
        self.replace_with = replace_with
        self.input_key = input_key
        self.output_key = output_key

    def __call__(self, **data_dict):
        labels = data_dict[self.input_key]
        hits = labels == self.remove_label
        labels[hits] = self.replace_with
        data_dict[self.output_key] = labels
        return data_dict
class RenameTransform:
    """Copy ``data_dict[in_key]`` to ``data_dict[out_key]``, optionally dropping ``in_key``.

    If ``in_key`` and ``out_key`` are equal, the transform is a no-op, even with
    ``delete_old=True``. Previously that case deleted the value entirely.
    """

    def __init__(self, in_key, out_key, delete_old=False):
        """
        :param in_key [str]: key to read the value from.
        :param out_key [str]: key to store the value under.
        :param delete_old [bool]: remove ``in_key`` after copying.
        """
        self.delete_old = delete_old
        self.out_key = out_key
        self.in_key = in_key

    def __call__(self, **data_dict):
        data_dict[self.out_key] = data_dict[self.in_key]
        # Deleting when in_key == out_key would remove the value we just stored.
        if self.delete_old and self.in_key != self.out_key:
            del data_dict[self.in_key]
        return data_dict
class OneHotTransform:
    """One-hot encode ``data_dict['target']`` with the class axis placed at dim 1.

    Only channel 0 of the target is used (``target[:, 0, :]``). A target of shape
    ``(B, C, *spatial)`` therefore becomes a ``float32`` tensor of shape
    ``(B, num_classes, *spatial)``. Any number of spatial axes is accepted
    (e.g. 2-D or 3-D images).
    """

    def __init__(self, num_classes):
        """
        :param num_classes [int]: number of classes, passed to
            ``torch.nn.functional.one_hot``.
        """
        self.num_classes = num_classes

    def __call__(self, **data_dict):
        labels = data_dict['target'][:, 0, :]
        # torch.tensor() on an existing tensor emits a copy-construct warning;
        # use a dtype conversion for tensors and torch.tensor() for arrays.
        if isinstance(labels, torch.Tensor):
            labels = labels.to(torch.long)
        else:
            labels = torch.tensor(labels, dtype=torch.long)
        # (B, *spatial) -> (B, *spatial, num_classes)
        encoded = torch.nn.functional.one_hot(labels, num_classes=self.num_classes)
        # Move the class axis to dim 1: (B, num_classes, *spatial).
        # Equivalent to permute((0, -1, 1, 2, 3)) for 5-D, but rank-agnostic.
        data_dict['target'] = torch.movedim(encoded.to(torch.float32), -1, 1)
        return data_dict
class OneHotTransformKits:
    """One-hot encode ``data_dict['label']`` with the class axis placed at dim 1.

    A label of shape ``(B, *spatial)`` becomes a ``float32`` tensor of shape
    ``(B, num_classes, *spatial)``. Any number of trailing axes is accepted.
    Unlike ``OneHotTransform``, the sample dict is taken as a single positional
    argument, and no channel slicing is done.
    """

    def __init__(self, num_classes):
        """
        :param num_classes [int]: number of classes, passed to
            ``torch.nn.functional.one_hot``.
        """
        self.num_classes = num_classes

    def __call__(self, data_dict):
        labels = data_dict['label']
        # torch.tensor() on an existing tensor emits a copy-construct warning;
        # use a dtype conversion for tensors and torch.tensor() for arrays.
        if isinstance(labels, torch.Tensor):
            labels = labels.to(torch.long)
        else:
            labels = torch.tensor(labels, dtype=torch.long)
        # (B, *spatial) -> (B, *spatial, num_classes)
        encoded = torch.nn.functional.one_hot(labels, num_classes=self.num_classes)
        # Move the class axis to dim 1: (B, num_classes, *spatial).
        # Equivalent to permute((0, -1, 1, 2, 3)) for 4-D labels, but rank-agnostic.
        data_dict['label'] = torch.movedim(encoded.to(torch.float32), -1, 1)
        return data_dict