diff --git a/osc/util/models.py b/osc/util/models.py index 10f9d9ee..46f2f422 100644 --- a/osc/util/models.py +++ b/osc/util/models.py @@ -130,6 +130,23 @@ class Field(property): return get_origin(types[0]) or types[0] return origin_type + @property + def inner_type(self): + if self.is_optional: + types = [i for i in self.type.__args__ if i != type(None)] + type_ = types[0] + else: + type_ = self.type + + if get_origin(type_) != list: + return None + + if not hasattr(type_, "__args__"): + return None + + inner_type = [i for i in type_.__args__ if i != type(None)][0] + return inner_type + @property def is_optional(self): origin_type = get_origin(self.type) or self.type @@ -139,6 +156,10 @@ class Field(property): def is_model(self): return inspect.isclass(self.origin_type) and issubclass(self.origin_type, BaseModel) + @property + def is_model_list(self): + return inspect.isclass(self.inner_type) and issubclass(self.inner_type, BaseModel) + def validate_type(self, value, expected_types=None): if not expected_types and self.is_optional and value is None: return True @@ -255,6 +276,15 @@ class Field(property): # initialize a model instance from a dictionary klass = self.origin_type value = klass(**value) # pylint: disable=not-callable + elif self.is_model_list and isinstance(value, list): + new_value = [] + for i in value: + if isinstance(i, dict): + klass = self.inner_type + new_value.append(klass(**i)) + else: + new_value.append(i) + value = new_value self.validate_type(value) obj._values[self.name] = value @@ -342,6 +372,8 @@ class BaseModel(metaclass=ModelMeta): value = getattr(self, name) if value is not None and field.is_model: result[name] = value.dict() + if value is not None and field.is_model_list: + result[name] = [i.dict() for i in value] else: result[name] = value diff --git a/tests/test_models.py b/tests/test_models.py index bc36ce7b..add9bd55 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -198,6 +198,60 @@ class Test(unittest.TestCase): self.assertEqual(m.field.text, "text") m.dict() + def test_list_submodels(self): + class TestSubmodel(BaseModel): + text: str = Field(default="default") + + class TestModel(BaseModel): + field: List[TestSubmodel] = Field(default=[]) + + m = TestModel() + + field = m.__fields__["field"] + self.assertEqual(field.is_model, False) + self.assertEqual(field.is_model_list, True) + self.assertEqual(field.is_optional, False) + self.assertEqual(field.origin_type, list) + m.dict() + + m = TestModel(field=[TestSubmodel()]) + self.assertEqual(m.field[0].text, "default") + m.dict() + + m = TestModel(field=[{"text": "text"}]) + self.assertEqual(m.field[0].text, "text") + m.dict() + + self.assertRaises(TypeError, getattr(m, "field")) + + def test_optional_list_submodels(self): + class TestSubmodel(BaseModel): + text: str = Field(default="default") + + class TestModel(BaseModel): + field: Optional[List[TestSubmodel]] = Field(default=[]) + + m = TestModel() + + field = m.__fields__["field"] + self.assertEqual(field.is_model, False) + self.assertEqual(field.is_model_list, True) + self.assertEqual(field.is_optional, True) + self.assertEqual(field.origin_type, list) + m.dict() + + m = TestModel(field=[TestSubmodel()]) + self.assertEqual(m.field[0].text, "default") + m.dict() + + m = TestModel(field=[{"text": "text"}]) + self.assertEqual(m.field[0].text, "text") + m.dict() + + m.field = None + self.assertEqual(m.field, None) + m.dict() + def test_enum(self): class Numbers(Enum): one = "one"