From 930b7a8a4ec3941f34906f634e1d9ffe9d8ab29c Mon Sep 17 00:00:00 2001 From: Daniel Mach Date: Tue, 22 Aug 2023 15:15:04 +0200 Subject: [PATCH] Add 'osc.util.models' module implementing an alternative pydantic-like data validation --- osc/util/models.py | 318 +++++++++++++++++++++++++++++++++++++++++++ tests/test_models.py | 202 +++++++++++++++++++++++++++ 2 files changed, 520 insertions(+) create mode 100644 osc/util/models.py create mode 100644 tests/test_models.py diff --git a/osc/util/models.py b/osc/util/models.py new file mode 100644 index 00000000..633479b1 --- /dev/null +++ b/osc/util/models.py @@ -0,0 +1,318 @@ +""" +This module implements a lightweight and limited alternative +to pydantic's BaseModel and Field classes. +It works on python 3.6+. + +This module IS NOT a supported API, it is meant for osc internal use only. +""" + +import inspect +import sys +import types +from typing import get_type_hints + +# supported types +from typing import Any +from typing import Dict +from typing import List +from typing import NewType +from typing import Optional +from typing import Tuple +from typing import Union + + +if sys.version_info < (3, 8): + + def get_origin(typ): + result = getattr(typ, "__origin__", None) + bases = getattr(result, "__orig_bases__", None) + if bases: + result = bases[0] + return result + +else: + from typing import get_origin + + +__all__ = ( + "BaseModel", + "Field", + "NotSet", + "FromParent", + "Dict", + "List", + "NewType", + "Optional", + "Tuple", + "Union", +) + + +class NotSetClass: + def __repr__(self): + return "NotSet" + + def __bool__(self): + return False + + +NotSet = NotSetClass() + + +class FromParent(NotSetClass): + def __init__(self, field_name): + self.field_name = field_name + + def __repr__(self): + return f"FromParent(field_name={self.field_name})" + + +class Field(property): + def __init__( + self, + default: Any = NotSet, + description: Optional[str] = None, + exclude: bool = False, + **extra, + ): + # the default value; it can be a factory function that is lazily evaluated on the first use + # model sets it to None if it equals to NotSet (for better usability) + self.default = default + + # whether the field was set + self.is_set = False + + # the name of model's attribute associated with this field instance - set from the model + self.name = None + + # the type of this field instance - set from the model + self.type = None + + # the description of the field + self.description = description + + # docstring - for sphinx and help() + self.__doc__ = self.description + if self.__doc__: + # append information about the default value + if isinstance(self.default, FromParent): + self.__doc__ += f"\n\nDefault: inherited from parent config's field ``{self.default.field_name}``" + elif self.default is not NotSet: + self.__doc__ += f"\n\nDefault: ``{self.default}``" + + # whether to exclude this field from export + self.exclude = exclude + + # extra fields + self.extra = extra + + # create an instance specific of self.get() so we can annotate it in the model + self.get_copy = types.FunctionType( + self.get.__code__, + self.get.__globals__, + self.get.__name__, + self.get.__defaults__, + self.get.__closure__, + ) + # turn function into a method by binding it to the instance + self.get_copy = types.MethodType(self.get_copy, self) + + super().__init__(fget=self.get_copy, fset=self.set, doc=description) + + @property + def origin_type(self): + origin_type = get_origin(self.type) or self.type + if self.is_optional: + types = [i for i in self.type.__args__ if i != type(None)] + return types[0] + return origin_type + + @property + def is_optional(self): + origin_type = get_origin(self.type) or self.type + return origin_type == Union and len(self.type.__args__) == 2 and type(None) in self.type.__args__ + + @property + def is_model(self): + return inspect.isclass(self.origin_type) and issubclass(self.origin_type, BaseModel) + + def validate_type(self, value, expected_types=None): + if not expected_types and self.is_optional and value is None: + return True + + if expected_types is None: + expected_types = (self.type,) + elif not isinstance(expected_types, (list, tuple)): + expected_types = (expected_types,) + + valid_type = False + + for expected_type in expected_types: + if valid_type: + break + + origin_type = get_origin(expected_type) or expected_type + + # unwrap Union + if origin_type == Union: + if value is None and type(None) in expected_type.__args__: + valid_type = True + continue + + valid_type |= self.validate_type(value, expected_types=expected_type.__args__) + continue + + # unwrap NewType + if (callable(NewType) or isinstance(origin_type, NewType)) and hasattr(origin_type, "__supertype__"): + valid_type |= self.validate_type(value, expected_types=(origin_type.__supertype__,)) + continue + + if ( + inspect.isclass(expected_type) + and issubclass(expected_type, BaseModel) + and isinstance(value, (expected_type, dict)) + ): + valid_type = True + continue + + if not isinstance(value, origin_type): + msg = f"Field '{self.name}' has type '{self.type}'. Cannot assign a value with type '{type(value).__name__}'." + raise TypeError(msg) + + # the type annotation has no arguments -> no need to check those + if not getattr(expected_type, "__args__", None): + valid_type = True + continue + + if origin_type in (list, tuple): + valid_type_items = True + for i in value: + valid_type_items &= self.validate_type(i, expected_type.__args__) + valid_type |= valid_type_items + elif origin_type in (dict,): + valid_type_items = True + for k, v in value.items(): + valid_type_items &= self.validate_type(k, expected_type.__args__[0]) + valid_type_items &= self.validate_type(v, expected_type.__args__[1]) + valid_type |= valid_type_items + else: + raise TypeError(f"Field '{self.name}' has unsupported type '{self.type}'.") + + return valid_type + + def get(self, obj): + try: + return getattr(obj, f"_{self.name}") + except AttributeError: + pass + + if isinstance(self.default, FromParent): + if obj._parent is None: + raise RuntimeError(f"The field '{self.name}' has default {self.default} but the model has no parent set") + return getattr(obj._parent, self.default.field_name or self.name) + + if self.default is NotSet: + raise RuntimeError(f"The field '{self.name}' has no default") + + # lazy evaluation of a factory function on first use + if callable(self.default): + self.default = self.default() + + # if this is a model field, convert dict to a model instance + if self.is_model and isinstance(self.default, dict): + new_value = self.origin_type() # pylint: disable=not-callable + for k, v in self.default.items(): + setattr(new_value, k, v) + self.default = new_value + + return self.default + + def set(self, obj, value): + # if this is a model field, convert dict to a model instance + if self.is_model and isinstance(value, dict): + new_value = self.origin_type() # pylint: disable=not-callable + for k, v in value.items(): + setattr(new_value, k, v) + value = new_value + + self.validate_type(value) + setattr(obj, f"_{self.name}", value) + self.is_set = True + + +class ModelMeta(type): + def __new__(mcs, name, bases, attrs): + new_cls = super().__new__(mcs, name, bases, attrs) + new_cls.__fields__ = {} + + # NOTE: dir() doesn't preserve attribute order + # we need to iterate through __mro__ classes to workaround that + for parent_cls in reversed(new_cls.__mro__): + for field_name in parent_cls.__dict__: + if field_name in new_cls.__fields__: + continue + field = getattr(new_cls, field_name) + if not isinstance(field, Field): + continue + new_cls.__fields__[field_name] = field + + # fill model specific details back to the fields + for field_name, field in new_cls.__fields__.items(): + # property name associated with the field in this model + field.name = field_name + + # field type associated with the field in this model + field.type = get_type_hints(new_cls)[field_name] + + # set annotation for the getter so it shows up in sphinx + field.get_copy.__func__.__annotations__ = {"return": field.type} + + # set 'None' as the default for optional fields + if field.default is NotSet and field.is_optional: + field.default = None + + return new_cls + + +class BaseModel(metaclass=ModelMeta): + __fields__: Dict[str, Field] + + def __init__(self, **kwargs): + self._parent = kwargs.pop("_parent", None) + + uninitialized_fields = [] + + for name, field in self.__fields__.items(): + if name not in kwargs: + if field.default is NotSet: + uninitialized_fields.append(field.name) + continue + value = kwargs.pop(name) + setattr(self, name, value) + + if kwargs: + unknown_fields_str = ", ".join([f"'{i}'" for i in kwargs]) + raise TypeError(f"The following kwargs do not match any field: {unknown_fields_str}") + + if uninitialized_fields: + uninitialized_fields_str = ", ".join([f"'{i}'" for i in uninitialized_fields]) + raise TypeError( + f"The following fields are not initialized and have no default either: {uninitialized_fields_str}" + ) + + for name, field in self.__fields__.items(): + field.validate_type(getattr(self, name)) + + def dict(self, exclude_unset=False): + result = {} + for name, field in self.__fields__.items(): + if field.exclude: + continue + if exclude_unset and not field.is_set and field.is_optional: + # include only mandatory fields and optional fields that were set to an actual value + continue + if field.is_model: + result[name] = getattr(self, name).dict(exclude_unset=exclude_unset) + else: + result[name] = getattr(self, name) + return result diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 00000000..082b4fae --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,202 @@ +import unittest +from typing import Set + +from osc.util.models import * +from osc.util.models import get_origin + + +class TestTyping(unittest.TestCase): + def test_get_origin_list(self): + typ = get_origin(list) + self.assertEqual(typ, None) + + def test_get_origin_list_str(self): + typ = get_origin(List[str]) + self.assertEqual(typ, list) + + +class TestNotSet(unittest.TestCase): + def test_repr(self): + self.assertEqual(repr(NotSet), "NotSet") + + def test_bool(self): + self.assertEqual(bool(NotSet), False) + + +class Test(unittest.TestCase): + def test_modified(self): + class TestModel(BaseModel): + a: str = Field(default="default") + b: Optional[str] = Field(default=None) + + m = TestModel() + self.assertEqual(m.dict(exclude_unset=True), {"a": "default"}) + + m = TestModel(b=None) + self.assertEqual(m.dict(exclude_unset=True), {"a": "default", "b": None}) + + def test_unknown_fields(self): + class TestModel(BaseModel): + pass + + self.assertRaises(TypeError, TestModel, does_not_exist=None) + + def test_uninitialized(self): + class TestModel(BaseModel): + field: str = Field() + + self.assertRaises(TypeError, TestModel) + + def test_invalid_type(self): + class TestModel(BaseModel): + field: Optional[str] = Field() + + m = TestModel() + self.assertRaises(TypeError, setattr, m.field, []) + + def test_unsupported_type(self): + class TestModel(BaseModel): + field: Set[str] = Field(default=None) + + self.assertRaises(TypeError, TestModel) + + def test_is_set(self): + class TestModel(BaseModel): + field: Optional[str] = Field() + + m = TestModel() + + field = m.__fields__["field"] + self.assertEqual(field.is_set, False) + self.assertEqual(m.field, None) + m.field = "text" + self.assertEqual(field.is_set, True) + self.assertEqual(m.field, "text") + + def test_str(self): + class TestModel(BaseModel): + field: str = Field(default="default") + + m = TestModel() + + field = m.__fields__["field"] + self.assertEqual(field.is_model, False) + self.assertEqual(field.is_optional, False) + self.assertEqual(field.is_set, False) + self.assertEqual(field.origin_type, str) + + self.assertEqual(m.field, "default") + m.field = "text" + self.assertEqual(m.field, "text") + + def test_optional_str(self): + class TestModel(BaseModel): + field: Optional[str] = Field() + + m = TestModel() + + field = m.__fields__["field"] + self.assertEqual(field.is_model, False) + self.assertEqual(field.is_optional, True) + self.assertEqual(field.is_set, False) + self.assertEqual(field.origin_type, str) + + self.assertEqual(m.field, None) + m.field = "text" + self.assertEqual(m.field, "text") + + def test_int(self): + class TestModel(BaseModel): + field: int = Field(default=0) + + m = TestModel() + + field = m.__fields__["field"] + self.assertEqual(field.is_model, False) + self.assertEqual(field.is_optional, False) + self.assertEqual(field.origin_type, int) + + self.assertEqual(m.field, 0) + m.field = 1 + self.assertEqual(m.field, 1) + + def test_optional_int(self): + class TestModel(BaseModel): + field: Optional[int] = Field() + + m = TestModel() + + field = m.__fields__["field"] + self.assertEqual(field.is_model, False) + self.assertEqual(field.is_optional, True) + self.assertEqual(field.origin_type, int) + + self.assertEqual(m.field, None) + m.field = 1 + self.assertEqual(m.field, 1) + + def test_submodel(self): + class TestSubmodel(BaseModel): + text: str = Field(default="default") + + class TestModel(BaseModel): + field: TestSubmodel = Field(default={}) + + m = TestModel() + + field = m.__fields__["field"] + self.assertEqual(field.is_model, True) + self.assertEqual(field.is_optional, False) + self.assertEqual(field.origin_type, TestSubmodel) + + m = TestModel(field=TestSubmodel()) + self.assertEqual(m.field.text, "default") + + m = TestModel(field={"text": "text"}) + self.assertEqual(m.field.text, "text") + + def test_optional_submodel(self): + class TestSubmodel(BaseModel): + text: str = Field(default="default") + + class TestModel(BaseModel): + field: Optional[TestSubmodel] = Field(default=None) + + m = TestModel() + + field = m.__fields__["field"] + self.assertEqual(field.is_model, True) + self.assertEqual(field.is_optional, True) + self.assertEqual(field.origin_type, TestSubmodel) + self.assertEqual(m.field, None) + + m = TestModel(field=TestSubmodel()) + self.assertIsInstance(m.field, TestSubmodel) + self.assertEqual(m.field.text, "default") + + m = TestModel(field={"text": "text"}) + self.assertNotEqual(m.field, None) + self.assertEqual(m.field.text, "text") + + def test_parent(self): + class ParentModel(BaseModel): + field: str = Field(default="text") + + class ChildModel(BaseModel): + field: str = Field(default=FromParent("field")) + field2: str = Field(default=FromParent("field")) + + p = ParentModel() + c = ChildModel(_parent=p) + self.assertEqual(p.field, "text") + self.assertEqual(c.field, "text") + self.assertEqual(c.field2, "text") + + c.field = "new-text" + self.assertEqual(p.field, "text") + self.assertEqual(c.field, "new-text") + self.assertEqual(c.field2, "text") + + +if __name__ == "__main__": + unittest.main()