From da544d7ab09899717e57a02321928ceaf3c6465c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pablo=20Su=C3=A1rez=20Hern=C3=A1ndez?= Date: Tue, 22 Aug 2023 11:43:46 +0100 Subject: [PATCH] Do not fail on bad message pack message (bsc#1213441, CVE-2023-20897) (#595) * Do not fail on bad message pack message Fix unit test after backporting to openSUSE/release/3006.0 * Better error message when the decoded payload is inconsistent --------- Co-authored-by: Daniel A. Wozniak --- salt/channel/server.py | 10 +++ salt/transport/zeromq.py | 6 +- tests/pytests/unit/transport/test_zeromq.py | 69 +++++++++++++++++++++ 3 files changed, 84 insertions(+), 1 deletion(-) diff --git a/salt/channel/server.py b/salt/channel/server.py index a2117f2934..b6d51fef08 100644 --- a/salt/channel/server.py +++ b/salt/channel/server.py @@ -22,6 +22,7 @@ import salt.utils.minions import salt.utils.platform import salt.utils.stringutils import salt.utils.verify +from salt.exceptions import SaltDeserializationError from salt.utils.cache import CacheCli try: @@ -252,6 +253,15 @@ class ReqServerChannel: return False def _decode_payload(self, payload): + # Sometimes msgpack deserialization of random bytes can succeed, + # so we need to ensure the payload is well-formed before processing it.
+ if ( + not isinstance(payload, dict) + or "enc" not in payload + or "load" not in payload + ): + raise SaltDeserializationError("bad load received on socket!") + + # we need to decrypt it if payload["enc"] == "aes": try: diff --git a/salt/transport/zeromq.py b/salt/transport/zeromq.py index 3ec7f7726c..7cc6b9987f 100644 --- a/salt/transport/zeromq.py +++ b/salt/transport/zeromq.py @@ -428,7 +428,11 @@ class RequestServer(salt.transport.base.DaemonizedRequestServer): @salt.ext.tornado.gen.coroutine def handle_message(self, stream, payload): - payload = self.decode_payload(payload) + try: + payload = self.decode_payload(payload) + except salt.exceptions.SaltDeserializationError: + self.stream.send(self.encode_payload({"msg": "bad load"})) + return # XXX: Is header really needed? reply = yield self.message_handler(payload) self.stream.send(self.encode_payload(reply)) diff --git a/tests/pytests/unit/transport/test_zeromq.py b/tests/pytests/unit/transport/test_zeromq.py index 10bb4917b8..c7cbc53864 100644 --- a/tests/pytests/unit/transport/test_zeromq.py +++ b/tests/pytests/unit/transport/test_zeromq.py @@ -11,6 +11,7 @@ import threading import time import uuid +import msgpack import pytest import salt.channel.client @@ -1404,3 +1405,71 @@ async def test_req_chan_auth_v2_new_minion_without_master_pub(pki_dir, io_loop): assert "sig" in ret ret = client.auth.handle_signin_response(signin_payload, ret) assert ret == "retry" + + +async def test_req_server_garbage_request(io_loop): + """ + Validate that invalid msgpack messages do not raise exceptions in the + RequestServer's message handler.
+ """ + opts = salt.config.master_config("") + request_server = salt.transport.zeromq.RequestServer(opts) + + def message_handler(payload): + return payload + + request_server.post_fork(message_handler, io_loop) + + byts = msgpack.dumps({"foo": "bar"}) + badbyts = byts[:3] + b"^M" + byts[3:] + + valid_response = msgpack.dumps({"msg": "bad load"}) + + with MagicMock() as stream: + request_server.stream = stream + + try: + await request_server.handle_message(stream, badbyts) + except Exception as exc: # pylint: disable=broad-except + pytest.fail("Exception was raised {}".format(exc)) + + request_server.stream.send.assert_called_once_with(valid_response) + + +async def test_req_chan_bad_payload_to_decode(pki_dir, io_loop): + opts = { + "master_uri": "tcp://127.0.0.1:4506", + "interface": "127.0.0.1", + "ret_port": 4506, + "ipv6": False, + "sock_dir": ".", + "pki_dir": str(pki_dir.joinpath("minion")), + "id": "minion", + "__role": "minion", + "keysize": 4096, + "max_minions": 0, + "auto_accept": False, + "open_mode": False, + "key_pass": None, + "publish_port": 4505, + "auth_mode": 1, + "acceptance_wait_time": 3, + "acceptance_wait_time_max": 3, + } + SMaster.secrets["aes"] = { + "secret": multiprocessing.Array( + ctypes.c_char, + salt.utils.stringutils.to_bytes(salt.crypt.Crypticle.generate_key_string()), + ), + "reload": salt.crypt.Crypticle.generate_key_string, + } + master_opts = dict(opts, pki_dir=str(pki_dir.joinpath("master"))) + master_opts["master_sign_pubkey"] = False + server = salt.channel.server.ReqServerChannel.factory(master_opts) + + with pytest.raises(salt.exceptions.SaltDeserializationError): + server._decode_payload(None) + with pytest.raises(salt.exceptions.SaltDeserializationError): + server._decode_payload({}) + with pytest.raises(salt.exceptions.SaltDeserializationError): + server._decode_payload(12345) -- 2.41.0