diff --git a/osc/util/models.py b/osc/util/models.py index 061374cb..8d6b3c4a 100644 --- a/osc/util/models.py +++ b/osc/util/models.py @@ -39,6 +39,15 @@ if sys.version_info < (3, 8): else: from typing import get_origin + +# types.UnionType was added in Python 3.10 +if sys.version_info < (3, 10): + class UnionType: + pass +else: + from types import UnionType + + import urllib3.response from . import xml @@ -79,8 +88,9 @@ class FromParent(NotSetClass): def __repr__(self): return f"FromParent(field_name={self.field_name})" - -class Field(property): +# HACK: inheriting from Any fixes the following mypy error: +# Incompatible types in assignment (expression has type "Field", variable has type "X | None") [assignment] +class Field(property, *([Any] if typing.TYPE_CHECKING else [])): def __init__( self, default: Any = NotSet, @@ -165,7 +175,7 @@ class Field(property): @property def is_optional(self): origin_type = get_origin(self.type) or self.type - return origin_type == Union and len(self.type.__args__) == 2 and type(None) in self.type.__args__ + return origin_type in (Union, UnionType) and type(None) in self.type.__args__ @property def is_model(self): @@ -193,7 +203,7 @@ class Field(property): origin_type = get_origin(expected_type) or expected_type # unwrap Union - if origin_type == Union: + if origin_type in (Union, UnionType): if value is None and type(None) in expected_type.__args__: valid_type = True continue @@ -464,7 +474,7 @@ class BaseModel(metaclass=ModelMeta): class XmlModel(BaseModel): - XML_TAG = None + XML_TAG: Optional[str] = None _apiurl: Optional[str] = Field( exclude=True, diff --git a/tests/test_models.py b/tests/test_models.py index 29d0a6a3..04756ead 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,3 +1,4 @@ +import sys import unittest from typing import Set @@ -24,6 +25,16 @@ class TestNotSet(unittest.TestCase): class Test(unittest.TestCase): + @unittest.skipIf(sys.version_info[:2] < (3, 10), "added in python 3.10") + def test_union_or(self): + class TestModel(BaseModel): + text: str | None = Field() + + m = TestModel() + self.assertEqual(m.dict(), {"text": None}) + + self.assertRaises(TypeError, setattr, m.text, 123) + def test_dict(self): class TestSubmodel(BaseModel): text: str = Field(default="default")