diff --git a/osc/util/models.py b/osc/util/models.py index 8b1df84d..ab65b64f 100644 --- a/osc/util/models.py +++ b/osc/util/models.py @@ -13,6 +13,7 @@ import types from typing import get_type_hints # supported types +from enum import Enum from typing import Any from typing import Dict from typing import List @@ -40,6 +41,7 @@ __all__ = ( "Field", "NotSet", "FromParent", + "Enum", "Dict", "List", "NewType", @@ -176,6 +178,15 @@ class Field(property): valid_type = True continue + if ( + inspect.isclass(expected_type) + and issubclass(expected_type, Enum) + ): + # test if the value is part of the enum + expected_type(value) + valid_type = True + continue + if not isinstance(value, origin_type): msg = f"Field '{self.name}' has type '{self.type}'. Cannot assign a value with type '{type(value).__name__}'." raise TypeError(msg) diff --git a/tests/test_models.py b/tests/test_models.py index ccde8a4b..01501580 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -195,6 +195,26 @@ class Test(unittest.TestCase): self.assertNotEqual(m.field, None) self.assertEqual(m.field.text, "text") + def test_enum(self): + class Numbers(Enum): + one = "one" + two = "two" + + class TestModel(BaseModel): + field: Optional[Numbers] = Field(default=None) + + m = TestModel() + field = m.__fields__["field"] + self.assertEqual(field.is_model, False) + self.assertEqual(field.is_optional, True) + self.assertEqual(field.origin_type, Numbers) + self.assertEqual(m.field, None) + + m.field = "one" + self.assertEqual(m.field, "one") + + self.assertRaises(ValueError, setattr, m, "field", "does-not-exist") + def test_parent(self): class ParentModel(BaseModel): field: str = Field(default="text")