# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional

import torch

# TODO: edit dataclass attrs
@dataclass
class Bbox:
"""
Source: indicates how the box was made:
xclick: are manually drawn boxes using the method presented in [1], were the annotators click on the four extreme points of the object. In V6 we release the actual 4 extreme points for all xclick boxes in train (13M), see below.
activemil: are boxes produced using an enhanced version of the method [2]. These are human verified to be accurate at IoU>0.7.
LabelName: the MID of the object class this box belongs to.
Confidence: a dummy value, always 1.
XMin, XMax, YMin, YMax: coordinates of the box, in normalized image coordinates. XMin is in [0,1], where 0 is the leftmost pixel, and 1 is the rightmost pixel in the image. Y coordinates go from the top pixel (0) to the bottom pixel (1).
IsOccluded: Indicates that the object is occluded by another object in the image.
IsTruncated: Indicates that the object extends beyond the boundary of the image.
IsGroupOf: Indicates that the box spans a group of objects (e.g., a bed of flowers or a crowd of people). We asked annotators to use this tag for cases with more than 5 instances which are heavily occluding each other and are physically touching.
IsDepiction: Indicates that the object is a depiction (e.g., a cartoon or drawing of the object, not a real physical instance).
IsInside: Indicates a picture taken from the inside of the object (e.g., a car interior or inside of a building).
For each of them, value 1 indicates present, 0 not present, and -1 unknown.
"""
    XMin: float
    YMin: float
    XMax: float
    YMax: float
    ClassLabel: str
    ClassIntID: int
    ClassID: Optional[str] = None
    IsOccluded: Optional[bool] = None
    IsTruncated: Optional[bool] = None
    IsGroupOf: Optional[bool] = None
    IsDepiction: Optional[bool] = None
    IsInside: Optional[bool] = None
    IsTrainable: Optional[bool] = None
    Source: Optional[str] = None
    Confidence: Optional[int] = None  # dummy value of 1 in OpenImages

    def bbox_to_tensor(self, format="yxyx"):
        if format == "yxyx":
            out = torch.tensor(
                [self.YMin, self.XMin, self.YMax, self.XMax],
                dtype=torch.float32,
            )
        elif format == "xyxy":
            out = torch.tensor(
                [self.XMin, self.YMin, self.XMax, self.YMax],
                dtype=torch.float32,
            )
        else:
            raise ValueError(
                f"Unsupported format: {format}, supported values are "
                f"('xyxy', 'yxyx')"
            )
        return out

    def labelID_to_tensor(self):
        return torch.tensor(self.ClassIntID, dtype=torch.int64)
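

# Example usage (an illustrative sketch, not part of the original module; the
# field values below are invented): build a Bbox from normalized coordinates
# and convert it to model-ready tensors.
def _bbox_example():
    box = Bbox(
        XMin=0.1,
        YMin=0.2,
        XMax=0.5,
        YMax=0.8,
        ClassLabel="Cat",
        ClassIntID=17,
        Source="xclick",
        Confidence=1,
    )
    coords = box.bbox_to_tensor(format="xyxy")  # tensor([0.1, 0.2, 0.5, 0.8])
    label = box.labelID_to_tensor()  # tensor(17)
    return coords, label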


@dataclass
class ImageLabels:
    """Image-level class label; `ClassIntID` is the integer ID that
    corresponds to `ClassLabel`."""

    ClassLabel: str
    ClassIntID: int
    ClassID: Optional[str] = None
    Confidence: Optional[float] = None
    IsTrainable: Optional[bool] = None
    Source: Optional[str] = None


@dataclass
class ObjectDetectionFeaturesDict:
    """
    ImagePath: path to the image file.
    ImageID: name of the image.
    Image: decoded image as a (C, H, W) tensor.
    Bboxes: bounding-box annotations for the image.
    Objects: image-level label annotations.
    """

    ImagePath: str
    ImageID: str
    Image: Optional[torch.Tensor] = None  # (C, H, W)
    Bboxes: Optional[List[Bbox]] = None
    Objects: Optional[List[ImageLabels]] = None

    def compare(self, other):
        """Field-by-field equality check; `Bboxes` and `Objects` are
        compared as unordered collections."""
        for k in ["ImagePath", "ImageID"]:
            if getattr(self, k) != getattr(other, k):
                return False
        if self.Image is None and other.Image is None:
            pass
        elif isinstance(self.Image, torch.Tensor) and isinstance(
            other.Image, torch.Tensor
        ):
            if not torch.all(torch.eq(self.Image, other.Image)):
                return False
        else:
            return False
        for k in ["Bboxes", "Objects"]:
            obj_1 = getattr(self, k)
            obj_2 = getattr(other, k)
            if obj_1 is None and obj_2 is None:
                pass
            elif isinstance(obj_1, list) and isinstance(obj_2, list):
                # Sort by class ID so that ordering differences do not
                # affect the comparison.
                obj_1 = sorted(obj_1, key=lambda x: x.ClassIntID)
                obj_2 = sorted(obj_2, key=lambda x: x.ClassIntID)
                if obj_1 != obj_2:
                    return False
            else:
                return False
        return True

    def __eq__(self, other):
        return self.compare(other)
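

# Example usage (an illustrative sketch, not part of the original module; all
# paths, IDs, and labels are invented): equality of two feature dicts is
# delegated to compare(), so the order of Bboxes and Objects does not matter.
def _features_dict_example():
    cat_box = Bbox(XMin=0.1, YMin=0.2, XMax=0.5, YMax=0.8, ClassLabel="Cat", ClassIntID=17)
    dog_box = Bbox(XMin=0.0, YMin=0.0, XMax=1.0, YMax=1.0, ClassLabel="Dog", ClassIntID=3)
    labels = [
        ImageLabels(ClassLabel="Cat", ClassIntID=17),
        ImageLabels(ClassLabel="Dog", ClassIntID=3),
    ]
    left = ObjectDetectionFeaturesDict(
        ImagePath="/data/img_0.jpg",
        ImageID="img_0",
        Bboxes=[cat_box, dog_box],
        Objects=labels,
    )
    right = ObjectDetectionFeaturesDict(
        ImagePath="/data/img_0.jpg",
        ImageID="img_0",
        Bboxes=[dog_box, cat_box],
        Objects=list(reversed(labels)),
    )
    assert left == right  # annotation order does not affect equality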


@dataclass
class VQAAnswer:
    answer_id: int
    answer: str
    answer_confidence: str
    answer_language: str


@dataclass
class VQAQuestion:
    question_id: int
    question: str
    question_language: str


@dataclass
class VQAFeaturesDict:
    """
    image_path: path to the image file.
    image_id: name of the image.
    """

    image_path: str
    image_id: str
    question: VQAQuestion
    answers: List[VQAAnswer]
    multiple_choice_answer: str  # most frequent ground-truth answer
    multiple_choice_answer_language: str
    answer_type: Optional[str] = None
    image: Optional[torch.Tensor] = None  # (C, H, W)

    def __repr__(self):
        s = (
            f"VQAFeaturesDict.image_id: {self.image_id}, \n"
            + f"VQAFeaturesDict.image_path: {self.image_path}, \n\n"
        )
        s += repr(self.question) + "\n\n"
        for a in self.answers:
            s += repr(a) + "\n"
        s += f"VQAFeaturesDict.multiple_choice_answer: {self.multiple_choice_answer}, \n"
        s += f"VQAFeaturesDict.answer_type: {self.answer_type}, \n"
        return s
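

# Example usage (an illustrative sketch, not part of the original module; the
# question, answers, and IDs are invented): assemble a VQAFeaturesDict and
# print it via the custom __repr__.
def _vqa_example():
    question = VQAQuestion(
        question_id=1,
        question="What color is the cat?",
        question_language="en",
    )
    answers = [
        VQAAnswer(answer_id=1, answer="black", answer_confidence="yes", answer_language="en"),
        VQAAnswer(answer_id=2, answer="gray", answer_confidence="maybe", answer_language="en"),
    ]
    features = VQAFeaturesDict(
        image_path="/data/img_1.jpg",
        image_id="img_1",
        question=question,
        answers=answers,
        multiple_choice_answer="black",
        multiple_choice_answer_language="en",
    )
    print(features)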