diff --git a/osc/util/models.py b/osc/util/models.py index d0f8dfb0..250e4b6a 100644 --- a/osc/util/models.py +++ b/osc/util/models.py @@ -9,9 +9,12 @@ This module IS NOT a supported API, it is meant for osc internal use only. import copy import inspect import sys +import tempfile import types +import typing from typing import Callable from typing import get_type_hints +from xml.etree import ElementTree as ET # supported types from enum import Enum @@ -23,7 +26,6 @@ from typing import Optional from typing import Tuple from typing import Union - if sys.version_info < (3, 8): def get_origin(typ): @@ -36,9 +38,14 @@ if sys.version_info < (3, 8): else: from typing import get_origin +import urllib3.response + +from . import xml + __all__ = ( "BaseModel", + "XmlModel", "Field", "NotSet", "FromParent", @@ -297,6 +304,10 @@ class Field(property): else: new_value.append(i) value = new_value + elif self.is_model and isinstance(value, str) and hasattr(self.origin_type, "XML_TAG_FIELD"): + klass = self.origin_type + key = getattr(self.origin_type, "XML_TAG_FIELD") + value = klass(**{key: value}) self.validate_type(value) obj._values[self.name] = value @@ -390,3 +401,350 @@ class BaseModel(metaclass=ModelMeta): result[name] = value return result + + +class XmlModel(BaseModel): + XML_TAG = None + + def to_xml(self) -> ET.Element: + xml_tag = None + + # check if there's a special field that sets the tag + for field_name, field in self.__fields__.items(): + xml_set_tag = field.extra.get("xml_set_tag", False) + if xml_set_tag: + value = getattr(self, field_name) + xml_tag = value + break + + # use the value from the class + if xml_tag is None: + xml_tag = self.XML_TAG + + assert xml_tag is not None + root = ET.Element(xml_tag) + + for field_name, field in self.__fields__.items(): + xml_attribute = field.extra.get("xml_attribute", False) + xml_set_tag = field.extra.get("xml_set_tag", False) + xml_set_text = field.extra.get("xml_set_text", False) + xml_name = field.extra.get("xml_name", field_name) + xml_wrapped = field.extra.get("xml_wrapped", False) + xml_item_name = field.extra.get("xml_item_name", xml_name) + + if xml_set_tag: + # a special case when the field determines the top-level tag name + continue + + value = getattr(self, field_name) + if value is None: + # skip fields that are not set + continue + + # if value is wrapped into an external element, create it + if xml_wrapped: + wrapper_node = ET.SubElement(root, xml_name) + else: + wrapper_node = root + + if xml_set_text: + wrapper_node.text = str(value) + continue + + if field.origin_type == list: + for entry in value: + if isinstance(entry, dict): + klass = field.inner_type + obj = klass(**entry) + node = obj.to_xml() + wrapper_node.append(node) + elif field.inner_type and issubclass(field.inner_type, XmlModel): + wrapper_node.append(entry.to_xml()) + else: + node = ET.SubElement(wrapper_node, xml_item_name) + if xml_attribute: + node.attrib[xml_attribute] = entry + else: + node.text = entry + elif issubclass(field.origin_type, XmlModel): + wrapper_node.append(value.to_xml()) + elif xml_attribute: + wrapper_node.attrib[xml_name] = str(value) + else: + node = ET.SubElement(wrapper_node, xml_name) + node.text = str(value) + return root + + @classmethod + def from_string(cls, string: str) -> "XmlModel": + """ + Instantiate model from string. + """ + root = ET.fromstring(string) + return cls.from_xml(root) + + @classmethod + def from_file(cls, file: Union[str, typing.IO]) -> "XmlModel": + """ + Instantiate model from file. + """ + root = ET.parse(file).getroot() + return cls.from_xml(root) + + def to_bytes(self) -> bytes: + """ + Serialize the object as XML and return it as utf-8 encoded bytes. + """ + root = self.to_xml() + xml.xml_indent(root) + return ET.tostring(root, encoding="utf-8") + + def to_string(self) -> str: + """ + Serialize the object as XML and return it as a string. + """ + return self.to_bytes().decode("utf-8") + + def to_file(self, file: Union[str, typing.IO]) -> None: + """ + Serialize the object as XML and save it to an utf-8 encoded file. + """ + root = self.to_xml() + xml.xml_indent(root) + return ET.ElementTree(root).write(file, encoding="utf-8") + + @staticmethod + def value_from_string(field, value): + """ + Convert field value from string to the actual type of the field. + """ + if field.origin_type is bool: + if value.lower() in ["1", "yes", "true", "on"]: + value = True + return value + if value.lower() in ["0", "no", "false", "off"]: + value = False + return value + + if field.origin_type is int: + value = int(value) + return value + + return value + + @classmethod + def _remove_processed_node(cls, parent, node): + """ + Remove a node that has been fully processed and is now empty. + """ + if len(node) != 0: + raise RuntimeError(f"Node {node} contains unprocessed child elements {list(node)}") + if node.attrib: + raise RuntimeError(f"Node {node} contains unprocessed attributes {node.attrib}") + if node.text is not None and node.text.strip(): + raise RuntimeError(f"Node {node} contains unprocessed text {node.text}") + if parent is not None: + parent.remove(node) + + @classmethod + def from_xml(cls, root: ET.Element): + """ + Instantiate model from a XML root. + """ + + # We need to make sure we parse all data + # and that's why we remove processed elements and attributes and check that nothing remains. + # Otherwise we'd be sending partial XML back and that would lead to data loss. + # + # Let's make a copy of the xml tree because we'll destroy it during the process. + orig_root = root + root = copy.deepcopy(root) + + kwargs = {} + for field_name, field in cls.__fields__.items(): + xml_attribute = field.extra.get("xml_attribute", False) + xml_set_tag = field.extra.get("xml_set_tag", False) + xml_set_text = field.extra.get("xml_set_text", False) + xml_name = field.extra.get("xml_name", field_name) + xml_wrapped = field.extra.get("xml_wrapped", False) + xml_item_name = field.extra.get("xml_item_name", xml_name) + value: Any + node: Optional[ET.Element] + + if xml_set_tag: + # field contains name of the ``root`` tag + if xml_wrapped: + # the last node wins (overrides the previous nodes) + for node in root[:]: + value = node.tag + cls._remove_processed_node(root, node) + else: + value = root.tag + + kwargs[field_name] = value + continue + + if xml_set_text: + # field contains the value (text) of the element + if xml_wrapped: + # the last node wins (overrides the previous nodes) + for node in root[:]: + value = node.text + node.text = None + cls._remove_processed_node(root, node) + else: + value = root.text + root.text = None + + value = value.strip() + kwargs[field_name] = value + continue + + if xml_attribute: + # field is an attribute that contains a scalar + if xml_name not in root.attrib: + continue + value = cls.value_from_string(field, root.attrib.pop(xml_name)) + kwargs[field_name] = value + continue + + if field.origin_type is list: + if xml_wrapped: + wrapper_node = root.find(xml_name) + # we'll consider all nodes inside the wrapper node + nodes = wrapper_node[:] if wrapper_node is not None else None + else: + wrapper_node = None + # we'll consider only nodes with matching name + nodes = root.findall(xml_item_name) + + if not nodes: + if wrapper_node is not None: + cls._remove_processed_node(root, wrapper_node) + continue + + values = [] + for node in nodes: + if field.is_model_list: + klass = field.inner_type + entry = klass.from_xml(node) + + # clear node as it was checked in from_xml() already + node.text = None + node.attrib = {} + node[:] = [] + else: + entry = cls.value_from_string(field, node.text) + node.text = None + + values.append(entry) + + if xml_wrapped: + cls._remove_processed_node(wrapper_node, node) + else: + cls._remove_processed_node(root, node) + + if xml_wrapped: + cls._remove_processed_node(root, wrapper_node) + + kwargs[field_name] = values + continue + + if field.is_model: + # field contains an instance of XmlModel + assert xml_name is not None + node = root.find(xml_name) + if node is None: + continue + klass = field.origin_type + kwargs[field_name] = klass.from_xml(node) + + # clear node as it was checked in from_xml() already + node.text = None + node.attrib = {} + node[:] = [] + + cls._remove_processed_node(root, node) + continue + + # field contains a scalar + node = root.find(xml_name) + if node is None: + continue + value = cls.value_from_string(field, node.text) + node.text = None + if value is None: + if field.is_optional: + continue + value = "" + kwargs[field_name] = value + cls._remove_processed_node(root, node) + + cls._remove_processed_node(None, root) + + obj = cls(**kwargs) + obj.__dict__["_root"] = orig_root + return obj + + @classmethod + def xml_request(cls, method: str, apiurl: str, path: List[str], query: Optional[dict] = None, data: Optional[str] = None) -> urllib3.response.HTTPResponse: + from ..connection import http_request + from ..core import makeurl + url = makeurl(apiurl, path, query) + # TODO: catch HTTPError and return the wrapped response as XmlModel instance + return http_request(method, url, data=data, retry_on_400=False) + + def do_update(self, other: "XmlModel") -> None: + """ + Update values of the fields in the current model instance from another. + """ + self._values = copy.deepcopy(other._values) + + def do_edit(self) -> Tuple[str, str, "XmlModel"]: + """ + Serialize model as XML and open it in an editor for editing. + Return a tuple with: + * a string with original data + * a string with edited data + * an instance of the class with edited data loaded + + IMPORTANT: This method is always interactive. + """ + from ..core import run_editor + from ..output import get_user_input + + def write_file(f, data): + f.seek(0) + f.write(data) + f.truncate() + f.flush() + + with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8", prefix="obs_xml_", suffix=".xml") as f: + original_data = self.to_string() + write_file(f, original_data) + + while True: + run_editor(f.name) + try: + edited_obj = self.__class__.from_file(f.name) + f.seek(0) + edited_data = f.read() + break + except Exception as e: + reply = get_user_input( + f""" + The edited data is not valid. + {e} + """, + answers={"a": "abort", "e": "edit", "u": "undo changes and edit"}, + ) + if reply == "a": + from .. import oscerr + raise oscerr.UserAbort() + elif reply == "e": + continue + elif reply == "u": + write_file(f, original_data) + continue + + return original_data, edited_data, edited_obj diff --git a/tests/test_models_xmlmodel.py b/tests/test_models_xmlmodel.py new file mode 100644 index 00000000..36f3942b --- /dev/null +++ b/tests/test_models_xmlmodel.py @@ -0,0 +1,154 @@ +import textwrap +import unittest + +from osc.util.models import * + + +class TestXmlModel(unittest.TestCase): + def test_attribute(self): + class TestModel(XmlModel): + XML_TAG = "tag" + value: str = Field(xml_attribute=True) + + m = TestModel(value="FOO") + self.assertEqual(m.dict(), {"value": "FOO"}) + expected = """""" + self.assertEqual(m.to_string(), expected) + + # verify that we can also load the serialized data + m = TestModel.from_string(expected) + self.assertEqual(m.to_string(), expected) + + def test_element(self): + class TestModel(XmlModel): + XML_TAG = "tag" + value: str = Field() + + m = TestModel(value="FOO") + self.assertEqual(m.dict(), {"value": "FOO"}) + expected = textwrap.dedent( + """ + + FOO + + """ + ).strip() + self.assertEqual(m.to_string(), expected) + + # verify that we can also load the serialized data + m = TestModel.from_string(expected) + self.assertEqual(m.to_string(), expected) + + def test_element_list(self): + class TestModel(XmlModel): + XML_TAG = "tag" + value_list: List[str] = Field(xml_name="value") + + m = TestModel(value_list=["FOO", "BAR"]) + self.assertEqual(m.dict(), {"value_list": ["FOO", "BAR"]}) + expected = textwrap.dedent( + """ + + FOO + BAR + + """ + ).strip() + self.assertEqual(m.to_string(), expected) + + # verify that we can also load the serialized data + m = TestModel.from_string(expected) + self.assertEqual(m.to_string(), expected) + + def test_child_model(self): + class ChildModel(XmlModel): + XML_TAG = "child" + value: str = Field() + + class ParentModel(XmlModel): + XML_TAG = "parent" + text: str = Field() + child: ChildModel = Field() + + m = ParentModel(text="TEXT", child={"value": "FOO"}) + expected = textwrap.dedent( + """ + + TEXT + + FOO + + + """ + ).strip() + self.assertEqual(m.to_string(), expected) + + # verify that we can also load the serialized data + m = ParentModel.from_string(expected) + self.assertEqual(m.to_string(), expected) + + def test_child_model_list(self): + class ChildModel(XmlModel): + XML_TAG = "child" + value: str = Field() + + class ParentModel(XmlModel): + XML_TAG = "parent" + text: str = Field() + child: List[ChildModel] = Field() + + m = ParentModel(text="TEXT", child=[{"value": "FOO"}, {"value": "BAR"}]) + expected = textwrap.dedent( + """ + + TEXT + + FOO + + + BAR + + + """ + ).strip() + self.assertEqual(m.to_string(), expected) + + # verify that we can also load the serialized data + m = ParentModel.from_string(expected) + self.assertEqual(m.to_string(), expected) + + def test_child_model_list_wrapped(self): + class ChildModel(XmlModel): + XML_TAG = "child" + value: str = Field() + + class ParentModel(XmlModel): + XML_TAG = "parent" + text: str = Field() + child: List[ChildModel] = Field(xml_wrapped=True, xml_name="children") + + m = ParentModel(text="TEXT", child=[{"value": "FOO"}, {"value": "BAR"}]) + expected = textwrap.dedent( + """ + + TEXT + + + FOO + + + BAR + + + + """ + ).strip() + self.assertEqual(m.to_string(), expected) + + # verify that we can also load the serialized data + m = ParentModel.from_string(expected) + self.assertEqual(m.to_string(), expected) + + +if __name__ == "__main__": + unittest.main()