From e6def0c2b8820ca4cd8e1267419300970721a15a Mon Sep 17 00:00:00 2001 From: Cedric Bosdonnat Date: Mon, 7 Sep 2020 15:00:40 +0200 Subject: [PATCH] Backport virt patches from 3001+ (#256) * Fix various spelling mistakes in master branch (#55954) * Fix typo of additional Signed-off-by: Benjamin Drung * Fix typo of against Signed-off-by: Benjamin Drung * Fix typo of amount Signed-off-by: Benjamin Drung * Fix typo of argument Signed-off-by: Benjamin Drung * Fix typo of attempt Signed-off-by: Benjamin Drung * Fix typo of bandwidth Signed-off-by: Benjamin Drung * Fix typo of caught Signed-off-by: Benjamin Drung * Fix typo of compatibility Signed-off-by: Benjamin Drung * Fix typo of consistency Signed-off-by: Benjamin Drung * Fix typo of conversions Signed-off-by: Benjamin Drung * Fix typo of corresponding Signed-off-by: Benjamin Drung * Fix typo of dependent Signed-off-by: Benjamin Drung * Fix typo of dictionary Signed-off-by: Benjamin Drung * Fix typo of disabled Signed-off-by: Benjamin Drung * Fix typo of adapters Signed-off-by: Benjamin Drung * Fix typo of disassociates Signed-off-by: Benjamin Drung * Fix typo of changes Signed-off-by: Benjamin Drung * Fix typo of command Signed-off-by: Benjamin Drung * Fix typo of communicate Signed-off-by: Benjamin Drung * Fix typo of community Signed-off-by: Benjamin Drung * Fix typo of configuration Signed-off-by: Benjamin Drung * Fix typo of default Signed-off-by: Benjamin Drung * Fix typo of absence Signed-off-by: Benjamin Drung * Fix typo of attribute Signed-off-by: Benjamin Drung * Fix typo of container Signed-off-by: Benjamin Drung * Fix typo of described Signed-off-by: Benjamin Drung * Fix typo of existence Signed-off-by: Benjamin Drung * Fix typo of explicit Signed-off-by: Benjamin Drung * Fix typo of formatted Signed-off-by: Benjamin Drung * Fix typo of guarantees Signed-off-by: Benjamin Drung * Fix typo of hexadecimal Signed-off-by: Benjamin Drung * Fix typo of hierarchy Signed-off-by: Benjamin Drung * Fix typo of 
initialize Signed-off-by: Benjamin Drung * Fix typo of label Signed-off-by: Benjamin Drung * Fix typo of management Signed-off-by: Benjamin Drung * Fix typo of mismatch Signed-off-by: Benjamin Drung * Fix typo of don't Signed-off-by: Benjamin Drung * Fix typo of manually Signed-off-by: Benjamin Drung * Fix typo of getting Signed-off-by: Benjamin Drung * Fix typo of information Signed-off-by: Benjamin Drung * Fix typo of meant Signed-off-by: Benjamin Drung * Fix typo of nonexistent Signed-off-by: Benjamin Drung * Fix typo of occur Signed-off-by: Benjamin Drung * Fix typo of omitted Signed-off-by: Benjamin Drung * Fix typo of normally Signed-off-by: Benjamin Drung * Fix typo of overridden Signed-off-by: Benjamin Drung * Fix typo of repository Signed-off-by: Benjamin Drung * Fix typo of separate Signed-off-by: Benjamin Drung * Fix typo of separator Signed-off-by: Benjamin Drung * Fix typo of specific Signed-off-by: Benjamin Drung * Fix typo of successful Signed-off-by: Benjamin Drung * Fix typo of succeeded Signed-off-by: Benjamin Drung * Fix typo of support Signed-off-by: Benjamin Drung * Fix typo of version Signed-off-by: Benjamin Drung * Fix typo of that's Signed-off-by: Benjamin Drung * Fix typo of "will be removed" Signed-off-by: Benjamin Drung * Fix typo of release Signed-off-by: Benjamin Drung * Fix typo of synchronize Signed-off-by: Benjamin Drung * Fix typo of python Signed-off-by: Benjamin Drung * Fix typo of usually Signed-off-by: Benjamin Drung * Fix typo of override Signed-off-by: Benjamin Drung * Fix typo of running Signed-off-by: Benjamin Drung * Fix typo of whether Signed-off-by: Benjamin Drung * Fix typo of package Signed-off-by: Benjamin Drung * Fix typo of persist Signed-off-by: Benjamin Drung * Fix typo of preferred Signed-off-by: Benjamin Drung * Fix typo of present Signed-off-by: Benjamin Drung * Fix typo of run Signed-off-by: Benjamin Drung * Fix spelling mistake of "allows someone to..." "Allows to" is not correct English. 
It must either be "allows someone to" or "allows doing". Signed-off-by: Benjamin Drung * Fix spelling mistake of "number of times" Signed-off-by: Benjamin Drung * Fix spelling mistake of msgpack Signed-off-by: Benjamin Drung * Fix spelling mistake of daemonized Signed-off-by: Benjamin Drung * Fix spelling mistake of daemons Signed-off-by: Benjamin Drung * Fix spelling mistake of extemporaneous Signed-off-by: Benjamin Drung * Fix spelling mistake of instead Signed-off-by: Benjamin Drung * Fix spelling mistake of returning Signed-off-by: Benjamin Drung * Fix literal comparisons * virt: Convert cpu_baseline ElementTree to string In commit 0f5184c (Remove minidom use in virt module) the value of `cpu` became `xml.etree.ElementTree.Element` and no longer has a method `toxml()`. This results in the following error: $ salt '*' virt.cpu_baseline host2: The minion function caused an exception: Traceback (most recent call last): File "/usr/lib/python3.7/site-packages/salt/minion.py", line 1675, in _thread_return return_data = minion_instance.executors[fname](opts, data, func, args, kwargs) File "/usr/lib/python3.7/site-packages/salt/executors/direct_call.py", line 12, in execute return func(*args, **kwargs) File "/usr/lib/python3.7/site-packages/salt/modules/virt.py", line 4410, in cpu_baseline return cpu.toxml() AttributeError: 'xml.etree.ElementTree.Element' object has no attribute 'toxml' Signed-off-by: Radostin Stoyanov * PR#57374 backport virt: pool secret should be undefined in pool_undefine not pool_delete virt: handle build differently depending on the pool type virt: don't fail if the pool secret has been removed * PR #57396 backport add firmware auto select feature * virt: Update dependencies Closes: #57641 Signed-off-by: Radostin Stoyanov * use null in sls file to map None object add sls file example reword doc * Update virt module and states and their tests to python3 * PR #57545 backport Move virt.init boot_dev parameter away from the kwargs virt: handle boot 
device in virt.update() virt: add boot_dev parameter to virt.running state * PR #57431 backport virt: Handle no available hypervisors virt: Remove unused imports * Blacken salt * Add method to remove circular references in data objects and add test (#54930) * Add method to remove circular references in data objects and add test * remove trailing whitespace * Blacken changed files Co-authored-by: xeacott Co-authored-by: Frode Gundersen Co-authored-by: Daniel A. Wozniak * PR #58332 backport virt: add debug log with VM XML definition Add xmlutil.get_xml_node() helper function Add salt.utils.data.get_value function Add change_xml() function to xmlutil virt.update: refactor the XML diffing code virt.test_update: move some code to make test more readable Co-authored-by: Benjamin Drung Co-authored-by: Pedro Algarvio Co-authored-by: Radostin Stoyanov Co-authored-by: Firefly Co-authored-by: Blacken Salt Co-authored-by: Joe Eacott <31625359+xeacott@users.noreply.github.com> Co-authored-by: xeacott Co-authored-by: Frode Gundersen Co-authored-by: Daniel A. 
Wozniak --- changelog/56454.fixed | 1 + changelog/57544.added | 1 + changelog/58331.fixed | 1 + salt/modules/virt.py | 442 ++++--- salt/states/virt.py | 171 ++- salt/templates/virt/libvirt_domain.jinja | 2 +- salt/utils/data.py | 977 +++++++++------ salt/utils/xmlutil.py | 251 +++- tests/pytests/unit/utils/test_data.py | 57 + tests/pytests/unit/utils/test_xmlutil.py | 169 +++ tests/unit/modules/test_virt.py | 218 ++-- tests/unit/states/test_virt.py | 98 +- tests/unit/utils/test_data.py | 1399 ++++++++++++---------- tests/unit/utils/test_xmlutil.py | 164 +-- 14 files changed, 2588 insertions(+), 1363 deletions(-) create mode 100644 changelog/56454.fixed create mode 100644 changelog/57544.added create mode 100644 changelog/58331.fixed create mode 100644 tests/pytests/unit/utils/test_data.py create mode 100644 tests/pytests/unit/utils/test_xmlutil.py diff --git a/changelog/56454.fixed b/changelog/56454.fixed new file mode 100644 index 0000000000..978b4b6e03 --- /dev/null +++ b/changelog/56454.fixed @@ -0,0 +1 @@ +Better handle virt.pool_rebuild in virt.pool_running and virt.pool_defined states diff --git a/changelog/57544.added b/changelog/57544.added new file mode 100644 index 0000000000..52071cf2c7 --- /dev/null +++ b/changelog/57544.added @@ -0,0 +1 @@ +Allow setting VM boot devices order in virt.running and virt.defined states diff --git a/changelog/58331.fixed b/changelog/58331.fixed new file mode 100644 index 0000000000..4b8f78dd53 --- /dev/null +++ b/changelog/58331.fixed @@ -0,0 +1 @@ +Leave boot parameters untouched if boot parameter is set to None in virt.update diff --git a/salt/modules/virt.py b/salt/modules/virt.py index a78c21e323..cd80fbe608 100644 --- a/salt/modules/virt.py +++ b/salt/modules/virt.py @@ -1,8 +1,11 @@ -# -*- coding: utf-8 -*- """ Work with virtual machines managed by libvirt -:depends: libvirt Python module +:depends: + * libvirt Python module + * libvirt client + * qemu-img + * grep Connection ========== @@ -73,7 +76,7 @@ The calls not 
using the libvirt connection setup are: # of his in the virt func module have been used # Import python libs -from __future__ import absolute_import, print_function, unicode_literals + import base64 import copy import datetime @@ -89,23 +92,19 @@ from xml.etree import ElementTree from xml.sax import saxutils # Import third party libs -import jinja2 import jinja2.exceptions # Import salt libs +import salt.utils.data import salt.utils.files import salt.utils.json -import salt.utils.network import salt.utils.path import salt.utils.stringutils import salt.utils.templates -import salt.utils.validate.net -import salt.utils.versions import salt.utils.xmlutil as xmlutil import salt.utils.yaml from salt._compat import ipaddress from salt.exceptions import CommandExecutionError, SaltInvocationError -from salt.ext import six from salt.ext.six.moves import range # pylint: disable=import-error,redefined-builtin from salt.ext.six.moves.urllib.parse import urlparse, urlunparse from salt.utils.virt import check_remote, download_remote @@ -227,8 +226,8 @@ def __get_conn(**kwargs): ) except Exception: # pylint: disable=broad-except raise CommandExecutionError( - "Sorry, {0} failed to open a connection to the hypervisor " - "software at {1}".format(__grains__["fqdn"], conn_str) + "Sorry, {} failed to open a connection to the hypervisor " + "software at {}".format(__grains__["fqdn"], conn_str) ) return conn @@ -405,7 +404,7 @@ def _get_nics(dom): # driver, source, and match can all have optional attributes if re.match("(driver|source|address)", v_node.tag): temp = {} - for key, value in six.iteritems(v_node.attrib): + for key, value in v_node.attrib.items(): temp[key] = value nic[v_node.tag] = temp # virtualport needs to be handled separately, to pick up the @@ -413,7 +412,7 @@ def _get_nics(dom): if v_node.tag == "virtualport": temp = {} temp["type"] = v_node.get("type") - for key, value in six.iteritems(v_node.attrib): + for key, value in v_node.attrib.items(): temp[key] = value 
nic["virtualport"] = temp if "mac" not in nic: @@ -435,7 +434,7 @@ def _get_graphics(dom): } doc = ElementTree.fromstring(dom.XMLDesc(0)) for g_node in doc.findall("devices/graphics"): - for key, value in six.iteritems(g_node.attrib): + for key, value in g_node.attrib.items(): out[key] = value return out @@ -448,7 +447,7 @@ def _get_loader(dom): doc = ElementTree.fromstring(dom.XMLDesc(0)) for g_node in doc.findall("os/loader"): out["path"] = g_node.text - for key, value in six.iteritems(g_node.attrib): + for key, value in g_node.attrib.items(): out[key] = value return out @@ -503,7 +502,7 @@ def _get_disks(conn, dom): qemu_target = source.get("protocol") source_name = source.get("name") if source_name: - qemu_target = "{0}:{1}".format(qemu_target, source_name) + qemu_target = "{}:{}".format(qemu_target, source_name) # Reverse the magic for the rbd and gluster pools if source.get("protocol") in ["rbd", "gluster"]: @@ -633,7 +632,7 @@ def _get_target(target, ssh): proto = "qemu" if ssh: proto += "+ssh" - return " {0}://{1}/{2}".format(proto, target, "system") + return " {}://{}/{}".format(proto, target, "system") def _gen_xml( @@ -648,6 +647,7 @@ def _gen_xml( arch, graphics=None, boot=None, + boot_dev=None, **kwargs ): """ @@ -657,8 +657,8 @@ def _gen_xml( context = { "hypervisor": hypervisor, "name": name, - "cpu": six.text_type(cpu), - "mem": six.text_type(mem), + "cpu": str(cpu), + "mem": str(mem), } if hypervisor in ["qemu", "kvm"]: context["controller_model"] = False @@ -681,15 +681,17 @@ def _gen_xml( graphics = None context["graphics"] = graphics - if "boot_dev" in kwargs: - context["boot_dev"] = [] - for dev in kwargs["boot_dev"].split(): - context["boot_dev"].append(dev) - else: - context["boot_dev"] = ["hd"] + context["boot_dev"] = boot_dev.split() if boot_dev is not None else ["hd"] context["boot"] = boot if boot else {} + # if efi parameter is specified, prepare os_attrib + efi_value = context["boot"].get("efi", None) if boot else None + if efi_value is 
True: + context["boot"]["os_attrib"] = "firmware='efi'" + elif efi_value is not None and type(efi_value) != bool: + raise SaltInvocationError("Invalid efi value") + if os_type == "xen": # Compute the Xen PV boot method if __grains__["os_family"] == "Suse": @@ -720,7 +722,7 @@ def _gen_xml( "target_dev": _get_disk_target(targets, len(diskp), prefix), "disk_bus": disk["model"], "format": disk.get("format", "raw"), - "index": six.text_type(i), + "index": str(i), } targets.append(disk_context["target_dev"]) if disk.get("source_file"): @@ -825,8 +827,8 @@ def _gen_vol_xml( "name": name, "target": {"permissions": permissions, "nocow": nocow}, "format": format, - "size": six.text_type(size), - "allocation": six.text_type(int(allocation) * 1024), + "size": str(size), + "allocation": str(int(allocation) * 1024), "backingStore": backing_store, } fn_ = "libvirt_volume.jinja" @@ -978,31 +980,29 @@ def _zfs_image_create( """ if not disk_image_name and not disk_size: raise CommandExecutionError( - "Unable to create new disk {0}, please specify" + "Unable to create new disk {}, please specify" " the disk image name or disk size argument".format(disk_name) ) if not pool: raise CommandExecutionError( - "Unable to create new disk {0}, please specify" + "Unable to create new disk {}, please specify" " the disk pool name".format(disk_name) ) - destination_fs = os.path.join(pool, "{0}.{1}".format(vm_name, disk_name)) + destination_fs = os.path.join(pool, "{}.{}".format(vm_name, disk_name)) log.debug("Image destination will be %s", destination_fs) existing_disk = __salt__["zfs.list"](name=pool) if "error" in existing_disk: raise CommandExecutionError( - "Unable to create new disk {0}. {1}".format( + "Unable to create new disk {}. {}".format( destination_fs, existing_disk["error"] ) ) elif destination_fs in existing_disk: log.info( - "ZFS filesystem {0} already exists. Skipping creation".format( - destination_fs - ) + "ZFS filesystem {} already exists. 
Skipping creation".format(destination_fs) ) blockdevice_path = os.path.join("/dev/zvol", pool, vm_name) return blockdevice_path @@ -1025,7 +1025,7 @@ def _zfs_image_create( ) blockdevice_path = os.path.join( - "/dev/zvol", pool, "{0}.{1}".format(vm_name, disk_name) + "/dev/zvol", pool, "{}.{}".format(vm_name, disk_name) ) log.debug("Image path will be %s", blockdevice_path) return blockdevice_path @@ -1042,7 +1042,7 @@ def _qemu_image_create(disk, create_overlay=False, saltenv="base"): if not disk_size and not disk_image: raise CommandExecutionError( - "Unable to create new disk {0}, please specify" + "Unable to create new disk {}, please specify" " disk size and/or disk image argument".format(disk["filename"]) ) @@ -1066,7 +1066,7 @@ def _qemu_image_create(disk, create_overlay=False, saltenv="base"): if create_overlay and qcow2: log.info("Cloning qcow2 image %s using copy on write", sfn) __salt__["cmd.run"]( - 'qemu-img create -f qcow2 -o backing_file="{0}" "{1}"'.format( + 'qemu-img create -f qcow2 -o backing_file="{}" "{}"'.format( sfn, img_dest ).split() ) @@ -1079,16 +1079,16 @@ def _qemu_image_create(disk, create_overlay=False, saltenv="base"): if disk_size and qcow2: log.debug("Resize qcow2 image to %sM", disk_size) __salt__["cmd.run"]( - 'qemu-img resize "{0}" {1}M'.format(img_dest, disk_size) + 'qemu-img resize "{}" {}M'.format(img_dest, disk_size) ) log.debug("Apply umask and remove exec bit") mode = (0o0777 ^ mask) & 0o0666 os.chmod(img_dest, mode) - except (IOError, OSError) as err: + except OSError as err: raise CommandExecutionError( - "Problem while copying image. {0} - {1}".format(disk_image, err) + "Problem while copying image. 
{} - {}".format(disk_image, err) ) else: @@ -1099,13 +1099,13 @@ def _qemu_image_create(disk, create_overlay=False, saltenv="base"): if disk_size: log.debug("Create empty image with size %sM", disk_size) __salt__["cmd.run"]( - 'qemu-img create -f {0} "{1}" {2}M'.format( + 'qemu-img create -f {} "{}" {}M'.format( disk.get("format", "qcow2"), img_dest, disk_size ) ) else: raise CommandExecutionError( - "Unable to create new disk {0}," + "Unable to create new disk {}," " please specify argument".format(img_dest) ) @@ -1113,9 +1113,9 @@ def _qemu_image_create(disk, create_overlay=False, saltenv="base"): mode = (0o0777 ^ mask) & 0o0666 os.chmod(img_dest, mode) - except (IOError, OSError) as err: + except OSError as err: raise CommandExecutionError( - "Problem while creating volume {0} - {1}".format(img_dest, err) + "Problem while creating volume {} - {}".format(img_dest, err) ) return img_dest @@ -1252,7 +1252,7 @@ def _disk_profile(conn, profile, hypervisor, disks, vm_name): __salt__["config.get"]("virt:disk", {}).get(profile, default) ) - # Transform the list to remove one level of dictionnary and add the name as a property + # Transform the list to remove one level of dictionary and add the name as a property disklist = [dict(d, name=name) for disk in disklist for name, d in disk.items()] # Merge with the user-provided disks definitions @@ -1274,7 +1274,7 @@ def _disk_profile(conn, profile, hypervisor, disks, vm_name): disk["model"] = "ide" # Add the missing properties that have defaults - for key, val in six.iteritems(overlay): + for key, val in overlay.items(): if key not in disk: disk[key] = val @@ -1296,7 +1296,7 @@ def _fill_disk_filename(conn, vm_name, disk, hypervisor, pool_caps): Compute the disk file name and update it in the disk value. 
""" # Compute the filename without extension since it may not make sense for some pool types - disk["filename"] = "{0}_{1}".format(vm_name, disk["name"]) + disk["filename"] = "{}_{}".format(vm_name, disk["name"]) # Compute the source file path base_dir = disk.get("pool", None) @@ -1311,7 +1311,7 @@ def _fill_disk_filename(conn, vm_name, disk, hypervisor, pool_caps): # For path-based disks, keep the qcow2 default format if not disk.get("format"): disk["format"] = "qcow2" - disk["filename"] = "{0}.{1}".format(disk["filename"], disk["format"]) + disk["filename"] = "{}.{}".format(disk["filename"], disk["format"]) disk["source_file"] = os.path.join(base_dir, disk["filename"]) else: if "pool" not in disk: @@ -1365,7 +1365,7 @@ def _fill_disk_filename(conn, vm_name, disk, hypervisor, pool_caps): disk["format"] = volume_options.get("default_format", None) elif hypervisor == "bhyve" and vm_name: - disk["filename"] = "{0}.{1}".format(vm_name, disk["name"]) + disk["filename"] = "{}.{}".format(vm_name, disk["name"]) disk["source_file"] = os.path.join( "/dev/zvol", base_dir or "", disk["filename"] ) @@ -1373,8 +1373,8 @@ def _fill_disk_filename(conn, vm_name, disk, hypervisor, pool_caps): elif hypervisor in ["esxi", "vmware"]: if not base_dir: base_dir = __salt__["config.get"]("virt:storagepool", "[0] ") - disk["filename"] = "{0}.{1}".format(disk["filename"], disk["format"]) - disk["source_file"] = "{0}{1}".format(base_dir, disk["filename"]) + disk["filename"] = "{}.{}".format(disk["filename"], disk["format"]) + disk["source_file"] = "{}{}".format(base_dir, disk["filename"]) def _complete_nics(interfaces, hypervisor): @@ -1422,7 +1422,7 @@ def _complete_nics(interfaces, hypervisor): """ Apply the default overlay to attributes """ - for key, value in six.iteritems(overlays[hypervisor]): + for key, value in overlays[hypervisor].items(): if key not in attributes or not attributes[key]: attributes[key] = value @@ -1449,7 +1449,7 @@ def _nic_profile(profile_name, hypervisor): """ 
Append dictionary profile data to interfaces list """ - for interface_name, attributes in six.iteritems(profile_dict): + for interface_name, attributes in profile_dict.items(): attributes["name"] = interface_name interfaces.append(attributes) @@ -1522,17 +1522,24 @@ def _handle_remote_boot_params(orig_boot): new_boot = orig_boot.copy() keys = orig_boot.keys() cases = [ + {"efi"}, + {"kernel", "initrd", "efi"}, + {"kernel", "initrd", "cmdline", "efi"}, {"loader", "nvram"}, {"kernel", "initrd"}, {"kernel", "initrd", "cmdline"}, - {"loader", "nvram", "kernel", "initrd"}, - {"loader", "nvram", "kernel", "initrd", "cmdline"}, + {"kernel", "initrd", "loader", "nvram"}, + {"kernel", "initrd", "cmdline", "loader", "nvram"}, ] try: if keys in cases: for key in keys: - if orig_boot.get(key) is not None and check_remote(orig_boot.get(key)): + if key == "efi" and type(orig_boot.get(key)) == bool: + new_boot[key] = orig_boot.get(key) + elif orig_boot.get(key) is not None and check_remote( + orig_boot.get(key) + ): if saltinst_dir is None: os.makedirs(CACHE_DIR) saltinst_dir = CACHE_DIR @@ -1540,12 +1547,41 @@ def _handle_remote_boot_params(orig_boot): return new_boot else: raise SaltInvocationError( - "Invalid boot parameters, (kernel, initrd) or/and (loader, nvram) must be both present" + "Invalid boot parameters,It has to follow this combination: [(kernel, initrd) or/and cmdline] or/and [(loader, nvram) or efi]" ) except Exception as err: # pylint: disable=broad-except raise err +def _handle_efi_param(boot, desc): + """ + Checks if boot parameter contains efi boolean value, if so, handles the firmware attribute. + :param boot: The boot parameters passed to the init or update functions. + :param desc: The XML description of that domain. + :return: A boolean value. 
+ """ + efi_value = boot.get("efi", None) if boot else None + parent_tag = desc.find("os") + os_attrib = parent_tag.attrib + + # newly defined vm without running, loader tag might not be filled yet + if efi_value is False and os_attrib != {}: + parent_tag.attrib.pop("firmware", None) + return True + + # check the case that loader tag might be present. This happens after the vm ran + elif type(efi_value) == bool and os_attrib == {}: + if efi_value is True and parent_tag.find("loader") is None: + parent_tag.set("firmware", "efi") + if efi_value is False and parent_tag.find("loader") is not None: + parent_tag.remove(parent_tag.find("loader")) + parent_tag.remove(parent_tag.find("nvram")) + return True + elif type(efi_value) != bool: + raise SaltInvocationError("Invalid efi value") + return False + + def init( name, cpu, @@ -1566,6 +1602,7 @@ def init( os_type=None, arch=None, boot=None, + boot_dev=None, **kwargs ): """ @@ -1635,7 +1672,8 @@ def init( This is an optional parameter, all of the keys are optional within the dictionary. The structure of the dictionary is documented in :ref:`init-boot-def`. If a remote path is provided to kernel or initrd, salt will handle the downloading of the specified remote file and modify the XML accordingly. - To boot VM with UEFI, specify loader and nvram path. + To boot VM with UEFI, specify loader and nvram path or specify 'efi': ``True`` if your libvirtd version + is >= 5.2.0 and QEMU >= 3.0.0. .. versionadded:: 3000 @@ -1649,6 +1687,12 @@ def init( 'nvram': '/usr/share/OVMF/OVMF_VARS.ms.fd' } + :param boot_dev: + Space separated list of devices to boot from sorted by decreasing priority. + Values can be ``hd``, ``fd``, ``cdrom`` or ``network``. + + By default, the value will ``"hd"``. + .. _init-boot-def: .. rubric:: Boot parameters definition @@ -1674,6 +1718,11 @@ def init( .. versionadded:: sodium + efi + A boolean value. + + .. versionadded:: sodium + .. _init-nic-def: .. 
rubric:: Network Interfaces Definitions @@ -1797,7 +1846,7 @@ def init( .. rubric:: Graphics Definition - The graphics dictionnary can have the following properties: + The graphics dictionary can have the following properties: type Graphics type. The possible values are ``none``, ``'spice'``, ``'vnc'`` and other values @@ -1858,6 +1907,8 @@ def init( for x in y } ) + if len(hypervisors) == 0: + raise SaltInvocationError("No supported hypervisors were found") virt_hypervisor = "kvm" if "kvm" in hypervisors else hypervisors[0] # esxi used to be a possible value for the hypervisor: map it to vmware since it's the same @@ -1890,8 +1941,8 @@ def init( else: # assume libvirt manages disks for us log.debug("Generating libvirt XML for %s", _disk) - volume_name = "{0}/{1}".format(name, _disk["name"]) - filename = "{0}.{1}".format(volume_name, _disk["format"]) + volume_name = "{}/{}".format(name, _disk["name"]) + filename = "{}.{}".format(volume_name, _disk["format"]) vol_xml = _gen_vol_xml( filename, _disk["size"], format=_disk["format"] ) @@ -1939,7 +1990,7 @@ def init( else: # Unknown hypervisor raise SaltInvocationError( - "Unsupported hypervisor when handling disk image: {0}".format( + "Unsupported hypervisor when handling disk image: {}".format( virt_hypervisor ) ) @@ -1965,8 +2016,10 @@ def init( arch, graphics, boot, + boot_dev, **kwargs ) + log.debug("New virtual machine definition: %s", vm_xml) conn.defineXML(vm_xml) except libvirt.libvirtError as err: conn.close() @@ -2192,6 +2245,7 @@ def update( live=True, boot=None, test=False, + boot_dev=None, **kwargs ): """ @@ -2234,11 +2288,28 @@ def update( Refer to :ref:`init-boot-def` for the complete boot parameter description. - To update any boot parameters, specify the new path for each. To remove any boot parameters, - pass a None object, for instance: 'kernel': ``None``. + To update any boot parameters, specify the new path for each. 
To remove any boot parameters, pass ``None`` object, + for instance: 'kernel': ``None``. To switch back to BIOS boot, specify ('loader': ``None`` and 'nvram': ``None``) + or 'efi': ``False``. Please note that ``None`` is mapped to ``null`` in sls file, pass ``null`` in sls file instead. + + SLS file Example: + + .. code-block:: yaml + + - boot: + loader: null + nvram: null .. versionadded:: 3000 + :param boot_dev: + Space separated list of devices to boot from sorted by decreasing priority. + Values can be ``hd``, ``fd``, ``cdrom`` or ``network``. + + By default, the value will ``"hd"``. + + .. versionadded:: Magnesium + :param test: run in dry-run mode if set to True .. versionadded:: sodium @@ -2286,6 +2357,8 @@ def update( if boot is not None: boot = _handle_remote_boot_params(boot) + if boot.get("efi", None) is not None: + need_update = _handle_efi_param(boot, desc) new_desc = ElementTree.fromstring( _gen_xml( @@ -2307,76 +2380,58 @@ def update( # Update the cpu cpu_node = desc.find("vcpu") if cpu and int(cpu_node.text) != cpu: - cpu_node.text = six.text_type(cpu) - cpu_node.set("current", six.text_type(cpu)) + cpu_node.text = str(cpu) + cpu_node.set("current", str(cpu)) need_update = True - # Update the kernel boot parameters - boot_tags = ["kernel", "initrd", "cmdline", "loader", "nvram"] - parent_tag = desc.find("os") - - # We need to search for each possible subelement, and update it. - for tag in boot_tags: - # The Existing Tag... - found_tag = parent_tag.find(tag) - - # The new value - boot_tag_value = boot.get(tag, None) if boot else None - - # Existing tag is found and values don't match - if found_tag is not None and found_tag.text != boot_tag_value: - - # If the existing tag is found, but the new value is None - # remove it. If the existing tag is found, and the new value - # doesn't match update it. In either case, mark for update. 
- if boot_tag_value is None and boot is not None and parent_tag is not None: - parent_tag.remove(found_tag) - else: - found_tag.text = boot_tag_value + def _set_loader(node, value): + salt.utils.xmlutil.set_node_text(node, value) + if value is not None: + node.set("readonly", "yes") + node.set("type", "pflash") - # If the existing tag is loader or nvram, we need to update the corresponding attribute - if found_tag.tag == "loader" and boot_tag_value is not None: - found_tag.set("readonly", "yes") - found_tag.set("type", "pflash") + def _set_nvram(node, value): + node.set("template", value) - if found_tag.tag == "nvram" and boot_tag_value is not None: - found_tag.set("template", found_tag.text) - found_tag.text = None + def _set_with_mib_unit(node, value): + node.text = str(value) + node.set("unit", "MiB") - need_update = True - - # Existing tag is not found, but value is not None - elif found_tag is None and boot_tag_value is not None: - - # Need to check for parent tag, and add it if it does not exist. - # Add a subelement and set the value to the new value, and then - # mark for update. 
- if parent_tag is not None: - child_tag = ElementTree.SubElement(parent_tag, tag) - else: - new_parent_tag = ElementTree.Element("os") - child_tag = ElementTree.SubElement(new_parent_tag, tag) - - child_tag.text = boot_tag_value - - # If the newly created tag is loader or nvram, we need to update the corresponding attribute - if child_tag.tag == "loader": - child_tag.set("readonly", "yes") - child_tag.set("type", "pflash") - - if child_tag.tag == "nvram": - child_tag.set("template", child_tag.text) - child_tag.text = None - - need_update = True + # Update the kernel boot parameters + params_mapping = [ + {"path": "boot:kernel", "xpath": "os/kernel"}, + {"path": "boot:initrd", "xpath": "os/initrd"}, + {"path": "boot:cmdline", "xpath": "os/cmdline"}, + {"path": "boot:loader", "xpath": "os/loader", "set": _set_loader}, + {"path": "boot:nvram", "xpath": "os/nvram", "set": _set_nvram}, + # Update the memory, note that libvirt outputs all memory sizes in KiB + { + "path": "mem", + "xpath": "memory", + "get": lambda n: int(n.text) / 1024, + "set": _set_with_mib_unit, + }, + { + "path": "mem", + "xpath": "currentMemory", + "get": lambda n: int(n.text) / 1024, + "set": _set_with_mib_unit, + }, + { + "path": "boot_dev:{dev}", + "xpath": "os/boot[$dev]", + "get": lambda n: n.get("dev"), + "set": lambda n, v: n.set("dev", v), + "del": salt.utils.xmlutil.del_attribute("dev"), + }, + ] - # Update the memory, note that libvirt outputs all memory sizes in KiB - for mem_node_name in ["memory", "currentMemory"]: - mem_node = desc.find(mem_node_name) - if mem and int(mem_node.text) != mem * 1024: - mem_node.text = six.text_type(mem) - mem_node.set("unit", "MiB") - need_update = True + data = {k: v for k, v in locals().items() if bool(v)} + if boot_dev: + data["boot_dev"] = {i + 1: dev for i, dev in enumerate(boot_dev.split())} + need_update = need_update or salt.utils.xmlutil.change_xml( + desc, data, params_mapping + ) # Update the XML definition with the new disks and diff changes 
devices_node = desc.find("devices") @@ -2395,8 +2450,8 @@ def update( if func_locals.get(param, None) is not None ]: old = devices_node.findall(dev_type) - new = new_desc.findall("devices/{0}".format(dev_type)) - changes[dev_type] = globals()["_diff_{0}_lists".format(dev_type)](old, new) + new = new_desc.findall("devices/{}".format(dev_type)) + changes[dev_type] = globals()["_diff_{}_lists".format(dev_type)](old, new) if changes[dev_type]["deleted"] or changes[dev_type]["new"]: for item in old: devices_node.remove(item) @@ -2423,9 +2478,9 @@ def update( _disk_volume_create(conn, all_disks[idx]) if not test: - conn.defineXML( - salt.utils.stringutils.to_str(ElementTree.tostring(desc)) - ) + xml_desc = ElementTree.tostring(desc) + log.debug("Update virtual machine definition: %s", xml_desc) + conn.defineXML(salt.utils.stringutils.to_str(xml_desc)) status["definition"] = True except libvirt.libvirtError as err: conn.close() @@ -2554,7 +2609,7 @@ def update( except libvirt.libvirtError as err: if "errors" not in status: status["errors"] = [] - status["errors"].append(six.text_type(err)) + status["errors"].append(str(err)) conn.close() return status @@ -2768,7 +2823,7 @@ def _node_info(conn): info = { "cpucores": raw[6], "cpumhz": raw[3], - "cpumodel": six.text_type(raw[0]), + "cpumodel": str(raw[0]), "cpus": raw[2], "cputhreads": raw[7], "numanodes": raw[4], @@ -3207,24 +3262,21 @@ def get_profiles(hypervisor=None, **kwargs): for x in y } ) - default_hypervisor = "kvm" if "kvm" in hypervisors else hypervisors[0] + if len(hypervisors) == 0: + raise SaltInvocationError("No supported hypervisors were found") if not hypervisor: - hypervisor = default_hypervisor + hypervisor = "kvm" if "kvm" in hypervisors else hypervisors[0] virtconf = __salt__["config.get"]("virt", {}) for typ in ["disk", "nic"]: - _func = getattr(sys.modules[__name__], "_{0}_profile".format(typ)) + _func = getattr(sys.modules[__name__], "_{}_profile".format(typ)) ret[typ] = { - "default": _func( - 
"default", hypervisor if hypervisor else default_hypervisor - ) + "default": _func("default", hypervisor) } if typ in virtconf: ret.setdefault(typ, {}) for prf in virtconf[typ]: - ret[typ][prf] = _func( - prf, hypervisor if hypervisor else default_hypervisor - ) + ret[typ][prf] = _func(prf, hypervisor) return ret @@ -3506,7 +3558,7 @@ def create_xml_path(path, **kwargs): return create_xml_str( salt.utils.stringutils.to_unicode(fp_.read()), **kwargs ) - except (OSError, IOError): + except OSError: return False @@ -3564,7 +3616,7 @@ def define_xml_path(path, **kwargs): return define_xml_str( salt.utils.stringutils.to_unicode(fp_.read()), **kwargs ) - except (OSError, IOError): + except OSError: return False @@ -3576,7 +3628,7 @@ def _define_vol_xml_str(conn, xml, pool=None): # pylint: disable=redefined-oute poolname = ( pool if pool else __salt__["config.get"]("virt:storagepool", default_pool) ) - pool = conn.storagePoolLookupByName(six.text_type(poolname)) + pool = conn.storagePoolLookupByName(str(poolname)) ret = pool.createXML(xml, 0) is not None return ret @@ -3660,7 +3712,7 @@ def define_vol_xml_path(path, pool=None, **kwargs): return define_vol_xml_str( salt.utils.stringutils.to_unicode(fp_.read()), pool=pool, **kwargs ) - except (OSError, IOError): + except OSError: return False @@ -3777,7 +3829,7 @@ def seed_non_shared_migrate(disks, force=False): salt '*' virt.seed_non_shared_migrate """ - for _, data in six.iteritems(disks): + for _, data in disks.items(): fn_ = data["file"] form = data["file format"] size = data["virtual size"].split()[1][1:] @@ -3921,14 +3973,14 @@ def purge(vm_, dirs=False, removables=False, **kwargs): # TODO create solution for 'dataset is busy' time.sleep(3) fs_name = disks[disk]["file"][len("/dev/zvol/") :] - log.info("Destroying VM ZFS volume {0}".format(fs_name)) + log.info("Destroying VM ZFS volume {}".format(fs_name)) __salt__["zfs.destroy"](name=fs_name, force=True) elif os.path.exists(disks[disk]["file"]): 
os.remove(disks[disk]["file"]) directories.add(os.path.dirname(disks[disk]["file"])) else: # We may have a volume to delete here - matcher = re.match("^(?P[^/]+)/(?P.*)$", disks[disk]["file"],) + matcher = re.match("^(?P[^/]+)/(?P.*)$", disks[disk]["file"]) if matcher: pool_name = matcher.group("pool") pool = None @@ -3975,7 +4027,7 @@ def _is_kvm_hyper(): with salt.utils.files.fopen("/proc/modules") as fp_: if "kvm_" not in salt.utils.stringutils.to_unicode(fp_.read()): return False - except IOError: + except OSError: # No /proc/modules? Are we on Windows? Or Solaris? return False return "libvirtd" in __salt__["cmd.run"](__grains__["ps"]) @@ -3995,7 +4047,7 @@ def _is_xen_hyper(): with salt.utils.files.fopen("/proc/modules") as fp_: if "xen_" not in salt.utils.stringutils.to_unicode(fp_.read()): return False - except (OSError, IOError): + except OSError: # No /proc/modules? Are we on Windows? Or Solaris? return False return "libvirtd" in __salt__["cmd.run"](__grains__["ps"]) @@ -4110,7 +4162,7 @@ def vm_cputime(vm_=None, **kwargs): cputime_percent = (1.0e-7 * cputime / host_cpus) / vcpus return { "cputime": int(raw[4]), - "cputime_percent": int("{0:.0f}".format(cputime_percent)), + "cputime_percent": int("{:.0f}".format(cputime_percent)), } info = {} @@ -4180,7 +4232,7 @@ def vm_netstats(vm_=None, **kwargs): "tx_errs": 0, "tx_drop": 0, } - for attrs in six.itervalues(nics): + for attrs in nics.values(): if "target" in attrs: dev = attrs["target"] stats = dom.interfaceStats(dev) @@ -4508,7 +4560,7 @@ def revert_snapshot(name, vm_snapshot=None, cleanup=False, **kwargs): conn.close() raise CommandExecutionError( snapshot - and 'Snapshot "{0}" not found'.format(vm_snapshot) + and 'Snapshot "{}" not found'.format(vm_snapshot) or "No more previous snapshots available" ) elif snap.isCurrent(): @@ -5102,10 +5154,10 @@ def cpu_baseline(full=False, migratable=False, out="libvirt", **kwargs): ] if not cpu_specs: - raise ValueError("Model {0} not found in CPU 
map".format(cpu_model)) + raise ValueError("Model {} not found in CPU map".format(cpu_model)) elif len(cpu_specs) > 1: raise ValueError( - "Multiple models {0} found in CPU map".format(cpu_model) + "Multiple models {} found in CPU map".format(cpu_model) ) cpu_specs = cpu_specs[0] @@ -5126,7 +5178,7 @@ def cpu_baseline(full=False, migratable=False, out="libvirt", **kwargs): "vendor": cpu.find("vendor").text, "features": [feature.get("name") for feature in cpu.findall("feature")], } - return cpu.toxml() + return ElementTree.tostring(cpu) def network_define(name, bridge, forward, ipv4_config=None, ipv6_config=None, **kwargs): @@ -5250,7 +5302,7 @@ def list_networks(**kwargs): def network_info(name=None, **kwargs): """ - Return informations on a virtual network provided its name. + Return information on a virtual network provided its name. :param name: virtual network name :param connection: libvirt connection URI, overriding defaults @@ -5446,20 +5498,20 @@ def _parse_pools_caps(doc): for option_kind in ["pool", "vol"]: options = {} default_format_node = pool.find( - "{0}Options/defaultFormat".format(option_kind) + "{}Options/defaultFormat".format(option_kind) ) if default_format_node is not None: options["default_format"] = default_format_node.get("type") options_enums = { enum.get("name"): [value.text for value in enum.findall("value")] - for enum in pool.findall("{0}Options/enum".format(option_kind)) + for enum in pool.findall("{}Options/enum".format(option_kind)) } if options_enums: options.update(options_enums) if options: if "options" not in pool_caps: pool_caps["options"] = {} - kind = option_kind if option_kind is not "vol" else "volume" + kind = option_kind if option_kind != "vol" else "volume" pool_caps["options"][kind] = options return pool_caps @@ -5695,7 +5747,7 @@ def pool_define( keys. The path is the qualified name for iSCSI devices. 
Report to `this libvirt page `_ - for more informations on the use of ``part_separator`` + for more information on the use of ``part_separator`` :param source_dir: Path to the source directory for pools of type ``dir``, ``netfs`` or ``gluster``. (Default: ``None``) @@ -5847,15 +5899,19 @@ def _pool_set_secret( if secret_type: # Get the previously defined secret if any secret = None - if usage: - usage_type = ( - libvirt.VIR_SECRET_USAGE_TYPE_CEPH - if secret_type == "ceph" - else libvirt.VIR_SECRET_USAGE_TYPE_ISCSI - ) - secret = conn.secretLookupByUsage(usage_type, usage) - elif uuid: - secret = conn.secretLookupByUUIDString(uuid) + try: + if usage: + usage_type = ( + libvirt.VIR_SECRET_USAGE_TYPE_CEPH + if secret_type == "ceph" + else libvirt.VIR_SECRET_USAGE_TYPE_ISCSI + ) + secret = conn.secretLookupByUsage(usage_type, usage) + elif uuid: + secret = conn.secretLookupByUUIDString(uuid) + except libvirt.libvirtError as err: + # For some reason the secret has been removed. Don't fail since we'll recreate it + log.info("Secret not found: %s", err.get_error_message()) # Create secret if needed if not secret: @@ -5918,7 +5974,7 @@ def pool_update( keys. The path is the qualified name for iSCSI devices. Report to `this libvirt page `_ - for more informations on the use of ``part_separator`` + for more information on the use of ``part_separator`` :param source_dir: Path to the source directory for pools of type ``dir``, ``netfs`` or ``gluster``. (Default: ``None``) @@ -6107,7 +6163,7 @@ def list_pools(**kwargs): def pool_info(name=None, **kwargs): """ - Return informations on a storage pool provided its name. + Return information on a storage pool provided its name. 
:param name: libvirt storage pool name :param connection: libvirt connection URI, overriding defaults @@ -6283,6 +6339,22 @@ def pool_undefine(name, **kwargs): conn = __get_conn(**kwargs) try: pool = conn.storagePoolLookupByName(name) + desc = ElementTree.fromstring(pool.XMLDesc()) + + # Is there a secret that we generated and would need to be removed? + # Don't remove the other secrets + auth_node = desc.find("source/auth") + if auth_node is not None: + auth_types = { + "ceph": libvirt.VIR_SECRET_USAGE_TYPE_CEPH, + "iscsi": libvirt.VIR_SECRET_USAGE_TYPE_ISCSI, + } + secret_type = auth_types[auth_node.get("type")] + secret_usage = auth_node.find("secret").get("usage") + if secret_type and "pool_{}".format(name) == secret_usage: + secret = conn.secretLookupByUsage(secret_type, secret_usage) + secret.undefine() + return not bool(pool.undefine()) finally: conn.close() @@ -6308,22 +6380,6 @@ def pool_delete(name, **kwargs): conn = __get_conn(**kwargs) try: pool = conn.storagePoolLookupByName(name) - desc = ElementTree.fromstring(pool.XMLDesc()) - - # Is there a secret that we generated and would need to be removed? 
- # Don't remove the other secrets - auth_node = desc.find("source/auth") - if auth_node is not None: - auth_types = { - "ceph": libvirt.VIR_SECRET_USAGE_TYPE_CEPH, - "iscsi": libvirt.VIR_SECRET_USAGE_TYPE_ISCSI, - } - secret_type = auth_types[auth_node.get("type")] - secret_usage = auth_node.find("secret").get("usage") - if secret_type and "pool_{}".format(name) == secret_usage: - secret = conn.secretLookupByUsage(secret_type, secret_usage) - secret.undefine() - return not bool(pool.delete(libvirt.VIR_STORAGE_POOL_DELETE_NORMAL)) finally: conn.close() @@ -6768,7 +6824,7 @@ def _volume_upload(conn, pool, volume, file, offset=0, length=0, sparse=False): stream.abort() if ret: raise CommandExecutionError( - "Failed to close file: {0}".format(err.strerror) + "Failed to close file: {}".format(err.strerror) ) if stream: try: @@ -6776,7 +6832,7 @@ def _volume_upload(conn, pool, volume, file, offset=0, length=0, sparse=False): except libvirt.libvirtError as err: if ret: raise CommandExecutionError( - "Failed to finish stream: {0}".format(err.get_error_message()) + "Failed to finish stream: {}".format(err.get_error_message()) ) return ret diff --git a/salt/states/virt.py b/salt/states/virt.py index fdef002293..3d99fd53c8 100644 --- a/salt/states/virt.py +++ b/salt/states/virt.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ Manage virt =========== @@ -13,9 +12,9 @@ for the generation and signing of certificates for systems running libvirt: """ # Import Python libs -from __future__ import absolute_import, print_function, unicode_literals import fnmatch +import logging import os # Import Salt libs @@ -25,9 +24,6 @@ import salt.utils.stringutils import salt.utils.versions from salt.exceptions import CommandExecutionError, SaltInvocationError -# Import 3rd-party libs -from salt.ext import six - try: import libvirt # pylint: disable=import-error @@ -38,6 +34,8 @@ except ImportError: __virtualname__ = "virt" +log = logging.getLogger(__name__) + def __virtual__(): """ @@ -99,8 
+97,8 @@ def keys(name, basepath="/etc/pki", **kwargs): # rename them to something hopefully unique to avoid # overriding anything existing pillar_kwargs = {} - for key, value in six.iteritems(kwargs): - pillar_kwargs["ext_pillar_virt.{0}".format(key)] = value + for key, value in kwargs.items(): + pillar_kwargs["ext_pillar_virt.{}".format(key)] = value pillar = __salt__["pillar.ext"]({"libvirt": "_"}, pillar_kwargs) paths = { @@ -112,7 +110,7 @@ def keys(name, basepath="/etc/pki", **kwargs): } for key in paths: - p_key = "libvirt.{0}.pem".format(key) + p_key = "libvirt.{}.pem".format(key) if p_key not in pillar: continue if not os.path.exists(os.path.dirname(paths[key])): @@ -134,7 +132,7 @@ def keys(name, basepath="/etc/pki", **kwargs): for key in ret["changes"]: with salt.utils.files.fopen(paths[key], "w+") as fp_: fp_.write( - salt.utils.stringutils.to_str(pillar["libvirt.{0}.pem".format(key)]) + salt.utils.stringutils.to_str(pillar["libvirt.{}.pem".format(key)]) ) ret["comment"] = "Updated libvirt certs and keys" @@ -176,7 +174,7 @@ def _virt_call( domain_state = __salt__["virt.vm_state"](targeted_domain) action_needed = domain_state.get(targeted_domain) != state if action_needed: - response = __salt__["virt.{0}".format(function)]( + response = __salt__["virt.{}".format(function)]( targeted_domain, connection=connection, username=username, @@ -189,9 +187,7 @@ def _virt_call( else: noaction_domains.append(targeted_domain) except libvirt.libvirtError as err: - ignored_domains.append( - {"domain": targeted_domain, "issue": six.text_type(err)} - ) + ignored_domains.append({"domain": targeted_domain, "issue": str(err)}) if not changed_domains: ret["result"] = not ignored_domains and bool(targeted_domains) ret["comment"] = "No changes had happened" @@ -292,6 +288,7 @@ def defined( arch=None, boot=None, update=True, + boot_dev=None, ): """ Starts an existing guest, or defines and starts a new VM with specified arguments. @@ -352,6 +349,14 @@ def defined( .. 
deprecated:: sodium + :param boot_dev: + Space separated list of devices to boot from sorted by decreasing priority. + Values can be ``hd``, ``fd``, ``cdrom`` or ``network``. + + By default, the value will be ``"hd"``. + + .. versionadded:: Magnesium + .. rubric:: Example States Make sure a virtual machine called ``domain_name`` is defined: @@ -362,6 +367,7 @@ def defined( virt.defined: - cpu: 2 - mem: 2048 + - boot_dev: network hd - disk_profile: prod - disks: - name: system @@ -414,17 +420,18 @@ def defined( password=password, boot=boot, test=__opts__["test"], + boot_dev=boot_dev, ) ret["changes"][name] = status if not status.get("definition"): - ret["comment"] = "Domain {0} unchanged".format(name) + ret["comment"] = "Domain {} unchanged".format(name) ret["result"] = True elif status.get("errors"): ret[ "comment" - ] = "Domain {0} updated with live update(s) failures".format(name) + ] = "Domain {} updated with live update(s) failures".format(name) else: - ret["comment"] = "Domain {0} updated".format(name) + ret["comment"] = "Domain {} updated".format(name) else: if not __opts__["test"]: __salt__["virt.init"]( @@ -448,12 +455,13 @@ def defined( password=password, boot=boot, start=False, + boot_dev=boot_dev, ) ret["changes"][name] = {"definition": True} - ret["comment"] = "Domain {0} defined".format(name) + ret["comment"] = "Domain {} defined".format(name) except libvirt.libvirtError as err: # Something bad happened when defining / updating the VM, report it - ret["comment"] = six.text_type(err) + ret["comment"] = str(err) ret["result"] = False return ret @@ -480,6 +488,7 @@ def running( os_type=None, arch=None, boot=None, + boot_dev=None, ): """ Starts an existing guest, or defines and starts a new VM with specified arguments. @@ -591,6 +600,14 @@ def running( .. versionadded:: 3000 + :param boot_dev: + Space separated list of devices to boot from sorted by decreasing priority. + Values can be ``hd``, ``fd``, ``cdrom`` or ``network``.
+ + By default, the value will be ``"hd"``. + + .. versionadded:: Magnesium + .. rubric:: Example States Make sure an already-defined virtual machine called ``domain_name`` is running: @@ -609,6 +626,7 @@ def running( - cpu: 2 - mem: 2048 - disk_profile: prod + - boot_dev: network hd - disks: - name: system size: 8192 @@ -657,6 +675,7 @@ def running( arch=arch, boot=boot, update=update, + boot_dev=boot_dev, connection=connection, username=username, password=password, @@ -681,11 +700,11 @@ def running( ret["comment"] = comment ret["changes"][name]["started"] = True elif not changed: - ret["comment"] = "Domain {0} exists and is running".format(name) + ret["comment"] = "Domain {} exists and is running".format(name) except libvirt.libvirtError as err: # Something bad happened when starting / updating the VM, report it - ret["comment"] = six.text_type(err) + ret["comment"] = str(err) ret["result"] = False return ret @@ -830,7 +849,7 @@ def reverted( try: domains = fnmatch.filter(__salt__["virt.list_domains"](), name) if not domains: - ret["comment"] = 'No domains found for criteria "{0}"'.format(name) + ret["comment"] = 'No domains found for criteria "{}"'.format(name) else: ignored_domains = list() if len(domains) > 1: @@ -848,9 +867,7 @@ def reverted( } except CommandExecutionError as err: if len(domains) > 1: - ignored_domains.append( - {"domain": domain, "issue": six.text_type(err)} - ) + ignored_domains.append({"domain": domain, "issue": str(err)}) if len(domains) > 1: if result: ret["changes"]["reverted"].append(result) @@ -860,7 +877,7 @@ def reverted( ret["result"] = len(domains) != len(ignored_domains) if ret["result"]: - ret["comment"] = "Domain{0} has been reverted".format( + ret["comment"] = "Domain{} has been reverted".format( len(domains) > 1 and "s" or "" ) if ignored_domains: @@ -868,9 +885,9 @@ def reverted( if not ret["changes"]["reverted"]: ret["changes"].pop("reverted") except libvirt.libvirtError as err: - ret["comment"] = six.text_type(err) +
ret["comment"] = str(err) except CommandExecutionError as err: - ret["comment"] = six.text_type(err) + ret["comment"] = str(err) return ret @@ -955,7 +972,7 @@ def network_defined( name, connection=connection, username=username, password=password ) if info and info[name]: - ret["comment"] = "Network {0} exists".format(name) + ret["comment"] = "Network {} exists".format(name) ret["result"] = True else: if not __opts__["test"]: @@ -974,7 +991,7 @@ def network_defined( password=password, ) ret["changes"][name] = "Network defined" - ret["comment"] = "Network {0} defined".format(name) + ret["comment"] = "Network {} defined".format(name) except libvirt.libvirtError as err: ret["result"] = False ret["comment"] = err.get_error_message() @@ -1108,6 +1125,10 @@ def network_running( return ret +# Some of the libvirt storage drivers do not support the build action +BUILDABLE_POOL_TYPES = {"disk", "fs", "netfs", "dir", "logical", "vstorage", "zfs"} + + def pool_defined( name, ptype=None, @@ -1222,25 +1243,35 @@ def pool_defined( action = "" if info[name]["state"] != "running": - if not __opts__["test"]: - __salt__["virt.pool_build"]( - name, - connection=connection, - username=username, - password=password, - ) - action = ", built" + if ptype in BUILDABLE_POOL_TYPES: + if not __opts__["test"]: + # Storage pools build like disk or logical will fail if the disk or LV group + # was already existing. Since we can't easily figure that out, just log the + # possible libvirt error. 
+ try: + __salt__["virt.pool_build"]( + name, + connection=connection, + username=username, + password=password, + ) + except libvirt.libvirtError as err: + log.warning( + "Failed to build libvirt storage pool: %s", + err.get_error_message(), + ) + action = ", built" action = ( "{}, autostart flag changed".format(action) if needs_autostart else action ) - ret["changes"][name] = "Pool updated{0}".format(action) - ret["comment"] = "Pool {0} updated{1}".format(name, action) + ret["changes"][name] = "Pool updated{}".format(action) + ret["comment"] = "Pool {} updated{}".format(name, action) else: - ret["comment"] = "Pool {0} unchanged".format(name) + ret["comment"] = "Pool {} unchanged".format(name) ret["result"] = True else: needs_autostart = autostart @@ -1265,15 +1296,28 @@ def pool_defined( password=password, ) - __salt__["virt.pool_build"]( - name, connection=connection, username=username, password=password - ) + if ptype in BUILDABLE_POOL_TYPES: + # Storage pools build like disk or logical will fail if the disk or LV group + # was already existing. Since we can't easily figure that out, just log the + # possible libvirt error. 
+ try: + __salt__["virt.pool_build"]( + name, + connection=connection, + username=username, + password=password, + ) + except libvirt.libvirtError as err: + log.warning( + "Failed to build libvirt storage pool: %s", + err.get_error_message(), + ) if needs_autostart: ret["changes"][name] = "Pool defined, marked for autostart" - ret["comment"] = "Pool {0} defined, marked for autostart".format(name) + ret["comment"] = "Pool {} defined, marked for autostart".format(name) else: ret["changes"][name] = "Pool defined" - ret["comment"] = "Pool {0} defined".format(name) + ret["comment"] = "Pool {} defined".format(name) if needs_autostart: if not __opts__["test"]: @@ -1374,7 +1418,7 @@ def pool_running( is_running = info.get(name, {}).get("state", "stopped") == "running" if is_running: if updated: - action = "built, restarted" + action = "restarted" if not __opts__["test"]: __salt__["virt.pool_stop"]( name, @@ -1382,13 +1426,16 @@ def pool_running( username=username, password=password, ) - if not __opts__["test"]: - __salt__["virt.pool_build"]( - name, - connection=connection, - username=username, - password=password, - ) + # if the disk or LV group is already existing build will fail (issue #56454) + if ptype in BUILDABLE_POOL_TYPES - {"disk", "logical"}: + if not __opts__["test"]: + __salt__["virt.pool_build"]( + name, + connection=connection, + username=username, + password=password, + ) + action = "built, {}".format(action) else: action = "already running" result = True @@ -1402,16 +1449,16 @@ def pool_running( password=password, ) - comment = "Pool {0}".format(name) + comment = "Pool {}".format(name) change = "Pool" if name in ret["changes"]: - comment = "{0},".format(ret["comment"]) - change = "{0},".format(ret["changes"][name]) + comment = "{},".format(ret["comment"]) + change = "{},".format(ret["changes"][name]) if action != "already running": - ret["changes"][name] = "{0} {1}".format(change, action) + ret["changes"][name] = "{} {}".format(change, action) - 
ret["comment"] = "{0} {1}".format(comment, action) + ret["comment"] = "{} {}".format(comment, action) ret["result"] = result except libvirt.libvirtError as err: @@ -1539,15 +1586,13 @@ def pool_deleted(name, purge=False, connection=None, username=None, password=Non ret["result"] = None if unsupported: - ret[ - "comment" - ] = 'Unsupported actions for pool of type "{0}": {1}'.format( + ret["comment"] = 'Unsupported actions for pool of type "{}": {}'.format( info[name]["type"], ", ".join(unsupported) ) else: - ret["comment"] = "Storage pool could not be found: {0}".format(name) + ret["comment"] = "Storage pool could not be found: {}".format(name) except libvirt.libvirtError as err: - ret["comment"] = "Failed deleting pool: {0}".format(err.get_error_message()) + ret["comment"] = "Failed deleting pool: {}".format(err.get_error_message()) ret["result"] = False return ret diff --git a/salt/templates/virt/libvirt_domain.jinja b/salt/templates/virt/libvirt_domain.jinja index aac6283eb0..04a61ffa78 100644 --- a/salt/templates/virt/libvirt_domain.jinja +++ b/salt/templates/virt/libvirt_domain.jinja @@ -3,7 +3,7 @@ {{ cpu }} {{ mem }} {{ mem }} - + {{ os_type }} {% if boot %} {% if 'kernel' in boot %} diff --git a/salt/utils/data.py b/salt/utils/data.py index 8f84c2ea42..5a7acc9e7c 100644 --- a/salt/utils/data.py +++ b/salt/utils/data.py @@ -1,22 +1,15 @@ -# -*- coding: utf-8 -*- -''' +""" Functions for manipulating, inspecting, or otherwise working with data types and data structures. 
-''' +""" -from __future__ import absolute_import, print_function, unicode_literals # Import Python libs import copy import fnmatch +import functools import logging import re -import functools - -try: - from collections.abc import Mapping, MutableMapping, Sequence -except ImportError: - from collections import Mapping, MutableMapping, Sequence # Import Salt libs import salt.utils.dictupdate @@ -24,13 +17,22 @@ import salt.utils.stringutils import salt.utils.yaml from salt.defaults import DEFAULT_TARGET_DELIM from salt.exceptions import SaltException -from salt.utils.decorators.jinja import jinja_filter -from salt.utils.odict import OrderedDict +from salt.ext import six # Import 3rd-party libs -from salt.ext.six.moves import zip # pylint: disable=redefined-builtin -from salt.ext import six from salt.ext.six.moves import range # pylint: disable=redefined-builtin +from salt.ext.six.moves import zip # pylint: disable=redefined-builtin +from salt.utils.decorators.jinja import jinja_filter +from salt.utils.odict import OrderedDict + +try: + from collections.abc import Mapping, MutableMapping, Sequence +except ImportError: + # pylint: disable=no-name-in-module + from collections import Mapping, MutableMapping, Sequence + + # pylint: enable=no-name-in-module + try: import jmespath @@ -41,15 +43,16 @@ log = logging.getLogger(__name__) class CaseInsensitiveDict(MutableMapping): - ''' + """ Inspired by requests' case-insensitive dict implementation, but works with non-string keys as well. - ''' + """ + def __init__(self, init=None, **kwargs): - ''' + """ Force internal dict to be ordered to ensure a consistent iteration order, irrespective of case. 
- ''' + """ self._data = OrderedDict() self.update(init or {}, **kwargs) @@ -67,7 +70,7 @@ class CaseInsensitiveDict(MutableMapping): return self._data[to_lowercase(key)][1] def __iter__(self): - return (item[0] for item in six.itervalues(self._data)) + return (item[0] for item in self._data.values()) def __eq__(self, rval): if not isinstance(rval, Mapping): @@ -76,28 +79,28 @@ class CaseInsensitiveDict(MutableMapping): return dict(self.items_lower()) == dict(CaseInsensitiveDict(rval).items_lower()) def __repr__(self): - return repr(dict(six.iteritems(self))) + return repr(dict(self.items())) def items_lower(self): - ''' + """ Returns a generator iterating over keys and values, with the keys all being lowercase. - ''' - return ((key, val[1]) for key, val in six.iteritems(self._data)) + """ + return ((key, val[1]) for key, val in self._data.items()) def copy(self): - ''' + """ Returns a copy of the object - ''' - return CaseInsensitiveDict(six.iteritems(self._data)) + """ + return CaseInsensitiveDict(self._data.items()) def __change_case(data, attr, preserve_dict_class=False): - ''' + """ Calls data.attr() if data has an attribute/method called attr. Processes data recursively if data is a Mapping or Sequence. For Mapping, processes both keys and values. 
- ''' + """ try: return getattr(data, attr)() except AttributeError: @@ -107,73 +110,120 @@ def __change_case(data, attr, preserve_dict_class=False): if isinstance(data, Mapping): return (data_type if preserve_dict_class else dict)( - (__change_case(key, attr, preserve_dict_class), - __change_case(val, attr, preserve_dict_class)) - for key, val in six.iteritems(data) + ( + __change_case(key, attr, preserve_dict_class), + __change_case(val, attr, preserve_dict_class), + ) + for key, val in data.items() ) if isinstance(data, Sequence): return data_type( - __change_case(item, attr, preserve_dict_class) for item in data) + __change_case(item, attr, preserve_dict_class) for item in data + ) return data def to_lowercase(data, preserve_dict_class=False): - ''' + """ Recursively changes everything in data to lowercase. - ''' - return __change_case(data, 'lower', preserve_dict_class) + """ + return __change_case(data, "lower", preserve_dict_class) def to_uppercase(data, preserve_dict_class=False): - ''' + """ Recursively changes everything in data to uppercase. - ''' - return __change_case(data, 'upper', preserve_dict_class) + """ + return __change_case(data, "upper", preserve_dict_class) -@jinja_filter('compare_dicts') +@jinja_filter("compare_dicts") def compare_dicts(old=None, new=None): - ''' + """ Compare before and after results from various salt functions, returning a dict describing the changes that were made. 
- ''' + """ ret = {} - for key in set((new or {})).union((old or {})): + for key in set(new or {}).union(old or {}): if key not in old: # New key - ret[key] = {'old': '', - 'new': new[key]} + ret[key] = {"old": "", "new": new[key]} elif key not in new: # Key removed - ret[key] = {'new': '', - 'old': old[key]} + ret[key] = {"new": "", "old": old[key]} elif new[key] != old[key]: # Key modified - ret[key] = {'old': old[key], - 'new': new[key]} + ret[key] = {"old": old[key], "new": new[key]} return ret -@jinja_filter('compare_lists') +@jinja_filter("compare_lists") def compare_lists(old=None, new=None): - ''' + """ Compare before and after results from various salt functions, returning a dict describing the changes that were made - ''' + """ ret = {} for item in new: if item not in old: - ret.setdefault('new', []).append(item) + ret.setdefault("new", []).append(item) for item in old: if item not in new: - ret.setdefault('old', []).append(item) + ret.setdefault("old", []).append(item) return ret -def decode(data, encoding=None, errors='strict', keep=False, - normalize=False, preserve_dict_class=False, preserve_tuples=False, - to_str=False): - ''' +def _remove_circular_refs(ob, _seen=None): + """ + Generic method to remove circular references from objects. + This has been taken from author Martijn Pieters + https://stackoverflow.com/questions/44777369/ + remove-circular-references-in-dicts-lists-tuples/44777477#44777477 + :param ob: dict, list, tuple, set, and frozenset + Standard python object + :param object _seen: + Object that has circular reference + :returns: + Cleaned Python object + :rtype: + type(ob) + """ + if _seen is None: + _seen = set() + if id(ob) in _seen: + # Here we caught a circular reference. + # Alert user and cleanup to continue. + log.exception( + "Caught a circular reference in data structure below."
+ "Cleaning and continuing execution.\n%r\n", + ob, + ) + return None + _seen.add(id(ob)) + res = ob + if isinstance(ob, dict): + res = { + _remove_circular_refs(k, _seen): _remove_circular_refs(v, _seen) + for k, v in ob.items() + } + elif isinstance(ob, (list, tuple, set, frozenset)): + res = type(ob)(_remove_circular_refs(v, _seen) for v in ob) + # remove id again; only *nested* references count + _seen.remove(id(ob)) + return res + + +def decode( + data, + encoding=None, + errors="strict", + keep=False, + normalize=False, + preserve_dict_class=False, + preserve_tuples=False, + to_str=False, +): + """ Generic function which will decode whichever type is passed, if necessary. Optionally use to_str=True to ensure strings are str types and not unicode on Python 2. @@ -199,22 +249,55 @@ def decode(data, encoding=None, errors='strict', keep=False, two strings above, in which "й" is represented as two code points (i.e. one for the base character, and one for the breve mark). Normalizing allows for a more reliable test case. 
- ''' - _decode_func = salt.utils.stringutils.to_unicode \ - if not to_str \ + + """ + # Clean data object before decoding to avoid circular references + data = _remove_circular_refs(data) + + _decode_func = ( + salt.utils.stringutils.to_unicode + if not to_str else salt.utils.stringutils.to_str + ) if isinstance(data, Mapping): - return decode_dict(data, encoding, errors, keep, normalize, - preserve_dict_class, preserve_tuples, to_str) + return decode_dict( + data, + encoding, + errors, + keep, + normalize, + preserve_dict_class, + preserve_tuples, + to_str, + ) if isinstance(data, list): - return decode_list(data, encoding, errors, keep, normalize, - preserve_dict_class, preserve_tuples, to_str) + return decode_list( + data, + encoding, + errors, + keep, + normalize, + preserve_dict_class, + preserve_tuples, + to_str, + ) if isinstance(data, tuple): - return decode_tuple(data, encoding, errors, keep, normalize, - preserve_dict_class, to_str) \ - if preserve_tuples \ - else decode_list(data, encoding, errors, keep, normalize, - preserve_dict_class, preserve_tuples, to_str) + return ( + decode_tuple( + data, encoding, errors, keep, normalize, preserve_dict_class, to_str + ) + if preserve_tuples + else decode_list( + data, + encoding, + errors, + keep, + normalize, + preserve_dict_class, + preserve_tuples, + to_str, + ) + ) try: data = _decode_func(data, encoding, errors, normalize) except TypeError: @@ -228,25 +311,48 @@ def decode(data, encoding=None, errors='strict', keep=False, return data -def decode_dict(data, encoding=None, errors='strict', keep=False, - normalize=False, preserve_dict_class=False, - preserve_tuples=False, to_str=False): - ''' +def decode_dict( + data, + encoding=None, + errors="strict", + keep=False, + normalize=False, + preserve_dict_class=False, + preserve_tuples=False, + to_str=False, +): + """ Decode all string values to Unicode. Optionally use to_str=True to ensure strings are str types and not unicode on Python 2. 
- ''' - _decode_func = salt.utils.stringutils.to_unicode \ - if not to_str \ + """ + # Clean data object before decoding to avoid circular references + data = _remove_circular_refs(data) + + _decode_func = ( + salt.utils.stringutils.to_unicode + if not to_str else salt.utils.stringutils.to_str + ) # Make sure we preserve OrderedDicts ret = data.__class__() if preserve_dict_class else {} - for key, value in six.iteritems(data): + for key, value in data.items(): if isinstance(key, tuple): - key = decode_tuple(key, encoding, errors, keep, normalize, - preserve_dict_class, to_str) \ - if preserve_tuples \ - else decode_list(key, encoding, errors, keep, normalize, - preserve_dict_class, preserve_tuples, to_str) + key = ( + decode_tuple( + key, encoding, errors, keep, normalize, preserve_dict_class, to_str + ) + if preserve_tuples + else decode_list( + key, + encoding, + errors, + keep, + normalize, + preserve_dict_class, + preserve_tuples, + to_str, + ) + ) else: try: key = _decode_func(key, encoding, errors, normalize) @@ -260,17 +366,50 @@ def decode_dict(data, encoding=None, errors='strict', keep=False, raise if isinstance(value, list): - value = decode_list(value, encoding, errors, keep, normalize, - preserve_dict_class, preserve_tuples, to_str) + value = decode_list( + value, + encoding, + errors, + keep, + normalize, + preserve_dict_class, + preserve_tuples, + to_str, + ) elif isinstance(value, tuple): - value = decode_tuple(value, encoding, errors, keep, normalize, - preserve_dict_class, to_str) \ - if preserve_tuples \ - else decode_list(value, encoding, errors, keep, normalize, - preserve_dict_class, preserve_tuples, to_str) + value = ( + decode_tuple( + value, + encoding, + errors, + keep, + normalize, + preserve_dict_class, + to_str, + ) + if preserve_tuples + else decode_list( + value, + encoding, + errors, + keep, + normalize, + preserve_dict_class, + preserve_tuples, + to_str, + ) + ) elif isinstance(value, Mapping): - value = decode_dict(value, encoding, 
errors, keep, normalize, - preserve_dict_class, preserve_tuples, to_str) + value = decode_dict( + value, + encoding, + errors, + keep, + normalize, + preserve_dict_class, + preserve_tuples, + to_str, + ) else: try: value = _decode_func(value, encoding, errors, normalize) @@ -287,30 +426,69 @@ def decode_dict(data, encoding=None, errors='strict', keep=False, return ret -def decode_list(data, encoding=None, errors='strict', keep=False, - normalize=False, preserve_dict_class=False, - preserve_tuples=False, to_str=False): - ''' +def decode_list( + data, + encoding=None, + errors="strict", + keep=False, + normalize=False, + preserve_dict_class=False, + preserve_tuples=False, + to_str=False, +): + """ Decode all string values to Unicode. Optionally use to_str=True to ensure strings are str types and not unicode on Python 2. - ''' - _decode_func = salt.utils.stringutils.to_unicode \ - if not to_str \ + """ + # Clean data object before decoding to avoid circular references + data = _remove_circular_refs(data) + + _decode_func = ( + salt.utils.stringutils.to_unicode + if not to_str else salt.utils.stringutils.to_str + ) ret = [] for item in data: if isinstance(item, list): - item = decode_list(item, encoding, errors, keep, normalize, - preserve_dict_class, preserve_tuples, to_str) + item = decode_list( + item, + encoding, + errors, + keep, + normalize, + preserve_dict_class, + preserve_tuples, + to_str, + ) elif isinstance(item, tuple): - item = decode_tuple(item, encoding, errors, keep, normalize, - preserve_dict_class, to_str) \ - if preserve_tuples \ - else decode_list(item, encoding, errors, keep, normalize, - preserve_dict_class, preserve_tuples, to_str) + item = ( + decode_tuple( + item, encoding, errors, keep, normalize, preserve_dict_class, to_str + ) + if preserve_tuples + else decode_list( + item, + encoding, + errors, + keep, + normalize, + preserve_dict_class, + preserve_tuples, + to_str, + ) + ) elif isinstance(item, Mapping): - item = decode_dict(item, 
encoding, errors, keep, normalize, - preserve_dict_class, preserve_tuples, to_str) + item = decode_dict( + item, + encoding, + errors, + keep, + normalize, + preserve_dict_class, + preserve_tuples, + to_str, + ) else: try: item = _decode_func(item, encoding, errors, normalize) @@ -327,21 +505,35 @@ def decode_list(data, encoding=None, errors='strict', keep=False, return ret -def decode_tuple(data, encoding=None, errors='strict', keep=False, - normalize=False, preserve_dict_class=False, to_str=False): - ''' +def decode_tuple( + data, + encoding=None, + errors="strict", + keep=False, + normalize=False, + preserve_dict_class=False, + to_str=False, +): + """ Decode all string values to Unicode. Optionally use to_str=True to ensure strings are str types and not unicode on Python 2. - ''' + """ return tuple( - decode_list(data, encoding, errors, keep, normalize, - preserve_dict_class, True, to_str) + decode_list( + data, encoding, errors, keep, normalize, preserve_dict_class, True, to_str + ) ) -def encode(data, encoding=None, errors='strict', keep=False, - preserve_dict_class=False, preserve_tuples=False): - ''' +def encode( + data, + encoding=None, + errors="strict", + keep=False, + preserve_dict_class=False, + preserve_tuples=False, +): + """ Generic function which will encode whichever type is passed, if necessary If `strict` is True, and `keep` is False, and we fail to encode, a @@ -349,18 +541,27 @@ def encode(data, encoding=None, errors='strict', keep=False, original value to silently be returned in cases where encoding fails. This can be useful for cases where the data passed to this function is likely to contain binary blobs. 
- ''' + + """ + # Clean data object before encoding to avoid circular references + data = _remove_circular_refs(data) + if isinstance(data, Mapping): - return encode_dict(data, encoding, errors, keep, - preserve_dict_class, preserve_tuples) + return encode_dict( + data, encoding, errors, keep, preserve_dict_class, preserve_tuples + ) if isinstance(data, list): - return encode_list(data, encoding, errors, keep, - preserve_dict_class, preserve_tuples) + return encode_list( + data, encoding, errors, keep, preserve_dict_class, preserve_tuples + ) if isinstance(data, tuple): - return encode_tuple(data, encoding, errors, keep, preserve_dict_class) \ - if preserve_tuples \ - else encode_list(data, encoding, errors, keep, - preserve_dict_class, preserve_tuples) + return ( + encode_tuple(data, encoding, errors, keep, preserve_dict_class) + if preserve_tuples + else encode_list( + data, encoding, errors, keep, preserve_dict_class, preserve_tuples + ) + ) try: return salt.utils.stringutils.to_bytes(data, encoding, errors) except TypeError: @@ -374,20 +575,31 @@ def encode(data, encoding=None, errors='strict', keep=False, return data -@jinja_filter('json_decode_dict') # Remove this for Aluminium -@jinja_filter('json_encode_dict') -def encode_dict(data, encoding=None, errors='strict', keep=False, - preserve_dict_class=False, preserve_tuples=False): - ''' +@jinja_filter("json_decode_dict") # Remove this for Aluminium +@jinja_filter("json_encode_dict") +def encode_dict( + data, + encoding=None, + errors="strict", + keep=False, + preserve_dict_class=False, + preserve_tuples=False, +): + """ Encode all string values to bytes - ''' + """ + # Clean data object before encoding to avoid circular references + data = _remove_circular_refs(data) ret = data.__class__() if preserve_dict_class else {} - for key, value in six.iteritems(data): + for key, value in data.items(): if isinstance(key, tuple): - key = encode_tuple(key, encoding, errors, keep, preserve_dict_class) \ - if 
preserve_tuples \ - else encode_list(key, encoding, errors, keep, - preserve_dict_class, preserve_tuples) + key = ( + encode_tuple(key, encoding, errors, keep, preserve_dict_class) + if preserve_tuples + else encode_list( + key, encoding, errors, keep, preserve_dict_class, preserve_tuples + ) + ) else: try: key = salt.utils.stringutils.to_bytes(key, encoding, errors) @@ -401,16 +613,21 @@ def encode_dict(data, encoding=None, errors='strict', keep=False, raise if isinstance(value, list): - value = encode_list(value, encoding, errors, keep, - preserve_dict_class, preserve_tuples) + value = encode_list( + value, encoding, errors, keep, preserve_dict_class, preserve_tuples + ) elif isinstance(value, tuple): - value = encode_tuple(value, encoding, errors, keep, preserve_dict_class) \ - if preserve_tuples \ - else encode_list(value, encoding, errors, keep, - preserve_dict_class, preserve_tuples) + value = ( + encode_tuple(value, encoding, errors, keep, preserve_dict_class) + if preserve_tuples + else encode_list( + value, encoding, errors, keep, preserve_dict_class, preserve_tuples + ) + ) elif isinstance(value, Mapping): - value = encode_dict(value, encoding, errors, keep, - preserve_dict_class, preserve_tuples) + value = encode_dict( + value, encoding, errors, keep, preserve_dict_class, preserve_tuples + ) else: try: value = salt.utils.stringutils.to_bytes(value, encoding, errors) @@ -427,26 +644,40 @@ def encode_dict(data, encoding=None, errors='strict', keep=False, return ret -@jinja_filter('json_decode_list') # Remove this for Aluminium -@jinja_filter('json_encode_list') -def encode_list(data, encoding=None, errors='strict', keep=False, - preserve_dict_class=False, preserve_tuples=False): - ''' +@jinja_filter("json_decode_list") # Remove this for Aluminium +@jinja_filter("json_encode_list") +def encode_list( + data, + encoding=None, + errors="strict", + keep=False, + preserve_dict_class=False, + preserve_tuples=False, +): + """ Encode all string values to bytes - 
''' + """ + # Clean data object before encoding to avoid circular references + data = _remove_circular_refs(data) + ret = [] for item in data: if isinstance(item, list): - item = encode_list(item, encoding, errors, keep, - preserve_dict_class, preserve_tuples) + item = encode_list( + item, encoding, errors, keep, preserve_dict_class, preserve_tuples + ) elif isinstance(item, tuple): - item = encode_tuple(item, encoding, errors, keep, preserve_dict_class) \ - if preserve_tuples \ - else encode_list(item, encoding, errors, keep, - preserve_dict_class, preserve_tuples) + item = ( + encode_tuple(item, encoding, errors, keep, preserve_dict_class) + if preserve_tuples + else encode_list( + item, encoding, errors, keep, preserve_dict_class, preserve_tuples + ) + ) elif isinstance(item, Mapping): - item = encode_dict(item, encoding, errors, keep, - preserve_dict_class, preserve_tuples) + item = encode_dict( + item, encoding, errors, keep, preserve_dict_class, preserve_tuples + ) else: try: item = salt.utils.stringutils.to_bytes(item, encoding, errors) @@ -463,42 +694,37 @@ def encode_list(data, encoding=None, errors='strict', keep=False, return ret -def encode_tuple(data, encoding=None, errors='strict', keep=False, - preserve_dict_class=False): - ''' +def encode_tuple( + data, encoding=None, errors="strict", keep=False, preserve_dict_class=False +): + """ Encode all string values to Unicode - ''' - return tuple( - encode_list(data, encoding, errors, keep, preserve_dict_class, True)) + """ + return tuple(encode_list(data, encoding, errors, keep, preserve_dict_class, True)) -@jinja_filter('exactly_n_true') +@jinja_filter("exactly_n_true") def exactly_n(iterable, amount=1): - ''' + """ Tests that exactly N items in an iterable are "truthy" (neither None, False, nor 0). 
- ''' + """ i = iter(iterable) return all(any(i) for j in range(amount)) and not any(i) -@jinja_filter('exactly_one_true') +@jinja_filter("exactly_one_true") def exactly_one(iterable): - ''' + """ Check if only one item is not None, False, or 0 in an iterable. - ''' + """ return exactly_n(iterable) -def filter_by(lookup_dict, - lookup, - traverse, - merge=None, - default='default', - base=None): - ''' +def filter_by(lookup_dict, lookup, traverse, merge=None, default="default", base=None): + """ Common code to filter data structures like grains and pillar - ''' + """ ret = None # Default value would be an empty list if lookup not found val = traverse_dict_and_list(traverse, lookup, []) @@ -507,10 +733,8 @@ def filter_by(lookup_dict, # lookup_dict keys for each in val if isinstance(val, list) else [val]: for key in lookup_dict: - test_key = key if isinstance(key, six.string_types) \ - else six.text_type(key) - test_each = each if isinstance(each, six.string_types) \ - else six.text_type(each) + test_key = key if isinstance(key, str) else str(key) + test_each = each if isinstance(each, str) else str(each) if fnmatch.fnmatchcase(test_each, test_key): ret = lookup_dict[key] break @@ -528,14 +752,13 @@ def filter_by(lookup_dict, elif isinstance(base_values, Mapping): if not isinstance(ret, Mapping): raise SaltException( - 'filter_by default and look-up values must both be ' - 'dictionaries.') + "filter_by default and look-up values must both be " "dictionaries." 
+ ) ret = salt.utils.dictupdate.update(copy.deepcopy(base_values), ret) if merge: if not isinstance(merge, Mapping): - raise SaltException( - 'filter_by merge argument must be a dictionary.') + raise SaltException("filter_by merge argument must be a dictionary.") if ret is None: ret = merge @@ -546,12 +769,12 @@ def filter_by(lookup_dict, def traverse_dict(data, key, default=None, delimiter=DEFAULT_TARGET_DELIM): - ''' + """ Traverse a dict using a colon-delimited (or otherwise delimited, using the 'delimiter' param) target string. The target 'foo:bar:baz' will return data['foo']['bar']['baz'] if this value exists, and will otherwise return the dict in the default argument. - ''' + """ ptr = data try: for each in key.split(delimiter): @@ -562,9 +785,9 @@ def traverse_dict(data, key, default=None, delimiter=DEFAULT_TARGET_DELIM): return ptr -@jinja_filter('traverse') +@jinja_filter("traverse") def traverse_dict_and_list(data, key, default=None, delimiter=DEFAULT_TARGET_DELIM): - ''' + """ Traverse a dict or list using a colon-delimited (or otherwise delimited, using the 'delimiter' param) target string. 
The target 'foo:bar:0' will return data['foo']['bar'][0] if this value exists, and will otherwise @@ -573,7 +796,7 @@ def traverse_dict_and_list(data, key, default=None, delimiter=DEFAULT_TARGET_DEL The target 'foo:bar:0' will return data['foo']['bar'][0] if data like {'foo':{'bar':['baz']}} , if data like {'foo':{'bar':{'0':'baz'}}} then return data['foo']['bar']['0'] - ''' + """ ptr = data for each in key.split(delimiter): if isinstance(ptr, list): @@ -605,18 +828,17 @@ def traverse_dict_and_list(data, key, default=None, delimiter=DEFAULT_TARGET_DEL return ptr -def subdict_match(data, - expr, - delimiter=DEFAULT_TARGET_DELIM, - regex_match=False, - exact_match=False): - ''' +def subdict_match( + data, expr, delimiter=DEFAULT_TARGET_DELIM, regex_match=False, exact_match=False +): + """ Check for a match in a dictionary using a delimiter character to denote levels of subdicts, and also allowing the delimiter character to be matched. Thus, 'foo:bar:baz' will match data['foo'] == 'bar:baz' and data['foo']['bar'] == 'baz'. The latter would take priority over the former, as more deeply-nested matches are tried first. - ''' + """ + def _match(target, pattern, regex_match=False, exact_match=False): # The reason for using six.text_type first and _then_ using # to_unicode as a fallback is because we want to eventually have @@ -628,11 +850,11 @@ def subdict_match(data, # begin with is that (by design) to_unicode will raise a TypeError if a # non-string/bytestring/bytearray value is passed. 
try: - target = six.text_type(target).lower() + target = str(target).lower() except UnicodeDecodeError: target = salt.utils.stringutils.to_unicode(target).lower() try: - pattern = six.text_type(pattern).lower() + pattern = str(pattern).lower() except UnicodeDecodeError: pattern = salt.utils.stringutils.to_unicode(pattern).lower() @@ -640,48 +862,54 @@ def subdict_match(data, try: return re.match(pattern, target) except Exception: # pylint: disable=broad-except - log.error('Invalid regex \'%s\' in match', pattern) + log.error("Invalid regex '%s' in match", pattern) return False else: - return target == pattern if exact_match \ - else fnmatch.fnmatch(target, pattern) + return ( + target == pattern if exact_match else fnmatch.fnmatch(target, pattern) + ) def _dict_match(target, pattern, regex_match=False, exact_match=False): ret = False - wildcard = pattern.startswith('*:') + wildcard = pattern.startswith("*:") if wildcard: pattern = pattern[2:] - if pattern == '*': + if pattern == "*": # We are just checking that the key exists ret = True if not ret and pattern in target: # We might want to search for a key ret = True - if not ret and subdict_match(target, - pattern, - regex_match=regex_match, - exact_match=exact_match): + if not ret and subdict_match( + target, pattern, regex_match=regex_match, exact_match=exact_match + ): ret = True if not ret and wildcard: for key in target: if isinstance(target[key], dict): - if _dict_match(target[key], - pattern, - regex_match=regex_match, - exact_match=exact_match): + if _dict_match( + target[key], + pattern, + regex_match=regex_match, + exact_match=exact_match, + ): return True elif isinstance(target[key], list): for item in target[key]: - if _match(item, - pattern, - regex_match=regex_match, - exact_match=exact_match): - return True - elif _match(target[key], + if _match( + item, pattern, regex_match=regex_match, - exact_match=exact_match): + exact_match=exact_match, + ): + return True + elif _match( + target[key], + pattern, 
+ regex_match=regex_match, + exact_match=exact_match, + ): return True return ret @@ -695,7 +923,7 @@ def subdict_match(data, # want to use are 3, 2, and 1, in that order. for idx in range(num_splits - 1, 0, -1): key = delimiter.join(splits[:idx]) - if key == '*': + if key == "*": # We are matching on everything under the top level, so we need to # treat the match as the entire data being passed in matchstr = expr @@ -703,54 +931,55 @@ def subdict_match(data, else: matchstr = delimiter.join(splits[idx:]) match = traverse_dict_and_list(data, key, {}, delimiter=delimiter) - log.debug("Attempting to match '%s' in '%s' using delimiter '%s'", - matchstr, key, delimiter) + log.debug( + "Attempting to match '%s' in '%s' using delimiter '%s'", + matchstr, + key, + delimiter, + ) if match == {}: continue if isinstance(match, dict): - if _dict_match(match, - matchstr, - regex_match=regex_match, - exact_match=exact_match): + if _dict_match( + match, matchstr, regex_match=regex_match, exact_match=exact_match + ): return True continue if isinstance(match, (list, tuple)): # We are matching a single component to a single list member for member in match: if isinstance(member, dict): - if _dict_match(member, - matchstr, - regex_match=regex_match, - exact_match=exact_match): + if _dict_match( + member, + matchstr, + regex_match=regex_match, + exact_match=exact_match, + ): return True - if _match(member, - matchstr, - regex_match=regex_match, - exact_match=exact_match): + if _match( + member, matchstr, regex_match=regex_match, exact_match=exact_match + ): return True continue - if _match(match, - matchstr, - regex_match=regex_match, - exact_match=exact_match): + if _match(match, matchstr, regex_match=regex_match, exact_match=exact_match): return True return False -@jinja_filter('substring_in_list') +@jinja_filter("substring_in_list") def substr_in_list(string_to_search_for, list_to_search): - ''' + """ Return a boolean value that indicates whether or not a given string is present in 
any of the strings which comprise a list - ''' + """ return any(string_to_search_for in s for s in list_to_search) def is_dictlist(data): - ''' + """ Returns True if data is a list of one-element dicts (as found in many SLS schemas), otherwise returns False - ''' + """ if isinstance(data, list): for element in data: if isinstance(element, dict): @@ -762,16 +991,12 @@ def is_dictlist(data): return False -def repack_dictlist(data, - strict=False, - recurse=False, - key_cb=None, - val_cb=None): - ''' +def repack_dictlist(data, strict=False, recurse=False, key_cb=None, val_cb=None): + """ Takes a list of one-element dicts (as found in many SLS schemas) and repacks into a single dictionary. - ''' - if isinstance(data, six.string_types): + """ + if isinstance(data, str): try: data = salt.utils.yaml.safe_load(data) except salt.utils.yaml.parser.ParserError as err: @@ -783,7 +1008,7 @@ def repack_dictlist(data, if val_cb is None: val_cb = lambda x, y: y - valid_non_dict = (six.string_types, six.integer_types, float) + valid_non_dict = ((str,), (int,), float) if isinstance(data, list): for element in data: if isinstance(element, valid_non_dict): @@ -791,21 +1016,21 @@ def repack_dictlist(data, if isinstance(element, dict): if len(element) != 1: log.error( - 'Invalid input for repack_dictlist: key/value pairs ' - 'must contain only one element (data passed: %s).', - element + "Invalid input for repack_dictlist: key/value pairs " + "must contain only one element (data passed: %s).", + element, ) return {} else: log.error( - 'Invalid input for repack_dictlist: element %s is ' - 'not a string/dict/numeric value', element + "Invalid input for repack_dictlist: element %s is " + "not a string/dict/numeric value", + element, ) return {} else: log.error( - 'Invalid input for repack_dictlist, data passed is not a list ' - '(%s)', data + "Invalid input for repack_dictlist, data passed is not a list " "(%s)", data ) return {} @@ -821,8 +1046,8 @@ def repack_dictlist(data, 
ret[key_cb(key)] = repack_dictlist(val, recurse=recurse) elif strict: log.error( - 'Invalid input for repack_dictlist: nested dictlist ' - 'found, but recurse is set to False' + "Invalid input for repack_dictlist: nested dictlist " + "found, but recurse is set to False" ) return {} else: @@ -832,17 +1057,17 @@ def repack_dictlist(data, return ret -@jinja_filter('is_list') +@jinja_filter("is_list") def is_list(value): - ''' + """ Check if a variable is a list. - ''' + """ return isinstance(value, list) -@jinja_filter('is_iter') -def is_iter(thing, ignore=six.string_types): - ''' +@jinja_filter("is_iter") +def is_iter(thing, ignore=(str,)): + """ Test if an object is iterable, but not a string type. Test if an object is an iterator or is iterable itself. By default this @@ -853,7 +1078,7 @@ def is_iter(thing, ignore=six.string_types): dictionaries or named tuples. Based on https://bitbucket.org/petershinners/yter - ''' + """ if ignore and isinstance(thing, ignore): return False try: @@ -863,9 +1088,9 @@ def is_iter(thing, ignore=six.string_types): return False -@jinja_filter('sorted_ignorecase') +@jinja_filter("sorted_ignorecase") def sorted_ignorecase(to_sort): - ''' + """ Sort a list of strings ignoring case. >>> L = ['foo', 'Foo', 'bar', 'Bar'] @@ -874,19 +1099,19 @@ def sorted_ignorecase(to_sort): >>> sorted(L, key=lambda x: x.lower()) ['bar', 'Bar', 'foo', 'Foo'] >>> - ''' + """ return sorted(to_sort, key=lambda x: x.lower()) def is_true(value=None): - ''' + """ Returns a boolean value representing the "truth" of the value passed. The rules for what is a "True" value are: 1. Integer/float values greater than 0 2. The string values "True" and "true" 3. 
Any object for which bool(obj) returns True - ''' + """ # First, try int/float conversion try: value = int(value) @@ -898,26 +1123,26 @@ def is_true(value=None): pass # Now check for truthiness - if isinstance(value, (six.integer_types, float)): + if isinstance(value, ((int,), float)): return value > 0 - if isinstance(value, six.string_types): - return six.text_type(value).lower() == 'true' + if isinstance(value, str): + return str(value).lower() == "true" return bool(value) -@jinja_filter('mysql_to_dict') +@jinja_filter("mysql_to_dict") def mysql_to_dict(data, key): - ''' + """ Convert MySQL-style output to a python dictionary - ''' + """ ret = {} - headers = [''] + headers = [""] for line in data: if not line: continue - if line.startswith('+'): + if line.startswith("+"): continue - comps = line.split('|') + comps = line.split("|") for comp in range(len(comps)): comps[comp] = comps[comp].strip() if len(headers) > 1: @@ -934,14 +1159,14 @@ def mysql_to_dict(data, key): def simple_types_filter(data): - ''' + """ Convert the data list, dictionary into simple types, i.e., int, float, string, bool, etc. - ''' + """ if data is None: return data - simpletypes_keys = (six.string_types, six.text_type, six.integer_types, float, bool) + simpletypes_keys = ((str,), str, (int,), float, bool) simpletypes_values = tuple(list(simpletypes_keys) + [list, tuple]) if isinstance(data, (list, tuple)): @@ -957,7 +1182,7 @@ def simple_types_filter(data): if isinstance(data, dict): simpledict = {} - for key, value in six.iteritems(data): + for key, value in data.items(): if key is not None and not isinstance(key, simpletypes_keys): key = repr(key) if value is not None and isinstance(value, (dict, list, tuple)): @@ -971,23 +1196,23 @@ def simple_types_filter(data): def stringify(data): - ''' + """ Given an iterable, returns its items as a list, with any non-string items converted to unicode strings. 
- ''' + """ ret = [] for item in data: if six.PY2 and isinstance(item, str): item = salt.utils.stringutils.to_unicode(item) - elif not isinstance(item, six.string_types): - item = six.text_type(item) + elif not isinstance(item, str): + item = str(item) ret.append(item) return ret -@jinja_filter('json_query') +@jinja_filter("json_query") def json_query(data, expr): - ''' + """ Query data using JMESPath language (http://jmespath.org). Requires the https://github.com/jmespath/jmespath.py library. @@ -1009,16 +1234,16 @@ def json_query(data, expr): .. code-block:: text [80, 25, 22] - ''' + """ if jmespath is None: - err = 'json_query requires jmespath module installed' + err = "json_query requires jmespath module installed" log.error(err) raise RuntimeError(err) return jmespath.search(expr, data) def _is_not_considered_falsey(value, ignore_types=()): - ''' + """ Helper function for filter_falsey to determine if something is not to be considered falsey. @@ -1026,12 +1251,12 @@ def _is_not_considered_falsey(value, ignore_types=()): :param list ignore_types: The types to ignore when considering the value. :return bool - ''' + """ return isinstance(value, bool) or type(value) in ignore_types or value def filter_falsey(data, recurse_depth=None, ignore_types=()): - ''' + """ Helper function to remove items from an iterable with falsey value. Removes ``None``, ``{}`` and ``[]``, 0, '' (but does not remove ``False``). Recurses into sub-iterables if ``recurse`` is set to ``True``. @@ -1045,37 +1270,42 @@ def filter_falsey(data, recurse_depth=None, ignore_types=()): :return type(data) .. 
versionadded:: 3000 - ''' + """ filter_element = ( - functools.partial(filter_falsey, - recurse_depth=recurse_depth-1, - ignore_types=ignore_types) - if recurse_depth else lambda x: x + functools.partial( + filter_falsey, recurse_depth=recurse_depth - 1, ignore_types=ignore_types + ) + if recurse_depth + else lambda x: x ) if isinstance(data, dict): - processed_elements = [(key, filter_element(value)) for key, value in six.iteritems(data)] - return type(data)([ - (key, value) - for key, value in processed_elements - if _is_not_considered_falsey(value, ignore_types=ignore_types) - ]) + processed_elements = [ + (key, filter_element(value)) for key, value in data.items() + ] + return type(data)( + [ + (key, value) + for key, value in processed_elements + if _is_not_considered_falsey(value, ignore_types=ignore_types) + ] + ) if is_iter(data): processed_elements = (filter_element(value) for value in data) - return type(data)([ - value for value in processed_elements - if _is_not_considered_falsey(value, ignore_types=ignore_types) - ]) + return type(data)( + [ + value + for value in processed_elements + if _is_not_considered_falsey(value, ignore_types=ignore_types) + ] + ) return data def recursive_diff( - old, - new, - ignore_keys=None, - ignore_order=False, - ignore_missing_keys=False): - ''' + old, new, ignore_keys=None, ignore_order=False, ignore_missing_keys=False +): + """ Performs a recursive diff on mappings and/or iterables and returns the result in a {'old': values, 'new': values}-style. Compares dicts and sets unordered (obviously), OrderedDicts and Lists ordered @@ -1090,12 +1320,16 @@ def recursive_diff( but missing in ``new``. Only works for regular dicts. :return dict: Returns dict with keys 'old' and 'new' containing the differences. 
- ''' + """ ignore_keys = ignore_keys or [] res = {} ret_old = copy.deepcopy(old) ret_new = copy.deepcopy(new) - if isinstance(old, OrderedDict) and isinstance(new, OrderedDict) and not ignore_order: + if ( + isinstance(old, OrderedDict) + and isinstance(new, OrderedDict) + and not ignore_order + ): append_old, append_new = [], [] if len(old) != len(new): min_length = min(len(old), len(new)) @@ -1114,13 +1348,14 @@ def recursive_diff( new[key_new], ignore_keys=ignore_keys, ignore_order=ignore_order, - ignore_missing_keys=ignore_missing_keys) + ignore_missing_keys=ignore_missing_keys, + ) if not res: # Equal del ret_old[key_old] del ret_new[key_new] else: - ret_old[key_old] = res['old'] - ret_new[key_new] = res['new'] + ret_old[key_old] = res["old"] + ret_new[key_new] = res["new"] else: if key_old in ignore_keys: del ret_old[key_old] @@ -1131,7 +1366,7 @@ def recursive_diff( ret_old[item] = old[item] for item in append_new: ret_new[item] = new[item] - ret = {'old': ret_old, 'new': ret_new} if ret_old or ret_new else {} + ret = {"old": ret_old, "new": ret_new} if ret_old or ret_new else {} elif isinstance(old, Mapping) and isinstance(new, Mapping): # Compare unordered for key in set(list(old) + list(new)): @@ -1146,16 +1381,17 @@ def recursive_diff( new[key], ignore_keys=ignore_keys, ignore_order=ignore_order, - ignore_missing_keys=ignore_missing_keys) + ignore_missing_keys=ignore_missing_keys, + ) if not res: # Equal del ret_old[key] del ret_new[key] else: - ret_old[key] = res['old'] - ret_new[key] = res['new'] - ret = {'old': ret_old, 'new': ret_new} if ret_old or ret_new else {} + ret_old[key] = res["old"] + ret_new[key] = res["new"] + ret = {"old": ret_old, "new": ret_new} if ret_old or ret_new else {} elif isinstance(old, set) and isinstance(new, set): - ret = {'old': old - new, 'new': new - old} if old - new or new - old else {} + ret = {"old": old - new, "new": new - old} if old - new or new - old else {} elif is_iter(old) and is_iter(new): # Create a list so 
we can edit on an index-basis. list_old = list(ret_old) @@ -1168,7 +1404,8 @@ def recursive_diff( item_new, ignore_keys=ignore_keys, ignore_order=ignore_order, - ignore_missing_keys=ignore_missing_keys) + ignore_missing_keys=ignore_missing_keys, + ) if not res: list_old.remove(item_old) list_new.remove(item_new) @@ -1181,19 +1418,87 @@ def recursive_diff( iter_new, ignore_keys=ignore_keys, ignore_order=ignore_order, - ignore_missing_keys=ignore_missing_keys) + ignore_missing_keys=ignore_missing_keys, + ) if not res: # Equal remove_indices.append(index) else: - list_old[index] = res['old'] - list_new[index] = res['new'] + list_old[index] = res["old"] + list_new[index] = res["new"] for index in reversed(remove_indices): list_old.pop(index) list_new.pop(index) # Instantiate a new whatever-it-was using the list as iterable source. # This may not be the most optimized in way of speed and memory usage, # but it will work for all iterable types. - ret = {'old': type(old)(list_old), 'new': type(new)(list_new)} if list_old or list_new else {} + ret = ( + {"old": type(old)(list_old), "new": type(new)(list_new)} + if list_old or list_new + else {} + ) else: - ret = {} if old == new else {'old': ret_old, 'new': ret_new} + ret = {} if old == new else {"old": ret_old, "new": ret_new} return ret + + +def get_value(obj, path, default=None): + """ + Get the values for a given path. + + :param path: + keys of the properties in the tree separated by colons. + One segment in the path can be replaced by an id surrounded by curly braces. + This will match all items in a list of dictionary. + + :param default: + default value to return when no value is found + + :return: + a list of dictionaries, with at least the "value" key providing the actual value. + If a placeholder was used, the placeholder id will be a key providing the replacement for it. + Note that a value that wasn't found in the tree will be an empty list. 
+ This ensures we can make the difference with a None value set by the user. + """ + res = [{"value": obj}] + if path: + key = path[: path.find(":")] if ":" in path else path + next_path = path[path.find(":") + 1 :] if ":" in path else None + + if key.startswith("{") and key.endswith("}"): + placeholder_name = key[1:-1] + # There will be multiple values to get here + items = [] + if obj is None: + return res + if isinstance(obj, dict): + items = obj.items() + elif isinstance(obj, list): + items = enumerate(obj) + + def _append_placeholder(value_dict, key): + value_dict[placeholder_name] = key + return value_dict + + values = [ + [ + _append_placeholder(item, key) + for item in get_value(val, next_path, default) + ] + for key, val in items + ] + + # flatten the list + values = [y for x in values for y in x] + return values + elif isinstance(obj, dict): + if key not in obj.keys(): + return [{"value": default}] + + value = obj.get(key) + if res is not None: + res = get_value(value, next_path, default) + else: + res = [{"value": value}] + else: + return [{"value": default if obj is not None else obj}] + return res diff --git a/salt/utils/xmlutil.py b/salt/utils/xmlutil.py index 6d8d74fd3f..2b9c7bf43f 100644 --- a/salt/utils/xmlutil.py +++ b/salt/utils/xmlutil.py @@ -1,30 +1,34 @@ -# -*- coding: utf-8 -*- -''' +""" Various XML utilities -''' +""" # Import Python libs -from __future__ import absolute_import, print_function, unicode_literals +import re +import string # pylint: disable=deprecated-module +from xml.etree import ElementTree + +# Import salt libs +import salt.utils.data def _conv_name(x): - ''' + """ If this XML tree has an xmlns attribute, then etree will add it to the beginning of the tag, like: "{http://path}tag". - ''' - if '}' in x: - comps = x.split('}') + """ + if "}" in x: + comps = x.split("}") name = comps[1] return name return x def _to_dict(xmltree): - ''' + """ Converts an XML ElementTree to a dictionary that only contains items. 
This is the default behavior in version 2017.7. This will default to prevent unexpected parsing issues on modules dependant on this. - ''' + """ # If this object has no children, the for..loop below will return nothing # for it, so just return a single dict representing it. if len(xmltree.getchildren()) < 1: @@ -51,9 +55,9 @@ def _to_dict(xmltree): def _to_full_dict(xmltree): - ''' + """ Returns the full XML dictionary including attributes. - ''' + """ xmldict = {} for attrName, attrValue in xmltree.attrib.items(): @@ -87,15 +91,234 @@ def _to_full_dict(xmltree): def to_dict(xmltree, attr=False): - ''' + """ Convert an XML tree into a dict. The tree that is passed in must be an ElementTree object. Args: xmltree: An ElementTree object. attr: If true, attributes will be parsed. If false, they will be ignored. - ''' + """ if attr: return _to_full_dict(xmltree) else: return _to_dict(xmltree) + + +def get_xml_node(node, xpath): + """ + Get an XML node using a path (super simple xpath showing complete node ancestry). + This also creates the missing nodes. + + The supported XPath can contain elements filtering using [@attr='value']. + + Args: + node: an Element object + xpath: simple XPath to look for. 
+ """ + if not xpath.startswith("./"): + xpath = "./{}".format(xpath) + res = node.find(xpath) + if res is None: + parent_xpath = xpath[: xpath.rfind("/")] + parent = node.find(parent_xpath) + if parent is None: + parent = get_xml_node(node, parent_xpath) + segment = xpath[xpath.rfind("/") + 1 :] + # We may have [] filter in the segment + matcher = re.match( + r"""(?P[^[]+)(?:\[@(?P\w+)=["'](?P[^"']+)["']])?""", + segment, + ) + attrib = ( + {matcher.group("attr"): matcher.group("value")} + if matcher.group("attr") and matcher.group("value") + else {} + ) + res = ElementTree.SubElement(parent, matcher.group("tag"), attrib) + return res + + +def set_node_text(node, value): + """ + Function to use in the ``set`` value in the :py:func:`change_xml` mapping items to set the text. + This is the default. + + :param node: the node to set the text to + :param value: the value to set + """ + node.text = str(value) + + +def clean_node(parent_map, node, ignored=None): + """ + Remove the node from its parent if it has no attribute but the ignored ones, no text and no child. + Recursively called up to the document root to ensure no empty node is left. + + :param parent_map: dictionary mapping each node to its parent + :param node: the node to clean + :param ignored: a list of ignored attributes. + """ + has_text = node.text is not None and node.text.strip() + parent = parent_map.get(node) + if ( + len(node.attrib.keys() - (ignored or [])) == 0 + and not list(node) + and not has_text + ): + parent.remove(node) + # Clean parent nodes if needed + if parent is not None: + clean_node(parent_map, parent, ignored) + + +def del_text(parent_map, node): + """ + Function to use as ``del`` value in the :py:func:`change_xml` mapping items to remove the text. + This is the default function. + Calls :py:func:`clean_node` before returning. 
+ """ + parent = parent_map[node] + parent.remove(node) + clean_node(parent, node) + + +def del_attribute(attribute, ignored=None): + """ + Helper returning a function to use as ``del`` value in the :py:func:`change_xml` mapping items to + remove an attribute. + + The generated function calls :py:func:`clean_node` before returning. + + :param attribute: the name of the attribute to remove + :param ignored: the list of attributes to ignore during the cleanup + + :return: the function called by :py:func:`change_xml`. + """ + + def _do_delete(parent_map, node): + if attribute not in node.keys(): + return + node.attrib.pop(attribute) + clean_node(parent_map, node, ignored) + + return _do_delete + + +def change_xml(doc, data, mapping): + """ + Change an XML ElementTree document according. + + :param doc: the ElementTree parsed XML document to modify + :param data: the dictionary of values used to modify the XML. + :param mapping: a list of items describing how to modify the XML document. + Each item is a dictionary containing the following keys: + + .. glossary:: + path + the path to the value to set or remove in the ``data`` parameter. + See :py:func:`salt.utils.data.get_value ` for the format + of the value. + + xpath + Simplified XPath expression used to locate the change in the XML tree. + See :py:func:`get_xml_node` documentation for details on the supported XPath syntax + + get + function gettin the value from the XML. + Takes a single parameter for the XML node found by the XPath expression. + Default returns the node text value. + This may be used to return an attribute or to perform value transformation. + + set + function setting the value in the XML. + Takes two parameters for the XML node and the value to set. + Default is to set the text value. + + del + function deleting the value in the XML. + Takes two parameters for the parent node and the node matched by the XPath. + Default is to remove the text value. 
+ More cleanup may be performed, see the :py:func:`clean_node` function for details. + + convert + function modifying the user-provided value right before comparing it with the one from the XML. + Takes the value as single parameter. + Default is to apply no conversion. + + :return: ``True`` if the XML has been modified, ``False`` otherwise. + """ + need_update = False + for param in mapping: + # Get the value from the function parameter using the path-like description + # Using an empty list as a default value will cause values not provided by the user + # to be left untouched, as opposed to explicit None unsetting the value + values = salt.utils.data.get_value(data, param["path"], []) + xpath = param["xpath"] + # Prepend the xpath with ./ to handle the root more easily + if not xpath.startswith("./"): + xpath = "./{}".format(xpath) + + placeholders = [ + s[1:-1] + for s in param["path"].split(":") + if s.startswith("{") and s.endswith("}") + ] + + ctx = {placeholder: "$$$" for placeholder in placeholders} + all_nodes_xpath = string.Template(xpath).substitute(ctx) + all_nodes_xpath = re.sub( + r"""(?:=['"]\$\$\$["'])|(?:\[\$\$\$\])""", "", all_nodes_xpath + ) + + # Store the nodes that are not removed for later cleanup + kept_nodes = set() + + for value_item in values: + new_value = value_item["value"] + + # Only handle simple type values. Use multiple entries or a custom get for dict or lists + if isinstance(new_value, list) or isinstance(new_value, dict): + continue + + if new_value is not None: + ctx = { + placeholder: value_item.get(placeholder, "") + for placeholder in placeholders + } + node_xpath = string.Template(xpath).substitute(ctx) + node = get_xml_node(doc, node_xpath) + + kept_nodes.add(node) + + get_fn = param.get("get", lambda n: n.text) + set_fn = param.get("set", set_node_text) + current_value = get_fn(node) + + # Do we need to apply some conversion to the user-provided value? 
+ convert_fn = param.get("convert") + if convert_fn: + new_value = convert_fn(new_value) + + if current_value != new_value: + set_fn(node, new_value) + need_update = True + else: + nodes = doc.findall(all_nodes_xpath) + del_fn = param.get("del", del_text) + parent_map = {c: p for p in doc.iter() for c in p} + for node in nodes: + del_fn(parent_map, node) + need_update = True + + # Clean the left over XML elements if there were placeholders + if placeholders and values[0].get("value") != []: + all_nodes = set(doc.findall(all_nodes_xpath)) + to_remove = all_nodes - kept_nodes + del_fn = param.get("del", del_text) + parent_map = {c: p for p in doc.iter() for c in p} + for node in to_remove: + del_fn(parent_map, node) + need_update = True + + return need_update diff --git a/tests/pytests/unit/utils/test_data.py b/tests/pytests/unit/utils/test_data.py new file mode 100644 index 0000000000..b3f0ba04ae --- /dev/null +++ b/tests/pytests/unit/utils/test_data.py @@ -0,0 +1,57 @@ +import salt.utils.data + + +def test_get_value_simple_path(): + data = {"a": {"b": {"c": "foo"}}} + assert [{"value": "foo"}] == salt.utils.data.get_value(data, "a:b:c") + + +def test_get_value_placeholder_dict(): + data = {"a": {"b": {"name": "foo"}, "c": {"name": "bar"}}} + assert [ + {"value": "foo", "id": "b"}, + {"value": "bar", "id": "c"}, + ] == salt.utils.data.get_value(data, "a:{id}:name") + + +def test_get_value_placeholder_list(): + data = {"a": [{"name": "foo"}, {"name": "bar"}]} + assert [ + {"value": "foo", "id": 0}, + {"value": "bar", "id": 1}, + ] == salt.utils.data.get_value(data, "a:{id}:name") + + +def test_get_value_nested_placeholder(): + data = { + "a": { + "b": {"b1": {"name": "foo1"}, "b2": {"name": "foo2"}}, + "c": {"c1": {"name": "bar"}}, + } + } + assert [ + {"value": "foo1", "id": "b", "sub": "b1"}, + {"value": "foo2", "id": "b", "sub": "b2"}, + {"value": "bar", "id": "c", "sub": "c1"}, + ] == salt.utils.data.get_value(data, "a:{id}:{sub}:name") + + +def 
test_get_value_nested_notfound(): + data = {"a": {"b": {"c": "foo"}}} + assert [{"value": []}] == salt.utils.data.get_value(data, "a:b:d", []) + + +def test_get_value_not_found(): + assert [{"value": []}] == salt.utils.data.get_value({}, "a", []) + + +def test_get_value_none(): + assert [{"value": None}] == salt.utils.data.get_value({"a": None}, "a") + + +def test_get_value_simple_type_path(): + assert [{"value": []}] == salt.utils.data.get_value({"a": 1024}, "a:b", []) + + +def test_get_value_None_path(): + assert [{"value": None}] == salt.utils.data.get_value({"a": None}, "a:b", []) diff --git a/tests/pytests/unit/utils/test_xmlutil.py b/tests/pytests/unit/utils/test_xmlutil.py new file mode 100644 index 0000000000..081cc64193 --- /dev/null +++ b/tests/pytests/unit/utils/test_xmlutil.py @@ -0,0 +1,169 @@ +import pytest +import salt.utils.xmlutil as xml +from salt._compat import ElementTree as ET + + +@pytest.fixture +def xml_doc(): + return ET.fromstring( + """ + + test01 + 1024 + + + + + + + + """ + ) + + +def test_change_xml_text(xml_doc): + ret = xml.change_xml( + xml_doc, {"name": "test02"}, [{"path": "name", "xpath": "name"}] + ) + assert ret + assert "test02" == xml_doc.find("name").text + + +def test_change_xml_text_nochange(xml_doc): + ret = xml.change_xml( + xml_doc, {"name": "test01"}, [{"path": "name", "xpath": "name"}] + ) + assert not ret + + +def test_change_xml_text_notdefined(xml_doc): + ret = xml.change_xml(xml_doc, {}, [{"path": "name", "xpath": "name"}]) + assert not ret + + +def test_change_xml_text_removed(xml_doc): + ret = xml.change_xml(xml_doc, {"name": None}, [{"path": "name", "xpath": "name"}]) + assert ret + assert xml_doc.find("name") is None + + +def test_change_xml_text_add(xml_doc): + ret = xml.change_xml( + xml_doc, + {"cpu": {"vendor": "ACME"}}, + [{"path": "cpu:vendor", "xpath": "cpu/vendor"}], + ) + assert ret + assert "ACME" == xml_doc.find("cpu/vendor").text + + +def test_change_xml_convert(xml_doc): + ret = xml.change_xml( + 
xml_doc, + {"mem": 2}, + [{"path": "mem", "xpath": "memory", "convert": lambda v: v * 1024}], + ) + assert ret + assert "2048" == xml_doc.find("memory").text + + +def test_change_xml_attr(xml_doc): + ret = xml.change_xml( + xml_doc, + {"cpu": {"topology": {"cores": 4}}}, + [ + { + "path": "cpu:topology:cores", + "xpath": "cpu/topology", + "get": lambda n: int(n.get("cores")) if n.get("cores") else None, + "set": lambda n, v: n.set("cores", str(v)), + "del": xml.del_attribute("cores"), + } + ], + ) + assert ret + assert "4" == xml_doc.find("cpu/topology").get("cores") + + +def test_change_xml_attr_unchanged(xml_doc): + ret = xml.change_xml( + xml_doc, + {"cpu": {"topology": {"sockets": 1}}}, + [ + { + "path": "cpu:topology:sockets", + "xpath": "cpu/topology", + "get": lambda n: int(n.get("sockets")) if n.get("sockets") else None, + "set": lambda n, v: n.set("sockets", str(v)), + "del": xml.del_attribute("sockets"), + } + ], + ) + assert not ret + + +def test_change_xml_attr_remove(xml_doc): + ret = xml.change_xml( + xml_doc, + {"cpu": {"topology": {"sockets": None}}}, + [ + { + "path": "cpu:topology:sockets", + "xpath": "./cpu/topology", + "get": lambda n: int(n.get("sockets")) if n.get("sockets") else None, + "set": lambda n, v: n.set("sockets", str(v)), + "del": xml.del_attribute("sockets"), + } + ], + ) + assert ret + assert xml_doc.find("cpu") is None + + +def test_change_xml_not_simple_value(xml_doc): + ret = xml.change_xml( + xml_doc, + {"cpu": {"topology": {"sockets": None}}}, + [{"path": "cpu", "xpath": "vcpu", "get": lambda n: int(n.text)}], + ) + assert not ret + + +def test_change_xml_template(xml_doc): + ret = xml.change_xml( + xml_doc, + {"cpu": {"vcpus": {2: {"enabled": True}, 4: {"enabled": False}}}}, + [ + { + "path": "cpu:vcpus:{id}:enabled", + "xpath": "vcpus/vcpu[@id='$id']", + "convert": lambda v: "yes" if v else "no", + "get": lambda n: n.get("enabled"), + "set": lambda n, v: n.set("enabled", v), + "del": xml.del_attribute("enabled", ["id"]), + 
}, + ], + ) + assert ret + assert xml_doc.find("vcpus/vcpu[@id='1']") is None + assert "yes" == xml_doc.find("vcpus/vcpu[@id='2']").get("enabled") + assert "no" == xml_doc.find("vcpus/vcpu[@id='4']").get("enabled") + + +def test_change_xml_template_remove(xml_doc): + ret = xml.change_xml( + xml_doc, + {"cpu": {"vcpus": None}}, + [ + { + "path": "cpu:vcpus:{id}:enabled", + "xpath": "vcpus/vcpu[@id='$id']", + "convert": lambda v: "yes" if v else "no", + "get": lambda n: n.get("enabled"), + "set": lambda n, v: n.set("enabled", v), + "del": xml.del_attribute("enabled", ["id"]), + }, + ], + ) + assert ret + assert xml_doc.find("vcpus") is None diff --git a/tests/unit/modules/test_virt.py b/tests/unit/modules/test_virt.py index d3988464f6..5ec8de77e7 100644 --- a/tests/unit/modules/test_virt.py +++ b/tests/unit/modules/test_virt.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ virt execution module unit tests """ @@ -6,7 +5,6 @@ virt execution module unit tests # pylint: disable=3rd-party-module-not-gated # Import python libs -from __future__ import absolute_import, print_function, unicode_literals import datetime import os @@ -23,9 +21,6 @@ import salt.utils.yaml from salt._compat import ElementTree as ET from salt.exceptions import CommandExecutionError, SaltInvocationError -# Import third party libs -from salt.ext import six - # pylint: disable=import-error from salt.ext.six.moves import range # pylint: disable=redefined-builtin @@ -136,7 +131,7 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): "model": "virtio", "filename": "myvm_system.qcow2", "image": "/path/to/image", - "source_file": "{0}{1}myvm_system.qcow2".format(root_dir, os.sep), + "source_file": "{}{}myvm_system.qcow2".format(root_dir, os.sep), }, { "name": "data", @@ -145,7 +140,7 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): "format": "raw", "model": "virtio", "filename": "myvm_data.raw", - "source_file": "{0}{1}myvm_data.raw".format(root_dir, os.sep), + "source_file": 
"{}{}myvm_data.raw".format(root_dir, os.sep), }, ], disks, @@ -582,8 +577,8 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): self.assertIsNone(root.get("type")) self.assertEqual(root.find("name").text, "vmname/system.vmdk") self.assertEqual(root.find("capacity").attrib["unit"], "KiB") - self.assertEqual(root.find("capacity").text, six.text_type(8192 * 1024)) - self.assertEqual(root.find("allocation").text, six.text_type(0)) + self.assertEqual(root.find("capacity").text, str(8192 * 1024)) + self.assertEqual(root.find("allocation").text, str(0)) self.assertEqual(root.find("target/format").get("type"), "vmdk") self.assertIsNone(root.find("target/permissions")) self.assertIsNone(root.find("target/nocow")) @@ -615,9 +610,9 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): self.assertIsNone(root.find("target/path")) self.assertEqual(root.find("target/format").get("type"), "qcow2") self.assertEqual(root.find("capacity").attrib["unit"], "KiB") - self.assertEqual(root.find("capacity").text, six.text_type(8192 * 1024)) + self.assertEqual(root.find("capacity").text, str(8192 * 1024)) self.assertEqual(root.find("capacity").attrib["unit"], "KiB") - self.assertEqual(root.find("allocation").text, six.text_type(4096 * 1024)) + self.assertEqual(root.find("allocation").text, str(4096 * 1024)) self.assertEqual(root.find("target/permissions/mode").text, "0775") self.assertEqual(root.find("target/permissions/owner").text, "123") self.assertEqual(root.find("target/permissions/group").text, "456") @@ -638,7 +633,7 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): root = ET.fromstring(xml_data) self.assertEqual(root.attrib["type"], "kvm") self.assertEqual(root.find("vcpu").text, "1") - self.assertEqual(root.find("memory").text, six.text_type(512 * 1024)) + self.assertEqual(root.find("memory").text, str(512 * 1024)) self.assertEqual(root.find("memory").attrib["unit"], "KiB") disks = root.findall(".//disk") @@ -671,7 +666,7 @@ class VirtTestCase(TestCase, 
LoaderModuleMockMixin): root = ET.fromstring(xml_data) self.assertEqual(root.attrib["type"], "vmware") self.assertEqual(root.find("vcpu").text, "1") - self.assertEqual(root.find("memory").text, six.text_type(512 * 1024)) + self.assertEqual(root.find("memory").text, str(512 * 1024)) self.assertEqual(root.find("memory").attrib["unit"], "KiB") disks = root.findall(".//disk") @@ -714,7 +709,7 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): root = ET.fromstring(xml_data) self.assertEqual(root.attrib["type"], "xen") self.assertEqual(root.find("vcpu").text, "1") - self.assertEqual(root.find("memory").text, six.text_type(512 * 1024)) + self.assertEqual(root.find("memory").text, str(512 * 1024)) self.assertEqual(root.find("memory").attrib["unit"], "KiB") self.assertEqual( root.find(".//kernel").text, "/usr/lib/grub2/x86_64-xen/grub.xen" @@ -768,7 +763,7 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): root = ET.fromstring(xml_data) self.assertEqual(root.attrib["type"], "vmware") self.assertEqual(root.find("vcpu").text, "1") - self.assertEqual(root.find("memory").text, six.text_type(512 * 1024)) + self.assertEqual(root.find("memory").text, str(512 * 1024)) self.assertEqual(root.find("memory").attrib["unit"], "KiB") self.assertTrue(len(root.findall(".//disk")) == 2) self.assertTrue(len(root.findall(".//interface")) == 2) @@ -801,7 +796,7 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): root = ET.fromstring(xml_data) self.assertEqual(root.attrib["type"], "kvm") self.assertEqual(root.find("vcpu").text, "1") - self.assertEqual(root.find("memory").text, six.text_type(512 * 1024)) + self.assertEqual(root.find("memory").text, str(512 * 1024)) self.assertEqual(root.find("memory").attrib["unit"], "KiB") disks = root.findall(".//disk") self.assertTrue(len(disks) == 2) @@ -1635,7 +1630,7 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): self.assertIsNone(definition.find("./devices/disk[2]/source")) self.assertEqual( mock_run.call_args[0][0], - 'qemu-img 
create -f qcow2 "{0}" 10240M'.format(expected_disk_path), + 'qemu-img create -f qcow2 "{}" 10240M'.format(expected_disk_path), ) self.assertEqual(mock_chmod.call_args[0][0], expected_disk_path) @@ -1729,11 +1724,12 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): 1 hvm + - + @@ -1850,17 +1846,36 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): "cmdline": "console=ttyS0 ks=http://example.com/f8-i386/os/", } - boot_uefi = { - "loader": "/usr/share/OVMF/OVMF_CODE.fd", - "nvram": "/usr/share/OVMF/OVMF_VARS.ms.fd", - } + # Update boot devices case + define_mock.reset_mock() + self.assertEqual( + { + "definition": True, + "disk": {"attached": [], "detached": [], "updated": []}, + "interface": {"attached": [], "detached": []}, + }, + virt.update("my_vm", boot_dev="cdrom network hd"), + ) + setxml = ET.fromstring(define_mock.call_args[0][0]) + self.assertEqual( + ["cdrom", "network", "hd"], + [node.get("dev") for node in setxml.findall("os/boot")], + ) - invalid_boot = { - "loader": "/usr/share/OVMF/OVMF_CODE.fd", - "initrd": "/root/f8-i386-initrd", - } + # Update unchanged boot devices case + define_mock.reset_mock() + self.assertEqual( + { + "definition": False, + "disk": {"attached": [], "detached": [], "updated": []}, + "interface": {"attached": [], "detached": []}, + }, + virt.update("my_vm", boot_dev="hd"), + ) + define_mock.assert_not_called() # Update with boot parameter case + define_mock.reset_mock() self.assertEqual( { "definition": True, @@ -1884,6 +1899,11 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): "console=ttyS0 ks=http://example.com/f8-i386/os/", ) + boot_uefi = { + "loader": "/usr/share/OVMF/OVMF_CODE.fd", + "nvram": "/usr/share/OVMF/OVMF_VARS.ms.fd", + } + self.assertEqual( { "definition": True, @@ -1903,9 +1923,28 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): "/usr/share/OVMF/OVMF_VARS.ms.fd", ) + self.assertEqual( + { + "definition": True, + "disk": {"attached": [], "detached": [], "updated": []}, + "interface": 
{"attached": [], "detached": []}, + }, + virt.update("my_vm", boot={"efi": True}), + ) + setxml = ET.fromstring(define_mock.call_args[0][0]) + self.assertEqual(setxml.find("os").attrib.get("firmware"), "efi") + + invalid_boot = { + "loader": "/usr/share/OVMF/OVMF_CODE.fd", + "initrd": "/root/f8-i386-initrd", + } + with self.assertRaises(SaltInvocationError): virt.update("my_vm", boot=invalid_boot) + with self.assertRaises(SaltInvocationError): + virt.update("my_vm", boot={"efi": "Not a boolean value"}) + # Update memory case setmem_mock = MagicMock(return_value=0) domain_mock.setMemoryFlags = setmem_mock @@ -1955,7 +1994,7 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): ) # pylint: disable=no-member self.assertEqual( mock_run.call_args[0][0], - 'qemu-img create -f qcow2 "{0}" 2048M'.format(added_disk_path), + 'qemu-img create -f qcow2 "{}" 2048M'.format(added_disk_path), ) self.assertEqual(mock_chmod.call_args[0][0], added_disk_path) self.assertListEqual( @@ -2397,6 +2436,43 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): ], ) + def test_update_xen_boot_params(self): + """ + Test virt.update() a Xen definition no boot parameter. 
+ """ + root_dir = os.path.join(salt.syspaths.ROOT_DIR, "srv", "salt-images") + xml_boot = """ + + vm + 1048576 + 1048576 + 1 + + hvm + /usr/lib/xen/boot/hvmloader + + + """ + domain_mock_boot = self.set_mock_vm("vm", xml_boot) + domain_mock_boot.OSType = MagicMock(return_value="hvm") + define_mock_boot = MagicMock(return_value=True) + define_mock_boot.setVcpusFlags = MagicMock(return_value=0) + self.mock_conn.defineXML = define_mock_boot + self.assertEqual( + { + "cpu": False, + "definition": True, + "disk": {"attached": [], "detached": [], "updated": []}, + "interface": {"attached": [], "detached": []}, + }, + virt.update("vm", cpu=2), + ) + setxml = ET.fromstring(define_mock_boot.call_args[0][0]) + self.assertEqual(setxml.find("os").find("loader").attrib.get("type"), "rom") + self.assertEqual( + setxml.find("os").find("loader").text, "/usr/lib/xen/boot/hvmloader" + ) + def test_update_existing_boot_params(self): """ Test virt.update() with existing boot parameters. @@ -2537,6 +2613,18 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): self.assertEqual(setxml.find("os").find("initrd"), None) self.assertEqual(setxml.find("os").find("cmdline"), None) + self.assertEqual( + { + "definition": True, + "disk": {"attached": [], "detached": [], "updated": []}, + "interface": {"attached": [], "detached": []}, + }, + virt.update("vm_with_boot_param", boot={"efi": False}), + ) + setxml = ET.fromstring(define_mock_boot.call_args[0][0]) + self.assertEqual(setxml.find("os").find("nvram"), None) + self.assertEqual(setxml.find("os").find("loader"), None) + self.assertEqual( { "definition": True, @@ -2582,7 +2670,7 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): salt.modules.config.__opts__, mock_config # pylint: disable=no-member ): - for name in six.iterkeys(mock_config["virt"]["nic"]): + for name in mock_config["virt"]["nic"].keys(): profile = salt.modules.virt._nic_profile(name, "kvm") self.assertEqual(len(profile), 2) @@ -3592,8 +3680,7 @@ class 
VirtTestCase(TestCase, LoaderModuleMockMixin): "44454c4c-3400-105a-8033-b3c04f4b344a", caps["host"]["host"]["uuid"] ) self.assertEqual( - set(["qemu", "kvm"]), - set([domainCaps["domain"] for domainCaps in caps["domains"]]), + {"qemu", "kvm"}, {domainCaps["domain"] for domainCaps in caps["domains"]}, ) def test_network_tag(self): @@ -3694,9 +3781,9 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): for i in range(2): net_mock = MagicMock() - net_mock.name.return_value = "net{0}".format(i) + net_mock.name.return_value = "net{}".format(i) net_mock.UUIDString.return_value = "some-uuid" - net_mock.bridgeName.return_value = "br{0}".format(i) + net_mock.bridgeName.return_value = "br{}".format(i) net_mock.autostart.return_value = True net_mock.isActive.return_value = False net_mock.isPersistent.return_value = True @@ -4156,8 +4243,8 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): pool_mocks = [] for i in range(2): pool_mock = MagicMock() - pool_mock.name.return_value = "pool{0}".format(i) - pool_mock.UUIDString.return_value = "some-uuid-{0}".format(i) + pool_mock.name.return_value = "pool{}".format(i) + pool_mock.UUIDString.return_value = "some-uuid-{}".format(i) pool_mock.info.return_value = [0, 1234, 5678, 123] pool_mock.autostart.return_value = True pool_mock.isPersistent.return_value = True @@ -4257,7 +4344,6 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): """ mock_pool = MagicMock() mock_pool.delete = MagicMock(return_value=0) - mock_pool.XMLDesc.return_value = "" self.mock_conn.storagePoolLookupByName = MagicMock(return_value=mock_pool) res = virt.pool_delete("test-pool") @@ -4271,12 +4357,12 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): self.mock_libvirt.VIR_STORAGE_POOL_DELETE_NORMAL ) - def test_pool_delete_secret(self): + def test_pool_undefine_secret(self): """ - Test virt.pool_delete function where the pool has a secret + Test virt.pool_undefine function where the pool has a secret """ mock_pool = MagicMock() - mock_pool.delete = 
MagicMock(return_value=0) + mock_pool.undefine = MagicMock(return_value=0) mock_pool.XMLDesc.return_value = """ test-ses @@ -4293,16 +4379,11 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): mock_undefine = MagicMock(return_value=0) self.mock_conn.secretLookupByUsage.return_value.undefine = mock_undefine - res = virt.pool_delete("test-ses") + res = virt.pool_undefine("test-ses") self.assertTrue(res) self.mock_conn.storagePoolLookupByName.assert_called_once_with("test-ses") - - # Shouldn't be called with another parameter so far since those are not implemented - # and thus throwing exceptions. - mock_pool.delete.assert_called_once_with( - self.mock_libvirt.VIR_STORAGE_POOL_DELETE_NORMAL - ) + mock_pool.undefine.assert_called_once_with() self.mock_conn.secretLookupByUsage.assert_called_once_with( self.mock_libvirt.VIR_SECRET_USAGE_TYPE_CEPH, "pool_test-ses" @@ -4571,24 +4652,6 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): """ - expected_xml = ( - '' - "default" - "20fbe05c-ab40-418a-9afa-136d512f0ede" - '1999421108224' - '713207042048' - '1286214066176' - "" - '' - '' - '' - '' - "" - "iscsi-images" - "" - "" - ) - mock_secret = MagicMock() self.mock_conn.secretLookupByUUIDString = MagicMock(return_value=mock_secret) @@ -4609,6 +4672,23 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): self.mock_conn.storagePoolDefineXML.assert_not_called() mock_secret.setValue.assert_called_once_with(b"secret") + # Case where the secret can't be found + self.mock_conn.secretLookupByUUIDString = MagicMock( + side_effect=self.mock_libvirt.libvirtError("secret not found") + ) + self.assertFalse( + virt.pool_update( + "default", + "rbd", + source_name="iscsi-images", + source_hosts=["ses4.tf.local", "ses5.tf.local"], + source_auth={"username": "libvirt", "password": "c2VjcmV0"}, + ) + ) + self.mock_conn.storagePoolDefineXML.assert_not_called() + self.mock_conn.secretDefineXML.assert_called_once() + mock_secret.setValue.assert_called_once_with(b"secret") + def 
test_pool_update_password_create(self): """ Test the pool_update function, where the password only is changed @@ -4695,11 +4775,11 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): for idx, disk in enumerate(vms_disks): vm = MagicMock() # pylint: disable=no-member - vm.name.return_value = "vm{0}".format(idx) + vm.name.return_value = "vm{}".format(idx) vm.XMLDesc.return_value = """ - vm{0} - {1} + vm{} + {} """.format( idx, disk @@ -4760,7 +4840,7 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): # pylint: disable=no-member mock_volume.name.return_value = vol_data["name"] mock_volume.key.return_value = vol_data["key"] - mock_volume.path.return_value = "/path/to/{0}.qcow2".format( + mock_volume.path.return_value = "/path/to/{}.qcow2".format( vol_data["name"] ) if vol_data["info"]: @@ -4769,7 +4849,7 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): """ - {0} + {} """.format( vol_data["backingStore"] @@ -5234,7 +5314,7 @@ class VirtTestCase(TestCase, LoaderModuleMockMixin): def create_mock_vm(idx): mock_vm = MagicMock() - mock_vm.name.return_value = "vm{0}".format(idx) + mock_vm.name.return_value = "vm{}".format(idx) return mock_vm mock_vms = [create_mock_vm(idx) for idx in range(3)] diff --git a/tests/unit/states/test_virt.py b/tests/unit/states/test_virt.py index c76f8a5fc0..f03159334b 100644 --- a/tests/unit/states/test_virt.py +++ b/tests/unit/states/test_virt.py @@ -1,9 +1,7 @@ -# -*- coding: utf-8 -*- """ :codeauthor: Jayesh Kariya """ # Import Python libs -from __future__ import absolute_import, print_function, unicode_literals import shutil import tempfile @@ -14,7 +12,6 @@ import salt.utils.files from salt.exceptions import CommandExecutionError, SaltInvocationError # Import 3rd-party libs -from salt.ext import six from tests.support.mixins import LoaderModuleMockMixin from tests.support.mock import MagicMock, mock_open, patch @@ -37,7 +34,7 @@ class LibvirtMock(MagicMock): # pylint: disable=too-many-ancestors """ Fake function return error 
message """ - return six.text_type(self) + return str(self) class LibvirtTestCase(TestCase, LoaderModuleMockMixin): @@ -341,6 +338,7 @@ class LibvirtTestCase(TestCase, LoaderModuleMockMixin): "myvm", cpu=2, mem=2048, + boot_dev="cdrom hd", os_type="linux", arch="i686", vm_type="qemu", @@ -363,6 +361,7 @@ class LibvirtTestCase(TestCase, LoaderModuleMockMixin): "myvm", cpu=2, mem=2048, + boot_dev="cdrom hd", os_type="linux", arch="i686", disk="prod", @@ -471,10 +470,13 @@ class LibvirtTestCase(TestCase, LoaderModuleMockMixin): "comment": "Domain myvm updated with live update(s) failures", } ) - self.assertDictEqual(virt.defined("myvm", cpu=2), ret) + self.assertDictEqual( + virt.defined("myvm", cpu=2, boot_dev="cdrom hd"), ret + ) update_mock.assert_called_with( "myvm", cpu=2, + boot_dev="cdrom hd", mem=None, disk_profile=None, disks=None, @@ -598,6 +600,7 @@ class LibvirtTestCase(TestCase, LoaderModuleMockMixin): password=None, boot=None, test=True, + boot_dev=None, ) # No changes case @@ -632,6 +635,7 @@ class LibvirtTestCase(TestCase, LoaderModuleMockMixin): password=None, boot=None, test=True, + boot_dev=None, ) def test_running(self): @@ -708,6 +712,7 @@ class LibvirtTestCase(TestCase, LoaderModuleMockMixin): install=True, pub_key=None, priv_key=None, + boot_dev=None, connection=None, username=None, password=None, @@ -769,6 +774,7 @@ class LibvirtTestCase(TestCase, LoaderModuleMockMixin): install=False, pub_key="/path/to/key.pub", priv_key="/path/to/key", + boot_dev="network hd", connection="someconnection", username="libvirtuser", password="supersecret", @@ -793,6 +799,7 @@ class LibvirtTestCase(TestCase, LoaderModuleMockMixin): start=False, pub_key="/path/to/key.pub", priv_key="/path/to/key", + boot_dev="network hd", connection="someconnection", username="libvirtuser", password="supersecret", @@ -937,6 +944,7 @@ class LibvirtTestCase(TestCase, LoaderModuleMockMixin): password=None, boot=None, test=False, + boot_dev=None, ) # Failed definition update case @@ 
-1055,6 +1063,7 @@ class LibvirtTestCase(TestCase, LoaderModuleMockMixin): password=None, boot=None, test=True, + boot_dev=None, ) start_mock.assert_not_called() @@ -1091,6 +1100,7 @@ class LibvirtTestCase(TestCase, LoaderModuleMockMixin): password=None, boot=None, test=True, + boot_dev=None, ) def test_stopped(self): @@ -1978,6 +1988,72 @@ class LibvirtTestCase(TestCase, LoaderModuleMockMixin): password="secret", ) + # Define a pool that doesn't handle build + for mock in mocks: + mocks[mock].reset_mock() + with patch.dict( + virt.__salt__, + { # pylint: disable=no-member + "virt.pool_info": MagicMock( + side_effect=[ + {}, + {"mypool": {"state": "stopped", "autostart": True}}, + ] + ), + "virt.pool_define": mocks["define"], + "virt.pool_build": mocks["build"], + "virt.pool_set_autostart": mocks["autostart"], + }, + ): + ret.update( + { + "changes": {"mypool": "Pool defined, marked for autostart"}, + "comment": "Pool mypool defined, marked for autostart", + } + ) + self.assertDictEqual( + virt.pool_defined( + "mypool", + ptype="rbd", + source={ + "name": "libvirt-pool", + "hosts": ["ses2.tf.local", "ses3.tf.local"], + "auth": { + "username": "libvirt", + "password": "AQAz+PRdtquBBRAASMv7nlMZYfxIyLw3St65Xw==", + }, + }, + autostart=True, + ), + ret, + ) + mocks["define"].assert_called_with( + "mypool", + ptype="rbd", + target=None, + permissions=None, + source_devices=None, + source_dir=None, + source_adapter=None, + source_hosts=["ses2.tf.local", "ses3.tf.local"], + source_auth={ + "username": "libvirt", + "password": "AQAz+PRdtquBBRAASMv7nlMZYfxIyLw3St65Xw==", + }, + source_name="libvirt-pool", + source_format=None, + source_initiator=None, + start=False, + transient=False, + connection=None, + username=None, + password=None, + ) + mocks["autostart"].assert_called_with( + "mypool", state="on", connection=None, username=None, password=None, + ) + mocks["build"].assert_not_called() + mocks["update"] = MagicMock(return_value=False) for mock in mocks: 
mocks[mock].reset_mock() @@ -2027,6 +2103,9 @@ class LibvirtTestCase(TestCase, LoaderModuleMockMixin): for mock in mocks: mocks[mock].reset_mock() mocks["update"] = MagicMock(return_value=True) + mocks["build"] = MagicMock( + side_effect=self.mock_libvirt.libvirtError("Existing VG") + ) with patch.dict( virt.__salt__, { # pylint: disable=no-member @@ -2130,6 +2209,7 @@ class LibvirtTestCase(TestCase, LoaderModuleMockMixin): ), ret, ) + mocks["build"].assert_not_called() mocks["update"].assert_called_with( "mypool", ptype="logical", @@ -2477,8 +2557,8 @@ class LibvirtTestCase(TestCase, LoaderModuleMockMixin): ): ret.update( { - "changes": {"mypool": "Pool updated, built, restarted"}, - "comment": "Pool mypool updated, built, restarted", + "changes": {"mypool": "Pool updated, restarted"}, + "comment": "Pool mypool updated, restarted", "result": True, } ) @@ -2504,9 +2584,7 @@ class LibvirtTestCase(TestCase, LoaderModuleMockMixin): mocks["start"].assert_called_with( "mypool", connection=None, username=None, password=None ) - mocks["build"].assert_called_with( - "mypool", connection=None, username=None, password=None - ) + mocks["build"].assert_not_called() mocks["update"].assert_called_with( "mypool", ptype="logical", diff --git a/tests/unit/utils/test_data.py b/tests/unit/utils/test_data.py index 8fa352321c..8a6956d442 100644 --- a/tests/unit/utils/test_data.py +++ b/tests/unit/utils/test_data.py @@ -1,38 +1,38 @@ -# -*- coding: utf-8 -*- -''' +""" Tests for salt.utils.data -''' +""" # Import Python libs -from __future__ import absolute_import, print_function, unicode_literals + import logging # Import Salt libs import salt.utils.data import salt.utils.stringutils -from salt.utils.odict import OrderedDict -from tests.support.unit import TestCase, LOREM_IPSUM -from tests.support.mock import patch # Import 3rd party libs -from salt.ext.six.moves import builtins # pylint: disable=import-error,redefined-builtin -from salt.ext import six +from salt.ext.six.moves import ( 
# pylint: disable=import-error,redefined-builtin + builtins, +) +from salt.utils.odict import OrderedDict +from tests.support.mock import patch +from tests.support.unit import LOREM_IPSUM, TestCase log = logging.getLogger(__name__) -_b = lambda x: x.encode('utf-8') +_b = lambda x: x.encode("utf-8") _s = lambda x: salt.utils.stringutils.to_str(x, normalize=True) # Some randomized data that will not decode -BYTES = b'1\x814\x10' +BYTES = b"1\x814\x10" # This is an example of a unicode string with й constructed using two separate # code points. Do not modify it. -EGGS = '\u044f\u0438\u0306\u0446\u0430' +EGGS = "\u044f\u0438\u0306\u0446\u0430" class DataTestCase(TestCase): test_data = [ - 'unicode_str', - _b('питон'), + "unicode_str", + _b("питон"), 123, 456.789, True, @@ -40,71 +40,79 @@ class DataTestCase(TestCase): None, EGGS, BYTES, - [123, 456.789, _b('спам'), True, False, None, EGGS, BYTES], - (987, 654.321, _b('яйца'), EGGS, None, (True, EGGS, BYTES)), - {_b('str_key'): _b('str_val'), - None: True, - 123: 456.789, - EGGS: BYTES, - _b('subdict'): {'unicode_key': EGGS, - _b('tuple'): (123, 'hello', _b('world'), True, EGGS, BYTES), - _b('list'): [456, _b('спам'), False, EGGS, BYTES]}}, - OrderedDict([(_b('foo'), 'bar'), (123, 456), (EGGS, BYTES)]) + [123, 456.789, _b("спам"), True, False, None, EGGS, BYTES], + (987, 654.321, _b("яйца"), EGGS, None, (True, EGGS, BYTES)), + { + _b("str_key"): _b("str_val"), + None: True, + 123: 456.789, + EGGS: BYTES, + _b("subdict"): { + "unicode_key": EGGS, + _b("tuple"): (123, "hello", _b("world"), True, EGGS, BYTES), + _b("list"): [456, _b("спам"), False, EGGS, BYTES], + }, + }, + OrderedDict([(_b("foo"), "bar"), (123, 456), (EGGS, BYTES)]), ] def test_sorted_ignorecase(self): - test_list = ['foo', 'Foo', 'bar', 'Bar'] - expected_list = ['bar', 'Bar', 'foo', 'Foo'] - self.assertEqual( - salt.utils.data.sorted_ignorecase(test_list), expected_list) + test_list = ["foo", "Foo", "bar", "Bar"] + expected_list = ["bar", "Bar", "foo", 
"Foo"] + self.assertEqual(salt.utils.data.sorted_ignorecase(test_list), expected_list) def test_mysql_to_dict(self): - test_mysql_output = ['+----+------+-----------+------+---------+------+-------+------------------+', - '| Id | User | Host | db | Command | Time | State | Info |', - '+----+------+-----------+------+---------+------+-------+------------------+', - '| 7 | root | localhost | NULL | Query | 0 | init | show processlist |', - '+----+------+-----------+------+---------+------+-------+------------------+'] + test_mysql_output = [ + "+----+------+-----------+------+---------+------+-------+------------------+", + "| Id | User | Host | db | Command | Time | State | Info |", + "+----+------+-----------+------+---------+------+-------+------------------+", + "| 7 | root | localhost | NULL | Query | 0 | init | show processlist |", + "+----+------+-----------+------+---------+------+-------+------------------+", + ] - ret = salt.utils.data.mysql_to_dict(test_mysql_output, 'Info') + ret = salt.utils.data.mysql_to_dict(test_mysql_output, "Info") expected_dict = { - 'show processlist': {'Info': 'show processlist', 'db': 'NULL', 'State': 'init', 'Host': 'localhost', - 'Command': 'Query', 'User': 'root', 'Time': 0, 'Id': 7}} + "show processlist": { + "Info": "show processlist", + "db": "NULL", + "State": "init", + "Host": "localhost", + "Command": "Query", + "User": "root", + "Time": 0, + "Id": 7, + } + } self.assertDictEqual(ret, expected_dict) def test_subdict_match(self): - test_two_level_dict = {'foo': {'bar': 'baz'}} - test_two_level_comb_dict = {'foo': {'bar': 'baz:woz'}} + test_two_level_dict = {"foo": {"bar": "baz"}} + test_two_level_comb_dict = {"foo": {"bar": "baz:woz"}} test_two_level_dict_and_list = { - 'abc': ['def', 'ghi', {'lorem': {'ipsum': [{'dolor': 'sit'}]}}], + "abc": ["def", "ghi", {"lorem": {"ipsum": [{"dolor": "sit"}]}}], } - test_three_level_dict = {'a': {'b': {'c': 'v'}}} + test_three_level_dict = {"a": {"b": {"c": "v"}}} self.assertTrue( - 
salt.utils.data.subdict_match( - test_two_level_dict, 'foo:bar:baz' - ) + salt.utils.data.subdict_match(test_two_level_dict, "foo:bar:baz") ) # In test_two_level_comb_dict, 'foo:bar' corresponds to 'baz:woz', not # 'baz'. This match should return False. self.assertFalse( - salt.utils.data.subdict_match( - test_two_level_comb_dict, 'foo:bar:baz' - ) + salt.utils.data.subdict_match(test_two_level_comb_dict, "foo:bar:baz") ) # This tests matching with the delimiter in the value part (in other # words, that the path 'foo:bar' corresponds to the string 'baz:woz'). self.assertTrue( - salt.utils.data.subdict_match( - test_two_level_comb_dict, 'foo:bar:baz:woz' - ) + salt.utils.data.subdict_match(test_two_level_comb_dict, "foo:bar:baz:woz") ) # This would match if test_two_level_comb_dict['foo']['bar'] was equal # to 'baz:woz:wiz', or if there was more deep nesting. But it does not, # so this should return False. self.assertFalse( salt.utils.data.subdict_match( - test_two_level_comb_dict, 'foo:bar:baz:woz:wiz' + test_two_level_comb_dict, "foo:bar:baz:woz:wiz" ) ) # This tests for cases when a key path corresponds to a list. The @@ -115,189 +123,171 @@ class DataTestCase(TestCase): # salt.utils.traverse_list_and_dict() so this particular assertion is a # sanity check. self.assertTrue( - salt.utils.data.subdict_match( - test_two_level_dict_and_list, 'abc:ghi' - ) + salt.utils.data.subdict_match(test_two_level_dict_and_list, "abc:ghi") ) # This tests the use case of a dict embedded in a list, embedded in a # list, embedded in a dict. This is a rather absurd case, but it # confirms that match recursion works properly. 
self.assertTrue( salt.utils.data.subdict_match( - test_two_level_dict_and_list, 'abc:lorem:ipsum:dolor:sit' + test_two_level_dict_and_list, "abc:lorem:ipsum:dolor:sit" ) ) # Test four level dict match for reference - self.assertTrue( - salt.utils.data.subdict_match( - test_three_level_dict, 'a:b:c:v' - ) - ) + self.assertTrue(salt.utils.data.subdict_match(test_three_level_dict, "a:b:c:v")) # Test regression in 2015.8 where 'a:c:v' would match 'a:b:c:v' - self.assertFalse( - salt.utils.data.subdict_match( - test_three_level_dict, 'a:c:v' - ) - ) + self.assertFalse(salt.utils.data.subdict_match(test_three_level_dict, "a:c:v")) # Test wildcard match - self.assertTrue( - salt.utils.data.subdict_match( - test_three_level_dict, 'a:*:c:v' - ) - ) + self.assertTrue(salt.utils.data.subdict_match(test_three_level_dict, "a:*:c:v")) def test_subdict_match_with_wildcards(self): - ''' + """ Tests subdict matching when wildcards are used in the expression - ''' - data = { - 'a': { - 'b': { - 'ç': 'd', - 'é': ['eff', 'gee', '8ch'], - 'ĩ': {'j': 'k'} - } - } - } - assert salt.utils.data.subdict_match(data, '*:*:*:*') - assert salt.utils.data.subdict_match(data, 'a:*:*:*') - assert salt.utils.data.subdict_match(data, 'a:b:*:*') - assert salt.utils.data.subdict_match(data, 'a:b:ç:*') - assert salt.utils.data.subdict_match(data, 'a:b:*:d') - assert salt.utils.data.subdict_match(data, 'a:*:ç:d') - assert salt.utils.data.subdict_match(data, '*:b:ç:d') - assert salt.utils.data.subdict_match(data, '*:*:ç:d') - assert salt.utils.data.subdict_match(data, '*:*:*:d') - assert salt.utils.data.subdict_match(data, 'a:*:*:d') - assert salt.utils.data.subdict_match(data, 'a:b:*:ef*') - assert salt.utils.data.subdict_match(data, 'a:b:*:g*') - assert salt.utils.data.subdict_match(data, 'a:b:*:j:*') - assert salt.utils.data.subdict_match(data, 'a:b:*:j:k') - assert salt.utils.data.subdict_match(data, 'a:b:*:*:k') - assert salt.utils.data.subdict_match(data, 'a:b:*:*:*') + """ + data = {"a": {"b": 
{"ç": "d", "é": ["eff", "gee", "8ch"], "ĩ": {"j": "k"}}}} + assert salt.utils.data.subdict_match(data, "*:*:*:*") + assert salt.utils.data.subdict_match(data, "a:*:*:*") + assert salt.utils.data.subdict_match(data, "a:b:*:*") + assert salt.utils.data.subdict_match(data, "a:b:ç:*") + assert salt.utils.data.subdict_match(data, "a:b:*:d") + assert salt.utils.data.subdict_match(data, "a:*:ç:d") + assert salt.utils.data.subdict_match(data, "*:b:ç:d") + assert salt.utils.data.subdict_match(data, "*:*:ç:d") + assert salt.utils.data.subdict_match(data, "*:*:*:d") + assert salt.utils.data.subdict_match(data, "a:*:*:d") + assert salt.utils.data.subdict_match(data, "a:b:*:ef*") + assert salt.utils.data.subdict_match(data, "a:b:*:g*") + assert salt.utils.data.subdict_match(data, "a:b:*:j:*") + assert salt.utils.data.subdict_match(data, "a:b:*:j:k") + assert salt.utils.data.subdict_match(data, "a:b:*:*:k") + assert salt.utils.data.subdict_match(data, "a:b:*:*:*") def test_traverse_dict(self): - test_two_level_dict = {'foo': {'bar': 'baz'}} + test_two_level_dict = {"foo": {"bar": "baz"}} self.assertDictEqual( - {'not_found': 'nope'}, + {"not_found": "nope"}, salt.utils.data.traverse_dict( - test_two_level_dict, 'foo:bar:baz', {'not_found': 'nope'} - ) + test_two_level_dict, "foo:bar:baz", {"not_found": "nope"} + ), ) self.assertEqual( - 'baz', + "baz", salt.utils.data.traverse_dict( - test_two_level_dict, 'foo:bar', {'not_found': 'not_found'} - ) + test_two_level_dict, "foo:bar", {"not_found": "not_found"} + ), ) def test_traverse_dict_and_list(self): - test_two_level_dict = {'foo': {'bar': 'baz'}} + test_two_level_dict = {"foo": {"bar": "baz"}} test_two_level_dict_and_list = { - 'foo': ['bar', 'baz', {'lorem': {'ipsum': [{'dolor': 'sit'}]}}] + "foo": ["bar", "baz", {"lorem": {"ipsum": [{"dolor": "sit"}]}}] } # Check traversing too far: salt.utils.data.traverse_dict_and_list() returns # the value corresponding to a given key path, and baz is a value # corresponding to the key 
path foo:bar. self.assertDictEqual( - {'not_found': 'nope'}, + {"not_found": "nope"}, salt.utils.data.traverse_dict_and_list( - test_two_level_dict, 'foo:bar:baz', {'not_found': 'nope'} - ) + test_two_level_dict, "foo:bar:baz", {"not_found": "nope"} + ), ) # Now check to ensure that foo:bar corresponds to baz self.assertEqual( - 'baz', + "baz", salt.utils.data.traverse_dict_and_list( - test_two_level_dict, 'foo:bar', {'not_found': 'not_found'} - ) + test_two_level_dict, "foo:bar", {"not_found": "not_found"} + ), ) # Check traversing too far self.assertDictEqual( - {'not_found': 'nope'}, + {"not_found": "nope"}, salt.utils.data.traverse_dict_and_list( - test_two_level_dict_and_list, 'foo:bar', {'not_found': 'nope'} - ) + test_two_level_dict_and_list, "foo:bar", {"not_found": "nope"} + ), ) # Check index 1 (2nd element) of list corresponding to path 'foo' self.assertEqual( - 'baz', + "baz", salt.utils.data.traverse_dict_and_list( - test_two_level_dict_and_list, 'foo:1', {'not_found': 'not_found'} - ) + test_two_level_dict_and_list, "foo:1", {"not_found": "not_found"} + ), ) # Traverse a couple times into dicts embedded in lists self.assertEqual( - 'sit', + "sit", salt.utils.data.traverse_dict_and_list( test_two_level_dict_and_list, - 'foo:lorem:ipsum:dolor', - {'not_found': 'not_found'} - ) + "foo:lorem:ipsum:dolor", + {"not_found": "not_found"}, + ), ) def test_compare_dicts(self): - ret = salt.utils.data.compare_dicts(old={'foo': 'bar'}, new={'foo': 'bar'}) + ret = salt.utils.data.compare_dicts(old={"foo": "bar"}, new={"foo": "bar"}) self.assertEqual(ret, {}) - ret = salt.utils.data.compare_dicts(old={'foo': 'bar'}, new={'foo': 'woz'}) - expected_ret = {'foo': {'new': 'woz', 'old': 'bar'}} + ret = salt.utils.data.compare_dicts(old={"foo": "bar"}, new={"foo": "woz"}) + expected_ret = {"foo": {"new": "woz", "old": "bar"}} self.assertDictEqual(ret, expected_ret) def test_compare_lists_no_change(self): - ret = salt.utils.data.compare_lists(old=[1, 2, 3, 'a', 'b', 'c'], 
- new=[1, 2, 3, 'a', 'b', 'c']) + ret = salt.utils.data.compare_lists( + old=[1, 2, 3, "a", "b", "c"], new=[1, 2, 3, "a", "b", "c"] + ) expected = {} self.assertDictEqual(ret, expected) def test_compare_lists_changes(self): - ret = salt.utils.data.compare_lists(old=[1, 2, 3, 'a', 'b', 'c'], - new=[1, 2, 4, 'x', 'y', 'z']) - expected = {'new': [4, 'x', 'y', 'z'], 'old': [3, 'a', 'b', 'c']} + ret = salt.utils.data.compare_lists( + old=[1, 2, 3, "a", "b", "c"], new=[1, 2, 4, "x", "y", "z"] + ) + expected = {"new": [4, "x", "y", "z"], "old": [3, "a", "b", "c"]} self.assertDictEqual(ret, expected) def test_compare_lists_changes_new(self): - ret = salt.utils.data.compare_lists(old=[1, 2, 3], - new=[1, 2, 3, 'x', 'y', 'z']) - expected = {'new': ['x', 'y', 'z']} + ret = salt.utils.data.compare_lists(old=[1, 2, 3], new=[1, 2, 3, "x", "y", "z"]) + expected = {"new": ["x", "y", "z"]} self.assertDictEqual(ret, expected) def test_compare_lists_changes_old(self): - ret = salt.utils.data.compare_lists(old=[1, 2, 3, 'a', 'b', 'c'], - new=[1, 2, 3]) - expected = {'old': ['a', 'b', 'c']} + ret = salt.utils.data.compare_lists(old=[1, 2, 3, "a", "b", "c"], new=[1, 2, 3]) + expected = {"old": ["a", "b", "c"]} self.assertDictEqual(ret, expected) def test_decode(self): - ''' + """ Companion to test_decode_to_str, they should both be kept up-to-date with one another. NOTE: This uses the lambda "_b" defined above in the global scope, which encodes a string to a bytestring, assuming utf-8. 
- ''' + """ expected = [ - 'unicode_str', - 'питон', + "unicode_str", + "питон", 123, 456.789, True, False, None, - 'яйца', + "яйца", BYTES, - [123, 456.789, 'спам', True, False, None, 'яйца', BYTES], - (987, 654.321, 'яйца', 'яйца', None, (True, 'яйца', BYTES)), - {'str_key': 'str_val', - None: True, - 123: 456.789, - 'яйца': BYTES, - 'subdict': {'unicode_key': 'яйца', - 'tuple': (123, 'hello', 'world', True, 'яйца', BYTES), - 'list': [456, 'спам', False, 'яйца', BYTES]}}, - OrderedDict([('foo', 'bar'), (123, 456), ('яйца', BYTES)]) + [123, 456.789, "спам", True, False, None, "яйца", BYTES], + (987, 654.321, "яйца", "яйца", None, (True, "яйца", BYTES)), + { + "str_key": "str_val", + None: True, + 123: 456.789, + "яйца": BYTES, + "subdict": { + "unicode_key": "яйца", + "tuple": (123, "hello", "world", True, "яйца", BYTES), + "list": [456, "спам", False, "яйца", BYTES], + }, + }, + OrderedDict([("foo", "bar"), (123, 456), ("яйца", BYTES)]), ] ret = salt.utils.data.decode( @@ -305,7 +295,8 @@ class DataTestCase(TestCase): keep=True, normalize=True, preserve_dict_class=True, - preserve_tuples=True) + preserve_tuples=True, + ) self.assertEqual(ret, expected) # The binary data in the data structure should fail to decode, even @@ -317,74 +308,100 @@ class DataTestCase(TestCase): keep=False, normalize=True, preserve_dict_class=True, - preserve_tuples=True) + preserve_tuples=True, + ) # Now munge the expected data so that we get what we would expect if we # disable preservation of dict class and tuples - expected[10] = [987, 654.321, 'яйца', 'яйца', None, [True, 'яйца', BYTES]] - expected[11]['subdict']['tuple'] = [123, 'hello', 'world', True, 'яйца', BYTES] - expected[12] = {'foo': 'bar', 123: 456, 'яйца': BYTES} + expected[10] = [987, 654.321, "яйца", "яйца", None, [True, "яйца", BYTES]] + expected[11]["subdict"]["tuple"] = [123, "hello", "world", True, "яйца", BYTES] + expected[12] = {"foo": "bar", 123: 456, "яйца": BYTES} ret = salt.utils.data.decode( self.test_data, 
keep=True, normalize=True, preserve_dict_class=False, - preserve_tuples=False) + preserve_tuples=False, + ) self.assertEqual(ret, expected) # Now test single non-string, non-data-structure items, these should # return the same value when passed to this function for item in (123, 4.56, True, False, None): - log.debug('Testing decode of %s', item) + log.debug("Testing decode of %s", item) self.assertEqual(salt.utils.data.decode(item), item) # Test single strings (not in a data structure) - self.assertEqual(salt.utils.data.decode('foo'), 'foo') - self.assertEqual(salt.utils.data.decode(_b('bar')), 'bar') - self.assertEqual(salt.utils.data.decode(EGGS, normalize=True), 'яйца') + self.assertEqual(salt.utils.data.decode("foo"), "foo") + self.assertEqual(salt.utils.data.decode(_b("bar")), "bar") + self.assertEqual(salt.utils.data.decode(EGGS, normalize=True), "яйца") self.assertEqual(salt.utils.data.decode(EGGS, normalize=False), EGGS) # Test binary blob self.assertEqual(salt.utils.data.decode(BYTES, keep=True), BYTES) - self.assertRaises( - UnicodeDecodeError, - salt.utils.data.decode, - BYTES, - keep=False) + self.assertRaises(UnicodeDecodeError, salt.utils.data.decode, BYTES, keep=False) + + def test_circular_refs_dicts(self): + test_dict = {"key": "value", "type": "test1"} + test_dict["self"] = test_dict + ret = salt.utils.data._remove_circular_refs(ob=test_dict) + self.assertDictEqual(ret, {"key": "value", "type": "test1", "self": None}) + + def test_circular_refs_lists(self): + test_list = { + "foo": [], + } + test_list["foo"].append((test_list,)) + ret = salt.utils.data._remove_circular_refs(ob=test_list) + self.assertDictEqual(ret, {"foo": [(None,)]}) + + def test_circular_refs_tuple(self): + test_dup = {"foo": "string 1", "bar": "string 1", "ham": 1, "spam": 1} + ret = salt.utils.data._remove_circular_refs(ob=test_dup) + self.assertDictEqual( + ret, {"foo": "string 1", "bar": "string 1", "ham": 1, "spam": 1} + ) def test_decode_to_str(self): - ''' + """ Companion 
to test_decode, they should both be kept up-to-date with one another. NOTE: This uses the lambda "_s" defined above in the global scope, which converts the string/bytestring to a str type. - ''' + """ expected = [ - _s('unicode_str'), - _s('питон'), + _s("unicode_str"), + _s("питон"), 123, 456.789, True, False, None, - _s('яйца'), + _s("яйца"), BYTES, - [123, 456.789, _s('спам'), True, False, None, _s('яйца'), BYTES], - (987, 654.321, _s('яйца'), _s('яйца'), None, (True, _s('яйца'), BYTES)), + [123, 456.789, _s("спам"), True, False, None, _s("яйца"), BYTES], + (987, 654.321, _s("яйца"), _s("яйца"), None, (True, _s("яйца"), BYTES)), { - _s('str_key'): _s('str_val'), + _s("str_key"): _s("str_val"), None: True, 123: 456.789, - _s('яйца'): BYTES, - _s('subdict'): { - _s('unicode_key'): _s('яйца'), - _s('tuple'): (123, _s('hello'), _s('world'), True, _s('яйца'), BYTES), - _s('list'): [456, _s('спам'), False, _s('яйца'), BYTES] - } + _s("яйца"): BYTES, + _s("subdict"): { + _s("unicode_key"): _s("яйца"), + _s("tuple"): ( + 123, + _s("hello"), + _s("world"), + True, + _s("яйца"), + BYTES, + ), + _s("list"): [456, _s("спам"), False, _s("яйца"), BYTES], + }, }, - OrderedDict([(_s('foo'), _s('bar')), (123, 456), (_s('яйца'), BYTES)]) + OrderedDict([(_s("foo"), _s("bar")), (123, 456), (_s("яйца"), BYTES)]), ] ret = salt.utils.data.decode( @@ -393,27 +410,42 @@ class DataTestCase(TestCase): normalize=True, preserve_dict_class=True, preserve_tuples=True, - to_str=True) + to_str=True, + ) self.assertEqual(ret, expected) - if six.PY3: - # The binary data in the data structure should fail to decode, even - # using the fallback, and raise an exception. - self.assertRaises( - UnicodeDecodeError, - salt.utils.data.decode, - self.test_data, - keep=False, - normalize=True, - preserve_dict_class=True, - preserve_tuples=True, - to_str=True) + # The binary data in the data structure should fail to decode, even + # using the fallback, and raise an exception. 
+ self.assertRaises( + UnicodeDecodeError, + salt.utils.data.decode, + self.test_data, + keep=False, + normalize=True, + preserve_dict_class=True, + preserve_tuples=True, + to_str=True, + ) # Now munge the expected data so that we get what we would expect if we # disable preservation of dict class and tuples - expected[10] = [987, 654.321, _s('яйца'), _s('яйца'), None, [True, _s('яйца'), BYTES]] - expected[11][_s('subdict')][_s('tuple')] = [123, _s('hello'), _s('world'), True, _s('яйца'), BYTES] - expected[12] = {_s('foo'): _s('bar'), 123: 456, _s('яйца'): BYTES} + expected[10] = [ + 987, + 654.321, + _s("яйца"), + _s("яйца"), + None, + [True, _s("яйца"), BYTES], + ] + expected[11][_s("subdict")][_s("tuple")] = [ + 123, + _s("hello"), + _s("world"), + True, + _s("яйца"), + BYTES, + ] + expected[12] = {_s("foo"): _s("bar"), 123: 456, _s("яйца"): BYTES} ret = salt.utils.data.decode( self.test_data, @@ -421,47 +453,41 @@ class DataTestCase(TestCase): normalize=True, preserve_dict_class=False, preserve_tuples=False, - to_str=True) + to_str=True, + ) self.assertEqual(ret, expected) # Now test single non-string, non-data-structure items, these should # return the same value when passed to this function for item in (123, 4.56, True, False, None): - log.debug('Testing decode of %s', item) + log.debug("Testing decode of %s", item) self.assertEqual(salt.utils.data.decode(item, to_str=True), item) # Test single strings (not in a data structure) - self.assertEqual(salt.utils.data.decode('foo', to_str=True), _s('foo')) - self.assertEqual(salt.utils.data.decode(_b('bar'), to_str=True), _s('bar')) + self.assertEqual(salt.utils.data.decode("foo", to_str=True), _s("foo")) + self.assertEqual(salt.utils.data.decode(_b("bar"), to_str=True), _s("bar")) # Test binary blob - self.assertEqual( - salt.utils.data.decode(BYTES, keep=True, to_str=True), - BYTES + self.assertEqual(salt.utils.data.decode(BYTES, keep=True, to_str=True), BYTES) + self.assertRaises( + UnicodeDecodeError, 
salt.utils.data.decode, BYTES, keep=False, to_str=True, ) - if six.PY3: - self.assertRaises( - UnicodeDecodeError, - salt.utils.data.decode, - BYTES, - keep=False, - to_str=True) def test_decode_fallback(self): - ''' + """ Test fallback to utf-8 - ''' - with patch.object(builtins, '__salt_system_encoding__', 'ascii'): - self.assertEqual(salt.utils.data.decode(_b('яйца')), 'яйца') + """ + with patch.object(builtins, "__salt_system_encoding__", "ascii"): + self.assertEqual(salt.utils.data.decode(_b("яйца")), "яйца") def test_encode(self): - ''' + """ NOTE: This uses the lambda "_b" defined above in the global scope, which encodes a string to a bytestring, assuming utf-8. - ''' + """ expected = [ - _b('unicode_str'), - _b('питон'), + _b("unicode_str"), + _b("питон"), 123, 456.789, True, @@ -469,67 +495,71 @@ class DataTestCase(TestCase): None, _b(EGGS), BYTES, - [123, 456.789, _b('спам'), True, False, None, _b(EGGS), BYTES], - (987, 654.321, _b('яйца'), _b(EGGS), None, (True, _b(EGGS), BYTES)), + [123, 456.789, _b("спам"), True, False, None, _b(EGGS), BYTES], + (987, 654.321, _b("яйца"), _b(EGGS), None, (True, _b(EGGS), BYTES)), { - _b('str_key'): _b('str_val'), + _b("str_key"): _b("str_val"), None: True, 123: 456.789, _b(EGGS): BYTES, - _b('subdict'): { - _b('unicode_key'): _b(EGGS), - _b('tuple'): (123, _b('hello'), _b('world'), True, _b(EGGS), BYTES), - _b('list'): [456, _b('спам'), False, _b(EGGS), BYTES] - } + _b("subdict"): { + _b("unicode_key"): _b(EGGS), + _b("tuple"): (123, _b("hello"), _b("world"), True, _b(EGGS), BYTES), + _b("list"): [456, _b("спам"), False, _b(EGGS), BYTES], + }, }, - OrderedDict([(_b('foo'), _b('bar')), (123, 456), (_b(EGGS), BYTES)]) + OrderedDict([(_b("foo"), _b("bar")), (123, 456), (_b(EGGS), BYTES)]), ] # Both keep=True and keep=False should work because the BYTES data is # already bytes. 
ret = salt.utils.data.encode( - self.test_data, - keep=True, - preserve_dict_class=True, - preserve_tuples=True) + self.test_data, keep=True, preserve_dict_class=True, preserve_tuples=True + ) self.assertEqual(ret, expected) ret = salt.utils.data.encode( - self.test_data, - keep=False, - preserve_dict_class=True, - preserve_tuples=True) + self.test_data, keep=False, preserve_dict_class=True, preserve_tuples=True + ) self.assertEqual(ret, expected) # Now munge the expected data so that we get what we would expect if we # disable preservation of dict class and tuples - expected[10] = [987, 654.321, _b('яйца'), _b(EGGS), None, [True, _b(EGGS), BYTES]] - expected[11][_b('subdict')][_b('tuple')] = [ - 123, _b('hello'), _b('world'), True, _b(EGGS), BYTES + expected[10] = [ + 987, + 654.321, + _b("яйца"), + _b(EGGS), + None, + [True, _b(EGGS), BYTES], ] - expected[12] = {_b('foo'): _b('bar'), 123: 456, _b(EGGS): BYTES} + expected[11][_b("subdict")][_b("tuple")] = [ + 123, + _b("hello"), + _b("world"), + True, + _b(EGGS), + BYTES, + ] + expected[12] = {_b("foo"): _b("bar"), 123: 456, _b(EGGS): BYTES} ret = salt.utils.data.encode( - self.test_data, - keep=True, - preserve_dict_class=False, - preserve_tuples=False) + self.test_data, keep=True, preserve_dict_class=False, preserve_tuples=False + ) self.assertEqual(ret, expected) ret = salt.utils.data.encode( - self.test_data, - keep=False, - preserve_dict_class=False, - preserve_tuples=False) + self.test_data, keep=False, preserve_dict_class=False, preserve_tuples=False + ) self.assertEqual(ret, expected) # Now test single non-string, non-data-structure items, these should # return the same value when passed to this function for item in (123, 4.56, True, False, None): - log.debug('Testing encode of %s', item) + log.debug("Testing encode of %s", item) self.assertEqual(salt.utils.data.encode(item), item) # Test single strings (not in a data structure) - self.assertEqual(salt.utils.data.encode('foo'), _b('foo')) - 
self.assertEqual(salt.utils.data.encode(_b('bar')), _b('bar')) + self.assertEqual(salt.utils.data.encode("foo"), _b("foo")) + self.assertEqual(salt.utils.data.encode(_b("bar")), _b("bar")) # Test binary blob, nothing should happen even when keep=False since # the data is already bytes @@ -537,41 +567,43 @@ class DataTestCase(TestCase): self.assertEqual(salt.utils.data.encode(BYTES, keep=False), BYTES) def test_encode_keep(self): - ''' + """ Whereas we tested the keep argument in test_decode, it is much easier to do a more comprehensive test of keep in its own function where we can force the encoding. - ''' - unicode_str = 'питон' - encoding = 'ascii' + """ + unicode_str = "питон" + encoding = "ascii" # Test single string self.assertEqual( - salt.utils.data.encode(unicode_str, encoding, keep=True), - unicode_str) + salt.utils.data.encode(unicode_str, encoding, keep=True), unicode_str + ) self.assertRaises( UnicodeEncodeError, salt.utils.data.encode, unicode_str, encoding, - keep=False) + keep=False, + ) data = [ unicode_str, - [b'foo', [unicode_str], {b'key': unicode_str}, (unicode_str,)], - {b'list': [b'foo', unicode_str], - b'dict': {b'key': unicode_str}, - b'tuple': (b'foo', unicode_str)}, - ([b'foo', unicode_str], {b'key': unicode_str}, (unicode_str,)) + [b"foo", [unicode_str], {b"key": unicode_str}, (unicode_str,)], + { + b"list": [b"foo", unicode_str], + b"dict": {b"key": unicode_str}, + b"tuple": (b"foo", unicode_str), + }, + ([b"foo", unicode_str], {b"key": unicode_str}, (unicode_str,)), ] # Since everything was a bytestring aside from the bogus data, the # return data should be identical. We don't need to test recursive # decoding, that has already been tested in test_encode. 
self.assertEqual( - salt.utils.data.encode(data, encoding, - keep=True, preserve_tuples=True), - data + salt.utils.data.encode(data, encoding, keep=True, preserve_tuples=True), + data, ) self.assertRaises( UnicodeEncodeError, @@ -579,13 +611,15 @@ class DataTestCase(TestCase): data, encoding, keep=False, - preserve_tuples=True) + preserve_tuples=True, + ) for index, _ in enumerate(data): self.assertEqual( - salt.utils.data.encode(data[index], encoding, - keep=True, preserve_tuples=True), - data[index] + salt.utils.data.encode( + data[index], encoding, keep=True, preserve_tuples=True + ), + data[index], ) self.assertRaises( UnicodeEncodeError, @@ -593,31 +627,36 @@ class DataTestCase(TestCase): data[index], encoding, keep=False, - preserve_tuples=True) + preserve_tuples=True, + ) def test_encode_fallback(self): - ''' + """ Test fallback to utf-8 - ''' - with patch.object(builtins, '__salt_system_encoding__', 'ascii'): - self.assertEqual(salt.utils.data.encode('яйца'), _b('яйца')) - with patch.object(builtins, '__salt_system_encoding__', 'CP1252'): - self.assertEqual(salt.utils.data.encode('Ψ'), _b('Ψ')) + """ + with patch.object(builtins, "__salt_system_encoding__", "ascii"): + self.assertEqual(salt.utils.data.encode("яйца"), _b("яйца")) + with patch.object(builtins, "__salt_system_encoding__", "CP1252"): + self.assertEqual(salt.utils.data.encode("Ψ"), _b("Ψ")) def test_repack_dict(self): - list_of_one_element_dicts = [{'dict_key_1': 'dict_val_1'}, - {'dict_key_2': 'dict_val_2'}, - {'dict_key_3': 'dict_val_3'}] - expected_ret = {'dict_key_1': 'dict_val_1', - 'dict_key_2': 'dict_val_2', - 'dict_key_3': 'dict_val_3'} + list_of_one_element_dicts = [ + {"dict_key_1": "dict_val_1"}, + {"dict_key_2": "dict_val_2"}, + {"dict_key_3": "dict_val_3"}, + ] + expected_ret = { + "dict_key_1": "dict_val_1", + "dict_key_2": "dict_val_2", + "dict_key_3": "dict_val_3", + } ret = salt.utils.data.repack_dictlist(list_of_one_element_dicts) self.assertDictEqual(ret, expected_ret) # Try 
with yaml - yaml_key_val_pair = '- key1: val1' + yaml_key_val_pair = "- key1: val1" ret = salt.utils.data.repack_dictlist(yaml_key_val_pair) - self.assertDictEqual(ret, {'key1': 'val1'}) + self.assertDictEqual(ret, {"key1": "val1"}) # Make sure we handle non-yaml junk data ret = salt.utils.data.repack_dictlist(LOREM_IPSUM) @@ -626,43 +665,47 @@ class DataTestCase(TestCase): def test_stringify(self): self.assertRaises(TypeError, salt.utils.data.stringify, 9) self.assertEqual( - salt.utils.data.stringify(['one', 'two', str('three'), 4, 5]), # future lint: disable=blacklisted-function - ['one', 'two', 'three', '4', '5'] + salt.utils.data.stringify( + ["one", "two", "three", 4, 5] + ), # future lint: disable=blacklisted-function + ["one", "two", "three", "4", "5"], ) def test_json_query(self): # Raises exception if jmespath module is not found - with patch('salt.utils.data.jmespath', None): + with patch("salt.utils.data.jmespath", None): self.assertRaisesRegex( - RuntimeError, 'requires jmespath', - salt.utils.data.json_query, {}, '@' + RuntimeError, "requires jmespath", salt.utils.data.json_query, {}, "@" ) # Test search user_groups = { - 'user1': {'groups': ['group1', 'group2', 'group3']}, - 'user2': {'groups': ['group1', 'group2']}, - 'user3': {'groups': ['group3']}, + "user1": {"groups": ["group1", "group2", "group3"]}, + "user2": {"groups": ["group1", "group2"]}, + "user3": {"groups": ["group3"]}, } - expression = '*.groups[0]' - primary_groups = ['group1', 'group1', 'group3'] + expression = "*.groups[0]" + primary_groups = ["group1", "group1", "group3"] self.assertEqual( - sorted(salt.utils.data.json_query(user_groups, expression)), - primary_groups + sorted(salt.utils.data.json_query(user_groups, expression)), primary_groups ) class FilterFalseyTestCase(TestCase): - ''' + """ Test suite for salt.utils.data.filter_falsey - ''' + """ def test_nop(self): - ''' + """ Test cases where nothing will be done. 
- ''' + """ # Test with dictionary without recursion - old_dict = {'foo': 'bar', 'bar': {'baz': {'qux': 'quux'}}, 'baz': ['qux', {'foo': 'bar'}]} + old_dict = { + "foo": "bar", + "bar": {"baz": {"qux": "quux"}}, + "baz": ["qux", {"foo": "bar"}], + } new_dict = salt.utils.data.filter_falsey(old_dict) self.assertEqual(old_dict, new_dict) # Check returned type equality @@ -671,23 +714,25 @@ class FilterFalseyTestCase(TestCase): new_dict = salt.utils.data.filter_falsey(old_dict, recurse_depth=3) self.assertEqual(old_dict, new_dict) # Test with list - old_list = ['foo', 'bar'] + old_list = ["foo", "bar"] new_list = salt.utils.data.filter_falsey(old_list) self.assertEqual(old_list, new_list) # Check returned type equality self.assertIs(type(old_list), type(new_list)) # Test with set - old_set = set(['foo', 'bar']) + old_set = {"foo", "bar"} new_set = salt.utils.data.filter_falsey(old_set) self.assertEqual(old_set, new_set) # Check returned type equality self.assertIs(type(old_set), type(new_set)) # Test with OrderedDict - old_dict = OrderedDict([ - ('foo', 'bar'), - ('bar', OrderedDict([('qux', 'quux')])), - ('baz', ['qux', OrderedDict([('foo', 'bar')])]) - ]) + old_dict = OrderedDict( + [ + ("foo", "bar"), + ("bar", OrderedDict([("qux", "quux")])), + ("baz", ["qux", OrderedDict([("foo", "bar")])]), + ] + ) new_dict = salt.utils.data.filter_falsey(old_dict) self.assertEqual(old_dict, new_dict) self.assertIs(type(old_dict), type(new_dict)) @@ -696,8 +741,8 @@ class FilterFalseyTestCase(TestCase): new_list = salt.utils.data.filter_falsey(old_list, ignore_types=[type(0)]) self.assertEqual(old_list, new_list) # Test excluding str (or unicode) (or both) - old_list = [''] - new_list = salt.utils.data.filter_falsey(old_list, ignore_types=[type('')]) + old_list = [""] + new_list = salt.utils.data.filter_falsey(old_list, ignore_types=[type("")]) self.assertEqual(old_list, new_list) # Test excluding list old_list = [[]] @@ -709,185 +754,264 @@ class FilterFalseyTestCase(TestCase): 
self.assertEqual(old_list, new_list) def test_filter_dict_no_recurse(self): - ''' + """ Test filtering a dictionary without recursing. This will only filter out key-values where the values are falsey. - ''' - old_dict = {'foo': None, - 'bar': {'baz': {'qux': None, 'quux': '', 'foo': []}}, - 'baz': ['qux'], - 'qux': {}, - 'quux': []} + """ + old_dict = { + "foo": None, + "bar": {"baz": {"qux": None, "quux": "", "foo": []}}, + "baz": ["qux"], + "qux": {}, + "quux": [], + } new_dict = salt.utils.data.filter_falsey(old_dict) - expect_dict = {'bar': {'baz': {'qux': None, 'quux': '', 'foo': []}}, 'baz': ['qux']} + expect_dict = { + "bar": {"baz": {"qux": None, "quux": "", "foo": []}}, + "baz": ["qux"], + } self.assertEqual(expect_dict, new_dict) self.assertIs(type(expect_dict), type(new_dict)) def test_filter_dict_recurse(self): - ''' + """ Test filtering a dictionary with recursing. This will filter out any key-values where the values are falsey or when the values *become* falsey after filtering their contents (in case they are lists or dicts). - ''' - old_dict = {'foo': None, - 'bar': {'baz': {'qux': None, 'quux': '', 'foo': []}}, - 'baz': ['qux'], - 'qux': {}, - 'quux': []} + """ + old_dict = { + "foo": None, + "bar": {"baz": {"qux": None, "quux": "", "foo": []}}, + "baz": ["qux"], + "qux": {}, + "quux": [], + } new_dict = salt.utils.data.filter_falsey(old_dict, recurse_depth=3) - expect_dict = {'baz': ['qux']} + expect_dict = {"baz": ["qux"]} self.assertEqual(expect_dict, new_dict) self.assertIs(type(expect_dict), type(new_dict)) def test_filter_list_no_recurse(self): - ''' + """ Test filtering a list without recursing. This will only filter out items which are falsey. 
- ''' - old_list = ['foo', None, [], {}, 0, ''] + """ + old_list = ["foo", None, [], {}, 0, ""] new_list = salt.utils.data.filter_falsey(old_list) - expect_list = ['foo'] + expect_list = ["foo"] self.assertEqual(expect_list, new_list) self.assertIs(type(expect_list), type(new_list)) # Ensure nested values are *not* filtered out. old_list = [ - 'foo', - ['foo'], - ['foo', None], - {'foo': 0}, - {'foo': 'bar', 'baz': []}, - [{'foo': ''}], + "foo", + ["foo"], + ["foo", None], + {"foo": 0}, + {"foo": "bar", "baz": []}, + [{"foo": ""}], ] new_list = salt.utils.data.filter_falsey(old_list) self.assertEqual(old_list, new_list) self.assertIs(type(old_list), type(new_list)) def test_filter_list_recurse(self): - ''' + """ Test filtering a list with recursing. This will filter out any items which are falsey, or which become falsey after filtering their contents (in case they are lists or dicts). - ''' + """ old_list = [ - 'foo', - ['foo'], - ['foo', None], - {'foo': 0}, - {'foo': 'bar', 'baz': []}, - [{'foo': ''}] + "foo", + ["foo"], + ["foo", None], + {"foo": 0}, + {"foo": "bar", "baz": []}, + [{"foo": ""}], ] new_list = salt.utils.data.filter_falsey(old_list, recurse_depth=3) - expect_list = ['foo', ['foo'], ['foo'], {'foo': 'bar'}] + expect_list = ["foo", ["foo"], ["foo"], {"foo": "bar"}] self.assertEqual(expect_list, new_list) self.assertIs(type(expect_list), type(new_list)) def test_filter_set_no_recurse(self): - ''' + """ Test filtering a set without recursing. Note that a set cannot contain unhashable types, so recursion is not possible. - ''' - old_set = set([ - 'foo', - None, - 0, - '', - ]) + """ + old_set = {"foo", None, 0, ""} new_set = salt.utils.data.filter_falsey(old_set) - expect_set = set(['foo']) + expect_set = {"foo"} self.assertEqual(expect_set, new_set) self.assertIs(type(expect_set), type(new_set)) def test_filter_ordereddict_no_recurse(self): - ''' + """ Test filtering an OrderedDict without recursing. 
- ''' - old_dict = OrderedDict([ - ('foo', None), - ('bar', OrderedDict([('baz', OrderedDict([('qux', None), ('quux', ''), ('foo', [])]))])), - ('baz', ['qux']), - ('qux', {}), - ('quux', []) - ]) + """ + old_dict = OrderedDict( + [ + ("foo", None), + ( + "bar", + OrderedDict( + [ + ( + "baz", + OrderedDict([("qux", None), ("quux", ""), ("foo", [])]), + ) + ] + ), + ), + ("baz", ["qux"]), + ("qux", {}), + ("quux", []), + ] + ) new_dict = salt.utils.data.filter_falsey(old_dict) - expect_dict = OrderedDict([ - ('bar', OrderedDict([('baz', OrderedDict([('qux', None), ('quux', ''), ('foo', [])]))])), - ('baz', ['qux']), - ]) + expect_dict = OrderedDict( + [ + ( + "bar", + OrderedDict( + [ + ( + "baz", + OrderedDict([("qux", None), ("quux", ""), ("foo", [])]), + ) + ] + ), + ), + ("baz", ["qux"]), + ] + ) self.assertEqual(expect_dict, new_dict) self.assertIs(type(expect_dict), type(new_dict)) def test_filter_ordereddict_recurse(self): - ''' + """ Test filtering an OrderedDict with recursing. - ''' - old_dict = OrderedDict([ - ('foo', None), - ('bar', OrderedDict([('baz', OrderedDict([('qux', None), ('quux', ''), ('foo', [])]))])), - ('baz', ['qux']), - ('qux', {}), - ('quux', []) - ]) + """ + old_dict = OrderedDict( + [ + ("foo", None), + ( + "bar", + OrderedDict( + [ + ( + "baz", + OrderedDict([("qux", None), ("quux", ""), ("foo", [])]), + ) + ] + ), + ), + ("baz", ["qux"]), + ("qux", {}), + ("quux", []), + ] + ) new_dict = salt.utils.data.filter_falsey(old_dict, recurse_depth=3) - expect_dict = OrderedDict([ - ('baz', ['qux']), - ]) + expect_dict = OrderedDict([("baz", ["qux"])]) self.assertEqual(expect_dict, new_dict) self.assertIs(type(expect_dict), type(new_dict)) def test_filter_list_recurse_limit(self): - ''' + """ Test filtering a list with recursing, but with a limited depth. Note that the top-level is always processed, so a recursion depth of 2 means that two *additional* levels are processed. 
- ''' + """ old_list = [None, [None, [None, [None]]]] new_list = salt.utils.data.filter_falsey(old_list, recurse_depth=2) self.assertEqual([[[[None]]]], new_list) def test_filter_dict_recurse_limit(self): - ''' + """ Test filtering a dict with recursing, but with a limited depth. Note that the top-level is always processed, so a recursion depth of 2 means that two *additional* levels are processed. - ''' - old_dict = {'one': None, - 'foo': {'two': None, 'bar': {'three': None, 'baz': {'four': None}}}} + """ + old_dict = { + "one": None, + "foo": {"two": None, "bar": {"three": None, "baz": {"four": None}}}, + } new_dict = salt.utils.data.filter_falsey(old_dict, recurse_depth=2) - self.assertEqual({'foo': {'bar': {'baz': {'four': None}}}}, new_dict) + self.assertEqual({"foo": {"bar": {"baz": {"four": None}}}}, new_dict) def test_filter_exclude_types(self): - ''' + """ Test filtering a list recursively, but also ignoring (i.e. not filtering) out certain types that can be falsey. - ''' + """ # Ignore int, unicode - old_list = ['foo', ['foo'], ['foo', None], {'foo': 0}, {'foo': 'bar', 'baz': []}, [{'foo': ''}]] - new_list = salt.utils.data.filter_falsey(old_list, recurse_depth=3, ignore_types=[type(0), type('')]) - self.assertEqual(['foo', ['foo'], ['foo'], {'foo': 0}, {'foo': 'bar'}, [{'foo': ''}]], new_list) + old_list = [ + "foo", + ["foo"], + ["foo", None], + {"foo": 0}, + {"foo": "bar", "baz": []}, + [{"foo": ""}], + ] + new_list = salt.utils.data.filter_falsey( + old_list, recurse_depth=3, ignore_types=[type(0), type("")] + ) + self.assertEqual( + ["foo", ["foo"], ["foo"], {"foo": 0}, {"foo": "bar"}, [{"foo": ""}]], + new_list, + ) # Ignore list - old_list = ['foo', ['foo'], ['foo', None], {'foo': 0}, {'foo': 'bar', 'baz': []}, [{'foo': ''}]] - new_list = salt.utils.data.filter_falsey(old_list, recurse_depth=3, ignore_types=[type([])]) - self.assertEqual(['foo', ['foo'], ['foo'], {'foo': 'bar', 'baz': []}, []], new_list) + old_list = [ + "foo", + ["foo"], + ["foo", 
None], + {"foo": 0}, + {"foo": "bar", "baz": []}, + [{"foo": ""}], + ] + new_list = salt.utils.data.filter_falsey( + old_list, recurse_depth=3, ignore_types=[type([])] + ) + self.assertEqual( + ["foo", ["foo"], ["foo"], {"foo": "bar", "baz": []}, []], new_list + ) # Ignore dict - old_list = ['foo', ['foo'], ['foo', None], {'foo': 0}, {'foo': 'bar', 'baz': []}, [{'foo': ''}]] - new_list = salt.utils.data.filter_falsey(old_list, recurse_depth=3, ignore_types=[type({})]) - self.assertEqual(['foo', ['foo'], ['foo'], {}, {'foo': 'bar'}, [{}]], new_list) + old_list = [ + "foo", + ["foo"], + ["foo", None], + {"foo": 0}, + {"foo": "bar", "baz": []}, + [{"foo": ""}], + ] + new_list = salt.utils.data.filter_falsey( + old_list, recurse_depth=3, ignore_types=[type({})] + ) + self.assertEqual(["foo", ["foo"], ["foo"], {}, {"foo": "bar"}, [{}]], new_list) # Ignore NoneType - old_list = ['foo', ['foo'], ['foo', None], {'foo': 0}, {'foo': 'bar', 'baz': []}, [{'foo': ''}]] - new_list = salt.utils.data.filter_falsey(old_list, recurse_depth=3, ignore_types=[type(None)]) - self.assertEqual(['foo', ['foo'], ['foo', None], {'foo': 'bar'}], new_list) + old_list = [ + "foo", + ["foo"], + ["foo", None], + {"foo": 0}, + {"foo": "bar", "baz": []}, + [{"foo": ""}], + ] + new_list = salt.utils.data.filter_falsey( + old_list, recurse_depth=3, ignore_types=[type(None)] + ) + self.assertEqual(["foo", ["foo"], ["foo", None], {"foo": "bar"}], new_list) class FilterRecursiveDiff(TestCase): - ''' + """ Test suite for salt.utils.data.recursive_diff - ''' + """ def test_list_equality(self): - ''' + """ Test cases where equal lists are compared. - ''' + """ test_list = [0, 1, 2] self.assertEqual({}, salt.utils.data.recursive_diff(test_list, test_list)) @@ -895,392 +1019,455 @@ class FilterRecursiveDiff(TestCase): self.assertEqual({}, salt.utils.data.recursive_diff(test_list, test_list)) def test_dict_equality(self): - ''' + """ Test cases where equal dicts are compared. 
- ''' - test_dict = {'foo': 'bar', 'bar': {'baz': {'qux': 'quux'}}, 'frop': 0} + """ + test_dict = {"foo": "bar", "bar": {"baz": {"qux": "quux"}}, "frop": 0} self.assertEqual({}, salt.utils.data.recursive_diff(test_dict, test_dict)) def test_ordereddict_equality(self): - ''' + """ Test cases where equal OrderedDicts are compared. - ''' - test_dict = OrderedDict([ - ('foo', 'bar'), - ('bar', OrderedDict([('baz', OrderedDict([('qux', 'quux')]))])), - ('frop', 0)]) + """ + test_dict = OrderedDict( + [ + ("foo", "bar"), + ("bar", OrderedDict([("baz", OrderedDict([("qux", "quux")]))])), + ("frop", 0), + ] + ) self.assertEqual({}, salt.utils.data.recursive_diff(test_dict, test_dict)) def test_mixed_equality(self): - ''' + """ Test cases where mixed nested lists and dicts are compared. - ''' + """ test_data = { - 'foo': 'bar', - 'baz': [0, 1, 2], - 'bar': {'baz': [{'qux': 'quux'}, {'froop', 0}]} + "foo": "bar", + "baz": [0, 1, 2], + "bar": {"baz": [{"qux": "quux"}, {"froop", 0}]}, } self.assertEqual({}, salt.utils.data.recursive_diff(test_data, test_data)) def test_set_equality(self): - ''' + """ Test cases where equal sets are compared. - ''' - test_set = set([0, 1, 2, 3, 'foo']) + """ + test_set = {0, 1, 2, 3, "foo"} self.assertEqual({}, salt.utils.data.recursive_diff(test_set, test_set)) # This is a bit of an oddity, as python seems to sort the sets in memory # so both sets end up with the same ordering (0..3). - set_one = set([0, 1, 2, 3]) - set_two = set([3, 2, 1, 0]) + set_one = {0, 1, 2, 3} + set_two = {3, 2, 1, 0} self.assertEqual({}, salt.utils.data.recursive_diff(set_one, set_two)) def test_tuple_equality(self): - ''' + """ Test cases where equal tuples are compared. - ''' - test_tuple = (0, 1, 2, 3, 'foo') + """ + test_tuple = (0, 1, 2, 3, "foo") self.assertEqual({}, salt.utils.data.recursive_diff(test_tuple, test_tuple)) def test_list_inequality(self): - ''' + """ Test cases where two inequal lists are compared. 
- ''' + """ list_one = [0, 1, 2] - list_two = ['foo', 'bar', 'baz'] - expected_result = {'old': list_one, 'new': list_two} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(list_one, list_two)) - expected_result = {'new': list_one, 'old': list_two} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(list_two, list_one)) - - list_one = [0, 'foo', 1, 'bar'] - list_two = [1, 'foo', 1, 'qux'] - expected_result = {'old': [0, 'bar'], 'new': [1, 'qux']} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(list_one, list_two)) - expected_result = {'new': [0, 'bar'], 'old': [1, 'qux']} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(list_two, list_one)) + list_two = ["foo", "bar", "baz"] + expected_result = {"old": list_one, "new": list_two} + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(list_one, list_two) + ) + expected_result = {"new": list_one, "old": list_two} + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(list_two, list_one) + ) + + list_one = [0, "foo", 1, "bar"] + list_two = [1, "foo", 1, "qux"] + expected_result = {"old": [0, "bar"], "new": [1, "qux"]} + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(list_one, list_two) + ) + expected_result = {"new": [0, "bar"], "old": [1, "qux"]} + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(list_two, list_one) + ) list_one = [0, 1, [2, 3]] - list_two = [0, 1, ['foo', 'bar']] - expected_result = {'old': [[2, 3]], 'new': [['foo', 'bar']]} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(list_one, list_two)) - expected_result = {'new': [[2, 3]], 'old': [['foo', 'bar']]} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(list_two, list_one)) + list_two = [0, 1, ["foo", "bar"]] + expected_result = {"old": [[2, 3]], "new": [["foo", "bar"]]} + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(list_one, list_two) + ) + 
expected_result = {"new": [[2, 3]], "old": [["foo", "bar"]]} + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(list_two, list_one) + ) def test_dict_inequality(self): - ''' + """ Test cases where two inequal dicts are compared. - ''' - dict_one = {'foo': 1, 'bar': 2, 'baz': 3} - dict_two = {'foo': 2, 1: 'bar', 'baz': 3} - expected_result = {'old': {'foo': 1, 'bar': 2}, 'new': {'foo': 2, 1: 'bar'}} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(dict_one, dict_two)) - expected_result = {'new': {'foo': 1, 'bar': 2}, 'old': {'foo': 2, 1: 'bar'}} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(dict_two, dict_one)) - - dict_one = {'foo': {'bar': {'baz': 1}}} - dict_two = {'foo': {'qux': {'baz': 1}}} - expected_result = {'old': dict_one, 'new': dict_two} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(dict_one, dict_two)) - expected_result = {'new': dict_one, 'old': dict_two} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(dict_two, dict_one)) + """ + dict_one = {"foo": 1, "bar": 2, "baz": 3} + dict_two = {"foo": 2, 1: "bar", "baz": 3} + expected_result = {"old": {"foo": 1, "bar": 2}, "new": {"foo": 2, 1: "bar"}} + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(dict_one, dict_two) + ) + expected_result = {"new": {"foo": 1, "bar": 2}, "old": {"foo": 2, 1: "bar"}} + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(dict_two, dict_one) + ) + + dict_one = {"foo": {"bar": {"baz": 1}}} + dict_two = {"foo": {"qux": {"baz": 1}}} + expected_result = {"old": dict_one, "new": dict_two} + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(dict_one, dict_two) + ) + expected_result = {"new": dict_one, "old": dict_two} + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(dict_two, dict_one) + ) def test_ordereddict_inequality(self): - ''' + """ Test cases where two inequal OrderedDicts are compared. 
- ''' - odict_one = OrderedDict([('foo', 'bar'), ('bar', 'baz')]) - odict_two = OrderedDict([('bar', 'baz'), ('foo', 'bar')]) - expected_result = {'old': odict_one, 'new': odict_two} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(odict_one, odict_two)) + """ + odict_one = OrderedDict([("foo", "bar"), ("bar", "baz")]) + odict_two = OrderedDict([("bar", "baz"), ("foo", "bar")]) + expected_result = {"old": odict_one, "new": odict_two} + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(odict_one, odict_two) + ) def test_set_inequality(self): - ''' + """ Test cases where two inequal sets are compared. Tricky as the sets are compared zipped, so shuffled sets of equal values are considered different. - ''' - set_one = set([0, 1, 2, 4]) - set_two = set([0, 1, 3, 4]) - expected_result = {'old': set([2]), 'new': set([3])} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(set_one, set_two)) - expected_result = {'new': set([2]), 'old': set([3])} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(set_two, set_one)) + """ + set_one = {0, 1, 2, 4} + set_two = {0, 1, 3, 4} + expected_result = {"old": {2}, "new": {3}} + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(set_one, set_two) + ) + expected_result = {"new": {2}, "old": {3}} + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(set_two, set_one) + ) # It is unknown how different python versions will store sets in memory. # Python 2.7 seems to sort it (i.e. set_one below becomes {0, 1, 'foo', 'bar'} # However Python 3.6.8 stores it differently each run. # So just test for "not equal" here. 
- set_one = set([0, 'foo', 1, 'bar']) - set_two = set(['foo', 1, 'bar', 2]) + set_one = {0, "foo", 1, "bar"} + set_two = {"foo", 1, "bar", 2} expected_result = {} - self.assertNotEqual(expected_result, salt.utils.data.recursive_diff(set_one, set_two)) + self.assertNotEqual( + expected_result, salt.utils.data.recursive_diff(set_one, set_two) + ) def test_mixed_inequality(self): - ''' + """ Test cases where two mixed dicts/iterables that are different are compared. - ''' - dict_one = {'foo': [1, 2, 3]} - dict_two = {'foo': [3, 2, 1]} - expected_result = {'old': {'foo': [1, 3]}, 'new': {'foo': [3, 1]}} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(dict_one, dict_two)) - expected_result = {'new': {'foo': [1, 3]}, 'old': {'foo': [3, 1]}} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(dict_two, dict_one)) - - list_one = [1, 2, {'foo': ['bar', {'foo': 1, 'bar': 2}]}] - list_two = [3, 4, {'foo': ['qux', {'foo': 1, 'bar': 2}]}] - expected_result = {'old': [1, 2, {'foo': ['bar']}], 'new': [3, 4, {'foo': ['qux']}]} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(list_one, list_two)) - expected_result = {'new': [1, 2, {'foo': ['bar']}], 'old': [3, 4, {'foo': ['qux']}]} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(list_two, list_one)) - - mixed_one = {'foo': set([0, 1, 2]), 'bar': [0, 1, 2]} - mixed_two = {'foo': set([1, 2, 3]), 'bar': [1, 2, 3]} + """ + dict_one = {"foo": [1, 2, 3]} + dict_two = {"foo": [3, 2, 1]} + expected_result = {"old": {"foo": [1, 3]}, "new": {"foo": [3, 1]}} + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(dict_one, dict_two) + ) + expected_result = {"new": {"foo": [1, 3]}, "old": {"foo": [3, 1]}} + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(dict_two, dict_one) + ) + + list_one = [1, 2, {"foo": ["bar", {"foo": 1, "bar": 2}]}] + list_two = [3, 4, {"foo": ["qux", {"foo": 1, "bar": 2}]}] expected_result = { - 'old': {'foo': 
set([0]), 'bar': [0, 1, 2]}, - 'new': {'foo': set([3]), 'bar': [1, 2, 3]} + "old": [1, 2, {"foo": ["bar"]}], + "new": [3, 4, {"foo": ["qux"]}], } - self.assertEqual(expected_result, salt.utils.data.recursive_diff(mixed_one, mixed_two)) + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(list_one, list_two) + ) + expected_result = { + "new": [1, 2, {"foo": ["bar"]}], + "old": [3, 4, {"foo": ["qux"]}], + } + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(list_two, list_one) + ) + + mixed_one = {"foo": {0, 1, 2}, "bar": [0, 1, 2]} + mixed_two = {"foo": {1, 2, 3}, "bar": [1, 2, 3]} expected_result = { - 'new': {'foo': set([0]), 'bar': [0, 1, 2]}, - 'old': {'foo': set([3]), 'bar': [1, 2, 3]} + "old": {"foo": {0}, "bar": [0, 1, 2]}, + "new": {"foo": {3}, "bar": [1, 2, 3]}, } - self.assertEqual(expected_result, salt.utils.data.recursive_diff(mixed_two, mixed_one)) + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(mixed_one, mixed_two) + ) + expected_result = { + "new": {"foo": {0}, "bar": [0, 1, 2]}, + "old": {"foo": {3}, "bar": [1, 2, 3]}, + } + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(mixed_two, mixed_one) + ) def test_tuple_inequality(self): - ''' + """ Test cases where two tuples that are different are compared. - ''' + """ tuple_one = (1, 2, 3) tuple_two = (3, 2, 1) - expected_result = {'old': (1, 3), 'new': (3, 1)} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(tuple_one, tuple_two)) + expected_result = {"old": (1, 3), "new": (3, 1)} + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(tuple_one, tuple_two) + ) def test_list_vs_set(self): - ''' + """ Test case comparing a list with a set, will be compared unordered. 
- ''' + """ mixed_one = [1, 2, 3] - mixed_two = set([3, 2, 1]) + mixed_two = {3, 2, 1} expected_result = {} - self.assertEqual(expected_result, salt.utils.data.recursive_diff(mixed_one, mixed_two)) - self.assertEqual(expected_result, salt.utils.data.recursive_diff(mixed_two, mixed_one)) + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(mixed_one, mixed_two) + ) + self.assertEqual( + expected_result, salt.utils.data.recursive_diff(mixed_two, mixed_one) + ) def test_dict_vs_ordereddict(self): - ''' + """ Test case comparing a dict with an ordereddict, will be compared unordered. - ''' - test_dict = {'foo': 'bar', 'bar': 'baz'} - test_odict = OrderedDict([('foo', 'bar'), ('bar', 'baz')]) + """ + test_dict = {"foo": "bar", "bar": "baz"} + test_odict = OrderedDict([("foo", "bar"), ("bar", "baz")]) self.assertEqual({}, salt.utils.data.recursive_diff(test_dict, test_odict)) self.assertEqual({}, salt.utils.data.recursive_diff(test_odict, test_dict)) - test_odict2 = OrderedDict([('bar', 'baz'), ('foo', 'bar')]) + test_odict2 = OrderedDict([("bar", "baz"), ("foo", "bar")]) self.assertEqual({}, salt.utils.data.recursive_diff(test_dict, test_odict2)) self.assertEqual({}, salt.utils.data.recursive_diff(test_odict2, test_dict)) def test_list_ignore_ignored(self): - ''' + """ Test case comparing two lists with ignore-list supplied (which is not used when comparing lists). - ''' + """ list_one = [1, 2, 3] list_two = [3, 2, 1] - expected_result = {'old': [1, 3], 'new': [3, 1]} + expected_result = {"old": [1, 3], "new": [3, 1]} self.assertEqual( expected_result, - salt.utils.data.recursive_diff(list_one, list_two, ignore_keys=[1, 3]) + salt.utils.data.recursive_diff(list_one, list_two, ignore_keys=[1, 3]), ) def test_dict_ignore(self): - ''' + """ Test case comparing two dicts with ignore-list supplied. 
- ''' - dict_one = {'foo': 1, 'bar': 2, 'baz': 3} - dict_two = {'foo': 3, 'bar': 2, 'baz': 1} - expected_result = {'old': {'baz': 3}, 'new': {'baz': 1}} + """ + dict_one = {"foo": 1, "bar": 2, "baz": 3} + dict_two = {"foo": 3, "bar": 2, "baz": 1} + expected_result = {"old": {"baz": 3}, "new": {"baz": 1}} self.assertEqual( expected_result, - salt.utils.data.recursive_diff(dict_one, dict_two, ignore_keys=['foo']) + salt.utils.data.recursive_diff(dict_one, dict_two, ignore_keys=["foo"]), ) def test_ordereddict_ignore(self): - ''' + """ Test case comparing two OrderedDicts with ignore-list supplied. - ''' - odict_one = OrderedDict([('foo', 1), ('bar', 2), ('baz', 3)]) - odict_two = OrderedDict([('baz', 1), ('bar', 2), ('foo', 3)]) + """ + odict_one = OrderedDict([("foo", 1), ("bar", 2), ("baz", 3)]) + odict_two = OrderedDict([("baz", 1), ("bar", 2), ("foo", 3)]) # The key 'foo' will be ignored, which means the key from the other OrderedDict # will always be considered "different" since OrderedDicts are compared ordered. - expected_result = {'old': OrderedDict([('baz', 3)]), 'new': OrderedDict([('baz', 1)])} + expected_result = { + "old": OrderedDict([("baz", 3)]), + "new": OrderedDict([("baz", 1)]), + } self.assertEqual( expected_result, - salt.utils.data.recursive_diff(odict_one, odict_two, ignore_keys=['foo']) + salt.utils.data.recursive_diff(odict_one, odict_two, ignore_keys=["foo"]), ) def test_dict_vs_ordereddict_ignore(self): - ''' + """ Test case comparing a dict with an OrderedDict with ignore-list supplied. 
- ''' - dict_one = {'foo': 1, 'bar': 2, 'baz': 3} - odict_two = OrderedDict([('foo', 3), ('bar', 2), ('baz', 1)]) - expected_result = {'old': {'baz': 3}, 'new': OrderedDict([('baz', 1)])} + """ + dict_one = {"foo": 1, "bar": 2, "baz": 3} + odict_two = OrderedDict([("foo", 3), ("bar", 2), ("baz", 1)]) + expected_result = {"old": {"baz": 3}, "new": OrderedDict([("baz", 1)])} self.assertEqual( expected_result, - salt.utils.data.recursive_diff(dict_one, odict_two, ignore_keys=['foo']) + salt.utils.data.recursive_diff(dict_one, odict_two, ignore_keys=["foo"]), ) def test_mixed_nested_ignore(self): - ''' + """ Test case comparing mixed, nested items with ignore-list supplied. - ''' - dict_one = {'foo': [1], 'bar': {'foo': 1, 'bar': 2}, 'baz': 3} - dict_two = {'foo': [2], 'bar': {'foo': 3, 'bar': 2}, 'baz': 1} - expected_result = {'old': {'baz': 3}, 'new': {'baz': 1}} + """ + dict_one = {"foo": [1], "bar": {"foo": 1, "bar": 2}, "baz": 3} + dict_two = {"foo": [2], "bar": {"foo": 3, "bar": 2}, "baz": 1} + expected_result = {"old": {"baz": 3}, "new": {"baz": 1}} self.assertEqual( expected_result, - salt.utils.data.recursive_diff(dict_one, dict_two, ignore_keys=['foo']) + salt.utils.data.recursive_diff(dict_one, dict_two, ignore_keys=["foo"]), ) def test_ordered_dict_unequal_length(self): - ''' + """ Test case comparing two OrderedDicts of unequal length. 
- ''' - odict_one = OrderedDict([('foo', 1), ('bar', 2), ('baz', 3)]) - odict_two = OrderedDict([('foo', 1), ('bar', 2)]) - expected_result = {'old': OrderedDict([('baz', 3)]), 'new': {}} + """ + odict_one = OrderedDict([("foo", 1), ("bar", 2), ("baz", 3)]) + odict_two = OrderedDict([("foo", 1), ("bar", 2)]) + expected_result = {"old": OrderedDict([("baz", 3)]), "new": {}} self.assertEqual( - expected_result, - salt.utils.data.recursive_diff(odict_one, odict_two) + expected_result, salt.utils.data.recursive_diff(odict_one, odict_two) ) def test_list_unequal_length(self): - ''' + """ Test case comparing two lists of unequal length. - ''' + """ list_one = [1, 2, 3] list_two = [1, 2, 3, 4] - expected_result = {'old': [], 'new': [4]} + expected_result = {"old": [], "new": [4]} self.assertEqual( - expected_result, - salt.utils.data.recursive_diff(list_one, list_two) + expected_result, salt.utils.data.recursive_diff(list_one, list_two) ) def test_set_unequal_length(self): - ''' + """ Test case comparing two sets of unequal length. This does not do anything special, as it is unordered. - ''' - set_one = set([1, 2, 3]) - set_two = set([4, 3, 2, 1]) - expected_result = {'old': set([]), 'new': set([4])} + """ + set_one = {1, 2, 3} + set_two = {4, 3, 2, 1} + expected_result = {"old": set(), "new": {4}} self.assertEqual( - expected_result, - salt.utils.data.recursive_diff(set_one, set_two) + expected_result, salt.utils.data.recursive_diff(set_one, set_two) ) def test_tuple_unequal_length(self): - ''' + """ Test case comparing two tuples of unequal length. This should be the same as comparing two ordered lists. 
- ''' + """ tuple_one = (1, 2, 3) tuple_two = (1, 2, 3, 4) - expected_result = {'old': (), 'new': (4,)} + expected_result = {"old": (), "new": (4,)} self.assertEqual( - expected_result, - salt.utils.data.recursive_diff(tuple_one, tuple_two) + expected_result, salt.utils.data.recursive_diff(tuple_one, tuple_two) ) def test_list_unordered(self): - ''' + """ Test case comparing two lists unordered. - ''' + """ list_one = [1, 2, 3, 4] list_two = [4, 3, 2] - expected_result = {'old': [1], 'new': []} + expected_result = {"old": [1], "new": []} self.assertEqual( expected_result, - salt.utils.data.recursive_diff(list_one, list_two, ignore_order=True) + salt.utils.data.recursive_diff(list_one, list_two, ignore_order=True), ) def test_mixed_nested_unordered(self): - ''' + """ Test case comparing nested dicts/lists unordered. - ''' - dict_one = {'foo': {'bar': [1, 2, 3]}, 'bar': [{'foo': 4}, 0]} - dict_two = {'foo': {'bar': [3, 2, 1]}, 'bar': [0, {'foo': 4}]} + """ + dict_one = {"foo": {"bar": [1, 2, 3]}, "bar": [{"foo": 4}, 0]} + dict_two = {"foo": {"bar": [3, 2, 1]}, "bar": [0, {"foo": 4}]} expected_result = {} self.assertEqual( expected_result, - salt.utils.data.recursive_diff(dict_one, dict_two, ignore_order=True) + salt.utils.data.recursive_diff(dict_one, dict_two, ignore_order=True), ) expected_result = { - 'old': {'foo': {'bar': [1, 3]}, 'bar': [{'foo': 4}, 0]}, - 'new': {'foo': {'bar': [3, 1]}, 'bar': [0, {'foo': 4}]}, + "old": {"foo": {"bar": [1, 3]}, "bar": [{"foo": 4}, 0]}, + "new": {"foo": {"bar": [3, 1]}, "bar": [0, {"foo": 4}]}, } self.assertEqual( - expected_result, - salt.utils.data.recursive_diff(dict_one, dict_two) + expected_result, salt.utils.data.recursive_diff(dict_one, dict_two) ) def test_ordered_dict_unordered(self): - ''' + """ Test case comparing OrderedDicts unordered. 
- ''' - odict_one = OrderedDict([('foo', 1), ('bar', 2), ('baz', 3)]) - odict_two = OrderedDict([('baz', 3), ('bar', 2), ('foo', 1)]) + """ + odict_one = OrderedDict([("foo", 1), ("bar", 2), ("baz", 3)]) + odict_two = OrderedDict([("baz", 3), ("bar", 2), ("foo", 1)]) expected_result = {} self.assertEqual( expected_result, - salt.utils.data.recursive_diff(odict_one, odict_two, ignore_order=True) + salt.utils.data.recursive_diff(odict_one, odict_two, ignore_order=True), ) def test_ignore_missing_keys_dict(self): - ''' + """ Test case ignoring missing keys on a comparison of dicts. - ''' - dict_one = {'foo': 1, 'bar': 2, 'baz': 3} - dict_two = {'bar': 3} - expected_result = {'old': {'bar': 2}, 'new': {'bar': 3}} + """ + dict_one = {"foo": 1, "bar": 2, "baz": 3} + dict_two = {"bar": 3} + expected_result = {"old": {"bar": 2}, "new": {"bar": 3}} self.assertEqual( expected_result, - salt.utils.data.recursive_diff(dict_one, dict_two, ignore_missing_keys=True) + salt.utils.data.recursive_diff( + dict_one, dict_two, ignore_missing_keys=True + ), ) def test_ignore_missing_keys_ordered_dict(self): - ''' + """ Test case not ignoring missing keys on a comparison of OrderedDicts. - ''' - odict_one = OrderedDict([('foo', 1), ('bar', 2), ('baz', 3)]) - odict_two = OrderedDict([('bar', 3)]) - expected_result = {'old': odict_one, 'new': odict_two} + """ + odict_one = OrderedDict([("foo", 1), ("bar", 2), ("baz", 3)]) + odict_two = OrderedDict([("bar", 3)]) + expected_result = {"old": odict_one, "new": odict_two} self.assertEqual( expected_result, - salt.utils.data.recursive_diff(odict_one, odict_two, ignore_missing_keys=True) + salt.utils.data.recursive_diff( + odict_one, odict_two, ignore_missing_keys=True + ), ) def test_ignore_missing_keys_recursive(self): - ''' + """ Test case ignoring missing keys on a comparison of nested dicts. 
- ''' - dict_one = {'foo': {'bar': 2, 'baz': 3}} - dict_two = {'foo': {'baz': 3}} + """ + dict_one = {"foo": {"bar": 2, "baz": 3}} + dict_two = {"foo": {"baz": 3}} expected_result = {} self.assertEqual( expected_result, - salt.utils.data.recursive_diff(dict_one, dict_two, ignore_missing_keys=True) + salt.utils.data.recursive_diff( + dict_one, dict_two, ignore_missing_keys=True + ), ) # Compare from dict-in-dict dict_two = {} self.assertEqual( expected_result, - salt.utils.data.recursive_diff(dict_one, dict_two, ignore_missing_keys=True) + salt.utils.data.recursive_diff( + dict_one, dict_two, ignore_missing_keys=True + ), ) # Compare from dict-in-list - dict_one = {'foo': ['bar', {'baz': 3}]} - dict_two = {'foo': ['bar', {}]} + dict_one = {"foo": ["bar", {"baz": 3}]} + dict_two = {"foo": ["bar", {}]} self.assertEqual( expected_result, - salt.utils.data.recursive_diff(dict_one, dict_two, ignore_missing_keys=True) + salt.utils.data.recursive_diff( + dict_one, dict_two, ignore_missing_keys=True + ), ) diff --git a/tests/unit/utils/test_xmlutil.py b/tests/unit/utils/test_xmlutil.py index c04f39498e..cbf73861e5 100644 --- a/tests/unit/utils/test_xmlutil.py +++ b/tests/unit/utils/test_xmlutil.py @@ -1,148 +1,170 @@ -# -*- coding: utf-8 -*- -''' +""" tests.unit.xmlutil_test ~~~~~~~~~~~~~~~~~~~~ -''' -from __future__ import absolute_import, print_function, unicode_literals -# Import Salt Testing libs -from tests.support.unit import TestCase +""" +import salt.utils.xmlutil as xml # Import Salt libs from salt._compat import ElementTree as ET -import salt.utils.xmlutil as xml + +# Import Salt Testing libs +from tests.support.unit import TestCase class XMLUtilTestCase(TestCase): - ''' + """ Tests that salt.utils.xmlutil properly parses XML data and returns as a properly formatted dictionary. The default method of parsing will ignore attributes and return only the child items. The full method will include parsing attributes. 
- ''' + """ def setUp(self): # Populate our use cases for specific XML formats. self.cases = { - 'a': { - 'xml': 'data', - 'legacy': {'parent': 'data'}, - 'full': 'data' + "a": { + "xml": "data", + "legacy": {"parent": "data"}, + "full": "data", }, - 'b': { - 'xml': 'data', - 'legacy': {'parent': 'data'}, - 'full': {'parent': 'data', 'value': 'data'} + "b": { + "xml": 'data', + "legacy": {"parent": "data"}, + "full": {"parent": "data", "value": "data"}, }, - 'c': { - 'xml': 'datadata' - '', - 'legacy': {'child': ['data', {'child': 'data'}, {'child': None}, {'child': None}]}, - 'full': {'child': ['data', {'child': 'data', 'value': 'data'}, {'value': 'data'}, None]} + "c": { + "xml": 'datadata' + '', + "legacy": { + "child": [ + "data", + {"child": "data"}, + {"child": None}, + {"child": None}, + ] + }, + "full": { + "child": [ + "data", + {"child": "data", "value": "data"}, + {"value": "data"}, + None, + ] + }, }, - 'd': { - 'xml': 'data', - 'legacy': {'child': 'data'}, - 'full': {'child': 'data', 'another': 'data', 'value': 'data'} + "d": { + "xml": 'data', + "legacy": {"child": "data"}, + "full": {"child": "data", "another": "data", "value": "data"}, }, - 'e': { - 'xml': 'data', - 'legacy': {'child': 'data'}, - 'full': {'child': {'child': 'data', 'value': 'data'}, 'another': 'data', 'value': 'data'} + "e": { + "xml": 'data', + "legacy": {"child": "data"}, + "full": { + "child": {"child": "data", "value": "data"}, + "another": "data", + "value": "data", + }, }, - 'f': { - 'xml': 'data' - 'data', - 'legacy': {'child': [{'sub-child': 'data'}, {'child': 'data'}]}, - 'full': {'child': [{'sub-child': {'value': 'data', 'sub-child': 'data'}}, 'data']} + "f": { + "xml": 'data' + "data", + "legacy": {"child": [{"sub-child": "data"}, {"child": "data"}]}, + "full": { + "child": [ + {"sub-child": {"value": "data", "sub-child": "data"}}, + "data", + ] + }, }, } def test_xml_case_a(self): - xmldata = ET.fromstring(self.cases['a']['xml']) + xmldata = 
ET.fromstring(self.cases["a"]["xml"]) defaultdict = xml.to_dict(xmldata) - self.assertEqual(defaultdict, self.cases['a']['legacy']) + self.assertEqual(defaultdict, self.cases["a"]["legacy"]) def test_xml_case_a_legacy(self): - xmldata = ET.fromstring(self.cases['a']['xml']) + xmldata = ET.fromstring(self.cases["a"]["xml"]) defaultdict = xml.to_dict(xmldata, False) - self.assertEqual(defaultdict, self.cases['a']['legacy']) + self.assertEqual(defaultdict, self.cases["a"]["legacy"]) def test_xml_case_a_full(self): - xmldata = ET.fromstring(self.cases['a']['xml']) + xmldata = ET.fromstring(self.cases["a"]["xml"]) defaultdict = xml.to_dict(xmldata, True) - self.assertEqual(defaultdict, self.cases['a']['full']) + self.assertEqual(defaultdict, self.cases["a"]["full"]) def test_xml_case_b(self): - xmldata = ET.fromstring(self.cases['b']['xml']) + xmldata = ET.fromstring(self.cases["b"]["xml"]) defaultdict = xml.to_dict(xmldata) - self.assertEqual(defaultdict, self.cases['b']['legacy']) + self.assertEqual(defaultdict, self.cases["b"]["legacy"]) def test_xml_case_b_legacy(self): - xmldata = ET.fromstring(self.cases['b']['xml']) + xmldata = ET.fromstring(self.cases["b"]["xml"]) defaultdict = xml.to_dict(xmldata, False) - self.assertEqual(defaultdict, self.cases['b']['legacy']) + self.assertEqual(defaultdict, self.cases["b"]["legacy"]) def test_xml_case_b_full(self): - xmldata = ET.fromstring(self.cases['b']['xml']) + xmldata = ET.fromstring(self.cases["b"]["xml"]) defaultdict = xml.to_dict(xmldata, True) - self.assertEqual(defaultdict, self.cases['b']['full']) + self.assertEqual(defaultdict, self.cases["b"]["full"]) def test_xml_case_c(self): - xmldata = ET.fromstring(self.cases['c']['xml']) + xmldata = ET.fromstring(self.cases["c"]["xml"]) defaultdict = xml.to_dict(xmldata) - self.assertEqual(defaultdict, self.cases['c']['legacy']) + self.assertEqual(defaultdict, self.cases["c"]["legacy"]) def test_xml_case_c_legacy(self): - xmldata = ET.fromstring(self.cases['c']['xml']) + 
xmldata = ET.fromstring(self.cases["c"]["xml"]) defaultdict = xml.to_dict(xmldata, False) - self.assertEqual(defaultdict, self.cases['c']['legacy']) + self.assertEqual(defaultdict, self.cases["c"]["legacy"]) def test_xml_case_c_full(self): - xmldata = ET.fromstring(self.cases['c']['xml']) + xmldata = ET.fromstring(self.cases["c"]["xml"]) defaultdict = xml.to_dict(xmldata, True) - self.assertEqual(defaultdict, self.cases['c']['full']) + self.assertEqual(defaultdict, self.cases["c"]["full"]) def test_xml_case_d(self): - xmldata = ET.fromstring(self.cases['d']['xml']) + xmldata = ET.fromstring(self.cases["d"]["xml"]) defaultdict = xml.to_dict(xmldata) - self.assertEqual(defaultdict, self.cases['d']['legacy']) + self.assertEqual(defaultdict, self.cases["d"]["legacy"]) def test_xml_case_d_legacy(self): - xmldata = ET.fromstring(self.cases['d']['xml']) + xmldata = ET.fromstring(self.cases["d"]["xml"]) defaultdict = xml.to_dict(xmldata, False) - self.assertEqual(defaultdict, self.cases['d']['legacy']) + self.assertEqual(defaultdict, self.cases["d"]["legacy"]) def test_xml_case_d_full(self): - xmldata = ET.fromstring(self.cases['d']['xml']) + xmldata = ET.fromstring(self.cases["d"]["xml"]) defaultdict = xml.to_dict(xmldata, True) - self.assertEqual(defaultdict, self.cases['d']['full']) + self.assertEqual(defaultdict, self.cases["d"]["full"]) def test_xml_case_e(self): - xmldata = ET.fromstring(self.cases['e']['xml']) + xmldata = ET.fromstring(self.cases["e"]["xml"]) defaultdict = xml.to_dict(xmldata) - self.assertEqual(defaultdict, self.cases['e']['legacy']) + self.assertEqual(defaultdict, self.cases["e"]["legacy"]) def test_xml_case_e_legacy(self): - xmldata = ET.fromstring(self.cases['e']['xml']) + xmldata = ET.fromstring(self.cases["e"]["xml"]) defaultdict = xml.to_dict(xmldata, False) - self.assertEqual(defaultdict, self.cases['e']['legacy']) + self.assertEqual(defaultdict, self.cases["e"]["legacy"]) def test_xml_case_e_full(self): - xmldata = 
ET.fromstring(self.cases['e']['xml']) + xmldata = ET.fromstring(self.cases["e"]["xml"]) defaultdict = xml.to_dict(xmldata, True) - self.assertEqual(defaultdict, self.cases['e']['full']) + self.assertEqual(defaultdict, self.cases["e"]["full"]) def test_xml_case_f(self): - xmldata = ET.fromstring(self.cases['f']['xml']) + xmldata = ET.fromstring(self.cases["f"]["xml"]) defaultdict = xml.to_dict(xmldata) - self.assertEqual(defaultdict, self.cases['f']['legacy']) + self.assertEqual(defaultdict, self.cases["f"]["legacy"]) def test_xml_case_f_legacy(self): - xmldata = ET.fromstring(self.cases['f']['xml']) + xmldata = ET.fromstring(self.cases["f"]["xml"]) defaultdict = xml.to_dict(xmldata, False) - self.assertEqual(defaultdict, self.cases['f']['legacy']) + self.assertEqual(defaultdict, self.cases["f"]["legacy"]) def test_xml_case_f_full(self): - xmldata = ET.fromstring(self.cases['f']['xml']) + xmldata = ET.fromstring(self.cases["f"]["xml"]) defaultdict = xml.to_dict(xmldata, True) - self.assertEqual(defaultdict, self.cases['f']['full']) + self.assertEqual(defaultdict, self.cases["f"]["full"]) -- 2.28.0