diff --git a/osc/util/models.py b/osc/util/models.py
index d0f8dfb0..250e4b6a 100644
--- a/osc/util/models.py
+++ b/osc/util/models.py
@@ -9,9 +9,12 @@ This module IS NOT a supported API, it is meant for osc internal use only.
import copy
import inspect
import sys
+import tempfile
import types
+import typing
from typing import Callable
from typing import get_type_hints
+from xml.etree import ElementTree as ET
# supported types
from enum import Enum
@@ -23,7 +26,6 @@ from typing import Optional
from typing import Tuple
from typing import Union
-
if sys.version_info < (3, 8):
def get_origin(typ):
@@ -36,9 +38,14 @@ if sys.version_info < (3, 8):
else:
from typing import get_origin
+import urllib3.response
+
+from . import xml
+
__all__ = (
"BaseModel",
+ "XmlModel",
"Field",
"NotSet",
"FromParent",
@@ -297,6 +304,10 @@ class Field(property):
else:
new_value.append(i)
value = new_value
+ elif self.is_model and isinstance(value, str) and hasattr(self.origin_type, "XML_TAG_FIELD"):
+ klass = self.origin_type
+ key = getattr(self.origin_type, "XML_TAG_FIELD")
+ value = klass(**{key: value})
self.validate_type(value)
obj._values[self.name] = value
@@ -390,3 +401,350 @@ class BaseModel(metaclass=ModelMeta):
result[name] = value
return result
+
+
+class XmlModel(BaseModel):
+ XML_TAG = None
+
+ def to_xml(self) -> ET.Element:
+ xml_tag = None
+
+ # check if there's a special field that sets the tag
+ for field_name, field in self.__fields__.items():
+ xml_set_tag = field.extra.get("xml_set_tag", False)
+ if xml_set_tag:
+ value = getattr(self, field_name)
+ xml_tag = value
+ break
+
+ # use the value from the class
+ if xml_tag is None:
+ xml_tag = self.XML_TAG
+
+ assert xml_tag is not None
+ root = ET.Element(xml_tag)
+
+ for field_name, field in self.__fields__.items():
+ xml_attribute = field.extra.get("xml_attribute", False)
+ xml_set_tag = field.extra.get("xml_set_tag", False)
+ xml_set_text = field.extra.get("xml_set_text", False)
+ xml_name = field.extra.get("xml_name", field_name)
+ xml_wrapped = field.extra.get("xml_wrapped", False)
+ xml_item_name = field.extra.get("xml_item_name", xml_name)
+
+ if xml_set_tag:
+ # a special case when the field determines the top-level tag name
+ continue
+
+ value = getattr(self, field_name)
+ if value is None:
+ # skip fields that are not set
+ continue
+
+ # if value is wrapped into an external element, create it
+ if xml_wrapped:
+ wrapper_node = ET.SubElement(root, xml_name)
+ else:
+ wrapper_node = root
+
+ if xml_set_text:
+ wrapper_node.text = str(value)
+ continue
+
+ if field.origin_type == list:
+ for entry in value:
+ if isinstance(entry, dict):
+ klass = field.inner_type
+ obj = klass(**entry)
+ node = obj.to_xml()
+ wrapper_node.append(node)
+ elif field.inner_type and issubclass(field.inner_type, XmlModel):
+ wrapper_node.append(entry.to_xml())
+ else:
+ node = ET.SubElement(wrapper_node, xml_item_name)
+ if xml_attribute:
+ node.attrib[xml_attribute] = entry
+ else:
+ node.text = entry
+ elif issubclass(field.origin_type, XmlModel):
+ wrapper_node.append(value.to_xml())
+ elif xml_attribute:
+ wrapper_node.attrib[xml_name] = str(value)
+ else:
+ node = ET.SubElement(wrapper_node, xml_name)
+ node.text = str(value)
+ return root
+
+ @classmethod
+ def from_string(cls, string: str) -> "XmlModel":
+ """
+ Instantiate model from string.
+ """
+ root = ET.fromstring(string)
+ return cls.from_xml(root)
+
+ @classmethod
+ def from_file(cls, file: Union[str, typing.IO]) -> "XmlModel":
+ """
+ Instantiate model from file.
+ """
+ root = ET.parse(file).getroot()
+ return cls.from_xml(root)
+
+ def to_bytes(self) -> bytes:
+ """
+ Serialize the object as XML and return it as utf-8 encoded bytes.
+ """
+ root = self.to_xml()
+ xml.xml_indent(root)
+ return ET.tostring(root, encoding="utf-8")
+
+ def to_string(self) -> str:
+ """
+ Serialize the object as XML and return it as a string.
+ """
+ return self.to_bytes().decode("utf-8")
+
+ def to_file(self, file: Union[str, typing.IO]) -> None:
+ """
+ Serialize the object as XML and save it to an utf-8 encoded file.
+ """
+ root = self.to_xml()
+ xml.xml_indent(root)
+ return ET.ElementTree(root).write(file, encoding="utf-8")
+
+ @staticmethod
+ def value_from_string(field, value):
+ """
+ Convert field value from string to the actual type of the field.
+ """
+ if field.origin_type is bool:
+ if value.lower() in ["1", "yes", "true", "on"]:
+ value = True
+ return value
+ if value.lower() in ["0", "no", "false", "off"]:
+ value = False
+ return value
+
+ if field.origin_type is int:
+ value = int(value)
+ return value
+
+ return value
+
+ @classmethod
+ def _remove_processed_node(cls, parent, node):
+ """
+ Remove a node that has been fully processed and is now empty.
+ """
+ if len(node) != 0:
+ raise RuntimeError(f"Node {node} contains unprocessed child elements {list(node)}")
+ if node.attrib:
+ raise RuntimeError(f"Node {node} contains unprocessed attributes {node.attrib}")
+ if node.text is not None and node.text.strip():
+ raise RuntimeError(f"Node {node} contains unprocessed text {node.text}")
+ if parent is not None:
+ parent.remove(node)
+
+ @classmethod
+ def from_xml(cls, root: ET.Element):
+ """
+ Instantiate model from a XML root.
+ """
+
+ # We need to make sure we parse all data
+ # and that's why we remove processed elements and attributes and check that nothing remains.
+ # Otherwise we'd be sending partial XML back and that would lead to data loss.
+ #
+ # Let's make a copy of the xml tree because we'll destroy it during the process.
+ orig_root = root
+ root = copy.deepcopy(root)
+
+ kwargs = {}
+ for field_name, field in cls.__fields__.items():
+ xml_attribute = field.extra.get("xml_attribute", False)
+ xml_set_tag = field.extra.get("xml_set_tag", False)
+ xml_set_text = field.extra.get("xml_set_text", False)
+ xml_name = field.extra.get("xml_name", field_name)
+ xml_wrapped = field.extra.get("xml_wrapped", False)
+ xml_item_name = field.extra.get("xml_item_name", xml_name)
+ value: Any
+ node: Optional[ET.Element]
+
+ if xml_set_tag:
+ # field contains name of the ``root`` tag
+ if xml_wrapped:
+ # the last node wins (overrides the previous nodes)
+ for node in root[:]:
+ value = node.tag
+ cls._remove_processed_node(root, node)
+ else:
+ value = root.tag
+
+ kwargs[field_name] = value
+ continue
+
+ if xml_set_text:
+ # field contains the value (text) of the element
+ if xml_wrapped:
+ # the last node wins (overrides the previous nodes)
+ for node in root[:]:
+ value = node.text
+ node.text = None
+ cls._remove_processed_node(root, node)
+ else:
+ value = root.text
+ root.text = None
+
+ value = value.strip()
+ kwargs[field_name] = value
+ continue
+
+ if xml_attribute:
+ # field is an attribute that contains a scalar
+ if xml_name not in root.attrib:
+ continue
+ value = cls.value_from_string(field, root.attrib.pop(xml_name))
+ kwargs[field_name] = value
+ continue
+
+ if field.origin_type is list:
+ if xml_wrapped:
+ wrapper_node = root.find(xml_name)
+ # we'll consider all nodes inside the wrapper node
+ nodes = wrapper_node[:] if wrapper_node is not None else None
+ else:
+ wrapper_node = None
+ # we'll consider only nodes with matching name
+ nodes = root.findall(xml_item_name)
+
+ if not nodes:
+ if wrapper_node is not None:
+ cls._remove_processed_node(root, wrapper_node)
+ continue
+
+ values = []
+ for node in nodes:
+ if field.is_model_list:
+ klass = field.inner_type
+ entry = klass.from_xml(node)
+
+ # clear node as it was checked in from_xml() already
+ node.text = None
+ node.attrib = {}
+ node[:] = []
+ else:
+ entry = cls.value_from_string(field, node.text)
+ node.text = None
+
+ values.append(entry)
+
+ if xml_wrapped:
+ cls._remove_processed_node(wrapper_node, node)
+ else:
+ cls._remove_processed_node(root, node)
+
+ if xml_wrapped:
+ cls._remove_processed_node(root, wrapper_node)
+
+ kwargs[field_name] = values
+ continue
+
+ if field.is_model:
+ # field contains an instance of XmlModel
+ assert xml_name is not None
+ node = root.find(xml_name)
+ if node is None:
+ continue
+ klass = field.origin_type
+ kwargs[field_name] = klass.from_xml(node)
+
+ # clear node as it was checked in from_xml() already
+ node.text = None
+ node.attrib = {}
+ node[:] = []
+
+ cls._remove_processed_node(root, node)
+ continue
+
+ # field contains a scalar
+ node = root.find(xml_name)
+ if node is None:
+ continue
+ value = cls.value_from_string(field, node.text)
+ node.text = None
+ if value is None:
+ if field.is_optional:
+ continue
+ value = ""
+ kwargs[field_name] = value
+ cls._remove_processed_node(root, node)
+
+ cls._remove_processed_node(None, root)
+
+ obj = cls(**kwargs)
+ obj.__dict__["_root"] = orig_root
+ return obj
+
+ @classmethod
+ def xml_request(cls, method: str, apiurl: str, path: List[str], query: Optional[dict] = None, data: Optional[str] = None) -> urllib3.response.HTTPResponse:
+ from ..connection import http_request
+ from ..core import makeurl
+ url = makeurl(apiurl, path, query)
+ # TODO: catch HTTPError and return the wrapped response as XmlModel instance
+ return http_request(method, url, data=data, retry_on_400=False)
+
+ def do_update(self, other: "XmlModel") -> None:
+ """
+ Update values of the fields in the current model instance from another.
+ """
+ self._values = copy.deepcopy(other._values)
+
+ def do_edit(self) -> Tuple[str, str, "XmlModel"]:
+ """
+ Serialize model as XML and open it in an editor for editing.
+ Return a tuple with:
+ * a string with original data
+ * a string with edited data
+ * an instance of the class with edited data loaded
+
+ IMPORTANT: This method is always interactive.
+ """
+ from ..core import run_editor
+ from ..output import get_user_input
+
+ def write_file(f, data):
+ f.seek(0)
+ f.write(data)
+ f.truncate()
+ f.flush()
+
+ with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8", prefix="obs_xml_", suffix=".xml") as f:
+ original_data = self.to_string()
+ write_file(f, original_data)
+
+ while True:
+ run_editor(f.name)
+ try:
+ edited_obj = self.__class__.from_file(f.name)
+ f.seek(0)
+ edited_data = f.read()
+ break
+ except Exception as e:
+ reply = get_user_input(
+ f"""
+ The edited data is not valid.
+ {e}
+ """,
+ answers={"a": "abort", "e": "edit", "u": "undo changes and edit"},
+ )
+ if reply == "a":
+ from .. import oscerr
+ raise oscerr.UserAbort()
+ elif reply == "e":
+ continue
+ elif reply == "u":
+ write_file(f, original_data)
+ continue
+
+ return original_data, edited_data, edited_obj
diff --git a/tests/test_models_xmlmodel.py b/tests/test_models_xmlmodel.py
new file mode 100644
index 00000000..36f3942b
--- /dev/null
+++ b/tests/test_models_xmlmodel.py
@@ -0,0 +1,154 @@
+import textwrap
+import unittest
+
+from osc.util.models import *
+
+
+class TestXmlModel(unittest.TestCase):
+ def test_attribute(self):
+ class TestModel(XmlModel):
+ XML_TAG = "tag"
+ value: str = Field(xml_attribute=True)
+
+ m = TestModel(value="FOO")
+ self.assertEqual(m.dict(), {"value": "FOO"})
+ expected = """"""
+ self.assertEqual(m.to_string(), expected)
+
+ # verify that we can also load the serialized data
+ m = TestModel.from_string(expected)
+ self.assertEqual(m.to_string(), expected)
+
+ def test_element(self):
+ class TestModel(XmlModel):
+ XML_TAG = "tag"
+ value: str = Field()
+
+ m = TestModel(value="FOO")
+ self.assertEqual(m.dict(), {"value": "FOO"})
+ expected = textwrap.dedent(
+ """
+
+ FOO
+
+ """
+ ).strip()
+ self.assertEqual(m.to_string(), expected)
+
+ # verify that we can also load the serialized data
+ m = TestModel.from_string(expected)
+ self.assertEqual(m.to_string(), expected)
+
+ def test_element_list(self):
+ class TestModel(XmlModel):
+ XML_TAG = "tag"
+ value_list: List[str] = Field(xml_name="value")
+
+ m = TestModel(value_list=["FOO", "BAR"])
+ self.assertEqual(m.dict(), {"value_list": ["FOO", "BAR"]})
+ expected = textwrap.dedent(
+ """
+
+ FOO
+ BAR
+
+ """
+ ).strip()
+ self.assertEqual(m.to_string(), expected)
+
+ # verify that we can also load the serialized data
+ m = TestModel.from_string(expected)
+ self.assertEqual(m.to_string(), expected)
+
+ def test_child_model(self):
+ class ChildModel(XmlModel):
+ XML_TAG = "child"
+ value: str = Field()
+
+ class ParentModel(XmlModel):
+ XML_TAG = "parent"
+ text: str = Field()
+ child: ChildModel = Field()
+
+ m = ParentModel(text="TEXT", child={"value": "FOO"})
+ expected = textwrap.dedent(
+ """
+
+ TEXT
+
+ FOO
+
+
+ """
+ ).strip()
+ self.assertEqual(m.to_string(), expected)
+
+ # verify that we can also load the serialized data
+ m = ParentModel.from_string(expected)
+ self.assertEqual(m.to_string(), expected)
+
+ def test_child_model_list(self):
+ class ChildModel(XmlModel):
+ XML_TAG = "child"
+ value: str = Field()
+
+ class ParentModel(XmlModel):
+ XML_TAG = "parent"
+ text: str = Field()
+ child: List[ChildModel] = Field()
+
+ m = ParentModel(text="TEXT", child=[{"value": "FOO"}, {"value": "BAR"}])
+ expected = textwrap.dedent(
+ """
+
+ TEXT
+
+ FOO
+
+
+ BAR
+
+
+ """
+ ).strip()
+ self.assertEqual(m.to_string(), expected)
+
+ # verify that we can also load the serialized data
+ m = ParentModel.from_string(expected)
+ self.assertEqual(m.to_string(), expected)
+
+ def test_child_model_list_wrapped(self):
+ class ChildModel(XmlModel):
+ XML_TAG = "child"
+ value: str = Field()
+
+ class ParentModel(XmlModel):
+ XML_TAG = "parent"
+ text: str = Field()
+ child: List[ChildModel] = Field(xml_wrapped=True, xml_name="children")
+
+ m = ParentModel(text="TEXT", child=[{"value": "FOO"}, {"value": "BAR"}])
+ expected = textwrap.dedent(
+ """
+
+ TEXT
+
+
+ FOO
+
+
+ BAR
+
+
+
+ """
+ ).strip()
+ self.assertEqual(m.to_string(), expected)
+
+ # verify that we can also load the serialized data
+ m = ParentModel.from_string(expected)
+ self.assertEqual(m.to_string(), expected)
+
+
+if __name__ == "__main__":
+ unittest.main()