diff --git a/osc/util/models.py b/osc/util/models.py index 125557ba..46f2f422 100644 --- a/osc/util/models.py +++ b/osc/util/models.py @@ -13,6 +13,7 @@ import types from typing import get_type_hints # supported types +from enum import Enum from typing import Any from typing import Dict from typing import List @@ -40,6 +41,7 @@ __all__ = ( "Field", "NotSet", "FromParent", + "Enum", "Dict", "List", "NewType", @@ -125,9 +127,26 @@ class Field(property): origin_type = get_origin(self.type) or self.type if self.is_optional: types = [i for i in self.type.__args__ if i != type(None)] - return types[0] + return get_origin(types[0]) or types[0] return origin_type + @property + def inner_type(self): + if self.is_optional: + types = [i for i in self.type.__args__ if i != type(None)] + type_ = types[0] + else: + type_ = self.type + + if get_origin(type_) != list: + return None + + if not hasattr(type_, "__args__"): + return None + + inner_type = [i for i in type_.__args__ if i != type(None)][0] + return inner_type + @property def is_optional(self): origin_type = get_origin(self.type) or self.type @@ -137,6 +156,10 @@ class Field(property): def is_model(self): return inspect.isclass(self.origin_type) and issubclass(self.origin_type, BaseModel) + @property + def is_model_list(self): + return inspect.isclass(self.inner_type) and issubclass(self.inner_type, BaseModel) + def validate_type(self, value, expected_types=None): if not expected_types and self.is_optional and value is None: return True @@ -176,6 +199,15 @@ class Field(property): valid_type = True continue + if ( + inspect.isclass(expected_type) + and issubclass(expected_type, Enum) + ): + # test if the value is part of the enum + expected_type(value) + valid_type = True + continue + if not isinstance(value, origin_type): msg = f"Field '{self.name}' has type '{self.type}'. Cannot assign a value with type '{type(value).__name__}'." raise TypeError(msg) @@ -241,9 +273,17 @@ class Field(property): def set(self, obj, value): # if this is a model field, convert dict to a model instance if self.is_model and isinstance(value, dict): - new_value = self.origin_type() # pylint: disable=not-callable - for k, v in value.items(): - setattr(new_value, k, v) + # initialize a model instance from a dictionary + klass = self.origin_type + value = klass(**value) # pylint: disable=not-callable + elif self.is_model_list and isinstance(value, list): + new_value = [] + for i in value: + if isinstance(i, dict): + klass = self.inner_type + new_value.append(klass(**i)) + else: + new_value.append(i) value = new_value self.validate_type(value) @@ -311,12 +351,12 @@ class BaseModel(metaclass=ModelMeta): if kwargs: unknown_fields_str = ", ".join([f"'{i}'" for i in kwargs]) - raise TypeError(f"The following kwargs do not match any field: {unknown_fields_str}") + raise TypeError(f"The following kwargs of '{self.__class__.__name__}.__init__()' do not match any field: {unknown_fields_str}") if uninitialized_fields: uninitialized_fields_str = ", ".join([f"'{i}'" for i in uninitialized_fields]) raise TypeError( - f"The following fields are not initialized and have no default either: {uninitialized_fields_str}" + f"The following fields of '{self.__class__.__name__}' object are not initialized and have no default either: {uninitialized_fields_str}" ) for name, field in self.__fields__.items(): @@ -329,8 +369,12 @@ class BaseModel(metaclass=ModelMeta): for name, field in self.__fields__.items(): if field.exclude: continue - if field.is_model: - result[name] = getattr(self, name).dict() + value = getattr(self, name) + if value is not None and field.is_model: + result[name] = value.dict() + if value is not None and field.is_model_list: + result[name] = [i.dict() for i in value] else: - result[name] = getattr(self, name) + result[name] = value + return result diff --git a/tests/test_models.py b/tests/test_models.py index ccde8a4b..add9bd55 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -186,14 +186,91 @@ class Test(unittest.TestCase): self.assertEqual(field.is_optional, True) self.assertEqual(field.origin_type, TestSubmodel) self.assertEqual(m.field, None) + m.dict() m = TestModel(field=TestSubmodel()) self.assertIsInstance(m.field, TestSubmodel) self.assertEqual(m.field.text, "default") + m.dict() m = TestModel(field={"text": "text"}) self.assertNotEqual(m.field, None) self.assertEqual(m.field.text, "text") + m.dict() + + def test_list_submodels(self): + class TestSubmodel(BaseModel): + text: str = Field(default="default") + + class TestModel(BaseModel): + field: List[TestSubmodel] = Field(default=[]) + + m = TestModel() + + field = m.__fields__["field"] + self.assertEqual(field.is_model, False) + self.assertEqual(field.is_model_list, True) + self.assertEqual(field.is_optional, False) + self.assertEqual(field.origin_type, list) + m.dict() + + m = TestModel(field=[TestSubmodel()]) + self.assertEqual(m.field[0].text, "default") + m.dict() + + m = TestModel(field=[{"text": "text"}]) + self.assertEqual(m.field[0].text, "text") + m.dict() + + self.assertRaises(TypeError, getattr(m, "field")) + + def test_optional_list_submodels(self): + class TestSubmodel(BaseModel): + text: str = Field(default="default") + + class TestModel(BaseModel): + field: Optional[List[TestSubmodel]] = Field(default=[]) + + m = TestModel() + + field = m.__fields__["field"] + self.assertEqual(field.is_model, False) + self.assertEqual(field.is_model_list, True) + self.assertEqual(field.is_optional, True) + self.assertEqual(field.origin_type, list) + m.dict() + + m = TestModel(field=[TestSubmodel()]) + self.assertEqual(m.field[0].text, "default") + m.dict() + + m = TestModel(field=[{"text": "text"}]) + self.assertEqual(m.field[0].text, "text") + m.dict() + + m.field = None + self.assertEqual(m.field, None) + m.dict() + + def test_enum(self): + class Numbers(Enum): + one = "one" + two = "two" + + class TestModel(BaseModel): + field: Optional[Numbers] = Field(default=None) + + m = TestModel() + field = m.__fields__["field"] + self.assertEqual(field.is_model, False) + self.assertEqual(field.is_optional, True) + self.assertEqual(field.origin_type, Numbers) + self.assertEqual(m.field, None) + + m.field = "one" + self.assertEqual(m.field, "one") + + self.assertRaises(ValueError, setattr, m, "field", "does-not-exist") def test_parent(self): class ParentModel(BaseModel):