From 2c1311f544950fe417fb8609aa3a30da32656637 Mon Sep 17 00:00:00 2001 From: Mihai Dinca Date: Thu, 11 Apr 2019 15:57:59 +0200 Subject: [PATCH] Fix async batch race conditions Close batching when there is no next batch --- salt/cli/batch_async.py | 80 +++++++++++++++++++------------------- tests/unit/cli/test_batch_async.py | 35 +++++++---------- 2 files changed, 54 insertions(+), 61 deletions(-) diff --git a/salt/cli/batch_async.py b/salt/cli/batch_async.py index 3160d46d8b..9c20b2fc6e 100644 --- a/salt/cli/batch_async.py +++ b/salt/cli/batch_async.py @@ -37,14 +37,14 @@ class BatchAsync(object): - tag: salt/batch//start - data: { "available_minions": self.minions, - "down_minions": self.down_minions + "down_minions": targeted_minions - presence_ping_minions } When the batch ends, an `done` event is fired: - tag: salt/batch//done - data: { "available_minions": self.minions, - "down_minions": self.down_minions, + "down_minions": targeted_minions - presence_ping_minions "done_minions": self.done_minions, "timedout_minions": self.timedout_minions } @@ -67,7 +67,7 @@ class BatchAsync(object): self.eauth = batch_get_eauth(clear_load['kwargs']) self.metadata = clear_load['kwargs'].get('metadata', {}) self.minions = set() - self.down_minions = set() + self.targeted_minions = set() self.timedout_minions = set() self.done_minions = set() self.active = set() @@ -108,8 +108,7 @@ class BatchAsync(object): minion = data['id'] if op == 'ping_return': self.minions.add(minion) - self.down_minions.remove(minion) - if not self.down_minions: + if self.targeted_minions == self.minions: self.event.io_loop.spawn_callback(self.start_batch) elif op == 'find_job_return': self.find_job_returned.add(minion) @@ -120,9 +119,6 @@ class BatchAsync(object): # call later so that we maybe gather more returns self.event.io_loop.call_later(self.batch_delay, self.schedule_next) - if self.initialized and self.done_minions == self.minions.difference(self.timedout_minions): - self.end_batch() - def _get_next(self): to_run = self.minions.difference( self.done_minions).difference( @@ -135,16 +131,13 @@ class BatchAsync(object): return set(list(to_run)[:next_batch_size]) @tornado.gen.coroutine - def check_find_job(self, minions): - did_not_return = minions.difference(self.find_job_returned) - if did_not_return: - for minion in did_not_return: - if minion in self.find_job_returned: - self.find_job_returned.remove(minion) - if minion in self.active: - self.active.remove(minion) - self.timedout_minions.add(minion) - running = minions.difference(did_not_return).difference(self.done_minions).difference(self.timedout_minions) + def check_find_job(self, batch_minions): + timedout_minions = batch_minions.difference(self.find_job_returned).difference(self.done_minions) + self.timedout_minions = self.timedout_minions.union(timedout_minions) + self.active = self.active.difference(self.timedout_minions) + running = batch_minions.difference(self.done_minions).difference(self.timedout_minions) + if timedout_minions: + self.event.io_loop.call_later(self.batch_delay, self.schedule_next) if running: self.event.io_loop.add_callback(self.find_job, running) @@ -183,7 +176,7 @@ class BatchAsync(object): jid=self.ping_jid, metadata=self.metadata, **self.eauth) - self.down_minions = set(ping_return['minions']) + self.targeted_minions = set(ping_return['minions']) @tornado.gen.coroutine def start_batch(self): @@ -192,36 +185,43 @@ class BatchAsync(object): self.initialized = True data = { "available_minions": self.minions, - "down_minions": self.down_minions, + "down_minions": self.targeted_minions.difference(self.minions), "metadata": self.metadata } self.event.fire_event(data, "salt/batch/{0}/start".format(self.batch_jid)) yield self.schedule_next() def end_batch(self): - data = { - "available_minions": self.minions, - "down_minions": self.down_minions, - "done_minions": self.done_minions, - "timedout_minions": self.timedout_minions, - "metadata": self.metadata - } - self.event.fire_event(data, "salt/batch/{0}/done".format(self.batch_jid)) - self.event.remove_event_handler(self.__event_handler) + left = self.minions.symmetric_difference(self.done_minions.union(self.timedout_minions)) + if not left: + data = { + "available_minions": self.minions, + "down_minions": self.targeted_minions.difference(self.minions), + "done_minions": self.done_minions, + "timedout_minions": self.timedout_minions, + "metadata": self.metadata + } + self.event.fire_event(data, "salt/batch/{0}/done".format(self.batch_jid)) + self.event.remove_event_handler(self.__event_handler) @tornado.gen.coroutine def schedule_next(self): next_batch = self._get_next() if next_batch: - yield self.local.run_job_async( - next_batch, - self.opts['fun'], - self.opts['arg'], - 'list', - raw=self.opts.get('raw', False), - ret=self.opts.get('return', ''), - gather_job_timeout=self.opts['gather_job_timeout'], - jid=self.batch_jid, - metadata=self.metadata) - self.event.io_loop.call_later(self.opts['timeout'], self.find_job, set(next_batch)) self.active = self.active.union(next_batch) + try: + yield self.local.run_job_async( + next_batch, + self.opts['fun'], + self.opts['arg'], + 'list', + raw=self.opts.get('raw', False), + ret=self.opts.get('return', ''), + gather_job_timeout=self.opts['gather_job_timeout'], + jid=self.batch_jid, + metadata=self.metadata) + self.event.io_loop.call_later(self.opts['timeout'], self.find_job, set(next_batch)) + except Exception as ex: + self.active = self.active.difference(next_batch) + else: + self.end_batch() diff --git a/tests/unit/cli/test_batch_async.py b/tests/unit/cli/test_batch_async.py index f65b6a06c3..d519157d92 100644 --- a/tests/unit/cli/test_batch_async.py +++ b/tests/unit/cli/test_batch_async.py @@ -75,8 +75,8 @@ class AsyncBatchTestCase(AsyncTestCase, TestCase): self.batch.local.run_job_async.call_args[0], ('*', 'test.ping', [], 'glob') ) - # assert down_minions == all minions matched by tgt - self.assertEqual(self.batch.down_minions, set(['foo', 'bar'])) + # assert targeted_minions == all minions matched by tgt + self.assertEqual(self.batch.targeted_minions, set(['foo', 'bar'])) @tornado.testing.gen_test def test_batch_start_on_gather_job_timeout(self): @@ -121,7 +121,10 @@ class AsyncBatchTestCase(AsyncTestCase, TestCase): self.assertEqual(len(self.batch.schedule_next.mock_calls), 1) def test_batch_fire_done_event(self): + self.batch.targeted_minions = {'foo', 'baz', 'bar'} self.batch.minions = set(['foo', 'bar']) + self.batch.done_minions = {'foo'} + self.batch.timedout_minions = {'bar'} self.batch.event = MagicMock() self.batch.metadata = {'mykey': 'myvalue'} self.batch.end_batch() @@ -130,9 +133,9 @@ class AsyncBatchTestCase(AsyncTestCase, TestCase): ( { 'available_minions': set(['foo', 'bar']), - 'done_minions': set(), - 'down_minions': set(), - 'timedout_minions': set(), + 'done_minions': self.batch.done_minions, + 'down_minions': {'baz'}, + 'timedout_minions': self.batch.timedout_minions, 'metadata': self.batch.metadata }, "salt/batch/1235/done" @@ -212,7 +215,7 @@ class AsyncBatchTestCase(AsyncTestCase, TestCase): self.assertEqual(self.batch._get_next(), set()) def test_batch__event_handler_ping_return(self): - self.batch.down_minions = {'foo'} + self.batch.targeted_minions = {'foo'} self.batch.event = MagicMock( unpack=MagicMock(return_value=('salt/job/1234/ret/foo', {'id': 'foo'}))) self.batch.start() @@ -222,7 +225,7 @@ class AsyncBatchTestCase(AsyncTestCase, TestCase): self.assertEqual(self.batch.done_minions, set()) def test_batch__event_handler_call_start_batch_when_all_pings_return(self): - self.batch.down_minions = {'foo'} + self.batch.targeted_minions = {'foo'} self.batch.event = MagicMock( unpack=MagicMock(return_value=('salt/job/1234/ret/foo', {'id': 'foo'}))) self.batch.start() @@ -232,7 +235,7 @@ class AsyncBatchTestCase(AsyncTestCase, TestCase): (self.batch.start_batch,)) def test_batch__event_handler_not_call_start_batch_when_not_all_pings_return(self): - self.batch.down_minions = {'foo', 'bar'} + self.batch.targeted_minions = {'foo', 'bar'} self.batch.event = MagicMock( unpack=MagicMock(return_value=('salt/job/1234/ret/foo', {'id': 'foo'}))) self.batch.start() @@ -260,20 +263,10 @@ class AsyncBatchTestCase(AsyncTestCase, TestCase): self.assertEqual(self.batch.find_job_returned, {'foo'}) @tornado.testing.gen_test - def test_batch__event_handler_end_batch(self): - self.batch.event = MagicMock( - unpack=MagicMock(return_value=('salt/job/not-my-jid/ret/foo', {'id': 'foo'}))) - future = tornado.gen.Future() - future.set_result({'minions': ['foo', 'bar', 'baz']}) - self.batch.local.run_job_async.return_value = future - self.batch.start() - self.batch.initialized = True - self.assertEqual(self.batch.down_minions, {'foo', 'bar', 'baz'}) + def test_batch_schedule_next_end_batch_when_no_next(self): self.batch.end_batch = MagicMock() - self.batch.minions = {'foo', 'bar', 'baz'} - self.batch.done_minions = {'foo', 'bar'} - self.batch.timedout_minions = {'baz'} - self.batch._BatchAsync__event_handler(MagicMock()) + self.batch._get_next = MagicMock(return_value={}) + self.batch.schedule_next() self.assertEqual(len(self.batch.end_batch.mock_calls), 1) @tornado.testing.gen_test -- 2.16.4