salt/async-batch-implementation.patch

1150 lines
42 KiB
Diff

From 76e69d9ef729365db1b0f1798f5f8a038d2065fc Mon Sep 17 00:00:00 2001
From: Mihai Dinca <mdinca@suse.de>
Date: Fri, 16 Nov 2018 17:05:29 +0100
Subject: [PATCH] Async batch implementation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Add find_job checks
Check if should close on all events
Make batch_delay a request parameter
Allow multiple event handlers
Use config value for gather_job_timeout when not in payload
Add async batch unittests
Allow metadata to pass
Pass metadata only to batch jobs
Add the metadata to the start/done events
Pass only metadata not all **kwargs
Add separate batch presence_ping timeout
Fix async batch race conditions
Close batching when there is no next batch
Add 'batch_presence_ping_timeout' and 'batch_presence_ping_gather_job_timeout' parameters for synchronous batching
Fix async-batch multiple done events
Fix memory leak produced by batch async find_jobs mechanism (bsc#1140912)
Multiple fixes:
- use different JIDs per find_job
- fix bug in detection of find_job returns
- fix timeout passed from request payload
- better cleanup at the end of batching
Co-authored-by: Pablo Suárez Hernández <psuarezhernandez@suse.com>
Improve batch_async to release consumed memory (bsc#1140912)
Use current IOLoop for the LocalClient instance of BatchAsync (bsc#1137642)
Fix failing unit tests for batch async
Remove unnecessary yield causing BadYieldError (bsc#1154620)
Fixing StreamClosed issue
Fix batch_async obsolete test
batch_async: avoid using fnmatch to match event (#217)
Batch Async: Catch exceptions and safely unregister and close instances
Fix unit tests for batch async after refactor
Changed imports to vendored Tornado
Async batch implementation fix (#320)
Remove deprecated usage of NO_MOCK and NO_MOCK_REASON
---
salt/auth/__init__.py | 2 +
salt/cli/batch.py | 109 ++++--
salt/cli/batch_async.py | 315 +++++++++++++++++
salt/cli/support/profiles/__init__.py | 5 +-
salt/client/__init__.py | 45 +--
salt/master.py | 20 ++
salt/transport/ipc.py | 9 +-
salt/utils/event.py | 8 +-
tests/pytests/unit/cli/test_batch_async.py | 386 +++++++++++++++++++++
9 files changed, 841 insertions(+), 58 deletions(-)
create mode 100644 salt/cli/batch_async.py
create mode 100644 tests/pytests/unit/cli/test_batch_async.py
diff --git a/salt/auth/__init__.py b/salt/auth/__init__.py
index 331baab211..b0f0c0ac6c 100644
--- a/salt/auth/__init__.py
+++ b/salt/auth/__init__.py
@@ -49,6 +49,8 @@ AUTH_INTERNAL_KEYWORDS = frozenset(
"print_event",
"raw",
"yield_pub_data",
+ "batch",
+ "batch_delay",
]
)
diff --git a/salt/cli/batch.py b/salt/cli/batch.py
index 8e1547c61d..fcd3f571d5 100644
--- a/salt/cli/batch.py
+++ b/salt/cli/batch.py
@@ -13,9 +13,88 @@ import salt.exceptions
import salt.output
import salt.utils.stringutils
+# pylint: disable=import-error,no-name-in-module,redefined-builtin
+
log = logging.getLogger(__name__)
+def get_bnum(opts, minions, quiet):
+ """
+ Return the active number of minions to maintain
+ """
+ partition = lambda x: float(x) / 100.0 * len(minions)
+ try:
+ if isinstance(opts["batch"], str) and "%" in opts["batch"]:
+ res = partition(float(opts["batch"].strip("%")))
+ if res < 1:
+ return int(math.ceil(res))
+ else:
+ return int(res)
+ else:
+ return int(opts["batch"])
+ except ValueError:
+ if not quiet:
+ salt.utils.stringutils.print_cli(
+ "Invalid batch data sent: {}\nData must be in the "
+ "form of %10, 10% or 3".format(opts["batch"])
+ )
+
+
+def batch_get_opts(
+ tgt, fun, batch, parent_opts, arg=(), tgt_type="glob", ret="", kwarg=None, **kwargs
+):
+ # We need to re-import salt.utils.args here
+ # even though it has already been imported.
+ # when cmd_batch is called via the NetAPI
+ # the module is unavailable.
+ import salt.utils.args
+
+ arg = salt.utils.args.condition_input(arg, kwarg)
+ opts = {
+ "tgt": tgt,
+ "fun": fun,
+ "arg": arg,
+ "tgt_type": tgt_type,
+ "ret": ret,
+ "batch": batch,
+ "failhard": kwargs.get("failhard", parent_opts.get("failhard", False)),
+ "raw": kwargs.get("raw", False),
+ }
+
+ if "timeout" in kwargs:
+ opts["timeout"] = kwargs["timeout"]
+ if "gather_job_timeout" in kwargs:
+ opts["gather_job_timeout"] = kwargs["gather_job_timeout"]
+ if "batch_wait" in kwargs:
+ opts["batch_wait"] = int(kwargs["batch_wait"])
+
+ for key, val in parent_opts.items():
+ if key not in opts:
+ opts[key] = val
+
+ opts["batch_presence_ping_timeout"] = kwargs.get(
+ "batch_presence_ping_timeout", opts["timeout"]
+ )
+ opts["batch_presence_ping_gather_job_timeout"] = kwargs.get(
+ "batch_presence_ping_gather_job_timeout", opts["gather_job_timeout"]
+ )
+
+ return opts
+
+
+def batch_get_eauth(kwargs):
+ eauth = {}
+ if "eauth" in kwargs:
+ eauth["eauth"] = kwargs.pop("eauth")
+ if "username" in kwargs:
+ eauth["username"] = kwargs.pop("username")
+ if "password" in kwargs:
+ eauth["password"] = kwargs.pop("password")
+ if "token" in kwargs:
+ eauth["token"] = kwargs.pop("token")
+ return eauth
+
+
class Batch:
"""
Manage the execution of batch runs
@@ -39,6 +118,7 @@ class Batch:
self.pub_kwargs = eauth if eauth else {}
self.quiet = quiet
self.options = _parser
+ self.minions = set()
# Passing listen True to local client will prevent it from purging
# cahced events while iterating over the batches.
self.local = salt.client.get_local_client(opts["conf_file"], listen=True)
@@ -51,7 +131,7 @@ class Batch:
self.opts["tgt"],
"test.ping",
[],
- self.opts["timeout"],
+ self.opts.get("batch_presence_ping_timeout", self.opts["timeout"]),
]
selected_target_option = self.opts.get("selected_target_option", None)
@@ -62,7 +142,12 @@ class Batch:
self.pub_kwargs["yield_pub_data"] = True
ping_gen = self.local.cmd_iter(
- *args, gather_job_timeout=self.opts["gather_job_timeout"], **self.pub_kwargs
+ *args,
+ gather_job_timeout=self.opts.get(
+ "batch_presence_ping_gather_job_timeout",
+ self.opts["gather_job_timeout"],
+ ),
+ **self.pub_kwargs
)
# Broadcast to targets
@@ -87,25 +172,7 @@ class Batch:
return (list(fret), ping_gen, nret.difference(fret))
def get_bnum(self):
- """
- Return the active number of minions to maintain
- """
- partition = lambda x: float(x) / 100.0 * len(self.minions)
- try:
- if isinstance(self.opts["batch"], str) and "%" in self.opts["batch"]:
- res = partition(float(self.opts["batch"].strip("%")))
- if res < 1:
- return int(math.ceil(res))
- else:
- return int(res)
- else:
- return int(self.opts["batch"])
- except ValueError:
- if not self.quiet:
- salt.utils.stringutils.print_cli(
- "Invalid batch data sent: {}\nData must be in the "
- "form of %10, 10% or 3".format(self.opts["batch"])
- )
+ return get_bnum(self.opts, self.minions, self.quiet)
def __update_wait(self, wait):
now = datetime.now()
diff --git a/salt/cli/batch_async.py b/salt/cli/batch_async.py
new file mode 100644
index 0000000000..09aa85258b
--- /dev/null
+++ b/salt/cli/batch_async.py
@@ -0,0 +1,315 @@
+"""
+Execute a job on the targeted minions by using a moving window of fixed size `batch`.
+"""
+
+import gc
+
+# pylint: enable=import-error,no-name-in-module,redefined-builtin
+import logging
+
+import salt.client
+import salt.ext.tornado
+import tornado
+from salt.cli.batch import batch_get_eauth, batch_get_opts, get_bnum
+
+log = logging.getLogger(__name__)
+
+
+class BatchAsync:
+ """
+ Run a job on the targeted minions by using a moving window of fixed size `batch`.
+
+ ``BatchAsync`` is used to execute a job on the targeted minions by keeping
+ the number of concurrent running minions to the size of `batch` parameter.
+
+ The control parameters are:
+ - batch: number/percentage of concurrent running minions
+ - batch_delay: minimum wait time between batches
+ - batch_presence_ping_timeout: time to wait for presence pings before starting the batch
+ - gather_job_timeout: `find_job` timeout
+ - timeout: time to wait before firing a `find_job`
+
+ When the batch starts, a `start` event is fired:
+ - tag: salt/batch/<batch-jid>/start
+ - data: {
+ "available_minions": self.minions,
+ "down_minions": targeted_minions - presence_ping_minions
+ }
+
+ When the batch ends, a `done` event is fired:
+ - tag: salt/batch/<batch-jid>/done
+ - data: {
+ "available_minions": self.minions,
+ "down_minions": targeted_minions - presence_ping_minions,
+ "done_minions": self.done_minions,
+ "timedout_minions": self.timedout_minions
+ }
+ """
+
+ def __init__(self, parent_opts, jid_gen, clear_load):
+ ioloop = salt.ext.tornado.ioloop.IOLoop.current()
+ self.local = salt.client.get_local_client(
+ parent_opts["conf_file"], io_loop=ioloop
+ )
+ if "gather_job_timeout" in clear_load["kwargs"]:
+ clear_load["gather_job_timeout"] = clear_load["kwargs"].pop(
+ "gather_job_timeout"
+ )
+ else:
+ clear_load["gather_job_timeout"] = self.local.opts["gather_job_timeout"]
+ self.batch_presence_ping_timeout = clear_load["kwargs"].get(
+ "batch_presence_ping_timeout", None
+ )
+ self.batch_delay = clear_load["kwargs"].get("batch_delay", 1)
+ self.opts = batch_get_opts(
+ clear_load.pop("tgt"),
+ clear_load.pop("fun"),
+ clear_load["kwargs"].pop("batch"),
+ self.local.opts,
+ **clear_load
+ )
+ self.eauth = batch_get_eauth(clear_load["kwargs"])
+ self.metadata = clear_load["kwargs"].get("metadata", {})
+ self.minions = set()
+ self.targeted_minions = set()
+ self.timedout_minions = set()
+ self.done_minions = set()
+ self.active = set()
+ self.initialized = False
+ self.jid_gen = jid_gen
+ self.ping_jid = jid_gen()
+ self.batch_jid = jid_gen()
+ self.find_job_jid = jid_gen()
+ self.find_job_returned = set()
+ self.ended = False
+ self.event = salt.utils.event.get_event(
+ "master",
+ self.opts["sock_dir"],
+ self.opts["transport"],
+ opts=self.opts,
+ listen=True,
+ io_loop=ioloop,
+ keep_loop=True,
+ )
+ self.scheduled = False
+ self.patterns = set()
+
+ def __set_event_handler(self):
+ ping_return_pattern = "salt/job/{}/ret/*".format(self.ping_jid)
+ batch_return_pattern = "salt/job/{}/ret/*".format(self.batch_jid)
+ self.event.subscribe(ping_return_pattern, match_type="glob")
+ self.event.subscribe(batch_return_pattern, match_type="glob")
+ self.patterns = {
+ (ping_return_pattern, "ping_return"),
+ (batch_return_pattern, "batch_run"),
+ }
+ self.event.set_event_handler(self.__event_handler)
+
+ def __event_handler(self, raw):
+ if not self.event:
+ return
+ try:
+ mtag, data = self.event.unpack(raw, self.event.serial)
+ for (pattern, op) in self.patterns:
+ if mtag.startswith(pattern[:-1]):
+ minion = data["id"]
+ if op == "ping_return":
+ self.minions.add(minion)
+ if self.targeted_minions == self.minions:
+ self.event.io_loop.spawn_callback(self.start_batch)
+ elif op == "find_job_return":
+ if data.get("return", None):
+ self.find_job_returned.add(minion)
+ elif op == "batch_run":
+ if minion in self.active:
+ self.active.remove(minion)
+ self.done_minions.add(minion)
+ self.event.io_loop.spawn_callback(self.schedule_next)
+ except Exception as ex:
+ log.error("Exception occured while processing event: {}".format(ex))
+
+ def _get_next(self):
+ to_run = (
+ self.minions.difference(self.done_minions)
+ .difference(self.active)
+ .difference(self.timedout_minions)
+ )
+ next_batch_size = min(
+ len(to_run), # partial batch (all left)
+ self.batch_size - len(self.active), # full batch or available slots
+ )
+ return set(list(to_run)[:next_batch_size])
+
+ def check_find_job(self, batch_minions, jid):
+ if self.event:
+ find_job_return_pattern = "salt/job/{}/ret/*".format(jid)
+ self.event.unsubscribe(find_job_return_pattern, match_type="glob")
+ self.patterns.remove((find_job_return_pattern, "find_job_return"))
+
+ timedout_minions = batch_minions.difference(
+ self.find_job_returned
+ ).difference(self.done_minions)
+ self.timedout_minions = self.timedout_minions.union(timedout_minions)
+ self.active = self.active.difference(self.timedout_minions)
+ running = batch_minions.difference(self.done_minions).difference(
+ self.timedout_minions
+ )
+
+ if timedout_minions:
+ self.schedule_next()
+
+ if self.event and running:
+ self.find_job_returned = self.find_job_returned.difference(running)
+ self.event.io_loop.spawn_callback(self.find_job, running)
+
+ @salt.ext.tornado.gen.coroutine
+ def find_job(self, minions):
+ if self.event:
+ not_done = minions.difference(self.done_minions).difference(
+ self.timedout_minions
+ )
+ try:
+ if not_done:
+ jid = self.jid_gen()
+ find_job_return_pattern = "salt/job/{}/ret/*".format(jid)
+ self.patterns.add((find_job_return_pattern, "find_job_return"))
+ self.event.subscribe(find_job_return_pattern, match_type="glob")
+ ret = yield self.local.run_job_async(
+ not_done,
+ "saltutil.find_job",
+ [self.batch_jid],
+ "list",
+ gather_job_timeout=self.opts["gather_job_timeout"],
+ jid=jid,
+ **self.eauth
+ )
+ yield salt.ext.tornado.gen.sleep(self.opts["gather_job_timeout"])
+ if self.event:
+ self.event.io_loop.spawn_callback(
+ self.check_find_job, not_done, jid
+ )
+ except Exception as ex:
+ log.error(
+ "Exception occured handling batch async: {}. Aborting execution.".format(
+ ex
+ )
+ )
+ self.close_safe()
+
+ @salt.ext.tornado.gen.coroutine
+ def start(self):
+ if self.event:
+ self.__set_event_handler()
+ ping_return = yield self.local.run_job_async(
+ self.opts["tgt"],
+ "test.ping",
+ [],
+ self.opts.get(
+ "selected_target_option", self.opts.get("tgt_type", "glob")
+ ),
+ gather_job_timeout=self.opts["gather_job_timeout"],
+ jid=self.ping_jid,
+ metadata=self.metadata,
+ **self.eauth
+ )
+ self.targeted_minions = set(ping_return["minions"])
+ # start batching even if not all minions respond to ping
+ yield salt.ext.tornado.gen.sleep(
+ self.batch_presence_ping_timeout or self.opts["gather_job_timeout"]
+ )
+ if self.event:
+ self.event.io_loop.spawn_callback(self.start_batch)
+
+ @salt.ext.tornado.gen.coroutine
+ def start_batch(self):
+ if not self.initialized:
+ self.batch_size = get_bnum(self.opts, self.minions, True)
+ self.initialized = True
+ data = {
+ "available_minions": self.minions,
+ "down_minions": self.targeted_minions.difference(self.minions),
+ "metadata": self.metadata,
+ }
+ ret = self.event.fire_event(
+ data, "salt/batch/{}/start".format(self.batch_jid)
+ )
+ if self.event:
+ self.event.io_loop.spawn_callback(self.run_next)
+
+ @salt.ext.tornado.gen.coroutine
+ def end_batch(self):
+ left = self.minions.symmetric_difference(
+ self.done_minions.union(self.timedout_minions)
+ )
+ if not left and not self.ended:
+ self.ended = True
+ data = {
+ "available_minions": self.minions,
+ "down_minions": self.targeted_minions.difference(self.minions),
+ "done_minions": self.done_minions,
+ "timedout_minions": self.timedout_minions,
+ "metadata": self.metadata,
+ }
+ self.event.fire_event(data, "salt/batch/{}/done".format(self.batch_jid))
+
+ # release to the IOLoop to allow the event to be published
+ # before closing batch async execution
+ yield salt.ext.tornado.gen.sleep(1)
+ self.close_safe()
+
+ def close_safe(self):
+ for (pattern, label) in self.patterns:
+ self.event.unsubscribe(pattern, match_type="glob")
+ self.event.remove_event_handler(self.__event_handler)
+ self.event = None
+ self.local = None
+ self.ioloop = None
+ del self
+ gc.collect()
+
+ @salt.ext.tornado.gen.coroutine
+ def schedule_next(self):
+ if not self.scheduled:
+ self.scheduled = True
+ # call later so that we maybe gather more returns
+ yield salt.ext.tornado.gen.sleep(self.batch_delay)
+ if self.event:
+ self.event.io_loop.spawn_callback(self.run_next)
+
+ @salt.ext.tornado.gen.coroutine
+ def run_next(self):
+ self.scheduled = False
+ next_batch = self._get_next()
+ if next_batch:
+ self.active = self.active.union(next_batch)
+ try:
+ ret = yield self.local.run_job_async(
+ next_batch,
+ self.opts["fun"],
+ self.opts["arg"],
+ "list",
+ raw=self.opts.get("raw", False),
+ ret=self.opts.get("return", ""),
+ gather_job_timeout=self.opts["gather_job_timeout"],
+ jid=self.batch_jid,
+ metadata=self.metadata,
+ )
+
+ yield salt.ext.tornado.gen.sleep(self.opts["timeout"])
+
+ # The batch can be done already at this point, which means no self.event
+ if self.event:
+ self.event.io_loop.spawn_callback(self.find_job, set(next_batch))
+ except Exception as ex:
+ log.error("Error in scheduling next batch: %s. Aborting execution", ex)
+ self.active = self.active.difference(next_batch)
+ self.close_safe()
+ else:
+ yield self.end_batch()
+ gc.collect()
+
+ def __del__(self):
+ self.local = None
+ self.event = None
+ self.ioloop = None
+ gc.collect()
diff --git a/salt/cli/support/profiles/__init__.py b/salt/cli/support/profiles/__init__.py
index b86aef30b8..4ae6d07b13 100644
--- a/salt/cli/support/profiles/__init__.py
+++ b/salt/cli/support/profiles/__init__.py
@@ -1,4 +1,3 @@
-# coding=utf-8
-'''
+"""
Profiles for salt-support.
-'''
+"""
diff --git a/salt/client/__init__.py b/salt/client/__init__.py
index 7ce8963b8f..bcda56c9b4 100644
--- a/salt/client/__init__.py
+++ b/salt/client/__init__.py
@@ -594,38 +594,20 @@ class LocalClient:
import salt.cli.batch
import salt.utils.args
- arg = salt.utils.args.condition_input(arg, kwarg)
- opts = {
- "tgt": tgt,
- "fun": fun,
- "arg": arg,
- "tgt_type": tgt_type,
- "ret": ret,
- "batch": batch,
- "failhard": kwargs.get("failhard", self.opts.get("failhard", False)),
- "raw": kwargs.get("raw", False),
- }
+ opts = salt.cli.batch.batch_get_opts(
+ tgt,
+ fun,
+ batch,
+ self.opts,
+ arg=arg,
+ tgt_type=tgt_type,
+ ret=ret,
+ kwarg=kwarg,
+ **kwargs
+ )
+
+ eauth = salt.cli.batch.batch_get_eauth(kwargs)
- if "timeout" in kwargs:
- opts["timeout"] = kwargs["timeout"]
- if "gather_job_timeout" in kwargs:
- opts["gather_job_timeout"] = kwargs["gather_job_timeout"]
- if "batch_wait" in kwargs:
- opts["batch_wait"] = int(kwargs["batch_wait"])
-
- eauth = {}
- if "eauth" in kwargs:
- eauth["eauth"] = kwargs.pop("eauth")
- if "username" in kwargs:
- eauth["username"] = kwargs.pop("username")
- if "password" in kwargs:
- eauth["password"] = kwargs.pop("password")
- if "token" in kwargs:
- eauth["token"] = kwargs.pop("token")
-
- for key, val in self.opts.items():
- if key not in opts:
- opts[key] = val
batch = salt.cli.batch.Batch(opts, eauth=eauth, quiet=True)
for ret, _ in batch.run():
yield ret
@@ -1826,6 +1808,7 @@ class LocalClient:
"key": self.key,
"tgt_type": tgt_type,
"ret": ret,
+ "timeout": timeout,
"jid": jid,
}
diff --git a/salt/master.py b/salt/master.py
index 9d2239bffb..2a526b4f21 100644
--- a/salt/master.py
+++ b/salt/master.py
@@ -19,6 +19,7 @@ import time
import salt.acl
import salt.auth
import salt.channel.server
+import salt.cli.batch_async
import salt.client
import salt.client.ssh.client
import salt.crypt
@@ -2153,6 +2154,22 @@ class ClearFuncs(TransportMethods):
return False
return self.loadauth.get_tok(clear_load["token"])
+ def publish_batch(self, clear_load, minions, missing):
+ batch_load = {}
+ batch_load.update(clear_load)
+ batch = salt.cli.batch_async.BatchAsync(
+ self.local.opts,
+ functools.partial(self._prep_jid, clear_load, {}),
+ batch_load,
+ )
+ ioloop = salt.ext.tornado.ioloop.IOLoop.current()
+ ioloop.add_callback(batch.start)
+
+ return {
+ "enc": "clear",
+ "load": {"jid": batch.batch_jid, "minions": minions, "missing": missing},
+ }
+
def publish(self, clear_load):
"""
This method sends out publications to the minions, it can only be used
@@ -2297,6 +2314,9 @@ class ClearFuncs(TransportMethods):
),
},
}
+ if extra.get("batch", None):
+ return self.publish_batch(clear_load, minions, missing)
+
jid = self._prep_jid(clear_load, extra)
if jid is None:
return {"enc": "clear", "load": {"error": "Master failed to assign jid"}}
diff --git a/salt/transport/ipc.py b/salt/transport/ipc.py
index ca13a498e3..3a3f0c7a5f 100644
--- a/salt/transport/ipc.py
+++ b/salt/transport/ipc.py
@@ -659,6 +659,7 @@ class IPCMessageSubscriber(IPCClient):
self._read_stream_future = None
self._saved_data = []
self._read_in_progress = Lock()
+ self.callbacks = set()
@salt.ext.tornado.gen.coroutine
def _read(self, timeout, callback=None):
@@ -764,8 +765,12 @@ class IPCMessageSubscriber(IPCClient):
return self._saved_data.pop(0)
return self.io_loop.run_sync(lambda: self._read(timeout))
+ def __run_callbacks(self, raw):
+ for callback in self.callbacks:
+ self.io_loop.spawn_callback(callback, raw)
+
@salt.ext.tornado.gen.coroutine
- def read_async(self, callback):
+ def read_async(self):
"""
Asynchronously read messages and invoke a callback when they are ready.
@@ -783,7 +788,7 @@ class IPCMessageSubscriber(IPCClient):
except Exception as exc: # pylint: disable=broad-except
log.error("Exception occurred while Subscriber connecting: %s", exc)
yield salt.ext.tornado.gen.sleep(1)
- yield self._read(None, callback)
+ yield self._read(None, self.__run_callbacks)
def close(self):
"""
diff --git a/salt/utils/event.py b/salt/utils/event.py
index a07ad513b1..869e12a140 100644
--- a/salt/utils/event.py
+++ b/salt/utils/event.py
@@ -946,6 +946,10 @@ class SaltEvent:
# Minion fired a bad retcode, fire an event
self._fire_ret_load_specific_fun(load)
+ def remove_event_handler(self, event_handler):
+ if event_handler in self.subscriber.callbacks:
+ self.subscriber.callbacks.remove(event_handler)
+
def set_event_handler(self, event_handler):
"""
Invoke the event_handler callback each time an event arrives.
@@ -954,8 +958,10 @@ class SaltEvent:
if not self.cpub:
self.connect_pub()
+
+ self.subscriber.callbacks.add(event_handler)
# This will handle reconnects
- return self.subscriber.read_async(event_handler)
+ return self.subscriber.read_async()
# pylint: disable=W1701
def __del__(self):
diff --git a/tests/pytests/unit/cli/test_batch_async.py b/tests/pytests/unit/cli/test_batch_async.py
new file mode 100644
index 0000000000..c0b708de76
--- /dev/null
+++ b/tests/pytests/unit/cli/test_batch_async.py
@@ -0,0 +1,386 @@
+import salt.ext.tornado
+from salt.cli.batch_async import BatchAsync
+from salt.ext.tornado.testing import AsyncTestCase
+from tests.support.mock import MagicMock, patch
+from tests.support.unit import TestCase, skipIf
+
+
+class AsyncBatchTestCase(AsyncTestCase, TestCase):
+ def setUp(self):
+ self.io_loop = self.get_new_ioloop()
+ opts = {
+ "batch": "1",
+ "conf_file": {},
+ "tgt": "*",
+ "timeout": 5,
+ "gather_job_timeout": 5,
+ "batch_presence_ping_timeout": 1,
+ "transport": None,
+ "sock_dir": "",
+ }
+
+ with patch("salt.client.get_local_client", MagicMock(return_value=MagicMock())):
+ with patch(
+ "salt.cli.batch_async.batch_get_opts", MagicMock(return_value=opts)
+ ):
+ self.batch = BatchAsync(
+ opts,
+ MagicMock(side_effect=["1234", "1235", "1236"]),
+ {
+ "tgt": "",
+ "fun": "",
+ "kwargs": {"batch": "", "batch_presence_ping_timeout": 1},
+ },
+ )
+
+ def test_ping_jid(self):
+ self.assertEqual(self.batch.ping_jid, "1234")
+
+ def test_batch_jid(self):
+ self.assertEqual(self.batch.batch_jid, "1235")
+
+ def test_find_job_jid(self):
+ self.assertEqual(self.batch.find_job_jid, "1236")
+
+ def test_batch_size(self):
+ """
+ Tests passing batch value as a number
+ """
+ self.batch.opts = {"batch": "2", "timeout": 5}
+ self.batch.minions = {"foo", "bar"}
+ self.batch.start_batch()
+ self.assertEqual(self.batch.batch_size, 2)
+
+ @salt.ext.tornado.testing.gen_test
+ def test_batch_start_on_batch_presence_ping_timeout(self):
+ self.batch.event = MagicMock()
+ future = salt.ext.tornado.gen.Future()
+ future.set_result({"minions": ["foo", "bar"]})
+ self.batch.local.run_job_async.return_value = future
+ ret = self.batch.start()
+ # assert start_batch is called later with batch_presence_ping_timeout as param
+ self.assertEqual(
+ self.batch.event.io_loop.spawn_callback.call_args[0],
+ (self.batch.start_batch,),
+ )
+ # assert test.ping called
+ self.assertEqual(
+ self.batch.local.run_job_async.call_args[0], ("*", "test.ping", [], "glob")
+ )
+ # assert targeted_minions == all minions matched by tgt
+ self.assertEqual(self.batch.targeted_minions, {"foo", "bar"})
+
+ @salt.ext.tornado.testing.gen_test
+ def test_batch_start_on_gather_job_timeout(self):
+ self.batch.event = MagicMock()
+ future = salt.ext.tornado.gen.Future()
+ future.set_result({"minions": ["foo", "bar"]})
+ self.batch.local.run_job_async.return_value = future
+ self.batch.batch_presence_ping_timeout = None
+ ret = self.batch.start()
+ # assert start_batch is called later with gather_job_timeout as param
+ self.assertEqual(
+ self.batch.event.io_loop.spawn_callback.call_args[0],
+ (self.batch.start_batch,),
+ )
+
+ def test_batch_fire_start_event(self):
+ self.batch.minions = {"foo", "bar"}
+ self.batch.opts = {"batch": "2", "timeout": 5}
+ self.batch.event = MagicMock()
+ self.batch.metadata = {"mykey": "myvalue"}
+ self.batch.start_batch()
+ self.assertEqual(
+ self.batch.event.fire_event.call_args[0],
+ (
+ {
+ "available_minions": {"foo", "bar"},
+ "down_minions": set(),
+ "metadata": self.batch.metadata,
+ },
+ "salt/batch/1235/start",
+ ),
+ )
+
+ @salt.ext.tornado.testing.gen_test
+ def test_start_batch_calls_next(self):
+ self.batch.run_next = MagicMock(return_value=MagicMock())
+ self.batch.event = MagicMock()
+ self.batch.start_batch()
+ self.assertEqual(self.batch.initialized, True)
+ self.assertEqual(
+ self.batch.event.io_loop.spawn_callback.call_args[0], (self.batch.run_next,)
+ )
+
+ def test_batch_fire_done_event(self):
+ self.batch.targeted_minions = {"foo", "baz", "bar"}
+ self.batch.minions = {"foo", "bar"}
+ self.batch.done_minions = {"foo"}
+ self.batch.timedout_minions = {"bar"}
+ self.batch.event = MagicMock()
+ self.batch.metadata = {"mykey": "myvalue"}
+ old_event = self.batch.event
+ self.batch.end_batch()
+ self.assertEqual(
+ old_event.fire_event.call_args[0],
+ (
+ {
+ "available_minions": {"foo", "bar"},
+ "done_minions": self.batch.done_minions,
+ "down_minions": {"baz"},
+ "timedout_minions": self.batch.timedout_minions,
+ "metadata": self.batch.metadata,
+ },
+ "salt/batch/1235/done",
+ ),
+ )
+
+ def test_batch__del__(self):
+ batch = BatchAsync(MagicMock(), MagicMock(), MagicMock())
+ event = MagicMock()
+ batch.event = event
+ batch.__del__()
+ self.assertEqual(batch.local, None)
+ self.assertEqual(batch.event, None)
+ self.assertEqual(batch.ioloop, None)
+
+ def test_batch_close_safe(self):
+ batch = BatchAsync(MagicMock(), MagicMock(), MagicMock())
+ event = MagicMock()
+ batch.event = event
+ batch.patterns = {
+ ("salt/job/1234/ret/*", "find_job_return"),
+ ("salt/job/4321/ret/*", "find_job_return"),
+ }
+ batch.close_safe()
+ self.assertEqual(batch.local, None)
+ self.assertEqual(batch.event, None)
+ self.assertEqual(batch.ioloop, None)
+ self.assertEqual(len(event.unsubscribe.mock_calls), 2)
+ self.assertEqual(len(event.remove_event_handler.mock_calls), 1)
+
+ @salt.ext.tornado.testing.gen_test
+ def test_batch_next(self):
+ self.batch.event = MagicMock()
+ self.batch.opts["fun"] = "my.fun"
+ self.batch.opts["arg"] = []
+ self.batch._get_next = MagicMock(return_value={"foo", "bar"})
+ self.batch.batch_size = 2
+ future = salt.ext.tornado.gen.Future()
+ future.set_result({"minions": ["foo", "bar"]})
+ self.batch.local.run_job_async.return_value = future
+ self.batch.run_next()
+ self.assertEqual(
+ self.batch.local.run_job_async.call_args[0],
+ ({"foo", "bar"}, "my.fun", [], "list"),
+ )
+ self.assertEqual(
+ self.batch.event.io_loop.spawn_callback.call_args[0],
+ (self.batch.find_job, {"foo", "bar"}),
+ )
+ self.assertEqual(self.batch.active, {"bar", "foo"})
+
+ def test_next_batch(self):
+ self.batch.minions = {"foo", "bar"}
+ self.batch.batch_size = 2
+ self.assertEqual(self.batch._get_next(), {"foo", "bar"})
+
+ def test_next_batch_one_done(self):
+ self.batch.minions = {"foo", "bar"}
+ self.batch.done_minions = {"bar"}
+ self.batch.batch_size = 2
+ self.assertEqual(self.batch._get_next(), {"foo"})
+
+ def test_next_batch_one_done_one_active(self):
+ self.batch.minions = {"foo", "bar", "baz"}
+ self.batch.done_minions = {"bar"}
+ self.batch.active = {"baz"}
+ self.batch.batch_size = 2
+ self.assertEqual(self.batch._get_next(), {"foo"})
+
+ def test_next_batch_one_done_one_active_one_timedout(self):
+ self.batch.minions = {"foo", "bar", "baz", "faz"}
+ self.batch.done_minions = {"bar"}
+ self.batch.active = {"baz"}
+ self.batch.timedout_minions = {"faz"}
+ self.batch.batch_size = 2
+ self.assertEqual(self.batch._get_next(), {"foo"})
+
+ def test_next_batch_bigger_size(self):
+ self.batch.minions = {"foo", "bar"}
+ self.batch.batch_size = 3
+ self.assertEqual(self.batch._get_next(), {"foo", "bar"})
+
+ def test_next_batch_all_done(self):
+ self.batch.minions = {"foo", "bar"}
+ self.batch.done_minions = {"foo", "bar"}
+ self.batch.batch_size = 2
+ self.assertEqual(self.batch._get_next(), set())
+
+ def test_next_batch_all_active(self):
+ self.batch.minions = {"foo", "bar"}
+ self.batch.active = {"foo", "bar"}
+ self.batch.batch_size = 2
+ self.assertEqual(self.batch._get_next(), set())
+
+ def test_next_batch_all_timedout(self):
+ self.batch.minions = {"foo", "bar"}
+ self.batch.timedout_minions = {"foo", "bar"}
+ self.batch.batch_size = 2
+ self.assertEqual(self.batch._get_next(), set())
+
+ def test_batch__event_handler_ping_return(self):
+ self.batch.targeted_minions = {"foo"}
+ self.batch.event = MagicMock(
+ unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"}))
+ )
+ self.batch.start()
+ self.assertEqual(self.batch.minions, set())
+ self.batch._BatchAsync__event_handler(MagicMock())
+ self.assertEqual(self.batch.minions, {"foo"})
+ self.assertEqual(self.batch.done_minions, set())
+
+ def test_batch__event_handler_call_start_batch_when_all_pings_return(self):
+ self.batch.targeted_minions = {"foo"}
+ self.batch.event = MagicMock(
+ unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"}))
+ )
+ self.batch.start()
+ self.batch._BatchAsync__event_handler(MagicMock())
+ self.assertEqual(
+ self.batch.event.io_loop.spawn_callback.call_args[0],
+ (self.batch.start_batch,),
+ )
+
+ def test_batch__event_handler_not_call_start_batch_when_not_all_pings_return(self):
+ self.batch.targeted_minions = {"foo", "bar"}
+ self.batch.event = MagicMock(
+ unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"}))
+ )
+ self.batch.start()
+ self.batch._BatchAsync__event_handler(MagicMock())
+ self.assertEqual(len(self.batch.event.io_loop.spawn_callback.mock_calls), 0)
+
+ def test_batch__event_handler_batch_run_return(self):
+ self.batch.event = MagicMock(
+ unpack=MagicMock(return_value=("salt/job/1235/ret/foo", {"id": "foo"}))
+ )
+ self.batch.start()
+ self.batch.active = {"foo"}
+ self.batch._BatchAsync__event_handler(MagicMock())
+ self.assertEqual(self.batch.active, set())
+ self.assertEqual(self.batch.done_minions, {"foo"})
+ self.assertEqual(
+ self.batch.event.io_loop.spawn_callback.call_args[0],
+ (self.batch.schedule_next,),
+ )
+
+ def test_batch__event_handler_find_job_return(self):
+ self.batch.event = MagicMock(
+ unpack=MagicMock(
+ return_value=(
+ "salt/job/1236/ret/foo",
+ {"id": "foo", "return": "deadbeaf"},
+ )
+ )
+ )
+ self.batch.start()
+ self.batch.patterns.add(("salt/job/1236/ret/*", "find_job_return"))
+ self.batch._BatchAsync__event_handler(MagicMock())
+ self.assertEqual(self.batch.find_job_returned, {"foo"})
+
+ @salt.ext.tornado.testing.gen_test
+ def test_batch_run_next_end_batch_when_no_next(self):
+ self.batch.end_batch = MagicMock()
+ self.batch._get_next = MagicMock(return_value={})
+ self.batch.run_next()
+ self.assertEqual(len(self.batch.end_batch.mock_calls), 1)
+
+ @salt.ext.tornado.testing.gen_test
+ def test_batch_find_job(self):
+ self.batch.event = MagicMock()
+ future = salt.ext.tornado.gen.Future()
+ future.set_result({})
+ self.batch.local.run_job_async.return_value = future
+ self.batch.minions = {"foo", "bar"}
+ self.batch.jid_gen = MagicMock(return_value="1234")
+ salt.ext.tornado.gen.sleep = MagicMock(return_value=future)
+ self.batch.find_job({"foo", "bar"})
+ self.assertEqual(
+ self.batch.event.io_loop.spawn_callback.call_args[0],
+ (self.batch.check_find_job, {"foo", "bar"}, "1234"),
+ )
+
+ @salt.ext.tornado.testing.gen_test
+ def test_batch_find_job_with_done_minions(self):
+ self.batch.done_minions = {"bar"}
+ self.batch.event = MagicMock()
+ future = salt.ext.tornado.gen.Future()
+ future.set_result({})
+ self.batch.local.run_job_async.return_value = future
+ self.batch.minions = {"foo", "bar"}
+ self.batch.jid_gen = MagicMock(return_value="1234")
+ salt.ext.tornado.gen.sleep = MagicMock(return_value=future)
+ self.batch.find_job({"foo", "bar"})
+ self.assertEqual(
+ self.batch.event.io_loop.spawn_callback.call_args[0],
+ (self.batch.check_find_job, {"foo"}, "1234"),
+ )
+
+ def test_batch_check_find_job_did_not_return(self):
+ self.batch.event = MagicMock()
+ self.batch.active = {"foo"}
+ self.batch.find_job_returned = set()
+ self.batch.patterns = {("salt/job/1234/ret/*", "find_job_return")}
+ self.batch.check_find_job({"foo"}, jid="1234")
+ self.assertEqual(self.batch.find_job_returned, set())
+ self.assertEqual(self.batch.active, set())
+ self.assertEqual(len(self.batch.event.io_loop.add_callback.mock_calls), 0)
+
+ def test_batch_check_find_job_did_return(self):
+ self.batch.event = MagicMock()
+ self.batch.find_job_returned = {"foo"}
+ self.batch.patterns = {("salt/job/1234/ret/*", "find_job_return")}
+ self.batch.check_find_job({"foo"}, jid="1234")
+ self.assertEqual(
+ self.batch.event.io_loop.spawn_callback.call_args[0],
+ (self.batch.find_job, {"foo"}),
+ )
+
+ def test_batch_check_find_job_multiple_states(self):
+ self.batch.event = MagicMock()
+ # currently running minions
+ self.batch.active = {"foo", "bar"}
+
+ # minion is running and find_job returns
+ self.batch.find_job_returned = {"foo"}
+
+ # minion started running but find_job did not return
+ self.batch.timedout_minions = {"faz"}
+
+ # minion finished
+ self.batch.done_minions = {"baz"}
+
+ # both not yet done but only 'foo' responded to find_job
+ not_done = {"foo", "bar"}
+
+ self.batch.patterns = {("salt/job/1234/ret/*", "find_job_return")}
+ self.batch.check_find_job(not_done, jid="1234")
+
+ # assert 'bar' removed from active
+ self.assertEqual(self.batch.active, {"foo"})
+
+ # assert 'bar' added to timedout_minions
+ self.assertEqual(self.batch.timedout_minions, {"bar", "faz"})
+
+ # assert 'find_job' scheduled again only for 'foo'
+ self.assertEqual(
+ self.batch.event.io_loop.spawn_callback.call_args[0],
+ (self.batch.find_job, {"foo"}),
+ )
+
+ def test_only_on_run_next_is_scheduled(self):
+ self.batch.event = MagicMock()
+ self.batch.scheduled = True
+ self.batch.schedule_next()
+ self.assertEqual(len(self.batch.event.io_loop.spawn_callback.mock_calls), 0)
--
2.39.2