# Source code for cerebras.modelzoo.config_manager.config_classes.base.base_config

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Base implementation for config classes with helper modules

"""

import ast
import logging
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any, Callable, Optional, Union, get_args, get_type_hints

from cerebras.modelzoo.config_manager.config_validators import (
    get_origin,
    validate_field_type,
)

# Sentinel used as the default of mandatory config fields (see ``config_field``).
# A field that still holds this exact object after construction was never given
# a value; BaseConfig.__validate__ detects that by identity (``is required``).
required = object()


def config_field(
    default: Any = required,
    constraint: Optional[Callable[..., Any]] = None,
):
    """
    Build a dataclass ``field`` for config classes, hiding the metadata layout.

    Args:
        default: Value the field takes when none is supplied. Leaving it as the
            ``required`` sentinel marks the field as mandatory.
        constraint: Optional callable stored in the field metadata under the
            ``"constraints"`` key; it is used to check the parameter value.

    Returns:
        The ``dataclasses.Field`` carrying the given default and metadata.
    """
    return field(default=default, metadata={"constraints": constraint})
def is_union_type_hint(type_hint):
    """Return True if the origin of ``type_hint`` (per ``get_origin``) is ``typing.Union``."""
    origin = get_origin(type_hint)
    return origin is Union
def get_member_type_hints(cls):
    """
    Collect the declared type of every dataclass field of ``cls``.

    Args:
        cls: A dataclass type or dataclass instance.

    Returns:
        A dict mapping each field name to its type hint. For parameterised
        hints (e.g. ``Optional[int]``) the tuple of type arguments is stored
        instead of the hint itself; plain hints are stored unchanged.
    """
    return {
        member.name: (get_args(member.type) or member.type)
        for member in fields(cls)
    }
def get_class_type(config_class, parameter):
    """
    Find the dataclass type that ``parameter`` of ``config_class`` is annotated with.

    For a ``Union`` annotation, the first dataclass member of the union is
    returned. For a plain dataclass annotation, that class is returned.

    Args:
        config_class: Class whose (resolved) type hints are inspected.
        parameter: Name of the annotated attribute.

    Returns:
        The dataclass type, or None if the annotation names no dataclass.

    Raises:
        KeyError: if ``parameter`` has no type hint on ``config_class``.
    """
    hint = get_type_hints(config_class)[parameter]
    if get_origin(hint) is not Union:
        return hint if is_dataclass(hint) else None
    return next((member for member in get_args(hint) if is_dataclass(member)), None)
def set_constraint(current_constraint, updated_constraint):
    """
    Store ``updated_constraint`` in slot 0 of the holder ``current_constraint``.

    The new value always replaces the old one. If slot 0 already held a
    non-None constraint, a warning is logged first (per its text: more than
    one valid type exists for an implicit constraint choice).
    """
    already_set = current_constraint[0] is not None
    if already_set:
        logging.warning(
            "Trying to select constraint for config implicitly, more than one valid type exists"
        )
    current_constraint[0] = updated_constraint
@dataclass
class BaseConfig:
    """
    Base class for config dataclasses, inherited by sub config classes.

    Subclasses declare their parameters as dataclass fields (typically via
    ``config_field``). ``__post_init__`` turns dict/list/str values of fields
    annotated with config classes into config objects. ``__validate__`` checks
    mandatory values, ``None`` handling, constraints and field types.
    """

    def __validate__(self):
        """
        Validate every dataclass field of this config, recursing into sub-configs.

        For each field, in declaration order:

        * nested ``BaseConfig`` values (held directly, or as items of a list
          field) are validated by calling their own ``__validate__``;
        * a value still equal to the ``required`` sentinel is rejected;
        * ``None`` is rejected unless the field annotation is a ``Union``
          containing ``NoneType`` (i.e. ``Optional[...]``);
        * if the field metadata carries a ``"constraints"`` callable, it is
          called with the value, and the value is rejected only when it
          returns ``False`` exactly; otherwise the value is type-checked with
          ``validate_field_type``.

        Raises:
            ValueError: if any of the checks above fails.
        """
        for class_field in fields(self):
            field_name = class_field.name
            field_value = getattr(self, field_name)

            # Recurse into nested configs: direct sub-configs as well as the
            # list items that __post_init__ converts into config objects.
            if isinstance(field_value, BaseConfig):
                field_value.__validate__()
            elif isinstance(field_value, list):
                for item in field_value:
                    if isinstance(item, BaseConfig):
                        item.__validate__()

            # Explicit constraint attached through config_field metadata, if any.
            constraint = class_field.metadata.get("constraints")

            # Mandatory params still holding the sentinel were never provided.
            if field_value is required:
                raise ValueError(
                    f"required value for {field_name}, which is mandatory and must be set"
                )

            # None is only acceptable for a Union that includes NoneType.
            is_optional = is_union_type_hint(class_field.type) and type(
                None
            ) in get_args(class_field.type)
            if not is_optional and field_value is None:
                raise ValueError(
                    f"None value for {field_name}, which is not of optional type"
                )

            if field_value is None:
                continue

            # A custom constraint replaces the type-based check; only an
            # explicit False return counts as a failure.
            if constraint is not None:
                if constraint(field_value) is False:
                    raise ValueError(
                        f"value for {field_name}, does not match the constraint"
                    )
            else:
                validate_field_type(class_field, field_value)

    def _set_class_type(self, field_name, class_type, field_value):
        """
        Convert ``field_value`` into ``class_type`` instance(s) and store it on ``field_name``.

        Accepted inputs:

        * ``dict`` -> ``class_type(**field_value)``;
        * ``list`` -> a new list in which every ``dict`` item is replaced by
          ``class_type(**item)`` (other items are kept as-is); the caller's
          list is not modified;
        * ``str`` -> parsed with ``ast.literal_eval`` (Python literals only,
          no code execution), then handled as the resulting dict or list.

        Dataclass instances are left untouched. Any other type only logs a
        warning. Construction fails if the keys don't match the class
        signature.

        Raises:
            RuntimeError: wrapping any error raised while parsing the string
                or constructing ``class_type``.
        """
        try:
            if isinstance(field_value, str):
                field_value = ast.literal_eval(field_value)
                if not isinstance(field_value, (dict, list)):
                    raise TypeError(
                        f"expected a dict or list literal, got {type(field_value)}"
                    )

            if isinstance(field_value, dict):
                setattr(self, field_name, class_type(**field_value))
            elif isinstance(field_value, list):
                # Build a fresh list rather than rewriting the caller's list in place.
                converted = [
                    class_type(**item) if isinstance(item, dict) else item
                    for item in field_value
                ]
                setattr(self, field_name, converted)
            elif not is_dataclass(field_value):
                logging.warning(
                    f"We got a config class initialization with invalid type {type(field_value)}"
                )
        except Exception as e:
            raise RuntimeError(
                f"Failed to set field \"{field_name}\" with class type \"{class_type}\"."
            ) from e

    @staticmethod
    def _config_class_for(field_type):
        """Return the dataclass a field annotated ``field_type`` should be built from, or None.

        ``SomeConfig`` and ``List[SomeConfig]`` both map to ``SomeConfig``. For
        a ``Union``, the first member of either shape wins.
        """
        if get_origin(field_type) is Union:
            candidates = get_args(field_type)
        else:
            candidates = (field_type,)
        for candidate in candidates:
            if is_dataclass(candidate):
                return candidate
            if get_origin(candidate) is list and get_args(candidate):
                item_type = get_args(candidate)[0]
                if is_dataclass(item_type):
                    return item_type
        return None

    def __post_init__(self):
        """
        Build sub-config objects from dict/list/str field values.

        Fields whose value is ``None`` or the ``required`` sentinel are left
        alone. For the others, the target config class is derived from the
        field annotation (see ``_config_class_for``). The conversion is
        delegated to ``_set_class_type``.
        """
        for curr_field in fields(self):
            field_value = getattr(self, curr_field.name)
            if field_value is None or field_value is required:
                continue
            class_type = self._config_class_for(curr_field.type)
            if class_type is not None:
                self._set_class_type(
                    field_name=curr_field.name,
                    class_type=class_type,
                    field_value=field_value,
                )