1
0
mirror of https://github.com/openSUSE/osc.git synced 2025-01-19 11:56:13 +01:00

Implement total ordering on BaseModel

This commit is contained in:
Daniel Mach 2024-03-04 13:52:57 +01:00
parent 747eb0ec52
commit cd95478ac8
2 changed files with 72 additions and 1 deletions

View File

@ -7,6 +7,7 @@ This module IS NOT a supported API, it is meant for osc internal use only.
""" """
import copy import copy
import functools
import inspect import inspect
import sys import sys
import tempfile import tempfile
@ -368,6 +369,7 @@ class ModelMeta(type):
return new_cls return new_cls
@functools.total_ordering
class BaseModel(metaclass=ModelMeta): class BaseModel(metaclass=ModelMeta):
__fields__: Dict[str, Field] __fields__: Dict[str, Field]
@ -412,10 +414,28 @@ class BaseModel(metaclass=ModelMeta):
self._allow_new_attributes = False self._allow_new_attributes = False
def _get_cmp_data(self):
result = []
for name, field in self.__fields__.items():
if field.exclude:
continue
value = getattr(self, name)
if isinstance(value, dict):
value = sorted(list(value.items()))
result.append((name, value))
return result
def __eq__(self, other): def __eq__(self, other):
if type(self) != type(other): if type(self) != type(other):
return False return False
return self.dict() == other.dict() if self._get_cmp_data() != other._get_cmp_data():
print(self._get_cmp_data(), other._get_cmp_data())
return self._get_cmp_data() == other._get_cmp_data()
def __lt__(self, other):
if type(self) != type(other):
return False
return self._get_cmp_data() < other._get_cmp_data()
def dict(self): def dict(self):
result = {} result = {}

View File

@ -395,6 +395,57 @@ class Test(unittest.TestCase):
self.assertIsInstance(m.field[0], BaseModel) self.assertIsInstance(m.field[0], BaseModel)
self.assertEqual(m.field[0].text, "value") self.assertEqual(m.field[0].text, "value")
def test_ordering(self):
class TestSubmodel(BaseModel):
txt: Optional[str] = Field()
class TestModel(BaseModel):
num: Optional[int] = Field()
txt: Optional[str] = Field()
sub: Optional[TestSubmodel] = Field()
dct: Optional[Dict[str, TestSubmodel]] = Field()
m1 = TestModel()
m2 = TestModel()
self.assertEqual(m1, m2)
m1 = TestModel(num=1)
m2 = TestModel(num=2)
self.assertNotEqual(m1, m2)
self.assertLess(m1, m2)
self.assertGreater(m2, m1)
m1 = TestModel(txt="a")
m2 = TestModel(txt="b")
self.assertNotEqual(m1, m2)
self.assertLess(m1, m2)
self.assertGreater(m2, m1)
m1 = TestModel(sub={})
m2 = TestModel(sub={})
self.assertEqual(m1, m2)
m1 = TestModel(sub={"txt": "a"})
m2 = TestModel(sub={"txt": "b"})
self.assertNotEqual(m1, m2)
self.assertLess(m1, m2)
self.assertGreater(m2, m1)
m1 = TestModel(dct={})
m2 = TestModel(dct={})
self.assertEqual(m1, m2)
m1 = TestModel(dct={"a": TestSubmodel()})
m2 = TestModel(dct={"b": TestSubmodel()})
self.assertNotEqual(m1, m2)
self.assertLess(m1, m2)
self.assertGreater(m2, m1)
# dict ordering doesn't matter
m1 = TestModel(dct={"a": TestSubmodel(), "b": TestSubmodel()})
m2 = TestModel(dct={"b": TestSubmodel(), "a": TestSubmodel()})
self.assertEqual(m1, m2)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()