1
0
mirror of https://github.com/openSUSE/osc.git synced 2025-03-01 05:32:13 +01:00

Allow storing apiurl in the XmlModel instances

This commit is contained in:
Daniel Mach 2024-02-29 09:36:22 +01:00
parent 9cce6dbb8e
commit 747eb0ec52
2 changed files with 53 additions and 9 deletions

View File

@ -379,6 +379,7 @@ class BaseModel(metaclass=ModelMeta):
raise AttributeError(f"Setting attribute '{self.__class__.__name__}.{name}' is not allowed")
def __init__(self, **kwargs):
self._allow_new_attributes = True
self._defaults = {} # field defaults cached in field.get()
self._values = {} # field values explicitly set after initializing the model
self._parent = kwargs.pop("_parent", None)
@ -447,6 +448,11 @@ class BaseModel(metaclass=ModelMeta):
class XmlModel(BaseModel):
XML_TAG = None
_apiurl: Optional[str] = Field(
exclude=True,
default=FromParent("_apiurl", fallback=None),
)
def to_xml(self) -> ET.Element:
xml_tag = None
@ -466,6 +472,8 @@ class XmlModel(BaseModel):
root = ET.Element(xml_tag)
for field_name, field in self.__fields__.items():
if field.exclude:
continue
xml_attribute = field.extra.get("xml_attribute", False)
xml_set_tag = field.extra.get("xml_set_tag", False)
xml_set_text = field.extra.get("xml_set_text", False)
@ -517,20 +525,20 @@ class XmlModel(BaseModel):
return root
@classmethod
def from_string(cls, string: str) -> "XmlModel":
def from_string(cls, string: str, *, apiurl: Optional[str] = None) -> "XmlModel":
"""
Instantiate model from string.
"""
root = ET.fromstring(string)
return cls.from_xml(root)
return cls.from_xml(root, apiurl=apiurl)
@classmethod
def from_file(cls, file: Union[str, typing.IO]) -> "XmlModel":
def from_file(cls, file: Union[str, typing.IO], *, apiurl: Optional[str] = None) -> "XmlModel":
"""
Instantiate model from file.
"""
root = ET.parse(file).getroot()
return cls.from_xml(root)
return cls.from_xml(root, apiurl=apiurl)
def to_bytes(self) -> bytes:
"""
@ -588,7 +596,7 @@ class XmlModel(BaseModel):
parent.remove(node)
@classmethod
def from_xml(cls, root: ET.Element):
def from_xml(cls, root: ET.Element, *, apiurl: Optional[str] = None):
"""
Instantiate model from a XML root.
"""
@ -668,7 +676,7 @@ class XmlModel(BaseModel):
for node in nodes:
if field.is_model_list:
klass = field.inner_type
entry = klass.from_xml(node)
entry = klass.from_xml(node, apiurl=apiurl)
# clear node as it was checked in from_xml() already
node.text = None
@ -698,7 +706,7 @@ class XmlModel(BaseModel):
if node is None:
continue
klass = field.origin_type
kwargs[field_name] = klass.from_xml(node)
kwargs[field_name] = klass.from_xml(node, apiurl=apiurl)
# clear node as it was checked in from_xml() already
node.text = None
@ -723,7 +731,7 @@ class XmlModel(BaseModel):
cls._remove_processed_node(None, root)
obj = cls(**kwargs)
obj = cls(**kwargs, _apiurl=apiurl)
obj.__dict__["_root"] = orig_root
return obj
@ -767,7 +775,7 @@ class XmlModel(BaseModel):
while True:
run_editor(f.name)
try:
edited_obj = self.__class__.from_file(f.name)
edited_obj = self.__class__.from_file(f.name, apiurl=self._apiurl)
f.seek(0)
edited_data = f.read()
break

View File

@ -1,3 +1,4 @@
import io
import textwrap
import unittest
@ -149,6 +150,41 @@ class TestXmlModel(unittest.TestCase):
m = ParentModel.from_string(expected)
self.assertEqual(m.to_string(), expected)
def test_apiurl(self):
class ChildModel(XmlModel):
XML_TAG = "child"
value: str = Field()
class ParentModel(XmlModel):
XML_TAG = "parent"
text: str = Field()
child: List[ChildModel] = Field(xml_wrapped=True, xml_name="children")
# serialize the model and load it with apiurl set
m = ParentModel(text="TEXT", child=[{"value": "FOO"}, {"value": "BAR"}])
xml = m.to_string()
apiurl = "https://api.example.com"
m = ParentModel.from_string(xml, apiurl=apiurl)
m.child.append({"value": "BAZ"})
self.assertEqual(m._apiurl, apiurl)
self.assertEqual(m.child[0]._apiurl, apiurl)
self.assertEqual(m.child[1]._apiurl, apiurl)
self.assertEqual(m.child[2]._apiurl, apiurl)
# test the same as above but with a file
f = io.StringIO(xml)
m = ParentModel.from_file(f, apiurl=apiurl)
m.child.append({"value": "BAZ"})
self.assertEqual(m._apiurl, apiurl)
self.assertEqual(m.child[0]._apiurl, apiurl)
self.assertEqual(m.child[1]._apiurl, apiurl)
self.assertEqual(m.child[2]._apiurl, apiurl)
if __name__ == "__main__":
unittest.main()