From 6fd70f39b3a79ea406d994f3b437d8887d2f6e70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sun, 16 Apr 2023 17:47:21 +0300 Subject: [PATCH 1/7] Changed TLSAttribute.shared_ciphers to match SSLSocket.shared_ciphers() --- docs/versionhistory.rst | 2 ++ src/anyio/streams/tls.py | 9 +++-- tests/streams/test_tls.py | 74 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 81 insertions(+), 4 deletions(-) Index: anyio-3.6.2/src/anyio/streams/tls.py =================================================================== --- anyio-3.6.2.orig/src/anyio/streams/tls.py +++ anyio-3.6.2/src/anyio/streams/tls.py @@ -37,8 +37,9 @@ class TLSAttribute(TypedAttributeSet): peer_certificate_binary: Optional[bytes] = typed_attribute() #: ``True`` if this is the server side of the connection server_side: bool = typed_attribute() - #: ciphers shared between both ends of the TLS connection - shared_ciphers: List[Tuple[str, str, int]] = typed_attribute() + #: ciphers shared by the client during the TLS handshake (``None`` if this is the + #: client side) + shared_ciphers: Optional[list[tuple[str, str, int]]] = typed_attribute() #: the :class:`~ssl.SSLObject` used for encryption ssl_object: ssl.SSLObject = typed_attribute() #: ``True`` if this stream does (and expects) a closing TLS handshake when the stream is being @@ -105,7 +106,8 @@ class TLSStream(ByteStream): # Re-enable detection of unexpected EOFs if it was disabled by Python if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"): - ssl_context.options ^= ssl.OP_IGNORE_UNEXPECTED_EOF # type: ignore[attr-defined] + ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF # type: ignore[attr-defined] + bio_in = ssl.MemoryBIO() bio_out = ssl.MemoryBIO() @@ -153,7 +155,7 @@ class TLSStream(ByteStream): self._read_bio.write_eof() self._write_bio.write_eof() if ( - isinstance(exc, ssl.SSLEOFError) + isinstance(exc, (ssl.SSLEOFError, ssl.SSLZeroReturnError)) or "UNEXPECTED_EOF_WHILE_READING" in exc.strerror ): if self.standard_compatible: @@ -228,7 +230,9 @@ class TLSStream(ByteStream): True ), TLSAttribute.server_side: lambda: self._ssl_object.server_side, - TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers(), + TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers() + if self._ssl_object.server_side + else None, TLSAttribute.standard_compatible: lambda: self.standard_compatible, TLSAttribute.ssl_object: lambda: self._ssl_object, TLSAttribute.tls_version: self._ssl_object.version, Index: anyio-3.6.2/tests/streams/test_tls.py =================================================================== --- anyio-3.6.2.orig/tests/streams/test_tls.py +++ anyio-3.6.2/tests/streams/test_tls.py @@ -5,6 +5,7 @@ from threading import Thread from typing import ContextManager, NoReturn import pytest +from pytest_mock import MockerFixture from trustme import CA from anyio import ( @@ -12,10 +13,13 @@ from anyio import ( EndOfStream, Event, connect_tcp, + create_memory_object_stream, create_task_group, create_tcp_listener, + to_thread, ) from anyio.abc import AnyByteStream, SocketAttribute, SocketStream +from anyio.streams.stapled import StapledObjectStream from anyio.streams.tls import TLSAttribute, TLSListener, TLSStream pytestmark = pytest.mark.anyio @@ -95,7 +99,7 @@ class TestTLSStream: wrapper.extra(TLSAttribute.peer_certificate_binary), bytes ) assert wrapper.extra(TLSAttribute.server_side) is False - assert isinstance(wrapper.extra(TLSAttribute.shared_ciphers), list) + assert wrapper.extra(TLSAttribute.shared_ciphers) is None assert isinstance(wrapper.extra(TLSAttribute.ssl_object), ssl.SSLObject) assert wrapper.extra(TLSAttribute.standard_compatible) is False assert wrapper.extra(TLSAttribute.tls_version).startswith("TLSv") @@ -368,6 +372,20 @@ class TestTLSStream: server_thread.join() server_sock.close() + @pytest.mark.skipif( + not hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"), + reason="The ssl module does not have the OP_IGNORE_UNEXPECTED_EOF attribute", + ) + async def test_default_context_ignore_unexpected_eof_flag_off( + self, mocker: MockerFixture + ) -> None: + send1, receive1 = create_memory_object_stream() + client_stream = StapledObjectStream(send1, receive1) + mocker.patch.object(TLSStream, "_call_sslobject_method") + tls_stream = await TLSStream.wrap(client_stream) + ssl_context = tls_stream.extra(TLSAttribute.ssl_object).context + assert not ssl_context.options & ssl.OP_IGNORE_UNEXPECTED_EOF + class TestTLSListener: async def test_handshake_fail(self, server_context: ssl.SSLContext) -> None: @@ -399,3 +417,74 @@ class TestTLSListener: tg.cancel_scope.cancel() assert isinstance(exception, BrokenResourceError) + + async def test_extra_attributes( + self, client_context: ssl.SSLContext, server_context: ssl.SSLContext, ca: CA + ) -> None: + def connect_sync(addr: tuple[str, int]) -> None: + with socket.create_connection(addr) as plain_sock: + plain_sock.settimeout(2) + with client_context.wrap_socket( + plain_sock, + server_side=False, + server_hostname="localhost", + suppress_ragged_eofs=False, + ) as conn: + conn.recv(1) + conn.unwrap() + + class CustomTLSListener(TLSListener): + @staticmethod + async def handle_handshake_error( + exc: BaseException, stream: AnyByteStream + ) -> None: + await TLSListener.handle_handshake_error(exc, stream) + pytest.fail("TLS handshake failed") + + async def handler(stream: TLSStream) -> None: + async with stream: + try: + assert stream.extra(TLSAttribute.alpn_protocol) == "h2" + assert isinstance( + stream.extra(TLSAttribute.channel_binding_tls_unique), bytes + ) + assert isinstance(stream.extra(TLSAttribute.cipher), tuple) + assert isinstance(stream.extra(TLSAttribute.peer_certificate), dict) + assert isinstance( + stream.extra(TLSAttribute.peer_certificate_binary), bytes + ) + assert stream.extra(TLSAttribute.server_side) is True + shared_ciphers = stream.extra(TLSAttribute.shared_ciphers) + assert isinstance(shared_ciphers, list) + assert len(shared_ciphers) > 1 + assert isinstance( + stream.extra(TLSAttribute.ssl_object), ssl.SSLObject + ) + assert stream.extra(TLSAttribute.standard_compatible) is True + assert stream.extra(TLSAttribute.tls_version).startswith("TLSv") + finally: + event.set() + await stream.send(b"\x00") + + # Issue a client certificate and make the server trust it + client_cert = ca.issue_cert("dummy-client") + client_cert.configure_cert(client_context) + ca.configure_trust(server_context) + server_context.verify_mode = ssl.CERT_REQUIRED + + event = Event() + server_context.set_alpn_protocols(["h2"]) + client_context.set_alpn_protocols(["h2"]) + listener = await create_tcp_listener(local_host="127.0.0.1") + tls_listener = CustomTLSListener(listener, server_context) + async with tls_listener, create_task_group() as tg: + assert tls_listener.extra(TLSAttribute.standard_compatible) is True + tg.start_soon(tls_listener.serve, handler) + client_thread = Thread( + target=connect_sync, + args=[listener.extra(SocketAttribute.local_address)], + ) + client_thread.start() + await event.wait() + await to_thread.run_sync(client_thread.join) + tg.cancel_scope.cancel() Index: anyio-3.6.2/tests/conftest.py =================================================================== --- anyio-3.6.2.orig/tests/conftest.py +++ anyio-3.6.2/tests/conftest.py @@ -54,7 +54,8 @@ def ca() -> CA: def server_context(ca: CA) -> SSLContext: server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"): - server_context.options ^= ssl.OP_IGNORE_UNEXPECTED_EOF # type: ignore[attr-defined] + server_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF # type: ignore[attr-defined] + ca.issue_cert("localhost").configure_cert(server_context) return server_context