From 78faccbd063b8635550935057b8630262958f669 Mon Sep 17 00:00:00 2001 From: Mihai Dinca Date: Fri, 16 Nov 2018 17:05:29 +0100 Subject: [PATCH] Async batch implementation Add find_job checks Check if should close on all events Make batch_delay a request parameter Allow multiple event handlers Use config value for gather_job_timeout when not in payload Add async batch unittests Allow metadata to pass Pass metadata only to batch jobs Add the metadata to the start/done events Pass only metadata not all **kwargs Add separate batch presence_ping timeout --- salt/auth/__init__.py | 2 + salt/cli/batch.py | 115 +++++++--- salt/cli/batch_async.py | 240 +++++++++++++++++++ salt/client/__init__.py | 14 ++ salt/master.py | 26 ++- salt/netapi/__init__.py | 3 +- salt/transport/ipc.py | 43 ++-- salt/utils/event.py | 8 +- tests/unit/cli/test_batch_async.py | 357 +++++++++++++++++++++++++++++ 9 files changed, 741 insertions(+), 67 deletions(-) create mode 100644 salt/cli/batch_async.py create mode 100644 tests/unit/cli/test_batch_async.py diff --git a/salt/auth/__init__.py b/salt/auth/__init__.py index ee1eac7ce4..22c54e8048 100644 --- a/salt/auth/__init__.py +++ b/salt/auth/__init__.py @@ -52,6 +52,8 @@ AUTH_INTERNAL_KEYWORDS = frozenset( "print_event", "raw", "yield_pub_data", + "batch", + "batch_delay", ] ) diff --git a/salt/cli/batch.py b/salt/cli/batch.py index 155dc734b7..527cffdeb7 100644 --- a/salt/cli/batch.py +++ b/salt/cli/batch.py @@ -1,10 +1,7 @@ -# -*- coding: utf-8 -*- """ Execute batch runs """ -# Import python libs -from __future__ import absolute_import, print_function, unicode_literals import copy @@ -17,11 +14,8 @@ from datetime import datetime, timedelta import salt.client import salt.exceptions import salt.output - -# Import salt libs import salt.utils.stringutils -# Import 3rd-party libs # pylint: disable=import-error,no-name-in-module,redefined-builtin from salt.ext import six from salt.ext.six.moves import range @@ -29,7 +23,77 @@ from salt.ext.six.moves 
import range log = logging.getLogger(__name__) -class Batch(object): +def get_bnum(opts, minions, quiet): + """ + Return the active number of minions to maintain + """ + partition = lambda x: float(x) / 100.0 * len(minions) + try: + if isinstance(opts["batch"], str) and "%" in opts["batch"]: + res = partition(float(opts["batch"].strip("%"))) + if res < 1: + return int(math.ceil(res)) + else: + return int(res) + else: + return int(opts["batch"]) + except ValueError: + if not quiet: + salt.utils.stringutils.print_cli( + "Invalid batch data sent: {}\nData must be in the " + "form of %10, 10% or 3".format(opts["batch"]) + ) + + +def batch_get_opts( + tgt, fun, batch, parent_opts, arg=(), tgt_type="glob", ret="", kwarg=None, **kwargs +): + # We need to re-import salt.utils.args here + # even though it has already been imported. + # when cmd_batch is called via the NetAPI + # the module is unavailable. + import salt.utils.args + + arg = salt.utils.args.condition_input(arg, kwarg) + opts = { + "tgt": tgt, + "fun": fun, + "arg": arg, + "tgt_type": tgt_type, + "ret": ret, + "batch": batch, + "failhard": kwargs.get("failhard", parent_opts.get("failhard", False)), + "raw": kwargs.get("raw", False), + } + + if "timeout" in kwargs: + opts["timeout"] = kwargs["timeout"] + if "gather_job_timeout" in kwargs: + opts["gather_job_timeout"] = kwargs["gather_job_timeout"] + if "batch_wait" in kwargs: + opts["batch_wait"] = int(kwargs["batch_wait"]) + + for key, val in parent_opts.items(): + if key not in opts: + opts[key] = val + + return opts + + +def batch_get_eauth(kwargs): + eauth = {} + if "eauth" in kwargs: + eauth["eauth"] = kwargs.pop("eauth") + if "username" in kwargs: + eauth["username"] = kwargs.pop("username") + if "password" in kwargs: + eauth["password"] = kwargs.pop("password") + if "token" in kwargs: + eauth["token"] = kwargs.pop("token") + return eauth + + +class Batch: """ Manage the execution of batch runs """ @@ -75,7 +139,7 @@ class Batch(object): continue else: 
try: - m = next(six.iterkeys(ret)) + m = next(iter(ret.keys())) except StopIteration: if not self.quiet: salt.utils.stringutils.print_cli( @@ -87,28 +151,7 @@ class Batch(object): return (list(fret), ping_gen, nret.difference(fret)) def get_bnum(self): - """ - Return the active number of minions to maintain - """ - partition = lambda x: float(x) / 100.0 * len(self.minions) - try: - if ( - isinstance(self.opts["batch"], six.string_types) - and "%" in self.opts["batch"] - ): - res = partition(float(self.opts["batch"].strip("%"))) - if res < 1: - return int(math.ceil(res)) - else: - return int(res) - else: - return int(self.opts["batch"]) - except ValueError: - if not self.quiet: - salt.utils.stringutils.print_cli( - "Invalid batch data sent: {0}\nData must be in the " - "form of %10, 10% or 3".format(self.opts["batch"]) - ) + return get_bnum(self.opts, self.minions, self.quiet) def __update_wait(self, wait): now = datetime.now() @@ -161,7 +204,7 @@ class Batch(object): # the user we won't be attempting to run a job on them for down_minion in self.down_minions: salt.utils.stringutils.print_cli( - "Minion {0} did not respond. No job will be sent.".format( + "Minion {} did not respond. 
No job will be sent.".format( down_minion ) ) @@ -190,7 +233,7 @@ class Batch(object): if next_: if not self.quiet: salt.utils.stringutils.print_cli( - "\nExecuting run on {0}\n".format(sorted(next_)) + "\nExecuting run on {}\n".format(sorted(next_)) ) # create a new iterator for this batch of minions return_value = self.opts.get("return", self.opts.get("ret", "")) @@ -218,7 +261,7 @@ class Batch(object): for ping_ret in self.ping_gen: if ping_ret is None: break - m = next(six.iterkeys(ping_ret)) + m = next(iter(ping_ret.keys())) if m not in self.minions: self.minions.append(m) to_run.append(m) @@ -243,7 +286,7 @@ class Batch(object): ) else: salt.utils.stringutils.print_cli( - "minion {0} was already deleted from tracker, probably a duplicate key".format( + "minion {} was already deleted from tracker, probably a duplicate key".format( part["id"] ) ) @@ -254,7 +297,7 @@ class Batch(object): minion_tracker[queue]["minions"].remove(id) else: salt.utils.stringutils.print_cli( - "minion {0} was already deleted from tracker, probably a duplicate key".format( + "minion {} was already deleted from tracker, probably a duplicate key".format( id ) ) @@ -274,7 +317,7 @@ class Batch(object): parts[minion] = {} parts[minion]["ret"] = {} - for minion, data in six.iteritems(parts): + for minion, data in parts.items(): if minion in active: active.remove(minion) if bwait: diff --git a/salt/cli/batch_async.py b/salt/cli/batch_async.py new file mode 100644 index 0000000000..1557e5105b --- /dev/null +++ b/salt/cli/batch_async.py @@ -0,0 +1,240 @@ +""" +Execute a job on the targeted minions by using a moving window of fixed size `batch`. 
+""" + +import fnmatch + +# pylint: enable=import-error,no-name-in-module,redefined-builtin +import logging + +import salt.client +import tornado +from salt.cli.batch import batch_get_eauth, batch_get_opts, get_bnum + +log = logging.getLogger(__name__) + + +class BatchAsync: + """ + Run a job on the targeted minions by using a moving window of fixed size `batch`. + + ``BatchAsync`` is used to execute a job on the targeted minions by keeping + the number of concurrent running minions to the size of `batch` parameter. + + The control parameters are: + - batch: number/percentage of concurrent running minions + - batch_delay: minimum wait time between batches + - batch_presence_ping_timeout: time to wait for presence pings before starting the batch + - gather_job_timeout: `find_job` timeout + - timeout: time to wait before firing a `find_job` + + When the batch starts, a `start` event is fired: + - tag: salt/batch/<batch-jid>/start + - data: { + "available_minions": self.minions, + "down_minions": self.down_minions + } + + When the batch ends, a `done` event is fired: + - tag: salt/batch/<batch-jid>/done + - data: { + "available_minions": self.minions, + "down_minions": self.down_minions, + "done_minions": self.done_minions, + "timedout_minions": self.timedout_minions + } + """ + + def __init__(self, parent_opts, jid_gen, clear_load): + ioloop = tornado.ioloop.IOLoop.current() + self.local = salt.client.get_local_client(parent_opts["conf_file"]) + if "gather_job_timeout" in clear_load["kwargs"]: + clear_load["gather_job_timeout"] = clear_load["kwargs"].pop( + "gather_job_timeout" + ) + else: + clear_load["gather_job_timeout"] = self.local.opts["gather_job_timeout"] + self.batch_presence_ping_timeout = clear_load["kwargs"].get( + "batch_presence_ping_timeout", None + ) + self.batch_delay = clear_load["kwargs"].get("batch_delay", 1) + self.opts = batch_get_opts( + clear_load.pop("tgt"), + clear_load.pop("fun"), + clear_load["kwargs"].pop("batch"), + self.local.opts, + **clear_load + ) +
self.eauth = batch_get_eauth(clear_load["kwargs"]) + self.metadata = clear_load["kwargs"].get("metadata", {}) + self.minions = set() + self.down_minions = set() + self.timedout_minions = set() + self.done_minions = set() + self.active = set() + self.initialized = False + self.ping_jid = jid_gen() + self.batch_jid = jid_gen() + self.find_job_jid = jid_gen() + self.find_job_returned = set() + self.event = salt.utils.event.get_event( + "master", + self.opts["sock_dir"], + self.opts["transport"], + opts=self.opts, + listen=True, + io_loop=ioloop, + keep_loop=True, + ) + + def __set_event_handler(self): + ping_return_pattern = "salt/job/{}/ret/*".format(self.ping_jid) + batch_return_pattern = "salt/job/{}/ret/*".format(self.batch_jid) + find_job_return_pattern = "salt/job/{}/ret/*".format(self.find_job_jid) + self.event.subscribe(ping_return_pattern, match_type="glob") + self.event.subscribe(batch_return_pattern, match_type="glob") + self.event.subscribe(find_job_return_pattern, match_type="glob") + self.event.patterns = { + (ping_return_pattern, "ping_return"), + (batch_return_pattern, "batch_run"), + (find_job_return_pattern, "find_job_return"), + } + self.event.set_event_handler(self.__event_handler) + + def __event_handler(self, raw): + if not self.event: + return + mtag, data = self.event.unpack(raw, self.event.serial) + for (pattern, op) in self.event.patterns: + if fnmatch.fnmatch(mtag, pattern): + minion = data["id"] + if op == "ping_return": + self.minions.add(minion) + self.down_minions.remove(minion) + if not self.down_minions: + self.event.io_loop.spawn_callback(self.start_batch) + elif op == "find_job_return": + self.find_job_returned.add(minion) + elif op == "batch_run": + if minion in self.active: + self.active.remove(minion) + self.done_minions.add(minion) + # call later so that we maybe gather more returns + self.event.io_loop.call_later( + self.batch_delay, self.schedule_next + ) + + if self.initialized and self.done_minions == self.minions.difference( 
+ self.timedout_minions + ): + self.end_batch() + + def _get_next(self): + to_run = ( + self.minions.difference(self.done_minions) + .difference(self.active) + .difference(self.timedout_minions) + ) + next_batch_size = min( + len(to_run), # partial batch (all left) + self.batch_size - len(self.active), # full batch or available slots + ) + return set(list(to_run)[:next_batch_size]) + + @tornado.gen.coroutine + def check_find_job(self, minions): + did_not_return = minions.difference(self.find_job_returned) + if did_not_return: + for minion in did_not_return: + if minion in self.find_job_returned: + self.find_job_returned.remove(minion) + if minion in self.active: + self.active.remove(minion) + self.timedout_minions.add(minion) + running = ( + minions.difference(did_not_return) + .difference(self.done_minions) + .difference(self.timedout_minions) + ) + if running: + self.event.io_loop.add_callback(self.find_job, running) + + @tornado.gen.coroutine + def find_job(self, minions): + not_done = minions.difference(self.done_minions) + ping_return = yield self.local.run_job_async( + not_done, + "saltutil.find_job", + [self.batch_jid], + "list", + gather_job_timeout=self.opts["gather_job_timeout"], + jid=self.find_job_jid, + **self.eauth + ) + self.event.io_loop.call_later( + self.opts["gather_job_timeout"], self.check_find_job, not_done + ) + + @tornado.gen.coroutine + def start(self): + self.__set_event_handler() + # start batching even if not all minions respond to ping + self.event.io_loop.call_later( + self.batch_presence_ping_timeout or self.opts["gather_job_timeout"], + self.start_batch, + ) + ping_return = yield self.local.run_job_async( + self.opts["tgt"], + "test.ping", + [], + self.opts.get("selected_target_option", self.opts.get("tgt_type", "glob")), + gather_job_timeout=self.opts["gather_job_timeout"], + jid=self.ping_jid, + metadata=self.metadata, + **self.eauth + ) + self.down_minions = set(ping_return["minions"]) + + @tornado.gen.coroutine + def 
start_batch(self): + if not self.initialized: + self.batch_size = get_bnum(self.opts, self.minions, True) + self.initialized = True + data = { + "available_minions": self.minions, + "down_minions": self.down_minions, + "metadata": self.metadata, + } + self.event.fire_event(data, "salt/batch/{}/start".format(self.batch_jid)) + yield self.schedule_next() + + def end_batch(self): + data = { + "available_minions": self.minions, + "down_minions": self.down_minions, + "done_minions": self.done_minions, + "timedout_minions": self.timedout_minions, + "metadata": self.metadata, + } + self.event.fire_event(data, "salt/batch/{}/done".format(self.batch_jid)) + self.event.remove_event_handler(self.__event_handler) + + @tornado.gen.coroutine + def schedule_next(self): + next_batch = self._get_next() + if next_batch: + yield self.local.run_job_async( + next_batch, + self.opts["fun"], + self.opts["arg"], + "list", + raw=self.opts.get("raw", False), + ret=self.opts.get("return", ""), + gather_job_timeout=self.opts["gather_job_timeout"], + jid=self.batch_jid, + metadata=self.metadata, + ) + self.event.io_loop.call_later( + self.opts["timeout"], self.find_job, set(next_batch) + ) + self.active = self.active.union(next_batch) diff --git a/salt/client/__init__.py b/salt/client/__init__.py index 6fab45fcbf..1e9f11df4c 100644 --- a/salt/client/__init__.py +++ b/salt/client/__init__.py @@ -543,6 +543,20 @@ class LocalClient: # Late import - not used anywhere else in this file import salt.cli.batch + opts = salt.cli.batch.batch_get_opts( + tgt, + fun, + batch, + self.opts, + arg=arg, + tgt_type=tgt_type, + ret=ret, + kwarg=kwarg, + **kwargs + ) + + eauth = salt.cli.batch.batch_get_eauth(kwargs) + arg = salt.utils.args.condition_input(arg, kwarg) opts = { "tgt": tgt, diff --git a/salt/master.py b/salt/master.py index 1c91c28209..b9bc1a7a67 100644 --- a/salt/master.py +++ b/salt/master.py @@ -3,7 +3,6 @@ This module contains all of the routines needed to set up a master server, this involves 
preparing the three listeners and the workers needed by the master. """ -# Import python libs import collections import copy @@ -21,10 +20,9 @@ import time import salt.acl import salt.auth +import salt.cli.batch_async import salt.client import salt.client.ssh.client - -# Import salt libs import salt.crypt import salt.daemons.masterapi import salt.defaults.exitcodes @@ -89,7 +87,6 @@ except ImportError: # resource is not available on windows HAS_RESOURCE = False -# Import halite libs try: import halite # pylint: disable=import-error @@ -2232,6 +2229,24 @@ class ClearFuncs(TransportMethods): return False return self.loadauth.get_tok(clear_load["token"]) + def publish_batch(self, clear_load, minions, missing): + batch_load = {} + batch_load.update(clear_load) + import salt.cli.batch_async + + batch = salt.cli.batch_async.BatchAsync( + self.local.opts, + functools.partial(self._prep_jid, clear_load, {}), + batch_load, + ) + ioloop = tornado.ioloop.IOLoop.current() + ioloop.add_callback(batch.start) + + return { + "enc": "clear", + "load": {"jid": batch.batch_jid, "minions": minions, "missing": missing}, + } + def publish(self, clear_load): """ This method sends out publications to the minions, it can only be used @@ -2349,6 +2364,9 @@ class ClearFuncs(TransportMethods): ), }, } + if extra.get("batch", None): + return self.publish_batch(clear_load, minions, missing) + jid = self._prep_jid(clear_load, extra) if jid is None: return {"enc": "clear", "load": {"error": "Master failed to assign jid"}} diff --git a/salt/netapi/__init__.py b/salt/netapi/__init__.py index 96f57f6c79..dec19b37ef 100644 --- a/salt/netapi/__init__.py +++ b/salt/netapi/__init__.py @@ -151,7 +151,8 @@ class NetapiClient: :return: job ID """ local = salt.client.get_local_client(mopts=self.opts) - return local.run_job(*args, **kwargs) + ret = local.run_job(*args, **kwargs) + return ret def local(self, *args, **kwargs): """ diff --git a/salt/transport/ipc.py b/salt/transport/ipc.py index 
041718d058..f411907da2 100644 --- a/salt/transport/ipc.py +++ b/salt/transport/ipc.py @@ -1,10 +1,7 @@ -# -*- coding: utf-8 -*- """ IPC transport classes """ -# Import Python libs -from __future__ import absolute_import, print_function, unicode_literals import errno import logging @@ -12,15 +9,12 @@ import socket import sys import time -# Import Tornado libs import salt.ext.tornado import salt.ext.tornado.concurrent import salt.ext.tornado.gen import salt.ext.tornado.netutil import salt.transport.client import salt.transport.frame - -# Import Salt libs import salt.utils.msgpack from salt.ext import six from salt.ext.tornado.ioloop import IOLoop @@ -42,7 +36,7 @@ def future_with_timeout_callback(future): class FutureWithTimeout(salt.ext.tornado.concurrent.Future): def __init__(self, io_loop, future, timeout): - super(FutureWithTimeout, self).__init__() + super().__init__() self.io_loop = io_loop self._future = future if timeout is not None: @@ -85,7 +79,7 @@ class FutureWithTimeout(salt.ext.tornado.concurrent.Future): self.set_exception(exc) -class IPCServer(object): +class IPCServer: """ A Tornado IPC server very similar to Tornado's TCPServer class but using either UNIX domain sockets or TCP sockets @@ -181,10 +175,7 @@ class IPCServer(object): # Under Py2 we still want raw to be set to True msgpack_kwargs = {"raw": six.PY2} else: - if six.PY2: - msgpack_kwargs = {"encoding": None} - else: - msgpack_kwargs = {"encoding": "utf-8"} + msgpack_kwargs = {"encoding": "utf-8"} unpacker = salt.utils.msgpack.Unpacker(**msgpack_kwargs) while not stream.closed(): try: @@ -200,7 +191,7 @@ class IPCServer(object): except StreamClosedError: log.trace("Client disconnected from IPC %s", self.socket_path) break - except socket.error as exc: + except OSError as exc: # On occasion an exception will occur with # an error code of 0, it's a spurious exception. 
if exc.errno == 0: @@ -247,7 +238,7 @@ class IPCServer(object): # pylint: enable=W1701 -class IPCClient(object): +class IPCClient: """ A Tornado IPC client very similar to Tornado's TCPClient class but using either UNIX domain sockets or TCP sockets @@ -282,10 +273,7 @@ class IPCClient(object): # Under Py2 we still want raw to be set to True msgpack_kwargs = {"raw": six.PY2} else: - if six.PY2: - msgpack_kwargs = {"encoding": None} - else: - msgpack_kwargs = {"encoding": "utf-8"} + msgpack_kwargs = {"encoding": "utf-8"} self.unpacker = salt.utils.msgpack.Unpacker(**msgpack_kwargs) def connected(self): @@ -385,10 +373,10 @@ class IPCClient(object): if self.stream is not None and not self.stream.closed(): try: self.stream.close() - except socket.error as exc: + except OSError as exc: if exc.errno != errno.EBADF: # If its not a bad file descriptor error, raise - six.reraise(*sys.exc_info()) + raise class IPCMessageClient(IPCClient): @@ -483,7 +471,7 @@ class IPCMessageServer(IPCServer): """ -class IPCMessagePublisher(object): +class IPCMessagePublisher: """ A Tornado IPC Publisher similar to Tornado's TCPServer class but using either UNIX domain sockets or TCP sockets @@ -645,10 +633,11 @@ class IPCMessageSubscriber(IPCClient): """ def __init__(self, socket_path, io_loop=None): - super(IPCMessageSubscriber, self).__init__(socket_path, io_loop=io_loop) + super().__init__(socket_path, io_loop=io_loop) self._read_stream_future = None self._saved_data = [] self._read_in_progress = Lock() + self.callbacks = set() @salt.ext.tornado.gen.coroutine def _read(self, timeout, callback=None): @@ -725,8 +714,12 @@ class IPCMessageSubscriber(IPCClient): return self._saved_data.pop(0) return self.io_loop.run_sync(lambda: self._read(timeout)) + def __run_callbacks(self, raw): + for callback in self.callbacks: + self.io_loop.spawn_callback(callback, raw) + @salt.ext.tornado.gen.coroutine - def read_async(self, callback): + def read_async(self): """ Asynchronously read messages and 
invoke a callback when they are ready. @@ -744,7 +737,7 @@ class IPCMessageSubscriber(IPCClient): except Exception as exc: # pylint: disable=broad-except log.error("Exception occurred while Subscriber connecting: %s", exc) yield salt.ext.tornado.gen.sleep(1) - yield self._read(None, callback) + yield self._read(None, self.__run_callbacks) def close(self): """ @@ -754,7 +747,7 @@ class IPCMessageSubscriber(IPCClient): """ if self._closing: return - super(IPCMessageSubscriber, self).close() + super().close() # This will prevent this message from showing up: # '[ERROR ] Future exception was never retrieved: # StreamClosedError' diff --git a/salt/utils/event.py b/salt/utils/event.py index 6f7edef4e5..ae200f9dfa 100644 --- a/salt/utils/event.py +++ b/salt/utils/event.py @@ -867,6 +867,10 @@ class SaltEvent: # Minion fired a bad retcode, fire an event self._fire_ret_load_specific_fun(load) + def remove_event_handler(self, event_handler): + if event_handler in self.subscriber.callbacks: + self.subscriber.callbacks.remove(event_handler) + def set_event_handler(self, event_handler): """ Invoke the event_handler callback each time an event arrives. 
@@ -875,8 +879,10 @@ class SaltEvent: if not self.cpub: self.connect_pub() + + self.subscriber.callbacks.add(event_handler) # This will handle reconnects - return self.subscriber.read_async(event_handler) + return self.subscriber.read_async() # pylint: disable=W1701 def __del__(self): diff --git a/tests/unit/cli/test_batch_async.py b/tests/unit/cli/test_batch_async.py new file mode 100644 index 0000000000..3f8626a2dd --- /dev/null +++ b/tests/unit/cli/test_batch_async.py @@ -0,0 +1,357 @@ +import tornado +from salt.cli.batch_async import BatchAsync +from tests.support.mock import NO_MOCK, NO_MOCK_REASON, MagicMock, patch +from tests.support.unit import TestCase, skipIf +from tornado.testing import AsyncTestCase + + +@skipIf(NO_MOCK, NO_MOCK_REASON) +class AsyncBatchTestCase(AsyncTestCase, TestCase): + def setUp(self): + self.io_loop = self.get_new_ioloop() + opts = { + "batch": "1", + "conf_file": {}, + "tgt": "*", + "timeout": 5, + "gather_job_timeout": 5, + "batch_presence_ping_timeout": 1, + "transport": None, + "sock_dir": "", + } + + with patch("salt.client.get_local_client", MagicMock(return_value=MagicMock())): + with patch( + "salt.cli.batch_async.batch_get_opts", MagicMock(return_value=opts) + ): + self.batch = BatchAsync( + opts, + MagicMock(side_effect=["1234", "1235", "1236"]), + { + "tgt": "", + "fun": "", + "kwargs": {"batch": "", "batch_presence_ping_timeout": 1}, + }, + ) + + def test_ping_jid(self): + self.assertEqual(self.batch.ping_jid, "1234") + + def test_batch_jid(self): + self.assertEqual(self.batch.batch_jid, "1235") + + def test_find_job_jid(self): + self.assertEqual(self.batch.find_job_jid, "1236") + + def test_batch_size(self): + """ + Tests passing batch value as a number + """ + self.batch.opts = {"batch": "2", "timeout": 5} + self.batch.minions = {"foo", "bar"} + self.batch.start_batch() + self.assertEqual(self.batch.batch_size, 2) + + @tornado.testing.gen_test + def test_batch_start_on_batch_presence_ping_timeout(self): + 
self.batch.event = MagicMock() + future = tornado.gen.Future() + future.set_result({"minions": ["foo", "bar"]}) + self.batch.local.run_job_async.return_value = future + ret = self.batch.start() + # assert start_batch is called later with batch_presence_ping_timeout as param + self.assertEqual( + self.batch.event.io_loop.call_later.call_args[0], + (self.batch.batch_presence_ping_timeout, self.batch.start_batch), + ) + # assert test.ping called + self.assertEqual( + self.batch.local.run_job_async.call_args[0], ("*", "test.ping", [], "glob") + ) + # assert down_minions == all minions matched by tgt + self.assertEqual(self.batch.down_minions, {"foo", "bar"}) + + @tornado.testing.gen_test + def test_batch_start_on_gather_job_timeout(self): + self.batch.event = MagicMock() + future = tornado.gen.Future() + future.set_result({"minions": ["foo", "bar"]}) + self.batch.local.run_job_async.return_value = future + self.batch.batch_presence_ping_timeout = None + ret = self.batch.start() + # assert start_batch is called later with gather_job_timeout as param + self.assertEqual( + self.batch.event.io_loop.call_later.call_args[0], + (self.batch.opts["gather_job_timeout"], self.batch.start_batch), + ) + + def test_batch_fire_start_event(self): + self.batch.minions = {"foo", "bar"} + self.batch.opts = {"batch": "2", "timeout": 5} + self.batch.event = MagicMock() + self.batch.metadata = {"mykey": "myvalue"} + self.batch.start_batch() + self.assertEqual( + self.batch.event.fire_event.call_args[0], + ( + { + "available_minions": {"foo", "bar"}, + "down_minions": set(), + "metadata": self.batch.metadata, + }, + "salt/batch/1235/start", + ), + ) + + @tornado.testing.gen_test + def test_start_batch_calls_next(self): + self.batch.schedule_next = MagicMock(return_value=MagicMock()) + self.batch.event = MagicMock() + future = tornado.gen.Future() + future.set_result(None) + self.batch.schedule_next = MagicMock(return_value=future) + self.batch.start_batch() + 
self.assertEqual(self.batch.initialized, True) + self.assertEqual(len(self.batch.schedule_next.mock_calls), 1) + + def test_batch_fire_done_event(self): + self.batch.minions = {"foo", "bar"} + self.batch.event = MagicMock() + self.batch.metadata = {"mykey": "myvalue"} + self.batch.end_batch() + self.assertEqual( + self.batch.event.fire_event.call_args[0], + ( + { + "available_minions": {"foo", "bar"}, + "done_minions": set(), + "down_minions": set(), + "timedout_minions": set(), + "metadata": self.batch.metadata, + }, + "salt/batch/1235/done", + ), + ) + self.assertEqual(len(self.batch.event.remove_event_handler.mock_calls), 1) + + @tornado.testing.gen_test + def test_batch_next(self): + self.batch.event = MagicMock() + self.batch.opts["fun"] = "my.fun" + self.batch.opts["arg"] = [] + self.batch._get_next = MagicMock(return_value={"foo", "bar"}) + self.batch.batch_size = 2 + future = tornado.gen.Future() + future.set_result({"minions": ["foo", "bar"]}) + self.batch.local.run_job_async.return_value = future + ret = self.batch.schedule_next().result() + self.assertEqual( + self.batch.local.run_job_async.call_args[0], + ({"foo", "bar"}, "my.fun", [], "list"), + ) + self.assertEqual( + self.batch.event.io_loop.call_later.call_args[0], + (self.batch.opts["timeout"], self.batch.find_job, {"foo", "bar"}), + ) + self.assertEqual(self.batch.active, {"bar", "foo"}) + + def test_next_batch(self): + self.batch.minions = {"foo", "bar"} + self.batch.batch_size = 2 + self.assertEqual(self.batch._get_next(), {"foo", "bar"}) + + def test_next_batch_one_done(self): + self.batch.minions = {"foo", "bar"} + self.batch.done_minions = {"bar"} + self.batch.batch_size = 2 + self.assertEqual(self.batch._get_next(), {"foo"}) + + def test_next_batch_one_done_one_active(self): + self.batch.minions = {"foo", "bar", "baz"} + self.batch.done_minions = {"bar"} + self.batch.active = {"baz"} + self.batch.batch_size = 2 + self.assertEqual(self.batch._get_next(), {"foo"}) + + def 
test_next_batch_one_done_one_active_one_timedout(self): + self.batch.minions = {"foo", "bar", "baz", "faz"} + self.batch.done_minions = {"bar"} + self.batch.active = {"baz"} + self.batch.timedout_minions = {"faz"} + self.batch.batch_size = 2 + self.assertEqual(self.batch._get_next(), {"foo"}) + + def test_next_batch_bigger_size(self): + self.batch.minions = {"foo", "bar"} + self.batch.batch_size = 3 + self.assertEqual(self.batch._get_next(), {"foo", "bar"}) + + def test_next_batch_all_done(self): + self.batch.minions = {"foo", "bar"} + self.batch.done_minions = {"foo", "bar"} + self.batch.batch_size = 2 + self.assertEqual(self.batch._get_next(), set()) + + def test_next_batch_all_active(self): + self.batch.minions = {"foo", "bar"} + self.batch.active = {"foo", "bar"} + self.batch.batch_size = 2 + self.assertEqual(self.batch._get_next(), set()) + + def test_next_batch_all_timedout(self): + self.batch.minions = {"foo", "bar"} + self.batch.timedout_minions = {"foo", "bar"} + self.batch.batch_size = 2 + self.assertEqual(self.batch._get_next(), set()) + + def test_batch__event_handler_ping_return(self): + self.batch.down_minions = {"foo"} + self.batch.event = MagicMock( + unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"})) + ) + self.batch.start() + self.assertEqual(self.batch.minions, set()) + self.batch._BatchAsync__event_handler(MagicMock()) + self.assertEqual(self.batch.minions, {"foo"}) + self.assertEqual(self.batch.done_minions, set()) + + def test_batch__event_handler_call_start_batch_when_all_pings_return(self): + self.batch.down_minions = {"foo"} + self.batch.event = MagicMock( + unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"})) + ) + self.batch.start() + self.batch._BatchAsync__event_handler(MagicMock()) + self.assertEqual( + self.batch.event.io_loop.spawn_callback.call_args[0], + (self.batch.start_batch,), + ) + + def test_batch__event_handler_not_call_start_batch_when_not_all_pings_return(self): + 
self.batch.down_minions = {"foo", "bar"} + self.batch.event = MagicMock( + unpack=MagicMock(return_value=("salt/job/1234/ret/foo", {"id": "foo"})) + ) + self.batch.start() + self.batch._BatchAsync__event_handler(MagicMock()) + self.assertEqual(len(self.batch.event.io_loop.spawn_callback.mock_calls), 0) + + def test_batch__event_handler_batch_run_return(self): + self.batch.event = MagicMock( + unpack=MagicMock(return_value=("salt/job/1235/ret/foo", {"id": "foo"})) + ) + self.batch.start() + self.batch.active = {"foo"} + self.batch._BatchAsync__event_handler(MagicMock()) + self.assertEqual(self.batch.active, set()) + self.assertEqual(self.batch.done_minions, {"foo"}) + self.assertEqual( + self.batch.event.io_loop.call_later.call_args[0], + (self.batch.batch_delay, self.batch.schedule_next), + ) + + def test_batch__event_handler_find_job_return(self): + self.batch.event = MagicMock( + unpack=MagicMock(return_value=("salt/job/1236/ret/foo", {"id": "foo"})) + ) + self.batch.start() + self.batch._BatchAsync__event_handler(MagicMock()) + self.assertEqual(self.batch.find_job_returned, {"foo"}) + + @tornado.testing.gen_test + def test_batch__event_handler_end_batch(self): + self.batch.event = MagicMock( + unpack=MagicMock( + return_value=("salt/job/not-my-jid/ret/foo", {"id": "foo"}) + ) + ) + future = tornado.gen.Future() + future.set_result({"minions": ["foo", "bar", "baz"]}) + self.batch.local.run_job_async.return_value = future + self.batch.start() + self.batch.initialized = True + self.assertEqual(self.batch.down_minions, {"foo", "bar", "baz"}) + self.batch.end_batch = MagicMock() + self.batch.minions = {"foo", "bar", "baz"} + self.batch.done_minions = {"foo", "bar"} + self.batch.timedout_minions = {"baz"} + self.batch._BatchAsync__event_handler(MagicMock()) + self.assertEqual(len(self.batch.end_batch.mock_calls), 1) + + @tornado.testing.gen_test + def test_batch_find_job(self): + self.batch.event = MagicMock() + future = tornado.gen.Future() + future.set_result({}) + 
self.batch.local.run_job_async.return_value = future + self.batch.find_job({"foo", "bar"}) + self.assertEqual( + self.batch.event.io_loop.call_later.call_args[0], + ( + self.batch.opts["gather_job_timeout"], + self.batch.check_find_job, + {"foo", "bar"}, + ), + ) + + @tornado.testing.gen_test + def test_batch_find_job_with_done_minions(self): + self.batch.done_minions = {"bar"} + self.batch.event = MagicMock() + future = tornado.gen.Future() + future.set_result({}) + self.batch.local.run_job_async.return_value = future + self.batch.find_job({"foo", "bar"}) + self.assertEqual( + self.batch.event.io_loop.call_later.call_args[0], + (self.batch.opts["gather_job_timeout"], self.batch.check_find_job, {"foo"}), + ) + + def test_batch_check_find_job_did_not_return(self): + self.batch.event = MagicMock() + self.batch.active = {"foo"} + self.batch.find_job_returned = set() + self.batch.check_find_job({"foo"}) + self.assertEqual(self.batch.find_job_returned, set()) + self.assertEqual(self.batch.active, set()) + self.assertEqual(len(self.batch.event.io_loop.add_callback.mock_calls), 0) + + def test_batch_check_find_job_did_return(self): + self.batch.event = MagicMock() + self.batch.find_job_returned = {"foo"} + self.batch.check_find_job({"foo"}) + self.assertEqual( + self.batch.event.io_loop.add_callback.call_args[0], + (self.batch.find_job, {"foo"}), + ) + + def test_batch_check_find_job_multiple_states(self): + self.batch.event = MagicMock() + # currently running minions + self.batch.active = {"foo", "bar"} + + # minion is running and find_job returns + self.batch.find_job_returned = {"foo"} + + # minion started running but find_job did not return + self.batch.timedout_minions = {"faz"} + + # minion finished + self.batch.done_minions = {"baz"} + + # both not yet done but only 'foo' responded to find_job + not_done = {"foo", "bar"} + + self.batch.check_find_job(not_done) + + # assert 'bar' removed from active + self.assertEqual(self.batch.active, {"foo"}) + + # assert 'bar' 
added to timedout_minions + self.assertEqual(self.batch.timedout_minions, {"bar", "faz"}) + + # assert 'find_job' scheduled again only for 'foo' + self.assertEqual( + self.batch.event.io_loop.add_callback.call_args[0], + (self.batch.find_job, {"foo"}), + ) -- 2.29.2