From fbab68fad41b1ddad2bbe7eccd793f5d6c52f761 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 11:41:21 +0200 Subject: [PATCH 01/20] Copy FlowControlMixin and StreamReaderProtocol. This is the official recommendation of Python core devs. The code is taken from the current 3.7 branch. --- src/websockets/protocol.py | 127 ++++++++++++++++++++++++++++++++++++- 1 file changed, 126 insertions(+), 1 deletion(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 1f0edcce..2f74cd23 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -61,7 +61,132 @@ class State(enum.IntEnum): # between the check and the assignment. -class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): +class FlowControlMixin(asyncio.Protocol): + """Reusable flow control logic for StreamWriter.drain(). + This implements the protocol methods pause_writing(), + resume_writing() and connection_lost(). If the subclass overrides + these it must call the super methods. + StreamWriter.drain() must wait for _drain_helper() coroutine. + """ + + def __init__(self, loop=None): + if loop is None: + self._loop = asyncio.get_event_loop() + else: + self._loop = loop + self._paused = False + self._drain_waiter = None + self._connection_lost = False + + def pause_writing(self): + assert not self._paused + self._paused = True + if self._loop.get_debug(): + logger.debug("%r pauses writing", self) + + def resume_writing(self): + assert self._paused + self._paused = False + if self._loop.get_debug(): + logger.debug("%r resumes writing", self) + + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + def connection_lost(self, exc): + self._connection_lost = True + # Wake up the writer if currently paused. + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + async def _drain_helper(self): + if self._connection_lost: + raise ConnectionResetError("Connection lost") + if not self._paused: + return + waiter = self._drain_waiter + assert waiter is None or waiter.cancelled() + waiter = self._loop.create_future() + self._drain_waiter = waiter + await waiter + + +class StreamReaderProtocol(FlowControlMixin, asyncio.Protocol): + """Helper class to adapt between Protocol and StreamReader. + (This is a helper class instead of making StreamReader itself a + Protocol subclass, because the StreamReader has other potential + uses, and to prevent the user of the StreamReader to accidentally + call inappropriate methods of the protocol.) + """ + + def __init__(self, stream_reader, client_connected_cb=None, loop=None): + super().__init__(loop=loop) + self._stream_reader = stream_reader + self._stream_writer = None + self._client_connected_cb = client_connected_cb + self._over_ssl = False + self._closed = self._loop.create_future() + + def connection_made(self, transport): + self._stream_reader.set_transport(transport) + self._over_ssl = transport.get_extra_info("sslcontext") is not None + if self._client_connected_cb is not None: + self._stream_writer = asyncio.StreamWriter( + transport, self, self._stream_reader, self._loop + ) + res = self._client_connected_cb(self._stream_reader, self._stream_writer) + if asyncio.iscoroutine(res): + self._loop.create_task(res) + + def connection_lost(self, exc): + if self._stream_reader is not None: + if exc is None: + self._stream_reader.feed_eof() + else: + self._stream_reader.set_exception(exc) + if not self._closed.done(): + if exc is None: + self._closed.set_result(None) + else: + self._closed.set_exception(exc) + super().connection_lost(exc) + self._stream_reader = None + self._stream_writer = None + + def data_received(self, data): + self._stream_reader.feed_data(data) + + def eof_received(self): + self._stream_reader.feed_eof() + if self._over_ssl: + # Prevent a warning in SSLProtocol.eof_received: + # "returning true from eof_received() + # has no effect when using ssl" + return False + return True + + def __del__(self): + # Prevent reports about unhandled exceptions. + # Better than self._closed._log_traceback = False hack + closed = self._closed + if closed.done() and not closed.cancelled(): + closed.exception() + + +class WebSocketCommonProtocol(StreamReaderProtocol): """ :class:`~asyncio.Protocol` subclass implementing the data transfer phase. From 57907330fe789f0b5f6b71b2810a0e36fee14844 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 11:54:04 +0200 Subject: [PATCH 02/20] Remove docstrings and debug logs. --- src/websockets/protocol.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 2f74cd23..d74c8157 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -62,12 +62,6 @@ class State(enum.IntEnum): class FlowControlMixin(asyncio.Protocol): - """Reusable flow control logic for StreamWriter.drain(). - This implements the protocol methods pause_writing(), - resume_writing() and connection_lost(). If the subclass overrides - these it must call the super methods. - StreamWriter.drain() must wait for _drain_helper() coroutine. - """ def __init__(self, loop=None): if loop is None: @@ -81,14 +75,10 @@ def __init__(self, loop=None): def pause_writing(self): assert not self._paused self._paused = True - if self._loop.get_debug(): - logger.debug("%r pauses writing", self) def resume_writing(self): assert self._paused self._paused = False - if self._loop.get_debug(): - logger.debug("%r resumes writing", self) waiter = self._drain_waiter if waiter is not None: @@ -125,12 +115,6 @@ def connection_lost(self, exc): class StreamReaderProtocol(FlowControlMixin, asyncio.Protocol): - """Helper class to adapt between Protocol and StreamReader. - (This is a helper class instead of making StreamReader itself a - Protocol subclass, because the StreamReader has other potential - uses, and to prevent the user of the StreamReader to accidentally - call inappropriate methods of the protocol.) - """ def __init__(self, stream_reader, client_connected_cb=None, loop=None): super().__init__(loop=loop) From 53523d06c7752b394c1f1e5566f464db741e2617 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 11:54:43 +0200 Subject: [PATCH 03/20] Merge FlowControlMixin in StreamReaderProtocol. --- src/websockets/protocol.py | 54 +++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 30 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index d74c8157..49d8b4f2 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -61,9 +61,9 @@ class State(enum.IntEnum): # between the check and the assignment. -class FlowControlMixin(asyncio.Protocol): +class StreamReaderProtocol(asyncio.Protocol): - def __init__(self, loop=None): + def __init__(self, stream_reader, client_connected_cb=None, loop=None): if loop is None: self._loop = asyncio.get_event_loop() else: @@ -72,6 +72,12 @@ def __init__(self, loop=None): self._drain_waiter = None self._connection_lost = False + self._stream_reader = stream_reader + self._stream_writer = None + self._client_connected_cb = client_connected_cb + self._over_ssl = False + self._closed = self._loop.create_future() + def pause_writing(self): assert not self._paused self._paused = True @@ -86,22 +92,6 @@ def resume_writing(self): if not waiter.done(): waiter.set_result(None) - def connection_lost(self, exc): - self._connection_lost = True - # Wake up the writer if currently paused. - if not self._paused: - return - waiter = self._drain_waiter - if waiter is None: - return - self._drain_waiter = None - if waiter.done(): - return - if exc is None: - waiter.set_result(None) - else: - waiter.set_exception(exc) - async def _drain_helper(self): if self._connection_lost: raise ConnectionResetError("Connection lost") @@ -113,17 +103,6 @@ def connection_lost(self, exc): self._drain_waiter = waiter await waiter - -class StreamReaderProtocol(FlowControlMixin, asyncio.Protocol): - - def __init__(self, stream_reader, client_connected_cb=None, loop=None): - super().__init__(loop=loop) - self._stream_reader = stream_reader - self._stream_writer = None - self._client_connected_cb = client_connected_cb - self._over_ssl = False - self._closed = self._loop.create_future() - def connection_made(self, transport): self._stream_reader.set_transport(transport) self._over_ssl = transport.get_extra_info("sslcontext") is not None @@ -146,7 +125,22 @@ def connection_lost(self, exc): self._closed.set_result(None) else: self._closed.set_exception(exc) - super().connection_lost(exc) + + self._connection_lost = True + # Wake up the writer if currently paused. + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + self._stream_reader = None self._stream_writer = None From e6d8bcfe2a0f16c50260cafab45f06da6baa1689 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 11:55:53 +0200 Subject: [PATCH 04/20] Deduplicate loop and _loop attributes. --- src/websockets/protocol.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 49d8b4f2..98c23ab1 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -63,11 +63,7 @@ class State(enum.IntEnum): class StreamReaderProtocol(asyncio.Protocol): - def __init__(self, stream_reader, client_connected_cb=None, loop=None): - if loop is None: - self._loop = asyncio.get_event_loop() - else: - self._loop = loop + def __init__(self, stream_reader, client_connected_cb=None): self._paused = False self._drain_waiter = None self._connection_lost = False @@ -76,7 +72,7 @@ def __init__(self, stream_reader, client_connected_cb=None, loop=None): self._stream_writer = None self._client_connected_cb = client_connected_cb self._over_ssl = False - self._closed = self._loop.create_future() + self._closed = self.loop.create_future() def pause_writing(self): assert not self._paused @@ -99,7 +95,7 @@ def resume_writing(self): return waiter = self._drain_waiter assert waiter is None or waiter.cancelled() - waiter = self._loop.create_future() + waiter = self.loop.create_future() self._drain_waiter = waiter await waiter @@ -108,11 +104,11 @@ def connection_made(self, transport): self._over_ssl = transport.get_extra_info("sslcontext") is not None if self._client_connected_cb is not None: self._stream_writer = asyncio.StreamWriter( - transport, self, self._stream_reader, self._loop + transport, self, self._stream_reader, self.loop ) res = self._client_connected_cb(self._stream_reader, self._stream_writer) if asyncio.iscoroutine(res): - self._loop.create_task(res) + self.loop.create_task(res) def connection_lost(self, exc): if self._stream_reader is not None: @@ -315,8 +311,6 @@ def __init__( self.read_limit = read_limit self.write_limit = write_limit - # Store a reference to loop to avoid relying on self._loop, a private - # attribute of StreamReaderProtocol, inherited from FlowControlMixin. if loop is None: loop = asyncio.get_event_loop() self.loop = loop @@ -331,7 +325,7 @@ def __init__( # limit and half the buffer limit of :class:`~asyncio.StreamReader`. # That's why it must be set to half of ``self.read_limit``. stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) - super().__init__(stream_reader, self.client_connected, loop) + super().__init__(stream_reader, self.client_connected) self.reader: asyncio.StreamReader self.writer: asyncio.StreamWriter From 4205cb176e62c938247af59e94214814242360ec Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 11:57:34 +0200 Subject: [PATCH 05/20] Remove client_connected callback. --- src/websockets/protocol.py | 31 +++++++------------------------ 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 98c23ab1..bfc354a8 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -63,14 +63,13 @@ class State(enum.IntEnum): class StreamReaderProtocol(asyncio.Protocol): - def __init__(self, stream_reader, client_connected_cb=None): + def __init__(self, stream_reader): self._paused = False self._drain_waiter = None self._connection_lost = False self._stream_reader = stream_reader self._stream_writer = None - self._client_connected_cb = client_connected_cb self._over_ssl = False self._closed = self.loop.create_future() @@ -102,13 +101,11 @@ def resume_writing(self): def connection_made(self, transport): self._stream_reader.set_transport(transport) self._over_ssl = transport.get_extra_info("sslcontext") is not None - if self._client_connected_cb is not None: - self._stream_writer = asyncio.StreamWriter( - transport, self, self._stream_reader, self.loop - ) - res = self._client_connected_cb(self._stream_reader, self._stream_writer) - if asyncio.iscoroutine(res): - self.loop.create_task(res) + self._stream_writer = asyncio.StreamWriter( + transport, self, self._stream_reader, self.loop + ) + self.reader = self._stream_reader + self.writer = self._stream_writer def connection_lost(self, exc): if self._stream_reader is not None: @@ -325,7 +322,7 @@ def __init__( # limit and half the buffer limit of :class:`~asyncio.StreamReader`. # That's why it must be set to half of ``self.read_limit``. stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) - super().__init__(stream_reader, self.client_connected) + super().__init__(stream_reader) self.reader: asyncio.StreamReader self.writer: asyncio.StreamWriter @@ -381,20 +378,6 @@ def __init__( # Task closing the TCP connection. self.close_connection_task: asyncio.Task[None] - def client_connected( - self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter - ) -> None: - """ - Callback when the TCP connection is established. - - Record references to the stream reader and the stream writer to avoid - using private attributes ``_stream_reader`` and ``_stream_writer`` of - :class:`~asyncio.StreamReaderProtocol`. - - """ - self.reader = reader - self.writer = writer - def connection_open(self) -> None: """ Callback when the WebSocket opening handshake completes. From 2ebb4b4eff76be388ee37570e891ed8432e19a9f Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 12:00:31 +0200 Subject: [PATCH 06/20] Deduplicate reader/writer and _stream_reader/writer attributes. --- src/websockets/protocol.py | 36 +++++++++++++++--------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index bfc354a8..9c61a409 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -62,14 +62,11 @@ class State(enum.IntEnum): class StreamReaderProtocol(asyncio.Protocol): - - def __init__(self, stream_reader): + def __init__(self): self._paused = False self._drain_waiter = None self._connection_lost = False - self._stream_reader = stream_reader - self._stream_writer = None self._over_ssl = False self._closed = self.loop.create_future() @@ -99,20 +96,16 @@ def resume_writing(self): await waiter def connection_made(self, transport): - self._stream_reader.set_transport(transport) + self.reader.set_transport(transport) self._over_ssl = transport.get_extra_info("sslcontext") is not None - self._stream_writer = asyncio.StreamWriter( - transport, self, self._stream_reader, self.loop - ) - self.reader = self._stream_reader - self.writer = self._stream_writer + self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop) def connection_lost(self, exc): - if self._stream_reader is not None: + if self.reader is not None: if exc is None: - self._stream_reader.feed_eof() + self.reader.feed_eof() else: - self._stream_reader.set_exception(exc) + self.reader.set_exception(exc) if not self._closed.done(): if exc is None: self._closed.set_result(None) @@ -134,14 +127,14 @@ def connection_lost(self, exc): else: waiter.set_exception(exc) - self._stream_reader = None - self._stream_writer = None + del self.reader + del self.writer def data_received(self, data): - self._stream_reader.feed_data(data) + self.reader.feed_data(data) def eof_received(self): - self._stream_reader.feed_eof() + self.reader.feed_eof() if self._over_ssl: # Prevent a warning in SSLProtocol.eof_received: # "returning true from eof_received() @@ -321,13 +314,14 @@ def __init__( # ``self.read_limit``. The ``limit`` argument controls the line length # limit and half the buffer limit of :class:`~asyncio.StreamReader`. # That's why it must be set to half of ``self.read_limit``. - stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) - super().__init__(stream_reader) - - self.reader: asyncio.StreamReader + self.reader: asyncio.StreamReader = asyncio.StreamReader( + limit=read_limit // 2, loop=loop + ) self.writer: asyncio.StreamWriter self._drain_lock = asyncio.Lock(loop=loop) + super().__init__() + # This class implements the data transfer and closing handshake, which # are shared between the client-side and the server-side. # Subclasses implement the opening handshake and, on success, execute From c7ae53f4dfc586f21af1fdbdcfba321bbccd491e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 12:25:40 +0200 Subject: [PATCH 07/20] Merge asyncio.Protocol methods. --- src/websockets/protocol.py | 160 ++++++++++++++++--------------------- 1 file changed, 70 insertions(+), 90 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 9c61a409..89e3464a 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -70,20 +70,6 @@ def __init__(self): self._over_ssl = False self._closed = self.loop.create_future() - def pause_writing(self): - assert not self._paused - self._paused = True - - def resume_writing(self): - assert self._paused - self._paused = False - - waiter = self._drain_waiter - if waiter is not None: - self._drain_waiter = None - if not waiter.done(): - waiter.set_result(None) - async def _drain_helper(self): if self._connection_lost: raise ConnectionResetError("Connection lost") @@ -95,53 +81,6 @@ def resume_writing(self): self._drain_waiter = waiter await waiter - def connection_made(self, transport): - self.reader.set_transport(transport) - self._over_ssl = transport.get_extra_info("sslcontext") is not None - self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop) - - def connection_lost(self, exc): - if self.reader is not None: - if exc is None: - self.reader.feed_eof() - else: - self.reader.set_exception(exc) - if not self._closed.done(): - if exc is None: - self._closed.set_result(None) - else: - self._closed.set_exception(exc) - - self._connection_lost = True - # Wake up the writer if currently paused. - if not self._paused: - return - waiter = self._drain_waiter - if waiter is None: - return - self._drain_waiter = None - if waiter.done(): - return - if exc is None: - waiter.set_result(None) - else: - waiter.set_exception(exc) - - del self.reader - del self.writer - - def data_received(self, data): - self.reader.feed_data(data) - - def eof_received(self): - self.reader.feed_eof() - if self._over_ssl: - # Prevent a warning in SSLProtocol.eof_received: - # "returning true from eof_received() - # has no effect when using ssl" - return False - return True - def __del__(self): # Prevent reports about unhandled exceptions. # Better than self._closed._log_traceback = False hack @@ -1363,7 +1302,7 @@ def abort_pings(self) -> None: "%s - aborted pending ping%s: %s", self.side, plural, pings_hex ) - # asyncio.StreamReaderProtocol methods + # asyncio.Protocol methods def connection_made(self, transport: asyncio.BaseTransport) -> None: """ @@ -1382,34 +1321,11 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: logger.debug("%s - event = connection_made(%s)", self.side, transport) # mypy thinks transport is a BaseTransport, not a Transport. transport.set_write_buffer_limits(self.write_limit) # type: ignore - super().connection_made(transport) - - def eof_received(self) -> bool: - """ - Close the transport after receiving EOF. - - Since Python 3.5, `:meth:~StreamReaderProtocol.eof_received` returns - ``True`` on non-TLS connections. - - See http://bugs.python.org/issue24539 for more information. - - This is inappropriate for ``websockets`` for at least three reasons: - - 1. The use case is to read data until EOF with self.reader.read(-1). - Since WebSocket is a TLV protocol, this never happens. - - 2. It doesn't work on TLS connections. A falsy value must be - returned to have the same behavior on TLS and plain connections. - - 3. The WebSocket protocol has its own closing handshake. Endpoints - close the TCP connection after sending a close frame. - - As a consequence we revert to the previous, more useful behavior. - """ - logger.debug("%s - event = eof_received()", self.side) - super().eof_received() - return False + # Copied from asyncio.StreamReaderProtocol + self.reader.set_transport(transport) + self._over_ssl = transport.get_extra_info("sslcontext") is not None + self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop) def connection_lost(self, exc: Optional[Exception]) -> None: """ @@ -1434,4 +1350,68 @@ def connection_lost(self, exc: Optional[Exception]) -> None: # - it's set only here in connection_lost() which is called only once; # - it must never be canceled. self.connection_lost_waiter.set_result(None) - super().connection_lost(exc) + + # Copied from asyncio.StreamReaderProtocol + if self.reader is not None: + if exc is None: + self.reader.feed_eof() + else: + self.reader.set_exception(exc) + if not self._closed.done(): + if exc is None: + self._closed.set_result(None) + else: + self._closed.set_exception(exc) + + # Copied from asyncio.FlowControlMixin + self._connection_lost = True + # Wake up the writer if currently paused. + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + del self.reader + del self.writer + + def pause_writing(self) -> None: + assert not self._paused + self._paused = True + + def resume_writing(self) -> None: + assert self._paused + self._paused = False + + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + def data_received(self, data: bytes) -> None: + logger.debug("%s - event = data_received(<%d bytes>)", self.side, len(data)) + self.reader.feed_data(data) + + def eof_received(self) -> None: + """ + Close the transport after receiving EOF. + + The WebSocket protocol has its own closing handshake: endpoints close + the TCP or TLS connection after sending and receiving a close frame. + + As a consequence, they never need to write after receiving EOF, so + there's no reason to keep the transport open by returning ``True``. + + Besides, that doesn't work on TLS connections. + + """ + logger.debug("%s - event = eof_received()", self.side) + self.reader.feed_eof() From ab5dd72aebd86b878ff08d1ebd5cbb3256e2d6e3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 13:09:40 +0200 Subject: [PATCH 08/20] Finish merging StreamReaderProtocol and FlowControlMixin. --- src/websockets/protocol.py | 59 +++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 89e3464a..2db44e5d 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -61,35 +61,7 @@ class State(enum.IntEnum): # between the check and the assignment. -class StreamReaderProtocol(asyncio.Protocol): - def __init__(self): - self._paused = False - self._drain_waiter = None - self._connection_lost = False - - self._over_ssl = False - self._closed = self.loop.create_future() - - async def _drain_helper(self): - if self._connection_lost: - raise ConnectionResetError("Connection lost") - if not self._paused: - return - waiter = self._drain_waiter - assert waiter is None or waiter.cancelled() - waiter = self.loop.create_future() - self._drain_waiter = waiter - await waiter - - def __del__(self): - # Prevent reports about unhandled exceptions. - # Better than self._closed._log_traceback = False hack - closed = self._closed - if closed.done() and not closed.cancelled(): - closed.exception() - - -class WebSocketCommonProtocol(StreamReaderProtocol): +class WebSocketCommonProtocol(asyncio.Protocol): """ :class:`~asyncio.Protocol` subclass implementing the data transfer phase. @@ -259,7 +231,14 @@ def __init__( self.writer: asyncio.StreamWriter self._drain_lock = asyncio.Lock(loop=loop) - super().__init__() + # Copied from asyncio.FlowControlMixin + self._paused = False + self._drain_waiter: Optional[asyncio.Future[None]] = None + self._connection_lost = False + + # Copied from asyncio.StreamReaderProtocol + self._over_ssl = False + self._closed = self.loop.create_future() # This class implements the data transfer and closing handshake, which # are shared between the client-side and the server-side. @@ -311,6 +290,26 @@ def __init__( # Task closing the TCP connection. self.close_connection_task: asyncio.Task[None] + # Copied from asyncio.StreamReaderProtocol + def __del__(self) -> None: + # Prevent reports about unhandled exceptions. + # Better than self._closed._log_traceback = False hack + closed = self._closed + if closed.done() and not closed.cancelled(): + closed.exception() + + # Copied from asyncio.FlowControlMixin + async def _drain_helper(self) -> None: + if self._connection_lost: + raise ConnectionResetError("Connection lost") + if not self._paused: + return + waiter = self._drain_waiter + assert waiter is None or waiter.cancelled() + waiter = self.loop.create_future() + self._drain_waiter = waiter + await waiter + def connection_open(self) -> None: """ Callback when the WebSocket opening handshake completes. From 080543c4961037fcd8a207956e48474366e13ad5 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 13:17:26 +0200 Subject: [PATCH 09/20] Deduplicate connection termination tracking. --- src/websockets/protocol.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 2db44e5d..1cd5a91c 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -234,11 +234,9 @@ def __init__( # Copied from asyncio.FlowControlMixin self._paused = False self._drain_waiter: Optional[asyncio.Future[None]] = None - self._connection_lost = False # Copied from asyncio.StreamReaderProtocol self._over_ssl = False - self._closed = self.loop.create_future() # This class implements the data transfer and closing handshake, which # are shared between the client-side and the server-side. @@ -290,17 +288,14 @@ def __init__( # Task closing the TCP connection. self.close_connection_task: asyncio.Task[None] - # Copied from asyncio.StreamReaderProtocol - def __del__(self) -> None: - # Prevent reports about unhandled exceptions. - # Better than self._closed._log_traceback = False hack - closed = self._closed - if closed.done() and not closed.cancelled(): - closed.exception() + # asyncio.StreamWriter expects this attribute on the Protocol + @property + def _closed(self) -> asyncio.Future: + return self.connection_lost_waiter # Copied from asyncio.FlowControlMixin async def _drain_helper(self) -> None: - if self._connection_lost: + if self.connection_lost_waiter.done(): raise ConnectionResetError("Connection lost") if not self._paused: return @@ -1356,14 +1351,8 @@ def connection_lost(self, exc: Optional[Exception]) -> None: self.reader.feed_eof() else: self.reader.set_exception(exc) - if not self._closed.done(): - if exc is None: - self._closed.set_result(None) - else: - self._closed.set_exception(exc) # Copied from asyncio.FlowControlMixin - self._connection_lost = True # Wake up the writer if currently paused. if not self._paused: return From 284932f967699467908e69226118f9bcd5ba8f9d Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 13:21:25 +0200 Subject: [PATCH 10/20] Remove unused attribute. --- src/websockets/protocol.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 1cd5a91c..a1c90916 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -235,9 +235,6 @@ def __init__( self._paused = False self._drain_waiter: Optional[asyncio.Future[None]] = None - # Copied from asyncio.StreamReaderProtocol - self._over_ssl = False - # This class implements the data transfer and closing handshake, which # are shared between the client-side and the server-side. # Subclasses implement the opening handshake and, on success, execute @@ -1318,7 +1315,6 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: # Copied from asyncio.StreamReaderProtocol self.reader.set_transport(transport) - self._over_ssl = transport.get_extra_info("sslcontext") is not None self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop) def connection_lost(self, exc: Optional[Exception]) -> None: From 9eeeeb3547eb7abc2253c75e462a50a8f68ba16b Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 13:28:49 +0200 Subject: [PATCH 11/20] Ignore quality checks for code copied from asyncio. --- src/websockets/protocol.py | 51 +++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index a1c90916..0bb12fd5 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -287,11 +287,11 @@ def __init__( # asyncio.StreamWriter expects this attribute on the Protocol @property - def _closed(self) -> asyncio.Future: + def _closed(self) -> Any: # pragma: no cover return self.connection_lost_waiter # Copied from asyncio.FlowControlMixin - async def _drain_helper(self) -> None: + async def _drain_helper(self) -> None: # pragma: no cover if self.connection_lost_waiter.done(): raise ConnectionResetError("Connection lost") if not self._paused: @@ -1341,36 +1341,35 @@ def connection_lost(self, exc: Optional[Exception]) -> None: # - it must never be canceled. self.connection_lost_waiter.set_result(None) - # Copied from asyncio.StreamReaderProtocol - if self.reader is not None: - if exc is None: - self.reader.feed_eof() - else: - self.reader.set_exception(exc) + if True: # pragma: no cover - # Copied from asyncio.FlowControlMixin - # Wake up the writer if currently paused. - if not self._paused: - return - waiter = self._drain_waiter - if waiter is None: - return - self._drain_waiter = None - if waiter.done(): - return - if exc is None: - waiter.set_result(None) - else: - waiter.set_exception(exc) + # Copied from asyncio.StreamReaderProtocol + if self.reader is not None: + if exc is None: + self.reader.feed_eof() + else: + self.reader.set_exception(exc) - del self.reader - del self.writer + # Copied from asyncio.FlowControlMixin + # Wake up the writer if currently paused. + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) - def pause_writing(self) -> None: + def pause_writing(self) -> None: # pragma: no cover assert not self._paused self._paused = True - def resume_writing(self) -> None: + def resume_writing(self) -> None: # pragma: no cover assert self._paused self._paused = False From f79cdd8e50e82433bbb0c003279a6b6f13e6c17e Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 13:54:13 +0200 Subject: [PATCH 12/20] Remove asyncio.StreamWriter. It adds only one method for flow control. Copy it, as we've already copied the rest of the flow control implementation. --- src/websockets/client.py | 2 +- src/websockets/protocol.py | 64 +++++++++++++++++++++---------------- src/websockets/server.py | 6 ++-- tests/test_client_server.py | 2 +- tests/test_protocol.py | 18 +++++------ 5 files changed, 51 insertions(+), 41 deletions(-) diff --git a/src/websockets/client.py b/src/websockets/client.py index c1fdf88a..34cd8624 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -85,7 +85,7 @@ def write_http_request(self, path: str, headers: Headers) -> None: request = f"GET {path} HTTP/1.1\r\n" request += str(headers) - self.writer.write(request.encode()) + self.transport.write(request.encode()) async def read_http_response(self) -> Tuple[int, Headers]: """ diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 0bb12fd5..eb3d6bcc 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -228,7 +228,8 @@ def __init__( self.reader: asyncio.StreamReader = asyncio.StreamReader( limit=read_limit // 2, loop=loop ) - self.writer: asyncio.StreamWriter + + self.transport: asyncio.Transport self._drain_lock = asyncio.Lock(loop=loop) # Copied from asyncio.FlowControlMixin @@ -285,11 +286,6 @@ def __init__( # Task closing the TCP connection. self.close_connection_task: asyncio.Task[None] - # asyncio.StreamWriter expects this attribute on the Protocol - @property - def _closed(self) -> Any: # pragma: no cover - return self.connection_lost_waiter - # Copied from asyncio.FlowControlMixin async def _drain_helper(self) -> None: # pragma: no cover if self.connection_lost_waiter.done(): @@ -302,6 +298,23 @@ def _closed(self) -> Any: # pragma: no cover self._drain_waiter = waiter await waiter + # Copied from asyncio.StreamWriter + async def _drain(self) -> None: # pragma: no cover + if self.reader is not None: + exc = self.reader.exception() + if exc is not None: + raise exc + if self.transport is not None: + if self.transport.is_closing(): + # Yield to the event loop so connection_lost() may be + # called. Without this, _drain_helper() would return + # immediately, and code that calls + # write(...); yield from drain() + # in a loop would never call connection_lost(), so it + # would not see an error when the socket is closed. + await asyncio.sleep(0) + await self._drain_helper() + def connection_open(self) -> None: """ Callback when the WebSocket opening handshake completes. @@ -348,9 +361,9 @@ def local_address(self) -> Any: been established yet. """ - if self.writer is None: + if self.transport is None: return None - return self.writer.get_extra_info("sockname") + return self.transport.get_extra_info("sockname") @property def remote_address(self) -> Any: @@ -361,9 +374,9 @@ def remote_address(self) -> Any: been established yet. """ - if self.writer is None: + if self.transport is None: return None - return self.writer.get_extra_info("peername") + return self.transport.get_extra_info("peername") @property def open(self) -> bool: @@ -1037,7 +1050,9 @@ def append(frame: Frame) -> None: frame = Frame(fin, opcode, data) logger.debug("%s > %r", self.side, frame) - frame.write(self.writer.write, mask=self.is_client, extensions=self.extensions) + frame.write( + self.transport.write, mask=self.is_client, extensions=self.extensions + ) try: # drain() cannot be called concurrently by multiple coroutines: @@ -1045,7 +1060,7 @@ def append(frame: Frame) -> None: # version of Python where this bugs exists is supported anymore. async with self._drain_lock: # Handle flow control automatically. - await self.writer.drain() + await self._drain() except ConnectionError: # Terminate the connection if the socket died. self.fail_connection() @@ -1147,9 +1162,9 @@ def append(frame: Frame) -> None: logger.debug("%s ! timed out waiting for TCP close", self.side) # Half-close the TCP connection if possible (when there's no TLS). - if self.writer.can_write_eof(): + if self.transport.can_write_eof(): logger.debug("%s x half-closing TCP connection", self.side) - self.writer.write_eof() + self.transport.write_eof() if await self.wait_for_connection_lost(): return @@ -1162,17 +1177,12 @@ def append(frame: Frame) -> None: # If connection_lost() was called, the TCP connection is closed. # However, if TLS is enabled, the transport still needs closing. # Else asyncio complains: ResourceWarning: unclosed transport. - try: - writer_is_closing = self.writer.is_closing # type: ignore - except AttributeError: # pragma: no cover - # Python < 3.7 - writer_is_closing = self.writer.transport.is_closing - if self.connection_lost_waiter.done() and writer_is_closing(): + if self.connection_lost_waiter.done() and self.transport.is_closing(): return # Close the TCP connection. Buffers are flushed asynchronously. logger.debug("%s x closing TCP connection", self.side) - self.writer.close() + self.transport.close() if await self.wait_for_connection_lost(): return @@ -1180,8 +1190,7 @@ def append(frame: Frame) -> None: # Abort the TCP connection. Buffers are discarded. logger.debug("%s x aborting TCP connection", self.side) - # mypy thinks self.writer.transport is a BaseTransport, not a Transport. - self.writer.transport.abort() # type: ignore + self.transport.abort() # connection_lost() is called quickly after aborting. await self.wait_for_connection_lost() @@ -1261,7 +1270,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: frame = Frame(True, OP_CLOSE, frame_data) logger.debug("%s > %r", self.side, frame) frame.write( - self.writer.write, mask=self.is_client, extensions=self.extensions + self.transport.write, mask=self.is_client, extensions=self.extensions ) # Start close_connection_task if the opening handshake didn't succeed. @@ -1310,12 +1319,13 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: """ logger.debug("%s - event = connection_made(%s)", self.side, transport) - # mypy thinks transport is a BaseTransport, not a Transport. - transport.set_write_buffer_limits(self.write_limit) # type: ignore + + transport = cast(asyncio.Transport, transport) + transport.set_write_buffer_limits(self.write_limit) + self.transport = transport # Copied from asyncio.StreamReaderProtocol self.reader.set_transport(transport) - self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop) def connection_lost(self, exc: Optional[Exception]) -> None: """ diff --git a/src/websockets/server.py b/src/websockets/server.py index b220a1b8..1e8ae861 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -211,7 +211,7 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: except Exception: # Last-ditch attempt to avoid leaking connections on errors. try: - self.writer.close() + self.transport.close() except Exception: # pragma: no cover pass @@ -265,11 +265,11 @@ def write_http_response( response = f"HTTP/1.1 {status.value} {status.phrase}\r\n" response += str(headers) - self.writer.write(response.encode()) + self.transport.write(response.encode()) if body is not None: logger.debug("%s > body (%d bytes)", self.side, len(body)) - self.writer.write(body) + self.transport.write(body) async def process_request( self, path: str, request_headers: Headers diff --git a/tests/test_client_server.py b/tests/test_client_server.py index e74ec6bf..6171f21b 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -1166,7 +1166,7 @@ def test_server_close_crashes(self, close): def test_client_closes_connection_before_handshake(self, handshake): # We have mocked the handshake() method to prevent the client from # performing the opening handshake. Force it to close the connection. - self.client.writer.close() + self.client.transport.close() # The server should stop properly anyway. It used to hang because the # task handling the connection was waiting for the opening handshake. diff --git a/tests/test_protocol.py b/tests/test_protocol.py index a6c42018..dfc2c6d4 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -94,16 +94,16 @@ def tearDown(self): # Utilities for writing tests. def make_drain_slow(self, delay=MS): - # Process connection_made in order to initialize self.protocol.writer. + # Process connection_made in order to initialize self.protocol.transport. self.run_loop_once() - original_drain = self.protocol.writer.drain + original_drain = self.protocol._drain async def delayed_drain(): await asyncio.sleep(delay, loop=self.loop) await original_drain() - self.protocol.writer.drain = delayed_drain + self.protocol._drain = delayed_drain close_frame = Frame(True, OP_CLOSE, serialize_close(1000, "close")) local_close = Frame(True, OP_CLOSE, serialize_close(1000, "local")) @@ -321,32 +321,32 @@ def test_local_address(self): self.transport.get_extra_info = get_extra_info self.assertEqual(self.protocol.local_address, ("host", 4312)) - get_extra_info.assert_called_with("sockname", None) + get_extra_info.assert_called_with("sockname") def test_local_address_before_connection(self): # Emulate the situation before connection_open() runs. - self.protocol.writer, _writer = None, self.protocol.writer + self.protocol.transport, _transport = None, self.protocol.transport try: self.assertEqual(self.protocol.local_address, None) finally: - self.protocol.writer = _writer + self.protocol.transport = _transport def test_remote_address(self): get_extra_info = unittest.mock.Mock(return_value=("host", 4312)) self.transport.get_extra_info = get_extra_info self.assertEqual(self.protocol.remote_address, ("host", 4312)) - get_extra_info.assert_called_with("peername", None) + get_extra_info.assert_called_with("peername") def test_remote_address_before_connection(self): # Emulate the situation before connection_open() runs. - self.protocol.writer, _writer = None, self.protocol.writer + self.protocol.transport, _transport = None, self.protocol.transport try: self.assertEqual(self.protocol.remote_address, None) finally: - self.protocol.writer = _writer + self.protocol.transport = _transport def test_open(self): self.assertTrue(self.protocol.open) From c55f374a93f7e9b14bc2bdcc8868f38f8a834a87 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 13:55:16 +0200 Subject: [PATCH 13/20] Rename writer to write. It's a better name for a function that writes bytes. --- src/websockets/framing.py | 8 ++++---- tests/test_framing.py | 14 +++++++------- tests/test_protocol.py | 4 ++-- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/websockets/framing.py b/src/websockets/framing.py index 81a3185b..c24b8a73 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -147,7 +147,7 @@ class Frame(NamedTuple): def write( frame, - writer: Callable[[bytes], Any], + write: Callable[[bytes], Any], *, mask: bool, extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, @@ -156,7 +156,7 @@ def write( Write a WebSocket frame. :param frame: frame to write - :param writer: function that writes bytes + :param write: function that writes bytes :param mask: whether the frame should be masked i.e. whether the write happens on the client side :param extensions: list of classes with an ``encode()`` method that @@ -210,10 +210,10 @@ def write( # Send the frame. - # The frame is written in a single call to writer in order to prevent + # The frame is written in a single call to write in order to prevent # TCP fragmentation. See #68 for details. This also makes it safe to # send frames concurrently from multiple coroutines. - writer(output.getvalue()) + write(output.getvalue()) def check(frame) -> None: """ diff --git a/tests/test_framing.py b/tests/test_framing.py index 9e6f1871..5def415d 100644 --- a/tests/test_framing.py +++ b/tests/test_framing.py @@ -27,15 +27,15 @@ def decode(self, message, mask=False, max_size=None, extensions=None): return frame def encode(self, frame, mask=False, extensions=None): - writer = unittest.mock.Mock() - frame.write(writer, mask=mask, extensions=extensions) - # Ensure the entire frame is sent with a single call to writer(). + write = unittest.mock.Mock() + frame.write(write, mask=mask, extensions=extensions) + # Ensure the entire frame is sent with a single call to write(). # Multiple calls cause TCP fragmentation and degrade performance. - self.assertEqual(writer.call_count, 1) + self.assertEqual(write.call_count, 1) # The frame data is the single positional argument of that call. - self.assertEqual(len(writer.call_args[0]), 1) - self.assertEqual(len(writer.call_args[1]), 0) - return writer.call_args[0][0] + self.assertEqual(len(write.call_args[0]), 1) + self.assertEqual(len(write.call_args[1]), 0) + return write.call_args[0][0] def round_trip(self, message, expected, mask=False, extensions=None): decoded = self.decode(message, mask, extensions=extensions) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index dfc2c6d4..d2793faf 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -114,9 +114,9 @@ def receive_frame(self, frame): Make the protocol receive a frame. """ - writer = self.protocol.data_received + write = self.protocol.data_received mask = not self.protocol.is_client - frame.write(writer, mask=mask) + frame.write(write, mask=mask) def receive_eof(self): """ From 3e03d1a19dbbf54e1c714c933b90ec283caa7bb0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 15:49:08 +0200 Subject: [PATCH 14/20] Update to the latest version of mypy. The bugs that were locking us on an old version are fixed. --- src/websockets/__init__.py | 15 ++++++++------- src/websockets/__main__.py | 8 ++++---- src/websockets/client.py | 3 +-- src/websockets/handshake.py | 15 ++++++++++----- src/websockets/protocol.py | 14 +++++++++----- src/websockets/server.py | 11 +++-------- tox.ini | 2 +- 7 files changed, 36 insertions(+), 32 deletions(-) diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index e7ba31ce..6bad0f7b 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -1,12 +1,13 @@ # This relies on each of the submodules having an __all__ variable. -from .auth import * -from .client import * -from .exceptions import * -from .protocol import * -from .server import * -from .typing import * -from .uri import * +from . import auth, client, exceptions, protocol, server, typing, uri +from .auth import * # noqa +from .client import * # noqa +from .exceptions import * # noqa +from .protocol import * # noqa +from .server import * # noqa +from .typing import * # noqa +from .uri import * # noqa from .version import version as __version__ # noqa diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index bccb8aa5..394f7ac7 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -6,8 +6,8 @@ import threading from typing import Any, Set -import websockets -from websockets.exceptions import format_close +from .client import connect +from .exceptions import ConnectionClosed, format_close if sys.platform == "win32": @@ -95,7 +95,7 @@ def print_over_input(string: str) -> None: stop: "asyncio.Future[None]", ) -> None: try: - websocket = await websockets.connect(uri) + websocket = await connect(uri) except Exception as exc: print_over_input(f"Failed to connect to {uri}: {exc}.") exit_from_event_loop_thread(loop, stop) @@ -122,7 +122,7 @@ def print_over_input(string: str) -> None: if incoming in done: try: message = incoming.result() - except websockets.ConnectionClosed: + except ConnectionClosed: break else: if isinstance(message, str): diff --git a/src/websockets/client.py b/src/websockets/client.py index 34cd8624..725ec1e7 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -24,7 +24,6 @@ from .extensions.permessage_deflate import ClientPerMessageDeflateFactory from .handshake import build_request, check_response from .headers import ( - ExtensionHeader, build_authorization_basic, build_extension, build_subprotocol, @@ -33,7 +32,7 @@ ) from .http import USER_AGENT, Headers, HeadersLike, read_response from .protocol import WebSocketCommonProtocol -from .typing import Origin, Subprotocol +from .typing import ExtensionHeader, Origin, Subprotocol from .uri import WebSocketURI, parse_uri diff --git a/src/websockets/handshake.py b/src/websockets/handshake.py index 17332d15..9bfe2775 100644 --- a/src/websockets/handshake.py +++ b/src/websockets/handshake.py @@ -29,9 +29,10 @@ import binascii import hashlib import random +from typing import List from .exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade -from .headers import parse_connection, parse_upgrade +from .headers import ConnectionOption, UpgradeProtocol, parse_connection, parse_upgrade from .http import Headers, MultipleValuesError @@ -74,14 +75,16 @@ def check_request(headers: Headers) -> str: is invalid; then the server must return 400 Bad Request error """ - connection = sum( + connection: List[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade("Connection", ", ".join(connection)) - upgrade = sum([parse_upgrade(value) for value in headers.get_all("Upgrade")], []) + upgrade: List[UpgradeProtocol] = sum( + [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] + ) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. @@ -148,14 +151,16 @@ def check_response(headers: Headers, key: str) -> None: is invalid """ - connection = sum( + connection: List[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade("Connection", " ".join(connection)) - upgrade = sum([parse_upgrade(value) for value in headers.get_all("Upgrade")], []) + upgrade: List[UpgradeProtocol] = sum( + [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] + ) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index eb3d6bcc..b7c1f19c 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -601,10 +601,14 @@ def closed(self) -> bool: elif isinstance(message, AsyncIterable): # aiter_message = aiter(message) without aiter - aiter_message = type(message).__aiter__(message) + # https://github.com/python/mypy/issues/5738 + aiter_message = type(message).__aiter__(message) # type: ignore try: # message_chunk = anext(aiter_message) without anext - message_chunk = await type(aiter_message).__anext__(aiter_message) + # https://github.com/python/mypy/issues/5738 + message_chunk = await type(aiter_message).__anext__( # type: ignore + aiter_message + ) except StopAsyncIteration: return opcode, data = prepare_data(message_chunk) @@ -615,7 +619,8 @@ def closed(self) -> bool: await self.write_frame(False, opcode, data) # Other fragments. - async for message_chunk in aiter_message: + # https://github.com/python/mypy/issues/5738 + async for message_chunk in aiter_message: # type: ignore confirm_opcode, data = prepare_data(message_chunk) if confirm_opcode != opcode: raise TypeError("data contains inconsistent types") @@ -899,8 +904,7 @@ def connection_closed_exc(self) -> ConnectionClosed: max_size = self.max_size if text: decoder_factory = codecs.getincrementaldecoder("utf-8") - # https://github.com/python/typeshed/pull/2752 - decoder = decoder_factory(errors="strict") # type: ignore + decoder = decoder_factory(errors="strict") if max_size is None: def append(frame: Frame) -> None: diff --git a/src/websockets/server.py b/src/websockets/server.py index 1e8ae861..5114646d 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -39,15 +39,10 @@ from .extensions.base import Extension, ServerExtensionFactory from .extensions.permessage_deflate import ServerPerMessageDeflateFactory from .handshake import build_response, check_request -from .headers import ( - ExtensionHeader, - build_extension, - parse_extension, - parse_subprotocol, -) +from .headers import build_extension, parse_extension, parse_subprotocol from .http import USER_AGENT, Headers, HeadersLike, MultipleValuesError, read_request from .protocol import WebSocketCommonProtocol -from .typing import Origin, Subprotocol +from .typing import ExtensionHeader, Origin, Subprotocol __all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"] @@ -662,7 +657,7 @@ def is_serving(self) -> bool: """ try: # Python ≥ 3.7 - return self.server.is_serving() # type: ignore + return self.server.is_serving() except AttributeError: # pragma: no cover # Python < 3.7 return self.server.sockets is not None diff --git a/tox.ini b/tox.ini index 801d4d5d..7397c90a 100644 --- a/tox.ini +++ b/tox.ini @@ -25,4 +25,4 @@ deps = isort [testenv:mypy] commands = mypy --strict src -deps = mypy==0.670 +deps = mypy From aef1e8781826653720f4db110a4d2af4418fa5b8 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 15:49:51 +0200 Subject: [PATCH 15/20] Fix deprecation warnings on Python 3.8. * Don't pass the deprecated loop argument. * Ignore deprecation warnings for @asyncio.coroutine. --- src/websockets/protocol.py | 28 ++++++++++++++++++++-------- src/websockets/server.py | 10 +++++++--- tests/test_client_server.py | 36 +++++++++++++++++++++--------------- tests/test_protocol.py | 5 ++++- 4 files changed, 52 insertions(+), 27 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index b7c1f19c..76d46ad9 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -14,6 +14,7 @@ import logging import random import struct +import sys import warnings from typing import ( Any, @@ -230,7 +231,9 @@ def __init__( ) self.transport: asyncio.Transport - self._drain_lock = asyncio.Lock(loop=loop) + self._drain_lock = asyncio.Lock( + loop=loop if sys.version_info[:2] < (3, 8) else None + ) # Copied from asyncio.FlowControlMixin self._paused = False @@ -312,7 +315,9 @@ def __init__( # write(...); yield from drain() # in a loop would never call connection_lost(), so it # would not see an error when the socket is closed. - await asyncio.sleep(0) + await asyncio.sleep( + 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None + ) await self._drain_helper() def connection_open(self) -> None: @@ -483,7 +488,7 @@ def closed(self) -> bool: # pop_message_waiter and self.transfer_data_task. await asyncio.wait( [pop_message_waiter, self.transfer_data_task], - loop=self.loop, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, return_when=asyncio.FIRST_COMPLETED, ) finally: @@ -668,7 +673,7 @@ def closed(self) -> bool: await asyncio.wait_for( self.write_close_frame(serialize_close(code, reason)), self.close_timeout, - loop=self.loop, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) except asyncio.TimeoutError: # If the close frame cannot be sent because the send buffers @@ -687,7 +692,9 @@ def closed(self) -> bool: # If close() is canceled during the wait, self.transfer_data_task # is canceled before the timeout elapses. await asyncio.wait_for( - self.transfer_data_task, self.close_timeout, loop=self.loop + self.transfer_data_task, + self.close_timeout, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -1106,7 +1113,10 @@ def append(frame: Frame) -> None: try: while True: - await asyncio.sleep(self.ping_interval, loop=self.loop) + await asyncio.sleep( + self.ping_interval, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, + ) # ping() raises CancelledError if the connection is closed, # when close_connection() cancels self.keepalive_ping_task. @@ -1119,7 +1129,9 @@ def append(frame: Frame) -> None: if self.ping_timeout is not None: try: await asyncio.wait_for( - ping_waiter, self.ping_timeout, loop=self.loop + ping_waiter, + self.ping_timeout, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) except asyncio.TimeoutError: logger.debug("%s ! timed out waiting for pong", self.side) @@ -1211,7 +1223,7 @@ def append(frame: Frame) -> None: await asyncio.wait_for( asyncio.shield(self.connection_lost_waiter), self.close_timeout, - loop=self.loop, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) except asyncio.TimeoutError: pass diff --git a/src/websockets/server.py b/src/websockets/server.py index 5114646d..4f5e9e0e 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -10,6 +10,7 @@ import http import logging import socket +import sys import warnings from types import TracebackType from typing import ( @@ -698,7 +699,9 @@ def close(self) -> None: # Wait until all accepted connections reach connection_made() and call # register(). See https://bugs.python.org/issue34852 for details. - await asyncio.sleep(0) + await asyncio.sleep( + 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None + ) # Close OPEN connections with status code 1001. Since the server was # closed, handshake() closes OPENING conections with a HTTP 503 error. @@ -707,7 +710,8 @@ def close(self) -> None: # asyncio.wait doesn't accept an empty first argument if self.websockets: await asyncio.wait( - [websocket.close(1001) for websocket in self.websockets], loop=self.loop + [websocket.close(1001) for websocket in self.websockets], + loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) # Wait until all connection handlers are complete. @@ -716,7 +720,7 @@ def close(self) -> None: if self.websockets: await asyncio.wait( [websocket.handler_task for websocket in self.websockets], - loop=self.loop, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) # Tell wait_closed() to return. diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 6171f21b..85828bdb 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -1381,13 +1381,16 @@ def test_client(self): start_server = serve(handler, "localhost", 0) server = self.loop.run_until_complete(start_server) - @asyncio.coroutine - def run_client(): - # Yield from connect. - client = yield from connect(get_server_uri(server)) - self.assertEqual(client.state, State.OPEN) - yield from client.close() - self.assertEqual(client.state, State.CLOSED) + # @asyncio.coroutine is deprecated on Python ≥ 3.8 + with warnings.catch_warnings(record=True): + + @asyncio.coroutine + def run_client(): + # Yield from connect. + client = yield from connect(get_server_uri(server)) + self.assertEqual(client.state, State.OPEN) + yield from client.close() + self.assertEqual(client.state, State.CLOSED) self.loop.run_until_complete(run_client()) @@ -1395,14 +1398,17 @@ def run_client(): self.loop.run_until_complete(server.wait_closed()) def test_server(self): - @asyncio.coroutine - def run_server(): - # Yield from serve. - server = yield from serve(handler, "localhost", 0) - self.assertTrue(server.sockets) - server.close() - yield from server.wait_closed() - self.assertFalse(server.sockets) + # @asyncio.coroutine is deprecated on Python ≥ 3.8 + with warnings.catch_warnings(record=True): + + @asyncio.coroutine + def run_server(): + # Yield from serve. + server = yield from serve(handler, "localhost", 0) + self.assertTrue(server.sockets) + server.close() + yield from server.wait_closed() + self.assertFalse(server.sockets) self.loop.run_until_complete(run_server()) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index d2793faf..04e2a38f 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,6 +1,7 @@ import asyncio import contextlib import logging +import sys import unittest import unittest.mock import warnings @@ -100,7 +101,9 @@ def make_drain_slow(self, delay=MS): original_drain = self.protocol._drain async def delayed_drain(): - await asyncio.sleep(delay, loop=self.loop) + await asyncio.sleep( + delay, loop=self.loop if sys.version_info[:2] < (3, 8) else None + ) await original_drain() self.protocol._drain = delayed_drain From 504c66cf01bc7e2c4f5fbcceb1387cadf056a678 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 14:02:51 +0200 Subject: [PATCH 16/20] Document and test support for Python 3.8. --- .circleci/config.yml | 12 ++++++++++++ docs/changelog.rst | 2 ++ setup.py | 1 + tox.ini | 2 +- 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index a6c85d23..0877c161 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -29,6 +29,15 @@ jobs: - checkout - run: sudo pip install tox - run: tox -e py37 + py38: + docker: + - image: circleci/python:3.8.0rc1 + steps: + # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. + - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc + - checkout + - run: sudo pip install tox + - run: tox -e py38 workflows: version: 2 @@ -41,3 +50,6 @@ workflows: - py37: requires: - main + - py38: + requires: + - main diff --git a/docs/changelog.rst b/docs/changelog.rst index 87b2e438..2a106fbc 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -8,6 +8,8 @@ Changelog *In development* +* Added compatibility with Python 3.8. + 8.0.2 ..... diff --git a/setup.py b/setup.py index c7643010..f3581924 100644 --- a/setup.py +++ b/setup.py @@ -53,6 +53,7 @@ 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', ], package_dir = {'': 'src'}, package_data = {'websockets': ['py.typed']}, diff --git a/tox.ini b/tox.ini index 7397c90a..825e3406 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py36,py37,coverage,black,flake8,isort,mypy +envlist = py36,py37,py38,coverage,black,flake8,isort,mypy [testenv] commands = python -W default -m unittest {posargs} From fc99c16a53e5d0c85448714021a5deb74021920a Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 21:18:30 +0200 Subject: [PATCH 17/20] Move test logging configuration to a single place. --- tests/__init__.py | 5 +++++ tests/test_client_server.py | 5 ----- tests/test_protocol.py | 5 ----- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..dd78609f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,5 @@ +import logging + + +# Avoid displaying stack traces at the ERROR logging level. +logging.basicConfig(level=logging.CRITICAL) diff --git a/tests/test_client_server.py b/tests/test_client_server.py index 85828bdb..ce0f66ce 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -2,7 +2,6 @@ import contextlib import functools import http -import logging import pathlib import random import socket @@ -37,10 +36,6 @@ from .utils import AsyncioTestCase -# Avoid displaying stack traces at the ERROR logging level. -logging.basicConfig(level=logging.CRITICAL) - - # Generate TLS certificate with: # $ openssl req -x509 -config test_localhost.cnf -days 15340 -newkey rsa:2048 \ # -out test_localhost.crt -keyout test_localhost.key diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 04e2a38f..d95260a8 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,6 +1,5 @@ import asyncio import contextlib -import logging import sys import unittest import unittest.mock @@ -13,10 +12,6 @@ from .utils import MS, AsyncioTestCase -# Avoid displaying stack traces at the ERROR logging level. -logging.basicConfig(level=logging.CRITICAL) - - async def async_iterable(iterable): for item in iterable: yield item From 91dea74a9430d8573f7fd58c35fd731eddbb6cdb Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 21:33:27 +0200 Subject: [PATCH 18/20] Remove test that no longer makes sense. Since version 7.0, when the server closes, it terminates connections with close code 1001 instead of canceling them. --- tests/test_client_server.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/tests/test_client_server.py b/tests/test_client_server.py index ce0f66ce..35913666 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -173,7 +173,7 @@ class HealthCheckServerProtocol(WebSocketServerProtocol): return http.HTTPStatus.OK, [("X-Access", "OK")], b"status = green\n" -class SlowServerProtocol(WebSocketServerProtocol): +class SlowOpeningHandshakeProtocol(WebSocketServerProtocol): async def process_request(self, path, request_headers): await asyncio.sleep(10 * MS) @@ -1165,7 +1165,7 @@ def test_client_closes_connection_before_handshake(self, handshake): # The server should stop properly anyway. It used to hang because the # task handling the connection was waiting for the opening handshake. - @with_server(create_protocol=SlowServerProtocol) + @with_server(create_protocol=SlowOpeningHandshakeProtocol) def test_server_shuts_down_during_opening_handshake(self): self.loop.call_later(5 * MS, self.server.close) with self.assertRaises(InvalidStatusCode) as raised: @@ -1188,20 +1188,6 @@ def test_server_shuts_down_during_connection_handling(self): self.assertEqual(self.client.close_code, 1001) self.assertEqual(server_ws.close_code, 1001) - @with_server() - @unittest.mock.patch("websockets.server.WebSocketServerProtocol.close") - def test_server_shuts_down_during_connection_close(self, _close): - _close.side_effect = asyncio.CancelledError - - self.server.closing = True - with self.temp_client(): - self.loop.run_until_complete(self.client.send("Hello!")) - reply = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(reply, "Hello!") - - # Websocket connection terminates abnormally. - self.assertEqual(self.client.close_code, 1006) - @with_server() def test_server_shuts_down_waits_until_handlers_terminate(self): # This handler waits a bit after the connection is closed in order From 88073fe7da7e8c07a5f46ae7c77e1a4ef80736ee Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 21:41:13 +0200 Subject: [PATCH 19/20] Fix refactoring error. WebSocketCommonProtocol.transport can be unset, but it cannot be None. --- src/websockets/protocol.py | 14 ++++++++++---- tests/test_protocol.py | 8 ++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 76d46ad9..0623e136 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -366,9 +366,12 @@ def local_address(self) -> Any: been established yet. """ - if self.transport is None: + try: + transport = self.transport + except AttributeError: return None - return self.transport.get_extra_info("sockname") + else: + return transport.get_extra_info("sockname") @property def remote_address(self) -> Any: @@ -379,9 +382,12 @@ def remote_address(self) -> Any: been established yet. """ - if self.transport is None: + try: + transport = self.transport + except AttributeError: return None - return self.transport.get_extra_info("peername") + else: + return transport.get_extra_info("peername") @property def open(self) -> bool: diff --git a/tests/test_protocol.py b/tests/test_protocol.py index d95260a8..66a822e7 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -323,8 +323,8 @@ def test_local_address(self): def test_local_address_before_connection(self): # Emulate the situation before connection_open() runs. - self.protocol.transport, _transport = None, self.protocol.transport - + _transport = self.protocol.transport + del self.protocol.transport try: self.assertEqual(self.protocol.local_address, None) finally: @@ -339,8 +339,8 @@ def test_remote_address(self): def test_remote_address_before_connection(self): # Emulate the situation before connection_open() runs. - self.protocol.transport, _transport = None, self.protocol.transport - + _transport = self.protocol.transport + del self.protocol.transport try: self.assertEqual(self.protocol.remote_address, None) finally: From d75d4a4d34a7485ab8e5f2e65dee06dc37b7cfba Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 5 Oct 2019 21:39:14 +0200 Subject: [PATCH 20/20] Remove useless type declaration. --- src/websockets/protocol.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 0623e136..6c29b2a5 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -226,19 +226,16 @@ def __init__( # ``self.read_limit``. The ``limit`` argument controls the line length # limit and half the buffer limit of :class:`~asyncio.StreamReader`. # That's why it must be set to half of ``self.read_limit``. - self.reader: asyncio.StreamReader = asyncio.StreamReader( - limit=read_limit // 2, loop=loop - ) - - self.transport: asyncio.Transport - self._drain_lock = asyncio.Lock( - loop=loop if sys.version_info[:2] < (3, 8) else None - ) + self.reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) # Copied from asyncio.FlowControlMixin self._paused = False self._drain_waiter: Optional[asyncio.Future[None]] = None + self._drain_lock = asyncio.Lock( + loop=loop if sys.version_info[:2] < (3, 8) else None + ) + # This class implements the data transfer and closing handshake, which # are shared between the client-side and the server-side. # Subclasses implement the opening handshake and, on success, execute