From 90236c844cfce7da8beb7a570be19a8677c60820 Mon Sep 17 00:00:00 2001 From: Victor Zhestkov Date: Tue, 12 Apr 2022 10:06:43 +0300 Subject: [PATCH] Prevent affection of SSH.opts with LazyLoader (bsc#1197637) * Prevent affection of SSH.opts with LazyLoader * Restore parsed targets * Fix test_ssh unit tests Adjust unit tests --- salt/client/ssh/__init__.py | 19 +++++++++------- .../pytests/unit/client/ssh/test_password.py | 4 +++- .../unit/client/ssh/test_return_events.py | 2 +- tests/pytests/unit/client/ssh/test_ssh.py | 22 +++++++++---------- 4 files changed, 26 insertions(+), 21 deletions(-) diff --git a/salt/client/ssh/__init__.py b/salt/client/ssh/__init__.py index a527c03de6..d5a679821e 100644 --- a/salt/client/ssh/__init__.py +++ b/salt/client/ssh/__init__.py @@ -224,15 +224,16 @@ class SSH(MultiprocessingStateMixin): ROSTER_UPDATE_FLAG = "#__needs_update" def __init__(self, opts, context=None): + self.opts = copy.deepcopy(opts) + self.sopts = copy.deepcopy(self.opts) self.__parsed_rosters = {SSH.ROSTER_UPDATE_FLAG: True} - pull_sock = os.path.join(opts["sock_dir"], "master_event_pull.ipc") + pull_sock = os.path.join(self.opts["sock_dir"], "master_event_pull.ipc") if os.path.exists(pull_sock) and zmq: self.event = salt.utils.event.get_event( - "master", opts["sock_dir"], opts=opts, listen=False + "master", self.opts["sock_dir"], opts=self.opts, listen=False ) else: self.event = None - self.opts = opts if self.opts["regen_thin"]: self.opts["ssh_wipe"] = True if not salt.utils.path.which("ssh"): @@ -243,7 +244,7 @@ class SSH(MultiprocessingStateMixin): " to run. Exiting." 
), ) - self.opts["_ssh_version"] = ssh_version() + self.sopts["_ssh_version"] = ssh_version() self.tgt_type = ( self.opts["selected_target_option"] if self.opts["selected_target_option"] @@ -339,6 +340,9 @@ class SSH(MultiprocessingStateMixin): self.opts["cachedir"], "salt-ssh.session.lock" ) self.ssh_session_grace_time = int(self.opts.get("ssh_session_grace_time", 1)) + self.sopts["tgt"] = copy.deepcopy(self.opts["tgt"]) + self.sopts["ssh_cli_tgt"] = copy.deepcopy(self.opts["ssh_cli_tgt"]) + self.opts = self.sopts # __setstate__ and __getstate__ are only used on spawning platforms. def __setstate__(self, state): @@ -607,7 +611,6 @@ class SSH(MultiprocessingStateMixin): Spin up the needed threads or processes and execute the subsequent routines """ - opts = copy.deepcopy(self.opts) que = multiprocessing.Queue() running = {} targets_queue = deque(self.targets.keys()) @@ -618,7 +621,7 @@ class SSH(MultiprocessingStateMixin): if not self.targets: log.error("No matching targets found in roster.") break - if len(running) < opts.get("ssh_max_procs", 25) and not init: + if len(running) < self.opts.get("ssh_max_procs", 25) and not init: if targets_queue: host = targets_queue.popleft() else: @@ -682,7 +685,7 @@ class SSH(MultiprocessingStateMixin): continue args = ( que, - opts, + self.opts, host, self.targets[host], mine, @@ -776,7 +779,7 @@ class SSH(MultiprocessingStateMixin): if len(rets) >= len(self.targets): break # Sleep when limit or all threads started - if len(running) >= opts.get("ssh_max_procs", 25) or len( + if len(running) >= self.opts.get("ssh_max_procs", 25) or len( self.targets ) >= len(running): time.sleep(0.1) diff --git a/tests/pytests/unit/client/ssh/test_password.py b/tests/pytests/unit/client/ssh/test_password.py index 8a7794d2f4..0ca28d022e 100644 --- a/tests/pytests/unit/client/ssh/test_password.py +++ b/tests/pytests/unit/client/ssh/test_password.py @@ -27,6 +27,8 @@ def test_password_failure(temp_salt_master, tmp_path): opts["argv"] = 
["test.ping"] opts["selected_target_option"] = "glob" opts["tgt"] = "localhost" + opts["ssh_cli_tgt"] = "localhost" + opts["_ssh_version"] = "foobar" opts["arg"] = [] roster = str(tmp_path / "roster") handle_ssh_ret = [ @@ -44,7 +46,7 @@ def test_password_failure(temp_salt_master, tmp_path): "salt.client.ssh.SSH.handle_ssh", MagicMock(return_value=handle_ssh_ret) ), patch("salt.client.ssh.SSH.key_deploy", MagicMock(return_value=expected)), patch( "salt.output.display_output", display_output - ): + ), patch("salt.client.ssh.ssh_version", MagicMock(return_value="foobar")): client = ssh.SSH(opts) ret = next(client.run_iter()) with pytest.raises(SystemExit): diff --git a/tests/pytests/unit/client/ssh/test_return_events.py b/tests/pytests/unit/client/ssh/test_return_events.py index 1f0b0dbf33..18714741b9 100644 --- a/tests/pytests/unit/client/ssh/test_return_events.py +++ b/tests/pytests/unit/client/ssh/test_return_events.py @@ -43,7 +43,7 @@ def test_not_missing_fun_calling_wfuncs(temp_salt_master, tmp_path): assert "localhost" in ret assert "fun" in ret["localhost"] client.run() - display_output.assert_called_once_with(expected, "nested", opts) + display_output.assert_called_once_with(expected, "nested", client.opts) assert ret is handle_ssh_ret[0] assert len(client.event.fire_event.call_args_list) == 2 assert "fun" in client.event.fire_event.call_args_list[0][0][0] diff --git a/tests/pytests/unit/client/ssh/test_ssh.py b/tests/pytests/unit/client/ssh/test_ssh.py index 2be96ab195..377aad9998 100644 --- a/tests/pytests/unit/client/ssh/test_ssh.py +++ b/tests/pytests/unit/client/ssh/test_ssh.py @@ -148,7 +148,7 @@ def test_expand_target_ip_address(opts, roster): MagicMock(return_value=salt.utils.yaml.safe_load(roster)), ): client._expand_target() - assert opts["tgt"] == host + assert client.opts["tgt"] == host def test_expand_target_no_host(opts, tmp_path): @@ -171,7 +171,7 @@ def test_expand_target_no_host(opts, tmp_path): assert opts["tgt"] == user + host with 
patch("salt.roster.get_roster_file", MagicMock(return_value=roster_file)): client._expand_target() - assert opts["tgt"] == host + assert client.opts["tgt"] == host def test_expand_target_dns(opts, roster): @@ -192,7 +192,7 @@ def test_expand_target_dns(opts, roster): MagicMock(return_value=salt.utils.yaml.safe_load(roster)), ): client._expand_target() - assert opts["tgt"] == host + assert client.opts["tgt"] == host def test_expand_target_no_user(opts, roster): @@ -204,7 +204,7 @@ def test_expand_target_no_user(opts, roster): with patch("salt.utils.network.is_reachable_host", MagicMock(return_value=False)): client = ssh.SSH(opts) - assert opts["tgt"] == host + assert client.opts["tgt"] == host with patch( "salt.roster.get_roster_file", MagicMock(return_value="/etc/salt/roster") @@ -213,7 +213,7 @@ def test_expand_target_no_user(opts, roster): MagicMock(return_value=salt.utils.yaml.safe_load(roster)), ): client._expand_target() - assert opts["tgt"] == host + assert client.opts["tgt"] == host def test_update_targets_ip_address(opts): @@ -228,7 +228,7 @@ def test_update_targets_ip_address(opts): client = ssh.SSH(opts) assert opts["tgt"] == user + host client._update_targets() - assert opts["tgt"] == host + assert client.opts["tgt"] == host assert client.targets[host]["user"] == user.split("@")[0] @@ -244,7 +244,7 @@ def test_update_targets_dns(opts): client = ssh.SSH(opts) assert opts["tgt"] == user + host client._update_targets() - assert opts["tgt"] == host + assert client.opts["tgt"] == host assert client.targets[host]["user"] == user.split("@")[0] @@ -259,7 +259,7 @@ def test_update_targets_no_user(opts): client = ssh.SSH(opts) assert opts["tgt"] == host client._update_targets() - assert opts["tgt"] == host + assert client.opts["tgt"] == host def test_update_expand_target_dns(opts, roster): @@ -281,7 +281,7 @@ def test_update_expand_target_dns(opts, roster): ): client._expand_target() client._update_targets() - assert opts["tgt"] == host + assert client.opts["tgt"] 
== host assert client.targets[host]["user"] == user.split("@")[0] @@ -299,7 +299,7 @@ def test_parse_tgt(opts): client = ssh.SSH(opts) assert client.parse_tgt["hostname"] == host assert client.parse_tgt["user"] == user.split("@")[0] - assert opts.get("ssh_cli_tgt") == user + host + assert client.opts.get("ssh_cli_tgt") == user + host def test_parse_tgt_no_user(opts): @@ -316,7 +316,7 @@ def test_parse_tgt_no_user(opts): client = ssh.SSH(opts) assert client.parse_tgt["hostname"] == host assert client.parse_tgt["user"] == opts["ssh_user"] - assert opts.get("ssh_cli_tgt") == host + assert client.opts.get("ssh_cli_tgt") == host def test_extra_filerefs(tmp_path, opts): -- 2.39.2