From 5d7a8a71151bd4e3d9f1c5830f785893b35f0f81e97c4764c91e7c1e4364c270 Mon Sep 17 00:00:00 2001 From: Steve Kowalik Date: Wed, 16 Oct 2019 03:20:52 +0000 Subject: [PATCH] - Add python38-support.patch from upstream that fixes build failures against Python 3.8. OBS-URL: https://build.opensuse.org/package/show/devel:languages:python/python-websockets?expand=0&rev=19 --- python-websockets.changes | 6 + python-websockets.spec | 2 + python38-support.patch | 2201 +++++++++++++++++++++++++++++++++++++ 3 files changed, 2209 insertions(+) create mode 100644 python38-support.patch diff --git a/python-websockets.changes b/python-websockets.changes index 8288ea8..4ead146 100644 --- a/python-websockets.changes +++ b/python-websockets.changes @@ -1,3 +1,9 @@ +------------------------------------------------------------------- +Wed Oct 16 03:19:43 UTC 2019 - Steve Kowalik + +- Add python38-support.patch from upstream that fixes build failures + against Python 3.8. + ------------------------------------------------------------------- Tue Aug 13 16:35:47 UTC 2019 - Tomáš Chvátal diff --git a/python-websockets.spec b/python-websockets.spec index e710f69..adf16d3 100644 --- a/python-websockets.spec +++ b/python-websockets.spec @@ -26,6 +26,7 @@ License: BSD-3-Clause Group: Development/Languages/Python URL: https://github.com/aaugustin/websockets Source: https://github.com/aaugustin/websockets/archive/%{version}.tar.gz +Patch0: python38-support.patch BuildRequires: %{python_module devel} BuildRequires: %{python_module setuptools} BuildRequires: fdupes @@ -44,6 +45,7 @@ concurrent applications. %prep %setup -q -n websockets-%{version} +%autopatch -p1 %build export CFLAGS="%{optflags}" diff --git a/python38-support.patch b/python38-support.patch new file mode 100644 index 0000000..3130c91 --- /dev/null +++ b/python38-support.patch @@ -0,0 +1,2201 @@ +From fbab68fad41b1ddad2bbe7eccd793f5d6c52f761 Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 11:41:21 +0200 +Subject: [PATCH 01/20] Copy FlowControlMixin and StreamReaderProtocol. + +This is the official recommendation of Python core devs. + +The code is taken from the current 3.7 branch. +--- + src/websockets/protocol.py | 127 ++++++++++++++++++++++++++++++++++++- + 1 file changed, 126 insertions(+), 1 deletion(-) + +diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py +index 1f0edcce..2f74cd23 100644 +--- a/src/websockets/protocol.py ++++ b/src/websockets/protocol.py +@@ -61,7 +61,132 @@ class State(enum.IntEnum): + # between the check and the assignment. + + +-class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): ++class FlowControlMixin(asyncio.Protocol): ++ """Reusable flow control logic for StreamWriter.drain(). ++ This implements the protocol methods pause_writing(), ++ resume_writing() and connection_lost(). If the subclass overrides ++ these it must call the super methods. ++ StreamWriter.drain() must wait for _drain_helper() coroutine. ++ """ ++ ++ def __init__(self, loop=None): ++ if loop is None: ++ self._loop = asyncio.get_event_loop() ++ else: ++ self._loop = loop ++ self._paused = False ++ self._drain_waiter = None ++ self._connection_lost = False ++ ++ def pause_writing(self): ++ assert not self._paused ++ self._paused = True ++ if self._loop.get_debug(): ++ logger.debug("%r pauses writing", self) ++ ++ def resume_writing(self): ++ assert self._paused ++ self._paused = False ++ if self._loop.get_debug(): ++ logger.debug("%r resumes writing", self) ++ ++ waiter = self._drain_waiter ++ if waiter is not None: ++ self._drain_waiter = None ++ if not waiter.done(): ++ waiter.set_result(None) ++ ++ def connection_lost(self, exc): ++ self._connection_lost = True ++ # Wake up the writer if currently paused. ++ if not self._paused: ++ return ++ waiter = self._drain_waiter ++ if waiter is None: ++ return ++ self._drain_waiter = None ++ if waiter.done(): ++ return ++ if exc is None: ++ waiter.set_result(None) ++ else: ++ waiter.set_exception(exc) ++ ++ async def _drain_helper(self): ++ if self._connection_lost: ++ raise ConnectionResetError("Connection lost") ++ if not self._paused: ++ return ++ waiter = self._drain_waiter ++ assert waiter is None or waiter.cancelled() ++ waiter = self._loop.create_future() ++ self._drain_waiter = waiter ++ await waiter ++ ++ ++class StreamReaderProtocol(FlowControlMixin, asyncio.Protocol): ++ """Helper class to adapt between Protocol and StreamReader. ++ (This is a helper class instead of making StreamReader itself a ++ Protocol subclass, because the StreamReader has other potential ++ uses, and to prevent the user of the StreamReader to accidentally ++ call inappropriate methods of the protocol.) ++ """ ++ ++ def __init__(self, stream_reader, client_connected_cb=None, loop=None): ++ super().__init__(loop=loop) ++ self._stream_reader = stream_reader ++ self._stream_writer = None ++ self._client_connected_cb = client_connected_cb ++ self._over_ssl = False ++ self._closed = self._loop.create_future() ++ ++ def connection_made(self, transport): ++ self._stream_reader.set_transport(transport) ++ self._over_ssl = transport.get_extra_info("sslcontext") is not None ++ if self._client_connected_cb is not None: ++ self._stream_writer = asyncio.StreamWriter( ++ transport, self, self._stream_reader, self._loop ++ ) ++ res = self._client_connected_cb(self._stream_reader, self._stream_writer) ++ if asyncio.iscoroutine(res): ++ self._loop.create_task(res) ++ ++ def connection_lost(self, exc): ++ if self._stream_reader is not None: ++ if exc is None: ++ self._stream_reader.feed_eof() ++ else: ++ self._stream_reader.set_exception(exc) ++ if not self._closed.done(): ++ if exc is None: ++ self._closed.set_result(None) ++ else: ++ self._closed.set_exception(exc) ++ super().connection_lost(exc) ++ self._stream_reader = None ++ self._stream_writer = None ++ ++ def data_received(self, data): ++ self._stream_reader.feed_data(data) ++ ++ def eof_received(self): ++ self._stream_reader.feed_eof() ++ if self._over_ssl: ++ # Prevent a warning in SSLProtocol.eof_received: ++ # "returning true from eof_received() ++ # has no effect when using ssl" ++ return False ++ return True ++ ++ def __del__(self): ++ # Prevent reports about unhandled exceptions. ++ # Better than self._closed._log_traceback = False hack ++ closed = self._closed ++ if closed.done() and not closed.cancelled(): ++ closed.exception() ++ ++ ++class WebSocketCommonProtocol(StreamReaderProtocol): + """ + :class:`~asyncio.Protocol` subclass implementing the data transfer phase. + + +From 57907330fe789f0b5f6b71b2810a0e36fee14844 Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 11:54:04 +0200 +Subject: [PATCH 02/20] Remove docstrings and debug logs. + +--- + src/websockets/protocol.py | 16 ---------------- + 1 file changed, 16 deletions(-) + +diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py +index 2f74cd23..d74c8157 100644 +--- a/src/websockets/protocol.py ++++ b/src/websockets/protocol.py +@@ -62,12 +62,6 @@ class State(enum.IntEnum): + + + class FlowControlMixin(asyncio.Protocol): +- """Reusable flow control logic for StreamWriter.drain(). +- This implements the protocol methods pause_writing(), +- resume_writing() and connection_lost(). If the subclass overrides +- these it must call the super methods. +- StreamWriter.drain() must wait for _drain_helper() coroutine. +- """ + + def __init__(self, loop=None): + if loop is None: +@@ -81,14 +75,10 @@ def __init__(self, loop=None): + def pause_writing(self): + assert not self._paused + self._paused = True +- if self._loop.get_debug(): +- logger.debug("%r pauses writing", self) + + def resume_writing(self): + assert self._paused + self._paused = False +- if self._loop.get_debug(): +- logger.debug("%r resumes writing", self) + + waiter = self._drain_waiter + if waiter is not None: +@@ -125,12 +115,6 @@ def connection_lost(self, exc): + + + class StreamReaderProtocol(FlowControlMixin, asyncio.Protocol): +- """Helper class to adapt between Protocol and StreamReader. +- (This is a helper class instead of making StreamReader itself a +- Protocol subclass, because the StreamReader has other potential +- uses, and to prevent the user of the StreamReader to accidentally +- call inappropriate methods of the protocol.) +- """ + + def __init__(self, stream_reader, client_connected_cb=None, loop=None): + super().__init__(loop=loop) + +From 53523d06c7752b394c1f1e5566f464db741e2617 Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 11:54:43 +0200 +Subject: [PATCH 03/20] Merge FlowControlMixin in StreamReaderProtocol. + +--- + src/websockets/protocol.py | 54 +++++++++++++++++--------------------- + 1 file changed, 24 insertions(+), 30 deletions(-) + +diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py +index d74c8157..49d8b4f2 100644 +--- a/src/websockets/protocol.py ++++ b/src/websockets/protocol.py +@@ -61,9 +61,9 @@ class State(enum.IntEnum): + # between the check and the assignment. + + +-class FlowControlMixin(asyncio.Protocol): ++class StreamReaderProtocol(asyncio.Protocol): + +- def __init__(self, loop=None): ++ def __init__(self, stream_reader, client_connected_cb=None, loop=None): + if loop is None: + self._loop = asyncio.get_event_loop() + else: +@@ -72,6 +72,12 @@ def __init__(self, loop=None): + self._drain_waiter = None + self._connection_lost = False + ++ self._stream_reader = stream_reader ++ self._stream_writer = None ++ self._client_connected_cb = client_connected_cb ++ self._over_ssl = False ++ self._closed = self._loop.create_future() ++ + def pause_writing(self): + assert not self._paused + self._paused = True +@@ -86,22 +92,6 @@ def resume_writing(self): + if not waiter.done(): + waiter.set_result(None) + +- def connection_lost(self, exc): +- self._connection_lost = True +- # Wake up the writer if currently paused. +- if not self._paused: +- return +- waiter = self._drain_waiter +- if waiter is None: +- return +- self._drain_waiter = None +- if waiter.done(): +- return +- if exc is None: +- waiter.set_result(None) +- else: +- waiter.set_exception(exc) +- + async def _drain_helper(self): + if self._connection_lost: + raise ConnectionResetError("Connection lost") +@@ -113,17 +103,6 @@ def connection_lost(self, exc): + self._drain_waiter = waiter + await waiter + +- +-class StreamReaderProtocol(FlowControlMixin, asyncio.Protocol): +- +- def __init__(self, stream_reader, client_connected_cb=None, loop=None): +- super().__init__(loop=loop) +- self._stream_reader = stream_reader +- self._stream_writer = None +- self._client_connected_cb = client_connected_cb +- self._over_ssl = False +- self._closed = self._loop.create_future() +- + def connection_made(self, transport): + self._stream_reader.set_transport(transport) + self._over_ssl = transport.get_extra_info("sslcontext") is not None +@@ -146,7 +125,22 @@ def connection_lost(self, exc): + self._closed.set_result(None) + else: + self._closed.set_exception(exc) +- super().connection_lost(exc) ++ ++ self._connection_lost = True ++ # Wake up the writer if currently paused. ++ if not self._paused: ++ return ++ waiter = self._drain_waiter ++ if waiter is None: ++ return ++ self._drain_waiter = None ++ if waiter.done(): ++ return ++ if exc is None: ++ waiter.set_result(None) ++ else: ++ waiter.set_exception(exc) ++ + self._stream_reader = None + self._stream_writer = None + + +From e6d8bcfe2a0f16c50260cafab45f06da6baa1689 Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 11:55:53 +0200 +Subject: [PATCH 04/20] Deduplicate loop and _loop attributes. + +--- + src/websockets/protocol.py | 18 ++++++------------ + 1 file changed, 6 insertions(+), 12 deletions(-) + +diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py +index 49d8b4f2..98c23ab1 100644 +--- a/src/websockets/protocol.py ++++ b/src/websockets/protocol.py +@@ -63,11 +63,7 @@ class State(enum.IntEnum): + + class StreamReaderProtocol(asyncio.Protocol): + +- def __init__(self, stream_reader, client_connected_cb=None, loop=None): +- if loop is None: +- self._loop = asyncio.get_event_loop() +- else: +- self._loop = loop ++ def __init__(self, stream_reader, client_connected_cb=None): + self._paused = False + self._drain_waiter = None + self._connection_lost = False +@@ -76,7 +72,7 @@ def __init__(self, stream_reader, client_connected_cb=None, loop=None): + self._stream_writer = None + self._client_connected_cb = client_connected_cb + self._over_ssl = False +- self._closed = self._loop.create_future() ++ self._closed = self.loop.create_future() + + def pause_writing(self): + assert not self._paused +@@ -99,7 +95,7 @@ def resume_writing(self): + return + waiter = self._drain_waiter + assert waiter is None or waiter.cancelled() +- waiter = self._loop.create_future() ++ waiter = self.loop.create_future() + self._drain_waiter = waiter + await waiter + +@@ -108,11 +104,11 @@ def connection_made(self, transport): + self._over_ssl = transport.get_extra_info("sslcontext") is not None + if self._client_connected_cb is not None: + self._stream_writer = asyncio.StreamWriter( +- transport, self, self._stream_reader, self._loop ++ transport, self, self._stream_reader, self.loop + ) + res = self._client_connected_cb(self._stream_reader, self._stream_writer) + if asyncio.iscoroutine(res): +- self._loop.create_task(res) ++ self.loop.create_task(res) + + def connection_lost(self, exc): + if self._stream_reader is not None: +@@ -315,8 +311,6 @@ def __init__( + self.read_limit = read_limit + self.write_limit = write_limit + +- # Store a reference to loop to avoid relying on self._loop, a private +- # attribute of StreamReaderProtocol, inherited from FlowControlMixin. + if loop is None: + loop = asyncio.get_event_loop() + self.loop = loop +@@ -331,7 +325,7 @@ def __init__( + # limit and half the buffer limit of :class:`~asyncio.StreamReader`. + # That's why it must be set to half of ``self.read_limit``. + stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) +- super().__init__(stream_reader, self.client_connected, loop) ++ super().__init__(stream_reader, self.client_connected) + + self.reader: asyncio.StreamReader + self.writer: asyncio.StreamWriter + +From 4205cb176e62c938247af59e94214814242360ec Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 11:57:34 +0200 +Subject: [PATCH 05/20] Remove client_connected callback. + +--- + src/websockets/protocol.py | 31 +++++++------------------------ + 1 file changed, 7 insertions(+), 24 deletions(-) + +diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py +index 98c23ab1..bfc354a8 100644 +--- a/src/websockets/protocol.py ++++ b/src/websockets/protocol.py +@@ -63,14 +63,13 @@ class State(enum.IntEnum): + + class StreamReaderProtocol(asyncio.Protocol): + +- def __init__(self, stream_reader, client_connected_cb=None): ++ def __init__(self, stream_reader): + self._paused = False + self._drain_waiter = None + self._connection_lost = False + + self._stream_reader = stream_reader + self._stream_writer = None +- self._client_connected_cb = client_connected_cb + self._over_ssl = False + self._closed = self.loop.create_future() + +@@ -102,13 +101,11 @@ def resume_writing(self): + def connection_made(self, transport): + self._stream_reader.set_transport(transport) + self._over_ssl = transport.get_extra_info("sslcontext") is not None +- if self._client_connected_cb is not None: +- self._stream_writer = asyncio.StreamWriter( +- transport, self, self._stream_reader, self.loop +- ) +- res = self._client_connected_cb(self._stream_reader, self._stream_writer) +- if asyncio.iscoroutine(res): +- self.loop.create_task(res) ++ self._stream_writer = asyncio.StreamWriter( ++ transport, self, self._stream_reader, self.loop ++ ) ++ self.reader = self._stream_reader ++ self.writer = self._stream_writer + + def connection_lost(self, exc): + if self._stream_reader is not None: +@@ -325,7 +322,7 @@ def __init__( + # limit and half the buffer limit of :class:`~asyncio.StreamReader`. + # That's why it must be set to half of ``self.read_limit``. + stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) +- super().__init__(stream_reader, self.client_connected) ++ super().__init__(stream_reader) + + self.reader: asyncio.StreamReader + self.writer: asyncio.StreamWriter +@@ -381,20 +378,6 @@ def __init__( + # Task closing the TCP connection. + self.close_connection_task: asyncio.Task[None] + +- def client_connected( +- self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter +- ) -> None: +- """ +- Callback when the TCP connection is established. +- +- Record references to the stream reader and the stream writer to avoid +- using private attributes ``_stream_reader`` and ``_stream_writer`` of +- :class:`~asyncio.StreamReaderProtocol`. +- +- """ +- self.reader = reader +- self.writer = writer +- + def connection_open(self) -> None: + """ + Callback when the WebSocket opening handshake completes. + +From 2ebb4b4eff76be388ee37570e891ed8432e19a9f Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 12:00:31 +0200 +Subject: [PATCH 06/20] Deduplicate reader/writer and _stream_reader/writer + attributes. + +--- + src/websockets/protocol.py | 36 +++++++++++++++--------------------- + 1 file changed, 15 insertions(+), 21 deletions(-) + +diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py +index bfc354a8..9c61a409 100644 +--- a/src/websockets/protocol.py ++++ b/src/websockets/protocol.py +@@ -62,14 +62,11 @@ class State(enum.IntEnum): + + + class StreamReaderProtocol(asyncio.Protocol): +- +- def __init__(self, stream_reader): ++ def __init__(self): + self._paused = False + self._drain_waiter = None + self._connection_lost = False + +- self._stream_reader = stream_reader +- self._stream_writer = None + self._over_ssl = False + self._closed = self.loop.create_future() + +@@ -99,20 +96,16 @@ def resume_writing(self): + await waiter + + def connection_made(self, transport): +- self._stream_reader.set_transport(transport) ++ self.reader.set_transport(transport) + self._over_ssl = transport.get_extra_info("sslcontext") is not None +- self._stream_writer = asyncio.StreamWriter( +- transport, self, self._stream_reader, self.loop +- ) +- self.reader = self._stream_reader +- self.writer = self._stream_writer ++ self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop) + + def connection_lost(self, exc): +- if self._stream_reader is not None: ++ if self.reader is not None: + if exc is None: +- self._stream_reader.feed_eof() ++ self.reader.feed_eof() + else: +- self._stream_reader.set_exception(exc) ++ self.reader.set_exception(exc) + if not self._closed.done(): + if exc is None: + self._closed.set_result(None) +@@ -134,14 +127,14 @@ def connection_lost(self, exc): + else: + waiter.set_exception(exc) + +- self._stream_reader = None +- self._stream_writer = None ++ del self.reader ++ del self.writer + + def data_received(self, data): +- self._stream_reader.feed_data(data) ++ self.reader.feed_data(data) + + def eof_received(self): +- self._stream_reader.feed_eof() ++ self.reader.feed_eof() + if self._over_ssl: + # Prevent a warning in SSLProtocol.eof_received: + # "returning true from eof_received() +@@ -321,13 +314,14 @@ def __init__( + # ``self.read_limit``. The ``limit`` argument controls the line length + # limit and half the buffer limit of :class:`~asyncio.StreamReader`. + # That's why it must be set to half of ``self.read_limit``. +- stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) +- super().__init__(stream_reader) +- +- self.reader: asyncio.StreamReader ++ self.reader: asyncio.StreamReader = asyncio.StreamReader( ++ limit=read_limit // 2, loop=loop ++ ) + self.writer: asyncio.StreamWriter + self._drain_lock = asyncio.Lock(loop=loop) + ++ super().__init__() ++ + # This class implements the data transfer and closing handshake, which + # are shared between the client-side and the server-side. + # Subclasses implement the opening handshake and, on success, execute + +From c7ae53f4dfc586f21af1fdbdcfba321bbccd491e Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 12:25:40 +0200 +Subject: [PATCH 07/20] Merge asyncio.Protocol methods. + +--- + src/websockets/protocol.py | 160 ++++++++++++++++--------------------- + 1 file changed, 70 insertions(+), 90 deletions(-) + +diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py +index 9c61a409..89e3464a 100644 +--- a/src/websockets/protocol.py ++++ b/src/websockets/protocol.py +@@ -70,20 +70,6 @@ def __init__(self): + self._over_ssl = False + self._closed = self.loop.create_future() + +- def pause_writing(self): +- assert not self._paused +- self._paused = True +- +- def resume_writing(self): +- assert self._paused +- self._paused = False +- +- waiter = self._drain_waiter +- if waiter is not None: +- self._drain_waiter = None +- if not waiter.done(): +- waiter.set_result(None) +- + async def _drain_helper(self): + if self._connection_lost: + raise ConnectionResetError("Connection lost") +@@ -95,53 +81,6 @@ def resume_writing(self): + self._drain_waiter = waiter + await waiter + +- def connection_made(self, transport): +- self.reader.set_transport(transport) +- self._over_ssl = transport.get_extra_info("sslcontext") is not None +- self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop) +- +- def connection_lost(self, exc): +- if self.reader is not None: +- if exc is None: +- self.reader.feed_eof() +- else: +- self.reader.set_exception(exc) +- if not self._closed.done(): +- if exc is None: +- self._closed.set_result(None) +- else: +- self._closed.set_exception(exc) +- +- self._connection_lost = True +- # Wake up the writer if currently paused. +- if not self._paused: +- return +- waiter = self._drain_waiter +- if waiter is None: +- return +- self._drain_waiter = None +- if waiter.done(): +- return +- if exc is None: +- waiter.set_result(None) +- else: +- waiter.set_exception(exc) +- +- del self.reader +- del self.writer +- +- def data_received(self, data): +- self.reader.feed_data(data) +- +- def eof_received(self): +- self.reader.feed_eof() +- if self._over_ssl: +- # Prevent a warning in SSLProtocol.eof_received: +- # "returning true from eof_received() +- # has no effect when using ssl" +- return False +- return True +- + def __del__(self): + # Prevent reports about unhandled exceptions. + # Better than self._closed._log_traceback = False hack +@@ -1363,7 +1302,7 @@ def abort_pings(self) -> None: + "%s - aborted pending ping%s: %s", self.side, plural, pings_hex + ) + +- # asyncio.StreamReaderProtocol methods ++ # asyncio.Protocol methods + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + """ +@@ -1382,34 +1321,11 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: + logger.debug("%s - event = connection_made(%s)", self.side, transport) + # mypy thinks transport is a BaseTransport, not a Transport. + transport.set_write_buffer_limits(self.write_limit) # type: ignore +- super().connection_made(transport) +- +- def eof_received(self) -> bool: +- """ +- Close the transport after receiving EOF. +- +- Since Python 3.5, `:meth:~StreamReaderProtocol.eof_received` returns +- ``True`` on non-TLS connections. +- +- See http://bugs.python.org/issue24539 for more information. +- +- This is inappropriate for ``websockets`` for at least three reasons: +- +- 1. The use case is to read data until EOF with self.reader.read(-1). +- Since WebSocket is a TLV protocol, this never happens. +- +- 2. It doesn't work on TLS connections. A falsy value must be +- returned to have the same behavior on TLS and plain connections. +- +- 3. The WebSocket protocol has its own closing handshake. Endpoints +- close the TCP connection after sending a close frame. +- +- As a consequence we revert to the previous, more useful behavior. + +- """ +- logger.debug("%s - event = eof_received()", self.side) +- super().eof_received() +- return False ++ # Copied from asyncio.StreamReaderProtocol ++ self.reader.set_transport(transport) ++ self._over_ssl = transport.get_extra_info("sslcontext") is not None ++ self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop) + + def connection_lost(self, exc: Optional[Exception]) -> None: + """ +@@ -1434,4 +1350,68 @@ def connection_lost(self, exc: Optional[Exception]) -> None: + # - it's set only here in connection_lost() which is called only once; + # - it must never be canceled. + self.connection_lost_waiter.set_result(None) +- super().connection_lost(exc) ++ ++ # Copied from asyncio.StreamReaderProtocol ++ if self.reader is not None: ++ if exc is None: ++ self.reader.feed_eof() ++ else: ++ self.reader.set_exception(exc) ++ if not self._closed.done(): ++ if exc is None: ++ self._closed.set_result(None) ++ else: ++ self._closed.set_exception(exc) ++ ++ # Copied from asyncio.FlowControlMixin ++ self._connection_lost = True ++ # Wake up the writer if currently paused. ++ if not self._paused: ++ return ++ waiter = self._drain_waiter ++ if waiter is None: ++ return ++ self._drain_waiter = None ++ if waiter.done(): ++ return ++ if exc is None: ++ waiter.set_result(None) ++ else: ++ waiter.set_exception(exc) ++ ++ del self.reader ++ del self.writer ++ ++ def pause_writing(self) -> None: ++ assert not self._paused ++ self._paused = True ++ ++ def resume_writing(self) -> None: ++ assert self._paused ++ self._paused = False ++ ++ waiter = self._drain_waiter ++ if waiter is not None: ++ self._drain_waiter = None ++ if not waiter.done(): ++ waiter.set_result(None) ++ ++ def data_received(self, data: bytes) -> None: ++ logger.debug("%s - event = data_received(<%d bytes>)", self.side, len(data)) ++ self.reader.feed_data(data) ++ ++ def eof_received(self) -> None: ++ """ ++ Close the transport after receiving EOF. ++ ++ The WebSocket protocol has its own closing handshake: endpoints close ++ the TCP or TLS connection after sending and receiving a close frame. ++ ++ As a consequence, they never need to write after receiving EOF, so ++ there's no reason to keep the transport open by returning ``True``. ++ ++ Besides, that doesn't work on TLS connections. ++ ++ """ ++ logger.debug("%s - event = eof_received()", self.side) ++ self.reader.feed_eof() + +From ab5dd72aebd86b878ff08d1ebd5cbb3256e2d6e3 Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 13:09:40 +0200 +Subject: [PATCH 08/20] Finish merging StreamReaderProtocol and + FlowControlMixin. + +--- + src/websockets/protocol.py | 59 +++++++++++++++++++------------------- + 1 file changed, 29 insertions(+), 30 deletions(-) + +diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py +index 89e3464a..2db44e5d 100644 +--- a/src/websockets/protocol.py ++++ b/src/websockets/protocol.py +@@ -61,35 +61,7 @@ class State(enum.IntEnum): + # between the check and the assignment. + + +-class StreamReaderProtocol(asyncio.Protocol): +- def __init__(self): +- self._paused = False +- self._drain_waiter = None +- self._connection_lost = False +- +- self._over_ssl = False +- self._closed = self.loop.create_future() +- +- async def _drain_helper(self): +- if self._connection_lost: +- raise ConnectionResetError("Connection lost") +- if not self._paused: +- return +- waiter = self._drain_waiter +- assert waiter is None or waiter.cancelled() +- waiter = self.loop.create_future() +- self._drain_waiter = waiter +- await waiter +- +- def __del__(self): +- # Prevent reports about unhandled exceptions. +- # Better than self._closed._log_traceback = False hack +- closed = self._closed +- if closed.done() and not closed.cancelled(): +- closed.exception() +- +- +-class WebSocketCommonProtocol(StreamReaderProtocol): ++class WebSocketCommonProtocol(asyncio.Protocol): + """ + :class:`~asyncio.Protocol` subclass implementing the data transfer phase. + +@@ -259,7 +231,14 @@ def __init__( + self.writer: asyncio.StreamWriter + self._drain_lock = asyncio.Lock(loop=loop) + +- super().__init__() ++ # Copied from asyncio.FlowControlMixin ++ self._paused = False ++ self._drain_waiter: Optional[asyncio.Future[None]] = None ++ self._connection_lost = False ++ ++ # Copied from asyncio.StreamReaderProtocol ++ self._over_ssl = False ++ self._closed = self.loop.create_future() + + # This class implements the data transfer and closing handshake, which + # are shared between the client-side and the server-side. +@@ -311,6 +290,26 @@ def __init__( + # Task closing the TCP connection. + self.close_connection_task: asyncio.Task[None] + ++ # Copied from asyncio.StreamReaderProtocol ++ def __del__(self) -> None: ++ # Prevent reports about unhandled exceptions. ++ # Better than self._closed._log_traceback = False hack ++ closed = self._closed ++ if closed.done() and not closed.cancelled(): ++ closed.exception() ++ ++ # Copied from asyncio.FlowControlMixin ++ async def _drain_helper(self) -> None: ++ if self._connection_lost: ++ raise ConnectionResetError("Connection lost") ++ if not self._paused: ++ return ++ waiter = self._drain_waiter ++ assert waiter is None or waiter.cancelled() ++ waiter = self.loop.create_future() ++ self._drain_waiter = waiter ++ await waiter ++ + def connection_open(self) -> None: + """ + Callback when the WebSocket opening handshake completes. + +From 080543c4961037fcd8a207956e48474366e13ad5 Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 13:17:26 +0200 +Subject: [PATCH 09/20] Deduplicate connection termination tracking. + +--- + src/websockets/protocol.py | 21 +++++---------------- + 1 file changed, 5 insertions(+), 16 deletions(-) + +diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py +index 2db44e5d..1cd5a91c 100644 +--- a/src/websockets/protocol.py ++++ b/src/websockets/protocol.py +@@ -234,11 +234,9 @@ def __init__( + # Copied from asyncio.FlowControlMixin + self._paused = False + self._drain_waiter: Optional[asyncio.Future[None]] = None +- self._connection_lost = False + + # Copied from asyncio.StreamReaderProtocol + self._over_ssl = False +- self._closed = self.loop.create_future() + + # This class implements the data transfer and closing handshake, which + # are shared between the client-side and the server-side. +@@ -290,17 +288,14 @@ def __init__( + # Task closing the TCP connection. + self.close_connection_task: asyncio.Task[None] + +- # Copied from asyncio.StreamReaderProtocol +- def __del__(self) -> None: +- # Prevent reports about unhandled exceptions. +- # Better than self._closed._log_traceback = False hack +- closed = self._closed +- if closed.done() and not closed.cancelled(): +- closed.exception() ++ # asyncio.StreamWriter expects this attribute on the Protocol ++ @property ++ def _closed(self) -> asyncio.Future: ++ return self.connection_lost_waiter + + # Copied from asyncio.FlowControlMixin + async def _drain_helper(self) -> None: +- if self._connection_lost: ++ if self.connection_lost_waiter.done(): + raise ConnectionResetError("Connection lost") + if not self._paused: + return +@@ -1356,14 +1351,8 @@ def connection_lost(self, exc: Optional[Exception]) -> None: + self.reader.feed_eof() + else: + self.reader.set_exception(exc) +- if not self._closed.done(): +- if exc is None: +- self._closed.set_result(None) +- else: +- self._closed.set_exception(exc) + + # Copied from asyncio.FlowControlMixin +- self._connection_lost = True + # Wake up the writer if currently paused. + if not self._paused: + return + +From 284932f967699467908e69226118f9bcd5ba8f9d Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 13:21:25 +0200 +Subject: [PATCH 10/20] Remove unused attribute. + +--- + src/websockets/protocol.py | 4 ---- + 1 file changed, 4 deletions(-) + +diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py +index 1cd5a91c..a1c90916 100644 +--- a/src/websockets/protocol.py ++++ b/src/websockets/protocol.py +@@ -235,9 +235,6 @@ def __init__( + self._paused = False + self._drain_waiter: Optional[asyncio.Future[None]] = None + +- # Copied from asyncio.StreamReaderProtocol +- self._over_ssl = False +- + # This class implements the data transfer and closing handshake, which + # are shared between the client-side and the server-side. + # Subclasses implement the opening handshake and, on success, execute +@@ -1318,7 +1315,6 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: + + # Copied from asyncio.StreamReaderProtocol + self.reader.set_transport(transport) +- self._over_ssl = transport.get_extra_info("sslcontext") is not None + self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop) + + def connection_lost(self, exc: Optional[Exception]) -> None: + +From 9eeeeb3547eb7abc2253c75e462a50a8f68ba16b Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 13:28:49 +0200 +Subject: [PATCH 11/20] Ignore quality checks for code copied from asyncio. + +--- + src/websockets/protocol.py | 51 +++++++++++++++++++------------------- + 1 file changed, 25 insertions(+), 26 deletions(-) + +diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py +index a1c90916..0bb12fd5 100644 +--- a/src/websockets/protocol.py ++++ b/src/websockets/protocol.py +@@ -287,11 +287,11 @@ def __init__( + + # asyncio.StreamWriter expects this attribute on the Protocol + @property +- def _closed(self) -> asyncio.Future: ++ def _closed(self) -> Any: # pragma: no cover + return self.connection_lost_waiter + + # Copied from asyncio.FlowControlMixin +- async def _drain_helper(self) -> None: ++ async def _drain_helper(self) -> None: # pragma: no cover + if self.connection_lost_waiter.done(): + raise ConnectionResetError("Connection lost") + if not self._paused: +@@ -1341,36 +1341,35 @@ def connection_lost(self, exc: Optional[Exception]) -> None: + # - it must never be canceled. + self.connection_lost_waiter.set_result(None) + +- # Copied from asyncio.StreamReaderProtocol +- if self.reader is not None: +- if exc is None: +- self.reader.feed_eof() +- else: +- self.reader.set_exception(exc) ++ if True: # pragma: no cover + +- # Copied from asyncio.FlowControlMixin +- # Wake up the writer if currently paused. +- if not self._paused: +- return +- waiter = self._drain_waiter +- if waiter is None: +- return +- self._drain_waiter = None +- if waiter.done(): +- return +- if exc is None: +- waiter.set_result(None) +- else: +- waiter.set_exception(exc) ++ # Copied from asyncio.StreamReaderProtocol ++ if self.reader is not None: ++ if exc is None: ++ self.reader.feed_eof() ++ else: ++ self.reader.set_exception(exc) + +- del self.reader +- del self.writer ++ # Copied from asyncio.FlowControlMixin ++ # Wake up the writer if currently paused. ++ if not self._paused: ++ return ++ waiter = self._drain_waiter ++ if waiter is None: ++ return ++ self._drain_waiter = None ++ if waiter.done(): ++ return ++ if exc is None: ++ waiter.set_result(None) ++ else: ++ waiter.set_exception(exc) + +- def pause_writing(self) -> None: ++ def pause_writing(self) -> None: # pragma: no cover + assert not self._paused + self._paused = True + +- def resume_writing(self) -> None: ++ def resume_writing(self) -> None: # pragma: no cover + assert self._paused + self._paused = False + + +From f79cdd8e50e82433bbb0c003279a6b6f13e6c17e Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 13:54:13 +0200 +Subject: [PATCH 12/20] Remove asyncio.StreamWriter. + +It adds only one method for flow control. Copy it, as we've already +copied the rest of the flow control implementation. +--- + src/websockets/client.py | 2 +- + src/websockets/protocol.py | 64 +++++++++++++++++++++---------------- + src/websockets/server.py | 6 ++-- + tests/test_client_server.py | 2 +- + tests/test_protocol.py | 18 +++++------ + 5 files changed, 51 insertions(+), 41 deletions(-) + +diff --git a/src/websockets/client.py b/src/websockets/client.py +index c1fdf88a..34cd8624 100644 +--- a/src/websockets/client.py ++++ b/src/websockets/client.py +@@ -85,7 +85,7 @@ def write_http_request(self, path: str, headers: Headers) -> None: + request = f"GET {path} HTTP/1.1\r\n" + request += str(headers) + +- self.writer.write(request.encode()) ++ self.transport.write(request.encode()) + + async def read_http_response(self) -> Tuple[int, Headers]: + """ +diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py +index 0bb12fd5..eb3d6bcc 100644 +--- a/src/websockets/protocol.py ++++ b/src/websockets/protocol.py +@@ -228,7 +228,8 @@ def __init__( + self.reader: asyncio.StreamReader = asyncio.StreamReader( + limit=read_limit // 2, loop=loop + ) +- self.writer: asyncio.StreamWriter ++ ++ self.transport: asyncio.Transport + self._drain_lock = asyncio.Lock(loop=loop) + + # Copied from asyncio.FlowControlMixin +@@ -285,11 +286,6 @@ def __init__( + # Task closing the TCP connection. + self.close_connection_task: asyncio.Task[None] + +- # asyncio.StreamWriter expects this attribute on the Protocol +- @property +- def _closed(self) -> Any: # pragma: no cover +- return self.connection_lost_waiter +- + # Copied from asyncio.FlowControlMixin + async def _drain_helper(self) -> None: # pragma: no cover + if self.connection_lost_waiter.done(): +@@ -302,6 +298,23 @@ def _closed(self) -> Any: # pragma: no cover + self._drain_waiter = waiter + await waiter + ++ # Copied from asyncio.StreamWriter ++ async def _drain(self) -> None: # pragma: no cover ++ if self.reader is not None: ++ exc = self.reader.exception() ++ if exc is not None: ++ raise exc ++ if self.transport is not None: ++ if self.transport.is_closing(): ++ # Yield to the event loop so connection_lost() may be ++ # called. Without this, _drain_helper() would return ++ # immediately, and code that calls ++ # write(...); yield from drain() ++ # in a loop would never call connection_lost(), so it ++ # would not see an error when the socket is closed. ++ await asyncio.sleep(0) ++ await self._drain_helper() ++ + def connection_open(self) -> None: + """ + Callback when the WebSocket opening handshake completes. +@@ -348,9 +361,9 @@ def local_address(self) -> Any: + been established yet. + + """ +- if self.writer is None: ++ if self.transport is None: + return None +- return self.writer.get_extra_info("sockname") ++ return self.transport.get_extra_info("sockname") + + @property + def remote_address(self) -> Any: +@@ -361,9 +374,9 @@ def remote_address(self) -> Any: + been established yet. + + """ +- if self.writer is None: ++ if self.transport is None: + return None +- return self.writer.get_extra_info("peername") ++ return self.transport.get_extra_info("peername") + + @property + def open(self) -> bool: +@@ -1037,7 +1050,9 @@ def append(frame: Frame) -> None: + + frame = Frame(fin, opcode, data) + logger.debug("%s > %r", self.side, frame) +- frame.write(self.writer.write, mask=self.is_client, extensions=self.extensions) ++ frame.write( ++ self.transport.write, mask=self.is_client, extensions=self.extensions ++ ) + + try: + # drain() cannot be called concurrently by multiple coroutines: +@@ -1045,7 +1060,7 @@ def append(frame: Frame) -> None: + # version of Python where this bugs exists is supported anymore. + async with self._drain_lock: + # Handle flow control automatically. +- await self.writer.drain() ++ await self._drain() + except ConnectionError: + # Terminate the connection if the socket died. + self.fail_connection() +@@ -1147,9 +1162,9 @@ def append(frame: Frame) -> None: + logger.debug("%s ! timed out waiting for TCP close", self.side) + + # Half-close the TCP connection if possible (when there's no TLS). +- if self.writer.can_write_eof(): ++ if self.transport.can_write_eof(): + logger.debug("%s x half-closing TCP connection", self.side) +- self.writer.write_eof() ++ self.transport.write_eof() + + if await self.wait_for_connection_lost(): + return +@@ -1162,17 +1177,12 @@ def append(frame: Frame) -> None: + # If connection_lost() was called, the TCP connection is closed. + # However, if TLS is enabled, the transport still needs closing. + # Else asyncio complains: ResourceWarning: unclosed transport. +- try: +- writer_is_closing = self.writer.is_closing # type: ignore +- except AttributeError: # pragma: no cover +- # Python < 3.7 +- writer_is_closing = self.writer.transport.is_closing +- if self.connection_lost_waiter.done() and writer_is_closing(): ++ if self.connection_lost_waiter.done() and self.transport.is_closing(): + return + + # Close the TCP connection. Buffers are flushed asynchronously. + logger.debug("%s x closing TCP connection", self.side) +- self.writer.close() ++ self.transport.close() + + if await self.wait_for_connection_lost(): + return +@@ -1180,8 +1190,7 @@ def append(frame: Frame) -> None: + + # Abort the TCP connection. Buffers are discarded. + logger.debug("%s x aborting TCP connection", self.side) +- # mypy thinks self.writer.transport is a BaseTransport, not a Transport. +- self.writer.transport.abort() # type: ignore ++ self.transport.abort() + + # connection_lost() is called quickly after aborting. + await self.wait_for_connection_lost() +@@ -1261,7 +1270,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: + frame = Frame(True, OP_CLOSE, frame_data) + logger.debug("%s > %r", self.side, frame) + frame.write( +- self.writer.write, mask=self.is_client, extensions=self.extensions ++ self.transport.write, mask=self.is_client, extensions=self.extensions + ) + + # Start close_connection_task if the opening handshake didn't succeed. +@@ -1310,12 +1319,13 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: + + """ + logger.debug("%s - event = connection_made(%s)", self.side, transport) +- # mypy thinks transport is a BaseTransport, not a Transport. +- transport.set_write_buffer_limits(self.write_limit) # type: ignore ++ ++ transport = cast(asyncio.Transport, transport) ++ transport.set_write_buffer_limits(self.write_limit) ++ self.transport = transport + + # Copied from asyncio.StreamReaderProtocol + self.reader.set_transport(transport) +- self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop) + + def connection_lost(self, exc: Optional[Exception]) -> None: + """ +diff --git a/src/websockets/server.py b/src/websockets/server.py +index b220a1b8..1e8ae861 100644 +--- a/src/websockets/server.py ++++ b/src/websockets/server.py +@@ -211,7 +211,7 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: + except Exception: + # Last-ditch attempt to avoid leaking connections on errors. + try: +- self.writer.close() ++ self.transport.close() + except Exception: # pragma: no cover + pass + +@@ -265,11 +265,11 @@ def write_http_response( + response = f"HTTP/1.1 {status.value} {status.phrase}\r\n" + response += str(headers) + +- self.writer.write(response.encode()) ++ self.transport.write(response.encode()) + + if body is not None: + logger.debug("%s > body (%d bytes)", self.side, len(body)) +- self.writer.write(body) ++ self.transport.write(body) + + async def process_request( + self, path: str, request_headers: Headers +diff --git a/tests/test_client_server.py b/tests/test_client_server.py +index e74ec6bf..6171f21b 100644 +--- a/tests/test_client_server.py ++++ b/tests/test_client_server.py +@@ -1166,7 +1166,7 @@ def test_server_close_crashes(self, close): + def test_client_closes_connection_before_handshake(self, handshake): + # We have mocked the handshake() method to prevent the client from + # performing the opening handshake. Force it to close the connection. +- self.client.writer.close() ++ self.client.transport.close() + # The server should stop properly anyway. It used to hang because the + # task handling the connection was waiting for the opening handshake. + +diff --git a/tests/test_protocol.py b/tests/test_protocol.py +index a6c42018..dfc2c6d4 100644 +--- a/tests/test_protocol.py ++++ b/tests/test_protocol.py +@@ -94,16 +94,16 @@ def tearDown(self): + # Utilities for writing tests. + + def make_drain_slow(self, delay=MS): +- # Process connection_made in order to initialize self.protocol.writer. ++ # Process connection_made in order to initialize self.protocol.transport. + self.run_loop_once() + +- original_drain = self.protocol.writer.drain ++ original_drain = self.protocol._drain + + async def delayed_drain(): + await asyncio.sleep(delay, loop=self.loop) + await original_drain() + +- self.protocol.writer.drain = delayed_drain ++ self.protocol._drain = delayed_drain + + close_frame = Frame(True, OP_CLOSE, serialize_close(1000, "close")) + local_close = Frame(True, OP_CLOSE, serialize_close(1000, "local")) +@@ -321,32 +321,32 @@ def test_local_address(self): + self.transport.get_extra_info = get_extra_info + + self.assertEqual(self.protocol.local_address, ("host", 4312)) +- get_extra_info.assert_called_with("sockname", None) ++ get_extra_info.assert_called_with("sockname") + + def test_local_address_before_connection(self): + # Emulate the situation before connection_open() runs. +- self.protocol.writer, _writer = None, self.protocol.writer ++ self.protocol.transport, _transport = None, self.protocol.transport + + try: + self.assertEqual(self.protocol.local_address, None) + finally: +- self.protocol.writer = _writer ++ self.protocol.transport = _transport + + def test_remote_address(self): + get_extra_info = unittest.mock.Mock(return_value=("host", 4312)) + self.transport.get_extra_info = get_extra_info + + self.assertEqual(self.protocol.remote_address, ("host", 4312)) +- get_extra_info.assert_called_with("peername", None) ++ get_extra_info.assert_called_with("peername") + + def test_remote_address_before_connection(self): + # Emulate the situation before connection_open() runs. +- self.protocol.writer, _writer = None, self.protocol.writer ++ self.protocol.transport, _transport = None, self.protocol.transport + + try: + self.assertEqual(self.protocol.remote_address, None) + finally: +- self.protocol.writer = _writer ++ self.protocol.transport = _transport + + def test_open(self): + self.assertTrue(self.protocol.open) + +From c55f374a93f7e9b14bc2bdcc8868f38f8a834a87 Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 13:55:16 +0200 +Subject: [PATCH 13/20] Rename writer to write. + +It's a better name for a function that writes bytes. +--- + src/websockets/framing.py | 8 ++++---- + tests/test_framing.py | 14 +++++++------- + tests/test_protocol.py | 4 ++-- + 3 files changed, 13 insertions(+), 13 deletions(-) + +diff --git a/src/websockets/framing.py b/src/websockets/framing.py +index 81a3185b..c24b8a73 100644 +--- a/src/websockets/framing.py ++++ b/src/websockets/framing.py +@@ -147,7 +147,7 @@ class Frame(NamedTuple): + + def write( + frame, +- writer: Callable[[bytes], Any], ++ write: Callable[[bytes], Any], + *, + mask: bool, + extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, +@@ -156,7 +156,7 @@ def write( + Write a WebSocket frame. + + :param frame: frame to write +- :param writer: function that writes bytes ++ :param write: function that writes bytes + :param mask: whether the frame should be masked i.e. whether the write + happens on the client side + :param extensions: list of classes with an ``encode()`` method that +@@ -210,10 +210,10 @@ def write( + + # Send the frame. + +- # The frame is written in a single call to writer in order to prevent ++ # The frame is written in a single call to write in order to prevent + # TCP fragmentation. See #68 for details. This also makes it safe to + # send frames concurrently from multiple coroutines. +- writer(output.getvalue()) ++ write(output.getvalue()) + + def check(frame) -> None: + """ +diff --git a/tests/test_framing.py b/tests/test_framing.py +index 9e6f1871..5def415d 100644 +--- a/tests/test_framing.py ++++ b/tests/test_framing.py +@@ -27,15 +27,15 @@ def decode(self, message, mask=False, max_size=None, extensions=None): + return frame + + def encode(self, frame, mask=False, extensions=None): +- writer = unittest.mock.Mock() +- frame.write(writer, mask=mask, extensions=extensions) +- # Ensure the entire frame is sent with a single call to writer(). ++ write = unittest.mock.Mock() ++ frame.write(write, mask=mask, extensions=extensions) ++ # Ensure the entire frame is sent with a single call to write(). + # Multiple calls cause TCP fragmentation and degrade performance. +- self.assertEqual(writer.call_count, 1) ++ self.assertEqual(write.call_count, 1) + # The frame data is the single positional argument of that call. +- self.assertEqual(len(writer.call_args[0]), 1) +- self.assertEqual(len(writer.call_args[1]), 0) +- return writer.call_args[0][0] ++ self.assertEqual(len(write.call_args[0]), 1) ++ self.assertEqual(len(write.call_args[1]), 0) ++ return write.call_args[0][0] + + def round_trip(self, message, expected, mask=False, extensions=None): + decoded = self.decode(message, mask, extensions=extensions) +diff --git a/tests/test_protocol.py b/tests/test_protocol.py +index dfc2c6d4..d2793faf 100644 +--- a/tests/test_protocol.py ++++ b/tests/test_protocol.py +@@ -114,9 +114,9 @@ def receive_frame(self, frame): + Make the protocol receive a frame. + + """ +- writer = self.protocol.data_received ++ write = self.protocol.data_received + mask = not self.protocol.is_client +- frame.write(writer, mask=mask) ++ frame.write(write, mask=mask) + + def receive_eof(self): + """ + +From 3e03d1a19dbbf54e1c714c933b90ec283caa7bb0 Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 15:49:08 +0200 +Subject: [PATCH 14/20] Update to the latest version of mypy. + +The bugs that were locking us on an old version are fixed. +--- + src/websockets/__init__.py | 15 ++++++++------- + src/websockets/__main__.py | 8 ++++---- + src/websockets/client.py | 3 +-- + src/websockets/handshake.py | 15 ++++++++++----- + src/websockets/protocol.py | 14 +++++++++----- + src/websockets/server.py | 11 +++-------- + tox.ini | 2 +- + 7 files changed, 36 insertions(+), 32 deletions(-) + +diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py +index e7ba31ce..6bad0f7b 100644 +--- a/src/websockets/__init__.py ++++ b/src/websockets/__init__.py +@@ -1,12 +1,13 @@ + # This relies on each of the submodules having an __all__ variable. + +-from .auth import * +-from .client import * +-from .exceptions import * +-from .protocol import * +-from .server import * +-from .typing import * +-from .uri import * ++from . import auth, client, exceptions, protocol, server, typing, uri ++from .auth import * # noqa ++from .client import * # noqa ++from .exceptions import * # noqa ++from .protocol import * # noqa ++from .server import * # noqa ++from .typing import * # noqa ++from .uri import * # noqa + from .version import version as __version__ # noqa + + +diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py +index bccb8aa5..394f7ac7 100644 +--- a/src/websockets/__main__.py ++++ b/src/websockets/__main__.py +@@ -6,8 +6,8 @@ + import threading + from typing import Any, Set + +-import websockets +-from websockets.exceptions import format_close ++from .client import connect ++from .exceptions import ConnectionClosed, format_close + + + if sys.platform == "win32": +@@ -95,7 +95,7 @@ def print_over_input(string: str) -> None: + stop: "asyncio.Future[None]", + ) -> None: + try: +- websocket = await websockets.connect(uri) ++ websocket = await connect(uri) + except Exception as exc: + print_over_input(f"Failed to connect to {uri}: {exc}.") + exit_from_event_loop_thread(loop, stop) +@@ -122,7 +122,7 @@ def print_over_input(string: str) -> None: + if incoming in done: + try: + message = incoming.result() +- except websockets.ConnectionClosed: ++ except ConnectionClosed: + break + else: + if isinstance(message, str): +diff --git a/src/websockets/client.py b/src/websockets/client.py +index 34cd8624..725ec1e7 100644 +--- a/src/websockets/client.py ++++ b/src/websockets/client.py +@@ -24,7 +24,6 @@ + from .extensions.permessage_deflate import ClientPerMessageDeflateFactory + from .handshake import build_request, check_response + from .headers import ( +- ExtensionHeader, + build_authorization_basic, + build_extension, + build_subprotocol, +@@ -33,7 +32,7 @@ + ) + from .http import USER_AGENT, Headers, HeadersLike, read_response + from .protocol import WebSocketCommonProtocol +-from .typing import Origin, Subprotocol ++from .typing import ExtensionHeader, Origin, Subprotocol + from .uri import WebSocketURI, parse_uri + + +diff --git a/src/websockets/handshake.py b/src/websockets/handshake.py +index 17332d15..9bfe2775 100644 +--- a/src/websockets/handshake.py ++++ b/src/websockets/handshake.py +@@ -29,9 +29,10 @@ + import binascii + import hashlib + import random ++from typing import List + + from .exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade +-from .headers import parse_connection, parse_upgrade ++from .headers import ConnectionOption, UpgradeProtocol, parse_connection, parse_upgrade + from .http import Headers, MultipleValuesError + + +@@ -74,14 +75,16 @@ def check_request(headers: Headers) -> str: + is invalid; then the server must return 400 Bad Request error + + """ +- connection = sum( ++ connection: List[ConnectionOption] = sum( + [parse_connection(value) for value in headers.get_all("Connection")], [] + ) + + if not any(value.lower() == "upgrade" for value in connection): + raise InvalidUpgrade("Connection", ", ".join(connection)) + +- upgrade = sum([parse_upgrade(value) for value in headers.get_all("Upgrade")], []) ++ upgrade: List[UpgradeProtocol] = sum( ++ [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ++ ) + + # For compatibility with non-strict implementations, ignore case when + # checking the Upgrade header. It's supposed to be 'WebSocket'. +@@ -148,14 +151,16 @@ def check_response(headers: Headers, key: str) -> None: + is invalid + + """ +- connection = sum( ++ connection: List[ConnectionOption] = sum( + [parse_connection(value) for value in headers.get_all("Connection")], [] + ) + + if not any(value.lower() == "upgrade" for value in connection): + raise InvalidUpgrade("Connection", " ".join(connection)) + +- upgrade = sum([parse_upgrade(value) for value in headers.get_all("Upgrade")], []) ++ upgrade: List[UpgradeProtocol] = sum( ++ [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] ++ ) + + # For compatibility with non-strict implementations, ignore case when + # checking the Upgrade header. It's supposed to be 'WebSocket'. +diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py +index eb3d6bcc..b7c1f19c 100644 +--- a/src/websockets/protocol.py ++++ b/src/websockets/protocol.py +@@ -601,10 +601,14 @@ def closed(self) -> bool: + + elif isinstance(message, AsyncIterable): + # aiter_message = aiter(message) without aiter +- aiter_message = type(message).__aiter__(message) ++ # https://github.com/python/mypy/issues/5738 ++ aiter_message = type(message).__aiter__(message) # type: ignore + try: + # message_chunk = anext(aiter_message) without anext +- message_chunk = await type(aiter_message).__anext__(aiter_message) ++ # https://github.com/python/mypy/issues/5738 ++ message_chunk = await type(aiter_message).__anext__( # type: ignore ++ aiter_message ++ ) + except StopAsyncIteration: + return + opcode, data = prepare_data(message_chunk) +@@ -615,7 +619,8 @@ def closed(self) -> bool: + await self.write_frame(False, opcode, data) + + # Other fragments. +- async for message_chunk in aiter_message: ++ # https://github.com/python/mypy/issues/5738 ++ async for message_chunk in aiter_message: # type: ignore + confirm_opcode, data = prepare_data(message_chunk) + if confirm_opcode != opcode: + raise TypeError("data contains inconsistent types") +@@ -899,8 +904,7 @@ def connection_closed_exc(self) -> ConnectionClosed: + max_size = self.max_size + if text: + decoder_factory = codecs.getincrementaldecoder("utf-8") +- # https://github.com/python/typeshed/pull/2752 +- decoder = decoder_factory(errors="strict") # type: ignore ++ decoder = decoder_factory(errors="strict") + if max_size is None: + + def append(frame: Frame) -> None: +diff --git a/src/websockets/server.py b/src/websockets/server.py +index 1e8ae861..5114646d 100644 +--- a/src/websockets/server.py ++++ b/src/websockets/server.py +@@ -39,15 +39,10 @@ + from .extensions.base import Extension, ServerExtensionFactory + from .extensions.permessage_deflate import ServerPerMessageDeflateFactory + from .handshake import build_response, check_request +-from .headers import ( +- ExtensionHeader, +- build_extension, +- parse_extension, +- parse_subprotocol, +-) ++from .headers import build_extension, parse_extension, parse_subprotocol + from .http import USER_AGENT, Headers, HeadersLike, MultipleValuesError, read_request + from .protocol import WebSocketCommonProtocol +-from .typing import Origin, Subprotocol ++from .typing import ExtensionHeader, Origin, Subprotocol + + + __all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"] +@@ -662,7 +657,7 @@ def is_serving(self) -> bool: + """ + try: + # Python ≥ 3.7 +- return self.server.is_serving() # type: ignore ++ return self.server.is_serving() + except AttributeError: # pragma: no cover + # Python < 3.7 + return self.server.sockets is not None +diff --git a/tox.ini b/tox.ini +index 801d4d5d..7397c90a 100644 +--- a/tox.ini ++++ b/tox.ini +@@ -25,4 +25,4 @@ deps = isort + + [testenv:mypy] + commands = mypy --strict src +-deps = mypy==0.670 ++deps = mypy + +From aef1e8781826653720f4db110a4d2af4418fa5b8 Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 15:49:51 +0200 +Subject: [PATCH 15/20] Fix deprecation warnings on Python 3.8. + +* Don't pass the deprecated loop argument. +* Ignore deprecation warnings for @asyncio.coroutine. +--- + src/websockets/protocol.py | 28 ++++++++++++++++++++-------- + src/websockets/server.py | 10 +++++++--- + tests/test_client_server.py | 36 +++++++++++++++++++++--------------- + tests/test_protocol.py | 5 ++++- + 4 files changed, 52 insertions(+), 27 deletions(-) + +diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py +index b7c1f19c..76d46ad9 100644 +--- a/src/websockets/protocol.py ++++ b/src/websockets/protocol.py +@@ -14,6 +14,7 @@ + import logging + import random + import struct ++import sys + import warnings + from typing import ( + Any, +@@ -230,7 +231,9 @@ def __init__( + ) + + self.transport: asyncio.Transport +- self._drain_lock = asyncio.Lock(loop=loop) ++ self._drain_lock = asyncio.Lock( ++ loop=loop if sys.version_info[:2] < (3, 8) else None ++ ) + + # Copied from asyncio.FlowControlMixin + self._paused = False +@@ -312,7 +315,9 @@ def __init__( + # write(...); yield from drain() + # in a loop would never call connection_lost(), so it + # would not see an error when the socket is closed. +- await asyncio.sleep(0) ++ await asyncio.sleep( ++ 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None ++ ) + await self._drain_helper() + + def connection_open(self) -> None: +@@ -483,7 +488,7 @@ def closed(self) -> bool: + # pop_message_waiter and self.transfer_data_task. + await asyncio.wait( + [pop_message_waiter, self.transfer_data_task], +- loop=self.loop, ++ loop=self.loop if sys.version_info[:2] < (3, 8) else None, + return_when=asyncio.FIRST_COMPLETED, + ) + finally: +@@ -668,7 +673,7 @@ def closed(self) -> bool: + await asyncio.wait_for( + self.write_close_frame(serialize_close(code, reason)), + self.close_timeout, +- loop=self.loop, ++ loop=self.loop if sys.version_info[:2] < (3, 8) else None, + ) + except asyncio.TimeoutError: + # If the close frame cannot be sent because the send buffers +@@ -687,7 +692,9 @@ def closed(self) -> bool: + # If close() is canceled during the wait, self.transfer_data_task + # is canceled before the timeout elapses. + await asyncio.wait_for( +- self.transfer_data_task, self.close_timeout, loop=self.loop ++ self.transfer_data_task, ++ self.close_timeout, ++ loop=self.loop if sys.version_info[:2] < (3, 8) else None, + ) + except (asyncio.TimeoutError, asyncio.CancelledError): + pass +@@ -1106,7 +1113,10 @@ def append(frame: Frame) -> None: + + try: + while True: +- await asyncio.sleep(self.ping_interval, loop=self.loop) ++ await asyncio.sleep( ++ self.ping_interval, ++ loop=self.loop if sys.version_info[:2] < (3, 8) else None, ++ ) + + # ping() raises CancelledError if the connection is closed, + # when close_connection() cancels self.keepalive_ping_task. +@@ -1119,7 +1129,9 @@ def append(frame: Frame) -> None: + if self.ping_timeout is not None: + try: + await asyncio.wait_for( +- ping_waiter, self.ping_timeout, loop=self.loop ++ ping_waiter, ++ self.ping_timeout, ++ loop=self.loop if sys.version_info[:2] < (3, 8) else None, + ) + except asyncio.TimeoutError: + logger.debug("%s ! timed out waiting for pong", self.side) +@@ -1211,7 +1223,7 @@ def append(frame: Frame) -> None: + await asyncio.wait_for( + asyncio.shield(self.connection_lost_waiter), + self.close_timeout, +- loop=self.loop, ++ loop=self.loop if sys.version_info[:2] < (3, 8) else None, + ) + except asyncio.TimeoutError: + pass +diff --git a/src/websockets/server.py b/src/websockets/server.py +index 5114646d..4f5e9e0e 100644 +--- a/src/websockets/server.py ++++ b/src/websockets/server.py +@@ -10,6 +10,7 @@ + import http + import logging + import socket ++import sys + import warnings + from types import TracebackType + from typing import ( +@@ -698,7 +699,9 @@ def close(self) -> None: + + # Wait until all accepted connections reach connection_made() and call + # register(). See https://bugs.python.org/issue34852 for details. +- await asyncio.sleep(0) ++ await asyncio.sleep( ++ 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None ++ ) + + # Close OPEN connections with status code 1001. Since the server was + # closed, handshake() closes OPENING conections with a HTTP 503 error. +@@ -707,7 +710,8 @@ def close(self) -> None: + # asyncio.wait doesn't accept an empty first argument + if self.websockets: + await asyncio.wait( +- [websocket.close(1001) for websocket in self.websockets], loop=self.loop ++ [websocket.close(1001) for websocket in self.websockets], ++ loop=self.loop if sys.version_info[:2] < (3, 8) else None, + ) + + # Wait until all connection handlers are complete. +@@ -716,7 +720,7 @@ def close(self) -> None: + if self.websockets: + await asyncio.wait( + [websocket.handler_task for websocket in self.websockets], +- loop=self.loop, ++ loop=self.loop if sys.version_info[:2] < (3, 8) else None, + ) + + # Tell wait_closed() to return. +diff --git a/tests/test_client_server.py b/tests/test_client_server.py +index 6171f21b..85828bdb 100644 +--- a/tests/test_client_server.py ++++ b/tests/test_client_server.py +@@ -1381,13 +1381,16 @@ def test_client(self): + start_server = serve(handler, "localhost", 0) + server = self.loop.run_until_complete(start_server) + +- @asyncio.coroutine +- def run_client(): +- # Yield from connect. +- client = yield from connect(get_server_uri(server)) +- self.assertEqual(client.state, State.OPEN) +- yield from client.close() +- self.assertEqual(client.state, State.CLOSED) ++ # @asyncio.coroutine is deprecated on Python ≥ 3.8 ++ with warnings.catch_warnings(record=True): ++ ++ @asyncio.coroutine ++ def run_client(): ++ # Yield from connect. ++ client = yield from connect(get_server_uri(server)) ++ self.assertEqual(client.state, State.OPEN) ++ yield from client.close() ++ self.assertEqual(client.state, State.CLOSED) + + self.loop.run_until_complete(run_client()) + +@@ -1395,14 +1398,17 @@ def run_client(): + self.loop.run_until_complete(server.wait_closed()) + + def test_server(self): +- @asyncio.coroutine +- def run_server(): +- # Yield from serve. +- server = yield from serve(handler, "localhost", 0) +- self.assertTrue(server.sockets) +- server.close() +- yield from server.wait_closed() +- self.assertFalse(server.sockets) ++ # @asyncio.coroutine is deprecated on Python ≥ 3.8 ++ with warnings.catch_warnings(record=True): ++ ++ @asyncio.coroutine ++ def run_server(): ++ # Yield from serve. ++ server = yield from serve(handler, "localhost", 0) ++ self.assertTrue(server.sockets) ++ server.close() ++ yield from server.wait_closed() ++ self.assertFalse(server.sockets) + + self.loop.run_until_complete(run_server()) + +diff --git a/tests/test_protocol.py b/tests/test_protocol.py +index d2793faf..04e2a38f 100644 +--- a/tests/test_protocol.py ++++ b/tests/test_protocol.py +@@ -1,6 +1,7 @@ + import asyncio + import contextlib + import logging ++import sys + import unittest + import unittest.mock + import warnings +@@ -100,7 +101,9 @@ def make_drain_slow(self, delay=MS): + original_drain = self.protocol._drain + + async def delayed_drain(): +- await asyncio.sleep(delay, loop=self.loop) ++ await asyncio.sleep( ++ delay, loop=self.loop if sys.version_info[:2] < (3, 8) else None ++ ) + await original_drain() + + self.protocol._drain = delayed_drain + +From 504c66cf01bc7e2c4f5fbcceb1387cadf056a678 Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 14:02:51 +0200 +Subject: [PATCH 16/20] Document and test support for Python 3.8. + +--- + .circleci/config.yml | 12 ++++++++++++ + docs/changelog.rst | 2 ++ + setup.py | 1 + + tox.ini | 2 +- + 4 files changed, 16 insertions(+), 1 deletion(-) + +diff --git a/.circleci/config.yml b/.circleci/config.yml +index a6c85d23..0877c161 100644 +--- a/.circleci/config.yml ++++ b/.circleci/config.yml +@@ -29,6 +29,15 @@ jobs: + - checkout + - run: sudo pip install tox + - run: tox -e py37 ++ py38: ++ docker: ++ - image: circleci/python:3.8.0rc1 ++ steps: ++ # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. ++ - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc ++ - checkout ++ - run: sudo pip install tox ++ - run: tox -e py38 + + workflows: + version: 2 +@@ -41,3 +50,6 @@ workflows: + - py37: + requires: + - main ++ - py38: ++ requires: ++ - main +diff --git a/docs/changelog.rst b/docs/changelog.rst +index 87b2e438..2a106fbc 100644 +--- a/docs/changelog.rst ++++ b/docs/changelog.rst +@@ -8,6 +8,8 @@ Changelog + + *In development* + ++* Added compatibility with Python 3.8. ++ + 8.0.2 + ..... + +diff --git a/setup.py b/setup.py +index c7643010..f3581924 100644 +--- a/setup.py ++++ b/setup.py +@@ -53,6 +53,7 @@ + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', ++ 'Programming Language :: Python :: 3.8', + ], + package_dir = {'': 'src'}, + package_data = {'websockets': ['py.typed']}, +diff --git a/tox.ini b/tox.ini +index 7397c90a..825e3406 100644 +--- a/tox.ini ++++ b/tox.ini +@@ -1,5 +1,5 @@ + [tox] +-envlist = py36,py37,coverage,black,flake8,isort,mypy ++envlist = py36,py37,py38,coverage,black,flake8,isort,mypy + + [testenv] + commands = python -W default -m unittest {posargs} + +From fc99c16a53e5d0c85448714021a5deb74021920a Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 21:18:30 +0200 +Subject: [PATCH 17/20] Move test logging configuration to a single place. + +--- + tests/__init__.py | 5 +++++ + tests/test_client_server.py | 5 ----- + tests/test_protocol.py | 5 ----- + 3 files changed, 5 insertions(+), 10 deletions(-) + +diff --git a/tests/__init__.py b/tests/__init__.py +index e69de29b..dd78609f 100644 +--- a/tests/__init__.py ++++ b/tests/__init__.py +@@ -0,0 +1,5 @@ ++import logging ++ ++ ++# Avoid displaying stack traces at the ERROR logging level. ++logging.basicConfig(level=logging.CRITICAL) +diff --git a/tests/test_client_server.py b/tests/test_client_server.py +index 85828bdb..ce0f66ce 100644 +--- a/tests/test_client_server.py ++++ b/tests/test_client_server.py +@@ -2,7 +2,6 @@ + import contextlib + import functools + import http +-import logging + import pathlib + import random + import socket +@@ -37,10 +36,6 @@ + from .utils import AsyncioTestCase + + +-# Avoid displaying stack traces at the ERROR logging level. +-logging.basicConfig(level=logging.CRITICAL) +- +- + # Generate TLS certificate with: + # $ openssl req -x509 -config test_localhost.cnf -days 15340 -newkey rsa:2048 \ + # -out test_localhost.crt -keyout test_localhost.key +diff --git a/tests/test_protocol.py b/tests/test_protocol.py +index 04e2a38f..d95260a8 100644 +--- a/tests/test_protocol.py ++++ b/tests/test_protocol.py +@@ -1,6 +1,5 @@ + import asyncio + import contextlib +-import logging + import sys + import unittest + import unittest.mock +@@ -13,10 +12,6 @@ + from .utils import MS, AsyncioTestCase + + +-# Avoid displaying stack traces at the ERROR logging level. +-logging.basicConfig(level=logging.CRITICAL) +- +- + async def async_iterable(iterable): + for item in iterable: + yield item + +From 91dea74a9430d8573f7fd58c35fd731eddbb6cdb Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 21:33:27 +0200 +Subject: [PATCH 18/20] Remove test that no longer makes sense. + +Since version 7.0, when the server closes, it terminates connections +with close code 1001 instead of canceling them. +--- + tests/test_client_server.py | 18 ++---------------- + 1 file changed, 2 insertions(+), 16 deletions(-) + +diff --git a/tests/test_client_server.py b/tests/test_client_server.py +index ce0f66ce..35913666 100644 +--- a/tests/test_client_server.py ++++ b/tests/test_client_server.py +@@ -173,7 +173,7 @@ class HealthCheckServerProtocol(WebSocketServerProtocol): + return http.HTTPStatus.OK, [("X-Access", "OK")], b"status = green\n" + + +-class SlowServerProtocol(WebSocketServerProtocol): ++class SlowOpeningHandshakeProtocol(WebSocketServerProtocol): + async def process_request(self, path, request_headers): + await asyncio.sleep(10 * MS) + +@@ -1165,7 +1165,7 @@ def test_client_closes_connection_before_handshake(self, handshake): + # The server should stop properly anyway. It used to hang because the + # task handling the connection was waiting for the opening handshake. + +- @with_server(create_protocol=SlowServerProtocol) ++ @with_server(create_protocol=SlowOpeningHandshakeProtocol) + def test_server_shuts_down_during_opening_handshake(self): + self.loop.call_later(5 * MS, self.server.close) + with self.assertRaises(InvalidStatusCode) as raised: +@@ -1188,20 +1188,6 @@ def test_server_shuts_down_during_connection_handling(self): + self.assertEqual(self.client.close_code, 1001) + self.assertEqual(server_ws.close_code, 1001) + +- @with_server() +- @unittest.mock.patch("websockets.server.WebSocketServerProtocol.close") +- def test_server_shuts_down_during_connection_close(self, _close): +- _close.side_effect = asyncio.CancelledError +- +- self.server.closing = True +- with self.temp_client(): +- self.loop.run_until_complete(self.client.send("Hello!")) +- reply = self.loop.run_until_complete(self.client.recv()) +- self.assertEqual(reply, "Hello!") +- +- # Websocket connection terminates abnormally. +- self.assertEqual(self.client.close_code, 1006) +- + @with_server() + def test_server_shuts_down_waits_until_handlers_terminate(self): + # This handler waits a bit after the connection is closed in order + +From 88073fe7da7e8c07a5f46ae7c77e1a4ef80736ee Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 21:41:13 +0200 +Subject: [PATCH 19/20] Fix refactoring error. + +WebSocketCommonProtocol.transport can be unset, but it cannot be None. +--- + src/websockets/protocol.py | 14 ++++++++++---- + tests/test_protocol.py | 8 ++++---- + 2 files changed, 14 insertions(+), 8 deletions(-) + +diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py +index 76d46ad9..0623e136 100644 +--- a/src/websockets/protocol.py ++++ b/src/websockets/protocol.py +@@ -366,9 +366,12 @@ def local_address(self) -> Any: + been established yet. + + """ +- if self.transport is None: ++ try: ++ transport = self.transport ++ except AttributeError: + return None +- return self.transport.get_extra_info("sockname") ++ else: ++ return transport.get_extra_info("sockname") + + @property + def remote_address(self) -> Any: +@@ -379,9 +382,12 @@ def remote_address(self) -> Any: + been established yet. + + """ +- if self.transport is None: ++ try: ++ transport = self.transport ++ except AttributeError: + return None +- return self.transport.get_extra_info("peername") ++ else: ++ return transport.get_extra_info("peername") + + @property + def open(self) -> bool: +diff --git a/tests/test_protocol.py b/tests/test_protocol.py +index d95260a8..66a822e7 100644 +--- a/tests/test_protocol.py ++++ b/tests/test_protocol.py +@@ -323,8 +323,8 @@ def test_local_address(self): + + def test_local_address_before_connection(self): + # Emulate the situation before connection_open() runs. +- self.protocol.transport, _transport = None, self.protocol.transport +- ++ _transport = self.protocol.transport ++ del self.protocol.transport + try: + self.assertEqual(self.protocol.local_address, None) + finally: +@@ -339,8 +339,8 @@ def test_remote_address(self): + + def test_remote_address_before_connection(self): + # Emulate the situation before connection_open() runs. +- self.protocol.transport, _transport = None, self.protocol.transport +- ++ _transport = self.protocol.transport ++ del self.protocol.transport + try: + self.assertEqual(self.protocol.remote_address, None) + finally: + +From d75d4a4d34a7485ab8e5f2e65dee06dc37b7cfba Mon Sep 17 00:00:00 2001 +From: Aymeric Augustin +Date: Sat, 5 Oct 2019 21:39:14 +0200 +Subject: [PATCH 20/20] Remove useless type declaration. + +--- + src/websockets/protocol.py | 13 +++++-------- + 1 file changed, 5 insertions(+), 8 deletions(-) + +diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py +index 0623e136..6c29b2a5 100644 +--- a/src/websockets/protocol.py ++++ b/src/websockets/protocol.py +@@ -226,19 +226,16 @@ def __init__( + # ``self.read_limit``. The ``limit`` argument controls the line length + # limit and half the buffer limit of :class:`~asyncio.StreamReader`. + # That's why it must be set to half of ``self.read_limit``. +- self.reader: asyncio.StreamReader = asyncio.StreamReader( +- limit=read_limit // 2, loop=loop +- ) +- +- self.transport: asyncio.Transport +- self._drain_lock = asyncio.Lock( +- loop=loop if sys.version_info[:2] < (3, 8) else None +- ) ++ self.reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) + + # Copied from asyncio.FlowControlMixin + self._paused = False + self._drain_waiter: Optional[asyncio.Future[None]] = None + ++ self._drain_lock = asyncio.Lock( ++ loop=loop if sys.version_info[:2] < (3, 8) else None ++ ) ++ + # This class implements the data transfer and closing handshake, which + # are shared between the client-side and the server-side. + # Subclasses implement the opening handshake and, on success, execute