14
0
Files
python-websockets/python38-support.patch

2202 lines
81 KiB
Diff

From fbab68fad41b1ddad2bbe7eccd793f5d6c52f761 Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 11:41:21 +0200
Subject: [PATCH 01/20] Copy FlowControlMixin and StreamReaderProtocol.
Copying these classes is the approach officially recommended by Python core developers.
The code is taken from the current 3.7 branch of CPython.
---
src/websockets/protocol.py | 127 ++++++++++++++++++++++++++++++++++++-
1 file changed, 126 insertions(+), 1 deletion(-)
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
index 1f0edcce..2f74cd23 100644
--- a/src/websockets/protocol.py
+++ b/src/websockets/protocol.py
@@ -61,7 +61,132 @@ class State(enum.IntEnum):
# between the check and the assignment.
-class WebSocketCommonProtocol(asyncio.StreamReaderProtocol):
+class FlowControlMixin(asyncio.Protocol):
+ """Reusable flow control logic for StreamWriter.drain().
+ This implements the protocol methods pause_writing(),
+ resume_writing() and connection_lost(). If the subclass overrides
+ these it must call the super methods.
+ StreamWriter.drain() must wait for _drain_helper() coroutine.
+ """
+
+ def __init__(self, loop=None):
+ if loop is None:
+ self._loop = asyncio.get_event_loop()
+ else:
+ self._loop = loop
+ self._paused = False
+ self._drain_waiter = None
+ self._connection_lost = False
+
+ def pause_writing(self):
+ assert not self._paused
+ self._paused = True
+ if self._loop.get_debug():
+ logger.debug("%r pauses writing", self)
+
+ def resume_writing(self):
+ assert self._paused
+ self._paused = False
+ if self._loop.get_debug():
+ logger.debug("%r resumes writing", self)
+
+ waiter = self._drain_waiter
+ if waiter is not None:
+ self._drain_waiter = None
+ if not waiter.done():
+ waiter.set_result(None)
+
+ def connection_lost(self, exc):
+ self._connection_lost = True
+ # Wake up the writer if currently paused.
+ if not self._paused:
+ return
+ waiter = self._drain_waiter
+ if waiter is None:
+ return
+ self._drain_waiter = None
+ if waiter.done():
+ return
+ if exc is None:
+ waiter.set_result(None)
+ else:
+ waiter.set_exception(exc)
+
+ async def _drain_helper(self):
+ if self._connection_lost:
+ raise ConnectionResetError("Connection lost")
+ if not self._paused:
+ return
+ waiter = self._drain_waiter
+ assert waiter is None or waiter.cancelled()
+ waiter = self._loop.create_future()
+ self._drain_waiter = waiter
+ await waiter
+
+
+class StreamReaderProtocol(FlowControlMixin, asyncio.Protocol):
+ """Helper class to adapt between Protocol and StreamReader.
+ (This is a helper class instead of making StreamReader itself a
+ Protocol subclass, because the StreamReader has other potential
+ uses, and to prevent the user of the StreamReader to accidentally
+ call inappropriate methods of the protocol.)
+ """
+
+ def __init__(self, stream_reader, client_connected_cb=None, loop=None):
+ super().__init__(loop=loop)
+ self._stream_reader = stream_reader
+ self._stream_writer = None
+ self._client_connected_cb = client_connected_cb
+ self._over_ssl = False
+ self._closed = self._loop.create_future()
+
+ def connection_made(self, transport):
+ self._stream_reader.set_transport(transport)
+ self._over_ssl = transport.get_extra_info("sslcontext") is not None
+ if self._client_connected_cb is not None:
+ self._stream_writer = asyncio.StreamWriter(
+ transport, self, self._stream_reader, self._loop
+ )
+ res = self._client_connected_cb(self._stream_reader, self._stream_writer)
+ if asyncio.iscoroutine(res):
+ self._loop.create_task(res)
+
+ def connection_lost(self, exc):
+ if self._stream_reader is not None:
+ if exc is None:
+ self._stream_reader.feed_eof()
+ else:
+ self._stream_reader.set_exception(exc)
+ if not self._closed.done():
+ if exc is None:
+ self._closed.set_result(None)
+ else:
+ self._closed.set_exception(exc)
+ super().connection_lost(exc)
+ self._stream_reader = None
+ self._stream_writer = None
+
+ def data_received(self, data):
+ self._stream_reader.feed_data(data)
+
+ def eof_received(self):
+ self._stream_reader.feed_eof()
+ if self._over_ssl:
+ # Prevent a warning in SSLProtocol.eof_received:
+ # "returning true from eof_received()
+ # has no effect when using ssl"
+ return False
+ return True
+
+ def __del__(self):
+ # Prevent reports about unhandled exceptions.
+ # Better than self._closed._log_traceback = False hack
+ closed = self._closed
+ if closed.done() and not closed.cancelled():
+ closed.exception()
+
+
+class WebSocketCommonProtocol(StreamReaderProtocol):
"""
:class:`~asyncio.Protocol` subclass implementing the data transfer phase.
From 57907330fe789f0b5f6b71b2810a0e36fee14844 Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 11:54:04 +0200
Subject: [PATCH 02/20] Remove docstrings and debug logs.
---
src/websockets/protocol.py | 16 ----------------
1 file changed, 16 deletions(-)
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
index 2f74cd23..d74c8157 100644
--- a/src/websockets/protocol.py
+++ b/src/websockets/protocol.py
@@ -62,12 +62,6 @@ class State(enum.IntEnum):
class FlowControlMixin(asyncio.Protocol):
- """Reusable flow control logic for StreamWriter.drain().
- This implements the protocol methods pause_writing(),
- resume_writing() and connection_lost(). If the subclass overrides
- these it must call the super methods.
- StreamWriter.drain() must wait for _drain_helper() coroutine.
- """
def __init__(self, loop=None):
if loop is None:
@@ -81,14 +75,10 @@ def __init__(self, loop=None):
def pause_writing(self):
assert not self._paused
self._paused = True
- if self._loop.get_debug():
- logger.debug("%r pauses writing", self)
def resume_writing(self):
assert self._paused
self._paused = False
- if self._loop.get_debug():
- logger.debug("%r resumes writing", self)
waiter = self._drain_waiter
if waiter is not None:
@@ -125,12 +115,6 @@ def connection_lost(self, exc):
class StreamReaderProtocol(FlowControlMixin, asyncio.Protocol):
- """Helper class to adapt between Protocol and StreamReader.
- (This is a helper class instead of making StreamReader itself a
- Protocol subclass, because the StreamReader has other potential
- uses, and to prevent the user of the StreamReader to accidentally
- call inappropriate methods of the protocol.)
- """
def __init__(self, stream_reader, client_connected_cb=None, loop=None):
super().__init__(loop=loop)
From 53523d06c7752b394c1f1e5566f464db741e2617 Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 11:54:43 +0200
Subject: [PATCH 03/20] Merge FlowControlMixin in StreamReaderProtocol.
---
src/websockets/protocol.py | 54 +++++++++++++++++---------------------
1 file changed, 24 insertions(+), 30 deletions(-)
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
index d74c8157..49d8b4f2 100644
--- a/src/websockets/protocol.py
+++ b/src/websockets/protocol.py
@@ -61,9 +61,9 @@ class State(enum.IntEnum):
# between the check and the assignment.
-class FlowControlMixin(asyncio.Protocol):
+class StreamReaderProtocol(asyncio.Protocol):
- def __init__(self, loop=None):
+ def __init__(self, stream_reader, client_connected_cb=None, loop=None):
if loop is None:
self._loop = asyncio.get_event_loop()
else:
@@ -72,6 +72,12 @@ def __init__(self, loop=None):
self._drain_waiter = None
self._connection_lost = False
+ self._stream_reader = stream_reader
+ self._stream_writer = None
+ self._client_connected_cb = client_connected_cb
+ self._over_ssl = False
+ self._closed = self._loop.create_future()
+
def pause_writing(self):
assert not self._paused
self._paused = True
@@ -86,22 +92,6 @@ def resume_writing(self):
if not waiter.done():
waiter.set_result(None)
- def connection_lost(self, exc):
- self._connection_lost = True
- # Wake up the writer if currently paused.
- if not self._paused:
- return
- waiter = self._drain_waiter
- if waiter is None:
- return
- self._drain_waiter = None
- if waiter.done():
- return
- if exc is None:
- waiter.set_result(None)
- else:
- waiter.set_exception(exc)
-
async def _drain_helper(self):
if self._connection_lost:
raise ConnectionResetError("Connection lost")
@@ -113,17 +103,6 @@ def connection_lost(self, exc):
self._drain_waiter = waiter
await waiter
-
-class StreamReaderProtocol(FlowControlMixin, asyncio.Protocol):
-
- def __init__(self, stream_reader, client_connected_cb=None, loop=None):
- super().__init__(loop=loop)
- self._stream_reader = stream_reader
- self._stream_writer = None
- self._client_connected_cb = client_connected_cb
- self._over_ssl = False
- self._closed = self._loop.create_future()
-
def connection_made(self, transport):
self._stream_reader.set_transport(transport)
self._over_ssl = transport.get_extra_info("sslcontext") is not None
@@ -146,7 +125,22 @@ def connection_lost(self, exc):
self._closed.set_result(None)
else:
self._closed.set_exception(exc)
- super().connection_lost(exc)
+
+ self._connection_lost = True
+ # Wake up the writer if currently paused.
+ if not self._paused:
+ return
+ waiter = self._drain_waiter
+ if waiter is None:
+ return
+ self._drain_waiter = None
+ if waiter.done():
+ return
+ if exc is None:
+ waiter.set_result(None)
+ else:
+ waiter.set_exception(exc)
+
self._stream_reader = None
self._stream_writer = None
From e6d8bcfe2a0f16c50260cafab45f06da6baa1689 Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 11:55:53 +0200
Subject: [PATCH 04/20] Deduplicate loop and _loop attributes.
---
src/websockets/protocol.py | 18 ++++++------------
1 file changed, 6 insertions(+), 12 deletions(-)
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
index 49d8b4f2..98c23ab1 100644
--- a/src/websockets/protocol.py
+++ b/src/websockets/protocol.py
@@ -63,11 +63,7 @@ class State(enum.IntEnum):
class StreamReaderProtocol(asyncio.Protocol):
- def __init__(self, stream_reader, client_connected_cb=None, loop=None):
- if loop is None:
- self._loop = asyncio.get_event_loop()
- else:
- self._loop = loop
+ def __init__(self, stream_reader, client_connected_cb=None):
self._paused = False
self._drain_waiter = None
self._connection_lost = False
@@ -76,7 +72,7 @@ def __init__(self, stream_reader, client_connected_cb=None, loop=None):
self._stream_writer = None
self._client_connected_cb = client_connected_cb
self._over_ssl = False
- self._closed = self._loop.create_future()
+ self._closed = self.loop.create_future()
def pause_writing(self):
assert not self._paused
@@ -99,7 +95,7 @@ def resume_writing(self):
return
waiter = self._drain_waiter
assert waiter is None or waiter.cancelled()
- waiter = self._loop.create_future()
+ waiter = self.loop.create_future()
self._drain_waiter = waiter
await waiter
@@ -108,11 +104,11 @@ def connection_made(self, transport):
self._over_ssl = transport.get_extra_info("sslcontext") is not None
if self._client_connected_cb is not None:
self._stream_writer = asyncio.StreamWriter(
- transport, self, self._stream_reader, self._loop
+ transport, self, self._stream_reader, self.loop
)
res = self._client_connected_cb(self._stream_reader, self._stream_writer)
if asyncio.iscoroutine(res):
- self._loop.create_task(res)
+ self.loop.create_task(res)
def connection_lost(self, exc):
if self._stream_reader is not None:
@@ -315,8 +311,6 @@ def __init__(
self.read_limit = read_limit
self.write_limit = write_limit
- # Store a reference to loop to avoid relying on self._loop, a private
- # attribute of StreamReaderProtocol, inherited from FlowControlMixin.
if loop is None:
loop = asyncio.get_event_loop()
self.loop = loop
@@ -331,7 +325,7 @@ def __init__(
# limit and half the buffer limit of :class:`~asyncio.StreamReader`.
# That's why it must be set to half of ``self.read_limit``.
stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop)
- super().__init__(stream_reader, self.client_connected, loop)
+ super().__init__(stream_reader, self.client_connected)
self.reader: asyncio.StreamReader
self.writer: asyncio.StreamWriter
From 4205cb176e62c938247af59e94214814242360ec Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 11:57:34 +0200
Subject: [PATCH 05/20] Remove client_connected callback.
---
src/websockets/protocol.py | 31 +++++++------------------------
1 file changed, 7 insertions(+), 24 deletions(-)
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
index 98c23ab1..bfc354a8 100644
--- a/src/websockets/protocol.py
+++ b/src/websockets/protocol.py
@@ -63,14 +63,13 @@ class State(enum.IntEnum):
class StreamReaderProtocol(asyncio.Protocol):
- def __init__(self, stream_reader, client_connected_cb=None):
+ def __init__(self, stream_reader):
self._paused = False
self._drain_waiter = None
self._connection_lost = False
self._stream_reader = stream_reader
self._stream_writer = None
- self._client_connected_cb = client_connected_cb
self._over_ssl = False
self._closed = self.loop.create_future()
@@ -102,13 +101,11 @@ def resume_writing(self):
def connection_made(self, transport):
self._stream_reader.set_transport(transport)
self._over_ssl = transport.get_extra_info("sslcontext") is not None
- if self._client_connected_cb is not None:
- self._stream_writer = asyncio.StreamWriter(
- transport, self, self._stream_reader, self.loop
- )
- res = self._client_connected_cb(self._stream_reader, self._stream_writer)
- if asyncio.iscoroutine(res):
- self.loop.create_task(res)
+ self._stream_writer = asyncio.StreamWriter(
+ transport, self, self._stream_reader, self.loop
+ )
+ self.reader = self._stream_reader
+ self.writer = self._stream_writer
def connection_lost(self, exc):
if self._stream_reader is not None:
@@ -325,7 +322,7 @@ def __init__(
# limit and half the buffer limit of :class:`~asyncio.StreamReader`.
# That's why it must be set to half of ``self.read_limit``.
stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop)
- super().__init__(stream_reader, self.client_connected)
+ super().__init__(stream_reader)
self.reader: asyncio.StreamReader
self.writer: asyncio.StreamWriter
@@ -381,20 +378,6 @@ def __init__(
# Task closing the TCP connection.
self.close_connection_task: asyncio.Task[None]
- def client_connected(
- self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
- ) -> None:
- """
- Callback when the TCP connection is established.
-
- Record references to the stream reader and the stream writer to avoid
- using private attributes ``_stream_reader`` and ``_stream_writer`` of
- :class:`~asyncio.StreamReaderProtocol`.
-
- """
- self.reader = reader
- self.writer = writer
-
def connection_open(self) -> None:
"""
Callback when the WebSocket opening handshake completes.
From 2ebb4b4eff76be388ee37570e891ed8432e19a9f Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 12:00:31 +0200
Subject: [PATCH 06/20] Deduplicate reader/writer and _stream_reader/writer
attributes.
---
src/websockets/protocol.py | 36 +++++++++++++++---------------------
1 file changed, 15 insertions(+), 21 deletions(-)
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
index bfc354a8..9c61a409 100644
--- a/src/websockets/protocol.py
+++ b/src/websockets/protocol.py
@@ -62,14 +62,11 @@ class State(enum.IntEnum):
class StreamReaderProtocol(asyncio.Protocol):
-
- def __init__(self, stream_reader):
+ def __init__(self):
self._paused = False
self._drain_waiter = None
self._connection_lost = False
- self._stream_reader = stream_reader
- self._stream_writer = None
self._over_ssl = False
self._closed = self.loop.create_future()
@@ -99,20 +96,16 @@ def resume_writing(self):
await waiter
def connection_made(self, transport):
- self._stream_reader.set_transport(transport)
+ self.reader.set_transport(transport)
self._over_ssl = transport.get_extra_info("sslcontext") is not None
- self._stream_writer = asyncio.StreamWriter(
- transport, self, self._stream_reader, self.loop
- )
- self.reader = self._stream_reader
- self.writer = self._stream_writer
+ self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop)
def connection_lost(self, exc):
- if self._stream_reader is not None:
+ if self.reader is not None:
if exc is None:
- self._stream_reader.feed_eof()
+ self.reader.feed_eof()
else:
- self._stream_reader.set_exception(exc)
+ self.reader.set_exception(exc)
if not self._closed.done():
if exc is None:
self._closed.set_result(None)
@@ -134,14 +127,14 @@ def connection_lost(self, exc):
else:
waiter.set_exception(exc)
- self._stream_reader = None
- self._stream_writer = None
+ del self.reader
+ del self.writer
def data_received(self, data):
- self._stream_reader.feed_data(data)
+ self.reader.feed_data(data)
def eof_received(self):
- self._stream_reader.feed_eof()
+ self.reader.feed_eof()
if self._over_ssl:
# Prevent a warning in SSLProtocol.eof_received:
# "returning true from eof_received()
@@ -321,13 +314,14 @@ def __init__(
# ``self.read_limit``. The ``limit`` argument controls the line length
# limit and half the buffer limit of :class:`~asyncio.StreamReader`.
# That's why it must be set to half of ``self.read_limit``.
- stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop)
- super().__init__(stream_reader)
-
- self.reader: asyncio.StreamReader
+ self.reader: asyncio.StreamReader = asyncio.StreamReader(
+ limit=read_limit // 2, loop=loop
+ )
self.writer: asyncio.StreamWriter
self._drain_lock = asyncio.Lock(loop=loop)
+ super().__init__()
+
# This class implements the data transfer and closing handshake, which
# are shared between the client-side and the server-side.
# Subclasses implement the opening handshake and, on success, execute
From c7ae53f4dfc586f21af1fdbdcfba321bbccd491e Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 12:25:40 +0200
Subject: [PATCH 07/20] Merge asyncio.Protocol methods.
---
src/websockets/protocol.py | 160 ++++++++++++++++---------------------
1 file changed, 70 insertions(+), 90 deletions(-)
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
index 9c61a409..89e3464a 100644
--- a/src/websockets/protocol.py
+++ b/src/websockets/protocol.py
@@ -70,20 +70,6 @@ def __init__(self):
self._over_ssl = False
self._closed = self.loop.create_future()
- def pause_writing(self):
- assert not self._paused
- self._paused = True
-
- def resume_writing(self):
- assert self._paused
- self._paused = False
-
- waiter = self._drain_waiter
- if waiter is not None:
- self._drain_waiter = None
- if not waiter.done():
- waiter.set_result(None)
-
async def _drain_helper(self):
if self._connection_lost:
raise ConnectionResetError("Connection lost")
@@ -95,53 +81,6 @@ def resume_writing(self):
self._drain_waiter = waiter
await waiter
- def connection_made(self, transport):
- self.reader.set_transport(transport)
- self._over_ssl = transport.get_extra_info("sslcontext") is not None
- self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop)
-
- def connection_lost(self, exc):
- if self.reader is not None:
- if exc is None:
- self.reader.feed_eof()
- else:
- self.reader.set_exception(exc)
- if not self._closed.done():
- if exc is None:
- self._closed.set_result(None)
- else:
- self._closed.set_exception(exc)
-
- self._connection_lost = True
- # Wake up the writer if currently paused.
- if not self._paused:
- return
- waiter = self._drain_waiter
- if waiter is None:
- return
- self._drain_waiter = None
- if waiter.done():
- return
- if exc is None:
- waiter.set_result(None)
- else:
- waiter.set_exception(exc)
-
- del self.reader
- del self.writer
-
- def data_received(self, data):
- self.reader.feed_data(data)
-
- def eof_received(self):
- self.reader.feed_eof()
- if self._over_ssl:
- # Prevent a warning in SSLProtocol.eof_received:
- # "returning true from eof_received()
- # has no effect when using ssl"
- return False
- return True
-
def __del__(self):
# Prevent reports about unhandled exceptions.
# Better than self._closed._log_traceback = False hack
@@ -1363,7 +1302,7 @@ def abort_pings(self) -> None:
"%s - aborted pending ping%s: %s", self.side, plural, pings_hex
)
- # asyncio.StreamReaderProtocol methods
+ # asyncio.Protocol methods
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""
@@ -1382,34 +1321,11 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None:
logger.debug("%s - event = connection_made(%s)", self.side, transport)
# mypy thinks transport is a BaseTransport, not a Transport.
transport.set_write_buffer_limits(self.write_limit) # type: ignore
- super().connection_made(transport)
-
- def eof_received(self) -> bool:
- """
- Close the transport after receiving EOF.
-
- Since Python 3.5, `:meth:~StreamReaderProtocol.eof_received` returns
- ``True`` on non-TLS connections.
-
- See http://bugs.python.org/issue24539 for more information.
-
- This is inappropriate for ``websockets`` for at least three reasons:
-
- 1. The use case is to read data until EOF with self.reader.read(-1).
- Since WebSocket is a TLV protocol, this never happens.
-
- 2. It doesn't work on TLS connections. A falsy value must be
- returned to have the same behavior on TLS and plain connections.
-
- 3. The WebSocket protocol has its own closing handshake. Endpoints
- close the TCP connection after sending a close frame.
-
- As a consequence we revert to the previous, more useful behavior.
- """
- logger.debug("%s - event = eof_received()", self.side)
- super().eof_received()
- return False
+ # Copied from asyncio.StreamReaderProtocol
+ self.reader.set_transport(transport)
+ self._over_ssl = transport.get_extra_info("sslcontext") is not None
+ self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop)
def connection_lost(self, exc: Optional[Exception]) -> None:
"""
@@ -1434,4 +1350,68 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
# - it's set only here in connection_lost() which is called only once;
# - it must never be canceled.
self.connection_lost_waiter.set_result(None)
- super().connection_lost(exc)
+
+ # Copied from asyncio.StreamReaderProtocol
+ if self.reader is not None:
+ if exc is None:
+ self.reader.feed_eof()
+ else:
+ self.reader.set_exception(exc)
+ if not self._closed.done():
+ if exc is None:
+ self._closed.set_result(None)
+ else:
+ self._closed.set_exception(exc)
+
+ # Copied from asyncio.FlowControlMixin
+ self._connection_lost = True
+ # Wake up the writer if currently paused.
+ if not self._paused:
+ return
+ waiter = self._drain_waiter
+ if waiter is None:
+ return
+ self._drain_waiter = None
+ if waiter.done():
+ return
+ if exc is None:
+ waiter.set_result(None)
+ else:
+ waiter.set_exception(exc)
+
+ del self.reader
+ del self.writer
+
+ def pause_writing(self) -> None:
+ assert not self._paused
+ self._paused = True
+
+ def resume_writing(self) -> None:
+ assert self._paused
+ self._paused = False
+
+ waiter = self._drain_waiter
+ if waiter is not None:
+ self._drain_waiter = None
+ if not waiter.done():
+ waiter.set_result(None)
+
+ def data_received(self, data: bytes) -> None:
+ logger.debug("%s - event = data_received(<%d bytes>)", self.side, len(data))
+ self.reader.feed_data(data)
+
+ def eof_received(self) -> None:
+ """
+ Close the transport after receiving EOF.
+
+ The WebSocket protocol has its own closing handshake: endpoints close
+ the TCP or TLS connection after sending and receiving a close frame.
+
+ As a consequence, they never need to write after receiving EOF, so
+ there's no reason to keep the transport open by returning ``True``.
+
+ Besides, that doesn't work on TLS connections.
+
+ """
+ logger.debug("%s - event = eof_received()", self.side)
+ self.reader.feed_eof()
From ab5dd72aebd86b878ff08d1ebd5cbb3256e2d6e3 Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 13:09:40 +0200
Subject: [PATCH 08/20] Finish merging StreamReaderProtocol and
FlowControlMixin.
---
src/websockets/protocol.py | 59 +++++++++++++++++++-------------------
1 file changed, 29 insertions(+), 30 deletions(-)
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
index 89e3464a..2db44e5d 100644
--- a/src/websockets/protocol.py
+++ b/src/websockets/protocol.py
@@ -61,35 +61,7 @@ class State(enum.IntEnum):
# between the check and the assignment.
-class StreamReaderProtocol(asyncio.Protocol):
- def __init__(self):
- self._paused = False
- self._drain_waiter = None
- self._connection_lost = False
-
- self._over_ssl = False
- self._closed = self.loop.create_future()
-
- async def _drain_helper(self):
- if self._connection_lost:
- raise ConnectionResetError("Connection lost")
- if not self._paused:
- return
- waiter = self._drain_waiter
- assert waiter is None or waiter.cancelled()
- waiter = self.loop.create_future()
- self._drain_waiter = waiter
- await waiter
-
- def __del__(self):
- # Prevent reports about unhandled exceptions.
- # Better than self._closed._log_traceback = False hack
- closed = self._closed
- if closed.done() and not closed.cancelled():
- closed.exception()
-
-
-class WebSocketCommonProtocol(StreamReaderProtocol):
+class WebSocketCommonProtocol(asyncio.Protocol):
"""
:class:`~asyncio.Protocol` subclass implementing the data transfer phase.
@@ -259,7 +231,14 @@ def __init__(
self.writer: asyncio.StreamWriter
self._drain_lock = asyncio.Lock(loop=loop)
- super().__init__()
+ # Copied from asyncio.FlowControlMixin
+ self._paused = False
+ self._drain_waiter: Optional[asyncio.Future[None]] = None
+ self._connection_lost = False
+
+ # Copied from asyncio.StreamReaderProtocol
+ self._over_ssl = False
+ self._closed = self.loop.create_future()
# This class implements the data transfer and closing handshake, which
# are shared between the client-side and the server-side.
@@ -311,6 +290,26 @@ def __init__(
# Task closing the TCP connection.
self.close_connection_task: asyncio.Task[None]
+ # Copied from asyncio.StreamReaderProtocol
+ def __del__(self) -> None:
+ # Prevent reports about unhandled exceptions.
+ # Better than self._closed._log_traceback = False hack
+ closed = self._closed
+ if closed.done() and not closed.cancelled():
+ closed.exception()
+
+ # Copied from asyncio.FlowControlMixin
+ async def _drain_helper(self) -> None:
+ if self._connection_lost:
+ raise ConnectionResetError("Connection lost")
+ if not self._paused:
+ return
+ waiter = self._drain_waiter
+ assert waiter is None or waiter.cancelled()
+ waiter = self.loop.create_future()
+ self._drain_waiter = waiter
+ await waiter
+
def connection_open(self) -> None:
"""
Callback when the WebSocket opening handshake completes.
From 080543c4961037fcd8a207956e48474366e13ad5 Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 13:17:26 +0200
Subject: [PATCH 09/20] Deduplicate connection termination tracking.
---
src/websockets/protocol.py | 21 +++++----------------
1 file changed, 5 insertions(+), 16 deletions(-)
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
index 2db44e5d..1cd5a91c 100644
--- a/src/websockets/protocol.py
+++ b/src/websockets/protocol.py
@@ -234,11 +234,9 @@ def __init__(
# Copied from asyncio.FlowControlMixin
self._paused = False
self._drain_waiter: Optional[asyncio.Future[None]] = None
- self._connection_lost = False
# Copied from asyncio.StreamReaderProtocol
self._over_ssl = False
- self._closed = self.loop.create_future()
# This class implements the data transfer and closing handshake, which
# are shared between the client-side and the server-side.
@@ -290,17 +288,14 @@ def __init__(
# Task closing the TCP connection.
self.close_connection_task: asyncio.Task[None]
- # Copied from asyncio.StreamReaderProtocol
- def __del__(self) -> None:
- # Prevent reports about unhandled exceptions.
- # Better than self._closed._log_traceback = False hack
- closed = self._closed
- if closed.done() and not closed.cancelled():
- closed.exception()
+ # asyncio.StreamWriter expects this attribute on the Protocol
+ @property
+ def _closed(self) -> asyncio.Future:
+ return self.connection_lost_waiter
# Copied from asyncio.FlowControlMixin
async def _drain_helper(self) -> None:
- if self._connection_lost:
+ if self.connection_lost_waiter.done():
raise ConnectionResetError("Connection lost")
if not self._paused:
return
@@ -1356,14 +1351,8 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
self.reader.feed_eof()
else:
self.reader.set_exception(exc)
- if not self._closed.done():
- if exc is None:
- self._closed.set_result(None)
- else:
- self._closed.set_exception(exc)
# Copied from asyncio.FlowControlMixin
- self._connection_lost = True
# Wake up the writer if currently paused.
if not self._paused:
return
From 284932f967699467908e69226118f9bcd5ba8f9d Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 13:21:25 +0200
Subject: [PATCH 10/20] Remove unused attribute.
---
src/websockets/protocol.py | 4 ----
1 file changed, 4 deletions(-)
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
index 1cd5a91c..a1c90916 100644
--- a/src/websockets/protocol.py
+++ b/src/websockets/protocol.py
@@ -235,9 +235,6 @@ def __init__(
self._paused = False
self._drain_waiter: Optional[asyncio.Future[None]] = None
- # Copied from asyncio.StreamReaderProtocol
- self._over_ssl = False
-
# This class implements the data transfer and closing handshake, which
# are shared between the client-side and the server-side.
# Subclasses implement the opening handshake and, on success, execute
@@ -1318,7 +1315,6 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None:
# Copied from asyncio.StreamReaderProtocol
self.reader.set_transport(transport)
- self._over_ssl = transport.get_extra_info("sslcontext") is not None
self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop)
def connection_lost(self, exc: Optional[Exception]) -> None:
From 9eeeeb3547eb7abc2253c75e462a50a8f68ba16b Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 13:28:49 +0200
Subject: [PATCH 11/20] Ignore quality checks for code copied from asyncio.
---
src/websockets/protocol.py | 51 +++++++++++++++++++-------------------
1 file changed, 25 insertions(+), 26 deletions(-)
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
index a1c90916..0bb12fd5 100644
--- a/src/websockets/protocol.py
+++ b/src/websockets/protocol.py
@@ -287,11 +287,11 @@ def __init__(
# asyncio.StreamWriter expects this attribute on the Protocol
@property
- def _closed(self) -> asyncio.Future:
+ def _closed(self) -> Any: # pragma: no cover
return self.connection_lost_waiter
# Copied from asyncio.FlowControlMixin
- async def _drain_helper(self) -> None:
+ async def _drain_helper(self) -> None: # pragma: no cover
if self.connection_lost_waiter.done():
raise ConnectionResetError("Connection lost")
if not self._paused:
@@ -1341,36 +1341,35 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
# - it must never be canceled.
self.connection_lost_waiter.set_result(None)
- # Copied from asyncio.StreamReaderProtocol
- if self.reader is not None:
- if exc is None:
- self.reader.feed_eof()
- else:
- self.reader.set_exception(exc)
+ if True: # pragma: no cover
- # Copied from asyncio.FlowControlMixin
- # Wake up the writer if currently paused.
- if not self._paused:
- return
- waiter = self._drain_waiter
- if waiter is None:
- return
- self._drain_waiter = None
- if waiter.done():
- return
- if exc is None:
- waiter.set_result(None)
- else:
- waiter.set_exception(exc)
+ # Copied from asyncio.StreamReaderProtocol
+ if self.reader is not None:
+ if exc is None:
+ self.reader.feed_eof()
+ else:
+ self.reader.set_exception(exc)
- del self.reader
- del self.writer
+ # Copied from asyncio.FlowControlMixin
+ # Wake up the writer if currently paused.
+ if not self._paused:
+ return
+ waiter = self._drain_waiter
+ if waiter is None:
+ return
+ self._drain_waiter = None
+ if waiter.done():
+ return
+ if exc is None:
+ waiter.set_result(None)
+ else:
+ waiter.set_exception(exc)
- def pause_writing(self) -> None:
+ def pause_writing(self) -> None: # pragma: no cover
assert not self._paused
self._paused = True
- def resume_writing(self) -> None:
+ def resume_writing(self) -> None: # pragma: no cover
assert self._paused
self._paused = False
From f79cdd8e50e82433bbb0c003279a6b6f13e6c17e Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 13:54:13 +0200
Subject: [PATCH 12/20] Remove asyncio.StreamWriter.
It only adds one flow-control method, drain(). Copy that method, since the
rest of the flow control implementation has already been copied.
---
src/websockets/client.py | 2 +-
src/websockets/protocol.py | 64 +++++++++++++++++++++----------------
src/websockets/server.py | 6 ++--
tests/test_client_server.py | 2 +-
tests/test_protocol.py | 18 +++++------
5 files changed, 51 insertions(+), 41 deletions(-)
diff --git a/src/websockets/client.py b/src/websockets/client.py
index c1fdf88a..34cd8624 100644
--- a/src/websockets/client.py
+++ b/src/websockets/client.py
@@ -85,7 +85,7 @@ def write_http_request(self, path: str, headers: Headers) -> None:
request = f"GET {path} HTTP/1.1\r\n"
request += str(headers)
- self.writer.write(request.encode())
+ self.transport.write(request.encode())
async def read_http_response(self) -> Tuple[int, Headers]:
"""
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
index 0bb12fd5..eb3d6bcc 100644
--- a/src/websockets/protocol.py
+++ b/src/websockets/protocol.py
@@ -228,7 +228,8 @@ def __init__(
self.reader: asyncio.StreamReader = asyncio.StreamReader(
limit=read_limit // 2, loop=loop
)
- self.writer: asyncio.StreamWriter
+
+ self.transport: asyncio.Transport
self._drain_lock = asyncio.Lock(loop=loop)
# Copied from asyncio.FlowControlMixin
@@ -285,11 +286,6 @@ def __init__(
# Task closing the TCP connection.
self.close_connection_task: asyncio.Task[None]
- # asyncio.StreamWriter expects this attribute on the Protocol
- @property
- def _closed(self) -> Any: # pragma: no cover
- return self.connection_lost_waiter
-
# Copied from asyncio.FlowControlMixin
async def _drain_helper(self) -> None: # pragma: no cover
if self.connection_lost_waiter.done():
@@ -302,6 +298,23 @@ def _closed(self) -> Any: # pragma: no cover
self._drain_waiter = waiter
await waiter
+ # Copied from asyncio.StreamWriter
+ async def _drain(self) -> None: # pragma: no cover
+ if self.reader is not None:
+ exc = self.reader.exception()
+ if exc is not None:
+ raise exc
+ if self.transport is not None:
+ if self.transport.is_closing():
+ # Yield to the event loop so connection_lost() may be
+ # called. Without this, _drain_helper() would return
+ # immediately, and code that calls
+ # write(...); yield from drain()
+ # in a loop would never call connection_lost(), so it
+ # would not see an error when the socket is closed.
+ await asyncio.sleep(0)
+ await self._drain_helper()
+
def connection_open(self) -> None:
"""
Callback when the WebSocket opening handshake completes.
@@ -348,9 +361,9 @@ def local_address(self) -> Any:
been established yet.
"""
- if self.writer is None:
+ if self.transport is None:
return None
- return self.writer.get_extra_info("sockname")
+ return self.transport.get_extra_info("sockname")
@property
def remote_address(self) -> Any:
@@ -361,9 +374,9 @@ def remote_address(self) -> Any:
been established yet.
"""
- if self.writer is None:
+ if self.transport is None:
return None
- return self.writer.get_extra_info("peername")
+ return self.transport.get_extra_info("peername")
@property
def open(self) -> bool:
@@ -1037,7 +1050,9 @@ def append(frame: Frame) -> None:
frame = Frame(fin, opcode, data)
logger.debug("%s > %r", self.side, frame)
- frame.write(self.writer.write, mask=self.is_client, extensions=self.extensions)
+ frame.write(
+ self.transport.write, mask=self.is_client, extensions=self.extensions
+ )
try:
# drain() cannot be called concurrently by multiple coroutines:
@@ -1045,7 +1060,7 @@ def append(frame: Frame) -> None:
# version of Python where this bugs exists is supported anymore.
async with self._drain_lock:
# Handle flow control automatically.
- await self.writer.drain()
+ await self._drain()
except ConnectionError:
# Terminate the connection if the socket died.
self.fail_connection()
@@ -1147,9 +1162,9 @@ def append(frame: Frame) -> None:
logger.debug("%s ! timed out waiting for TCP close", self.side)
# Half-close the TCP connection if possible (when there's no TLS).
- if self.writer.can_write_eof():
+ if self.transport.can_write_eof():
logger.debug("%s x half-closing TCP connection", self.side)
- self.writer.write_eof()
+ self.transport.write_eof()
if await self.wait_for_connection_lost():
return
@@ -1162,17 +1177,12 @@ def append(frame: Frame) -> None:
# If connection_lost() was called, the TCP connection is closed.
# However, if TLS is enabled, the transport still needs closing.
# Else asyncio complains: ResourceWarning: unclosed transport.
- try:
- writer_is_closing = self.writer.is_closing # type: ignore
- except AttributeError: # pragma: no cover
- # Python < 3.7
- writer_is_closing = self.writer.transport.is_closing
- if self.connection_lost_waiter.done() and writer_is_closing():
+ if self.connection_lost_waiter.done() and self.transport.is_closing():
return
# Close the TCP connection. Buffers are flushed asynchronously.
logger.debug("%s x closing TCP connection", self.side)
- self.writer.close()
+ self.transport.close()
if await self.wait_for_connection_lost():
return
@@ -1180,8 +1190,7 @@ def append(frame: Frame) -> None:
# Abort the TCP connection. Buffers are discarded.
logger.debug("%s x aborting TCP connection", self.side)
- # mypy thinks self.writer.transport is a BaseTransport, not a Transport.
- self.writer.transport.abort() # type: ignore
+ self.transport.abort()
# connection_lost() is called quickly after aborting.
await self.wait_for_connection_lost()
@@ -1261,7 +1270,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None:
frame = Frame(True, OP_CLOSE, frame_data)
logger.debug("%s > %r", self.side, frame)
frame.write(
- self.writer.write, mask=self.is_client, extensions=self.extensions
+ self.transport.write, mask=self.is_client, extensions=self.extensions
)
# Start close_connection_task if the opening handshake didn't succeed.
@@ -1310,12 +1319,13 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""
logger.debug("%s - event = connection_made(%s)", self.side, transport)
- # mypy thinks transport is a BaseTransport, not a Transport.
- transport.set_write_buffer_limits(self.write_limit) # type: ignore
+
+ transport = cast(asyncio.Transport, transport)
+ transport.set_write_buffer_limits(self.write_limit)
+ self.transport = transport
# Copied from asyncio.StreamReaderProtocol
self.reader.set_transport(transport)
- self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop)
def connection_lost(self, exc: Optional[Exception]) -> None:
"""
diff --git a/src/websockets/server.py b/src/websockets/server.py
index b220a1b8..1e8ae861 100644
--- a/src/websockets/server.py
+++ b/src/websockets/server.py
@@ -211,7 +211,7 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None:
except Exception:
# Last-ditch attempt to avoid leaking connections on errors.
try:
- self.writer.close()
+ self.transport.close()
except Exception: # pragma: no cover
pass
@@ -265,11 +265,11 @@ def write_http_response(
response = f"HTTP/1.1 {status.value} {status.phrase}\r\n"
response += str(headers)
- self.writer.write(response.encode())
+ self.transport.write(response.encode())
if body is not None:
logger.debug("%s > body (%d bytes)", self.side, len(body))
- self.writer.write(body)
+ self.transport.write(body)
async def process_request(
self, path: str, request_headers: Headers
diff --git a/tests/test_client_server.py b/tests/test_client_server.py
index e74ec6bf..6171f21b 100644
--- a/tests/test_client_server.py
+++ b/tests/test_client_server.py
@@ -1166,7 +1166,7 @@ def test_server_close_crashes(self, close):
def test_client_closes_connection_before_handshake(self, handshake):
# We have mocked the handshake() method to prevent the client from
# performing the opening handshake. Force it to close the connection.
- self.client.writer.close()
+ self.client.transport.close()
# The server should stop properly anyway. It used to hang because the
# task handling the connection was waiting for the opening handshake.
diff --git a/tests/test_protocol.py b/tests/test_protocol.py
index a6c42018..dfc2c6d4 100644
--- a/tests/test_protocol.py
+++ b/tests/test_protocol.py
@@ -94,16 +94,16 @@ def tearDown(self):
# Utilities for writing tests.
def make_drain_slow(self, delay=MS):
- # Process connection_made in order to initialize self.protocol.writer.
+ # Process connection_made in order to initialize self.protocol.transport.
self.run_loop_once()
- original_drain = self.protocol.writer.drain
+ original_drain = self.protocol._drain
async def delayed_drain():
await asyncio.sleep(delay, loop=self.loop)
await original_drain()
- self.protocol.writer.drain = delayed_drain
+ self.protocol._drain = delayed_drain
close_frame = Frame(True, OP_CLOSE, serialize_close(1000, "close"))
local_close = Frame(True, OP_CLOSE, serialize_close(1000, "local"))
@@ -321,32 +321,32 @@ def test_local_address(self):
self.transport.get_extra_info = get_extra_info
self.assertEqual(self.protocol.local_address, ("host", 4312))
- get_extra_info.assert_called_with("sockname", None)
+ get_extra_info.assert_called_with("sockname")
def test_local_address_before_connection(self):
# Emulate the situation before connection_open() runs.
- self.protocol.writer, _writer = None, self.protocol.writer
+ self.protocol.transport, _transport = None, self.protocol.transport
try:
self.assertEqual(self.protocol.local_address, None)
finally:
- self.protocol.writer = _writer
+ self.protocol.transport = _transport
def test_remote_address(self):
get_extra_info = unittest.mock.Mock(return_value=("host", 4312))
self.transport.get_extra_info = get_extra_info
self.assertEqual(self.protocol.remote_address, ("host", 4312))
- get_extra_info.assert_called_with("peername", None)
+ get_extra_info.assert_called_with("peername")
def test_remote_address_before_connection(self):
# Emulate the situation before connection_open() runs.
- self.protocol.writer, _writer = None, self.protocol.writer
+ self.protocol.transport, _transport = None, self.protocol.transport
try:
self.assertEqual(self.protocol.remote_address, None)
finally:
- self.protocol.writer = _writer
+ self.protocol.transport = _transport
def test_open(self):
self.assertTrue(self.protocol.open)
From c55f374a93f7e9b14bc2bdcc8868f38f8a834a87 Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 13:55:16 +0200
Subject: [PATCH 13/20] Rename writer to write.
It's a better name for a function that writes bytes.
---
src/websockets/framing.py | 8 ++++----
tests/test_framing.py | 14 +++++++-------
tests/test_protocol.py | 4 ++--
3 files changed, 13 insertions(+), 13 deletions(-)
diff --git a/src/websockets/framing.py b/src/websockets/framing.py
index 81a3185b..c24b8a73 100644
--- a/src/websockets/framing.py
+++ b/src/websockets/framing.py
@@ -147,7 +147,7 @@ class Frame(NamedTuple):
def write(
frame,
- writer: Callable[[bytes], Any],
+ write: Callable[[bytes], Any],
*,
mask: bool,
extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None,
@@ -156,7 +156,7 @@ def write(
Write a WebSocket frame.
:param frame: frame to write
- :param writer: function that writes bytes
+ :param write: function that writes bytes
:param mask: whether the frame should be masked i.e. whether the write
happens on the client side
:param extensions: list of classes with an ``encode()`` method that
@@ -210,10 +210,10 @@ def write(
# Send the frame.
- # The frame is written in a single call to writer in order to prevent
+ # The frame is written in a single call to write in order to prevent
# TCP fragmentation. See #68 for details. This also makes it safe to
# send frames concurrently from multiple coroutines.
- writer(output.getvalue())
+ write(output.getvalue())
def check(frame) -> None:
"""
diff --git a/tests/test_framing.py b/tests/test_framing.py
index 9e6f1871..5def415d 100644
--- a/tests/test_framing.py
+++ b/tests/test_framing.py
@@ -27,15 +27,15 @@ def decode(self, message, mask=False, max_size=None, extensions=None):
return frame
def encode(self, frame, mask=False, extensions=None):
- writer = unittest.mock.Mock()
- frame.write(writer, mask=mask, extensions=extensions)
- # Ensure the entire frame is sent with a single call to writer().
+ write = unittest.mock.Mock()
+ frame.write(write, mask=mask, extensions=extensions)
+ # Ensure the entire frame is sent with a single call to write().
# Multiple calls cause TCP fragmentation and degrade performance.
- self.assertEqual(writer.call_count, 1)
+ self.assertEqual(write.call_count, 1)
# The frame data is the single positional argument of that call.
- self.assertEqual(len(writer.call_args[0]), 1)
- self.assertEqual(len(writer.call_args[1]), 0)
- return writer.call_args[0][0]
+ self.assertEqual(len(write.call_args[0]), 1)
+ self.assertEqual(len(write.call_args[1]), 0)
+ return write.call_args[0][0]
def round_trip(self, message, expected, mask=False, extensions=None):
decoded = self.decode(message, mask, extensions=extensions)
diff --git a/tests/test_protocol.py b/tests/test_protocol.py
index dfc2c6d4..d2793faf 100644
--- a/tests/test_protocol.py
+++ b/tests/test_protocol.py
@@ -114,9 +114,9 @@ def receive_frame(self, frame):
Make the protocol receive a frame.
"""
- writer = self.protocol.data_received
+ write = self.protocol.data_received
mask = not self.protocol.is_client
- frame.write(writer, mask=mask)
+ frame.write(write, mask=mask)
def receive_eof(self):
"""
From 3e03d1a19dbbf54e1c714c933b90ec283caa7bb0 Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 15:49:08 +0200
Subject: [PATCH 14/20] Update to the latest version of mypy.
The mypy bugs that were pinning us to an old version are fixed.
---
src/websockets/__init__.py | 15 ++++++++-------
src/websockets/__main__.py | 8 ++++----
src/websockets/client.py | 3 +--
src/websockets/handshake.py | 15 ++++++++++-----
src/websockets/protocol.py | 14 +++++++++-----
src/websockets/server.py | 11 +++--------
tox.ini | 2 +-
7 files changed, 36 insertions(+), 32 deletions(-)
diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py
index e7ba31ce..6bad0f7b 100644
--- a/src/websockets/__init__.py
+++ b/src/websockets/__init__.py
@@ -1,12 +1,13 @@
# This relies on each of the submodules having an __all__ variable.
-from .auth import *
-from .client import *
-from .exceptions import *
-from .protocol import *
-from .server import *
-from .typing import *
-from .uri import *
+from . import auth, client, exceptions, protocol, server, typing, uri
+from .auth import * # noqa
+from .client import * # noqa
+from .exceptions import * # noqa
+from .protocol import * # noqa
+from .server import * # noqa
+from .typing import * # noqa
+from .uri import * # noqa
from .version import version as __version__ # noqa
diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py
index bccb8aa5..394f7ac7 100644
--- a/src/websockets/__main__.py
+++ b/src/websockets/__main__.py
@@ -6,8 +6,8 @@
import threading
from typing import Any, Set
-import websockets
-from websockets.exceptions import format_close
+from .client import connect
+from .exceptions import ConnectionClosed, format_close
if sys.platform == "win32":
@@ -95,7 +95,7 @@ def print_over_input(string: str) -> None:
stop: "asyncio.Future[None]",
) -> None:
try:
- websocket = await websockets.connect(uri)
+ websocket = await connect(uri)
except Exception as exc:
print_over_input(f"Failed to connect to {uri}: {exc}.")
exit_from_event_loop_thread(loop, stop)
@@ -122,7 +122,7 @@ def print_over_input(string: str) -> None:
if incoming in done:
try:
message = incoming.result()
- except websockets.ConnectionClosed:
+ except ConnectionClosed:
break
else:
if isinstance(message, str):
diff --git a/src/websockets/client.py b/src/websockets/client.py
index 34cd8624..725ec1e7 100644
--- a/src/websockets/client.py
+++ b/src/websockets/client.py
@@ -24,7 +24,6 @@
from .extensions.permessage_deflate import ClientPerMessageDeflateFactory
from .handshake import build_request, check_response
from .headers import (
- ExtensionHeader,
build_authorization_basic,
build_extension,
build_subprotocol,
@@ -33,7 +32,7 @@
)
from .http import USER_AGENT, Headers, HeadersLike, read_response
from .protocol import WebSocketCommonProtocol
-from .typing import Origin, Subprotocol
+from .typing import ExtensionHeader, Origin, Subprotocol
from .uri import WebSocketURI, parse_uri
diff --git a/src/websockets/handshake.py b/src/websockets/handshake.py
index 17332d15..9bfe2775 100644
--- a/src/websockets/handshake.py
+++ b/src/websockets/handshake.py
@@ -29,9 +29,10 @@
import binascii
import hashlib
import random
+from typing import List
from .exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade
-from .headers import parse_connection, parse_upgrade
+from .headers import ConnectionOption, UpgradeProtocol, parse_connection, parse_upgrade
from .http import Headers, MultipleValuesError
@@ -74,14 +75,16 @@ def check_request(headers: Headers) -> str:
is invalid; then the server must return 400 Bad Request error
"""
- connection = sum(
+ connection: List[ConnectionOption] = sum(
[parse_connection(value) for value in headers.get_all("Connection")], []
)
if not any(value.lower() == "upgrade" for value in connection):
raise InvalidUpgrade("Connection", ", ".join(connection))
- upgrade = sum([parse_upgrade(value) for value in headers.get_all("Upgrade")], [])
+ upgrade: List[UpgradeProtocol] = sum(
+ [parse_upgrade(value) for value in headers.get_all("Upgrade")], []
+ )
# For compatibility with non-strict implementations, ignore case when
# checking the Upgrade header. It's supposed to be 'WebSocket'.
@@ -148,14 +151,16 @@ def check_response(headers: Headers, key: str) -> None:
is invalid
"""
- connection = sum(
+ connection: List[ConnectionOption] = sum(
[parse_connection(value) for value in headers.get_all("Connection")], []
)
if not any(value.lower() == "upgrade" for value in connection):
raise InvalidUpgrade("Connection", " ".join(connection))
- upgrade = sum([parse_upgrade(value) for value in headers.get_all("Upgrade")], [])
+ upgrade: List[UpgradeProtocol] = sum(
+ [parse_upgrade(value) for value in headers.get_all("Upgrade")], []
+ )
# For compatibility with non-strict implementations, ignore case when
# checking the Upgrade header. It's supposed to be 'WebSocket'.
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
index eb3d6bcc..b7c1f19c 100644
--- a/src/websockets/protocol.py
+++ b/src/websockets/protocol.py
@@ -601,10 +601,14 @@ def closed(self) -> bool:
elif isinstance(message, AsyncIterable):
# aiter_message = aiter(message) without aiter
- aiter_message = type(message).__aiter__(message)
+ # https://github.com/python/mypy/issues/5738
+ aiter_message = type(message).__aiter__(message) # type: ignore
try:
# message_chunk = anext(aiter_message) without anext
- message_chunk = await type(aiter_message).__anext__(aiter_message)
+ # https://github.com/python/mypy/issues/5738
+ message_chunk = await type(aiter_message).__anext__( # type: ignore
+ aiter_message
+ )
except StopAsyncIteration:
return
opcode, data = prepare_data(message_chunk)
@@ -615,7 +619,8 @@ def closed(self) -> bool:
await self.write_frame(False, opcode, data)
# Other fragments.
- async for message_chunk in aiter_message:
+ # https://github.com/python/mypy/issues/5738
+ async for message_chunk in aiter_message: # type: ignore
confirm_opcode, data = prepare_data(message_chunk)
if confirm_opcode != opcode:
raise TypeError("data contains inconsistent types")
@@ -899,8 +904,7 @@ def connection_closed_exc(self) -> ConnectionClosed:
max_size = self.max_size
if text:
decoder_factory = codecs.getincrementaldecoder("utf-8")
- # https://github.com/python/typeshed/pull/2752
- decoder = decoder_factory(errors="strict") # type: ignore
+ decoder = decoder_factory(errors="strict")
if max_size is None:
def append(frame: Frame) -> None:
diff --git a/src/websockets/server.py b/src/websockets/server.py
index 1e8ae861..5114646d 100644
--- a/src/websockets/server.py
+++ b/src/websockets/server.py
@@ -39,15 +39,10 @@
from .extensions.base import Extension, ServerExtensionFactory
from .extensions.permessage_deflate import ServerPerMessageDeflateFactory
from .handshake import build_response, check_request
-from .headers import (
- ExtensionHeader,
- build_extension,
- parse_extension,
- parse_subprotocol,
-)
+from .headers import build_extension, parse_extension, parse_subprotocol
from .http import USER_AGENT, Headers, HeadersLike, MultipleValuesError, read_request
from .protocol import WebSocketCommonProtocol
-from .typing import Origin, Subprotocol
+from .typing import ExtensionHeader, Origin, Subprotocol
__all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"]
@@ -662,7 +657,7 @@ def is_serving(self) -> bool:
"""
try:
# Python ≥ 3.7
- return self.server.is_serving() # type: ignore
+ return self.server.is_serving()
except AttributeError: # pragma: no cover
# Python < 3.7
return self.server.sockets is not None
diff --git a/tox.ini b/tox.ini
index 801d4d5d..7397c90a 100644
--- a/tox.ini
+++ b/tox.ini
@@ -25,4 +25,4 @@ deps = isort
[testenv:mypy]
commands = mypy --strict src
-deps = mypy==0.670
+deps = mypy
From aef1e8781826653720f4db110a4d2af4418fa5b8 Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 15:49:51 +0200
Subject: [PATCH 15/20] Fix deprecation warnings on Python 3.8.
* Don't pass the deprecated loop argument.
* Ignore deprecation warnings for @asyncio.coroutine.
---
src/websockets/protocol.py | 28 ++++++++++++++++++++--------
src/websockets/server.py | 10 +++++++---
tests/test_client_server.py | 36 +++++++++++++++++++++---------------
tests/test_protocol.py | 5 ++++-
4 files changed, 52 insertions(+), 27 deletions(-)
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
index b7c1f19c..76d46ad9 100644
--- a/src/websockets/protocol.py
+++ b/src/websockets/protocol.py
@@ -14,6 +14,7 @@
import logging
import random
import struct
+import sys
import warnings
from typing import (
Any,
@@ -230,7 +231,9 @@ def __init__(
)
self.transport: asyncio.Transport
- self._drain_lock = asyncio.Lock(loop=loop)
+ self._drain_lock = asyncio.Lock(
+ loop=loop if sys.version_info[:2] < (3, 8) else None
+ )
# Copied from asyncio.FlowControlMixin
self._paused = False
@@ -312,7 +315,9 @@ def __init__(
# write(...); yield from drain()
# in a loop would never call connection_lost(), so it
# would not see an error when the socket is closed.
- await asyncio.sleep(0)
+ await asyncio.sleep(
+ 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None
+ )
await self._drain_helper()
def connection_open(self) -> None:
@@ -483,7 +488,7 @@ def closed(self) -> bool:
# pop_message_waiter and self.transfer_data_task.
await asyncio.wait(
[pop_message_waiter, self.transfer_data_task],
- loop=self.loop,
+ loop=self.loop if sys.version_info[:2] < (3, 8) else None,
return_when=asyncio.FIRST_COMPLETED,
)
finally:
@@ -668,7 +673,7 @@ def closed(self) -> bool:
await asyncio.wait_for(
self.write_close_frame(serialize_close(code, reason)),
self.close_timeout,
- loop=self.loop,
+ loop=self.loop if sys.version_info[:2] < (3, 8) else None,
)
except asyncio.TimeoutError:
# If the close frame cannot be sent because the send buffers
@@ -687,7 +692,9 @@ def closed(self) -> bool:
# If close() is canceled during the wait, self.transfer_data_task
# is canceled before the timeout elapses.
await asyncio.wait_for(
- self.transfer_data_task, self.close_timeout, loop=self.loop
+ self.transfer_data_task,
+ self.close_timeout,
+ loop=self.loop if sys.version_info[:2] < (3, 8) else None,
)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
@@ -1106,7 +1113,10 @@ def append(frame: Frame) -> None:
try:
while True:
- await asyncio.sleep(self.ping_interval, loop=self.loop)
+ await asyncio.sleep(
+ self.ping_interval,
+ loop=self.loop if sys.version_info[:2] < (3, 8) else None,
+ )
# ping() raises CancelledError if the connection is closed,
# when close_connection() cancels self.keepalive_ping_task.
@@ -1119,7 +1129,9 @@ def append(frame: Frame) -> None:
if self.ping_timeout is not None:
try:
await asyncio.wait_for(
- ping_waiter, self.ping_timeout, loop=self.loop
+ ping_waiter,
+ self.ping_timeout,
+ loop=self.loop if sys.version_info[:2] < (3, 8) else None,
)
except asyncio.TimeoutError:
logger.debug("%s ! timed out waiting for pong", self.side)
@@ -1211,7 +1223,7 @@ def append(frame: Frame) -> None:
await asyncio.wait_for(
asyncio.shield(self.connection_lost_waiter),
self.close_timeout,
- loop=self.loop,
+ loop=self.loop if sys.version_info[:2] < (3, 8) else None,
)
except asyncio.TimeoutError:
pass
diff --git a/src/websockets/server.py b/src/websockets/server.py
index 5114646d..4f5e9e0e 100644
--- a/src/websockets/server.py
+++ b/src/websockets/server.py
@@ -10,6 +10,7 @@
import http
import logging
import socket
+import sys
import warnings
from types import TracebackType
from typing import (
@@ -698,7 +699,9 @@ def close(self) -> None:
# Wait until all accepted connections reach connection_made() and call
# register(). See https://bugs.python.org/issue34852 for details.
- await asyncio.sleep(0)
+ await asyncio.sleep(
+ 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None
+ )
# Close OPEN connections with status code 1001. Since the server was
# closed, handshake() closes OPENING conections with a HTTP 503 error.
@@ -707,7 +710,8 @@ def close(self) -> None:
# asyncio.wait doesn't accept an empty first argument
if self.websockets:
await asyncio.wait(
- [websocket.close(1001) for websocket in self.websockets], loop=self.loop
+ [websocket.close(1001) for websocket in self.websockets],
+ loop=self.loop if sys.version_info[:2] < (3, 8) else None,
)
# Wait until all connection handlers are complete.
@@ -716,7 +720,7 @@ def close(self) -> None:
if self.websockets:
await asyncio.wait(
[websocket.handler_task for websocket in self.websockets],
- loop=self.loop,
+ loop=self.loop if sys.version_info[:2] < (3, 8) else None,
)
# Tell wait_closed() to return.
diff --git a/tests/test_client_server.py b/tests/test_client_server.py
index 6171f21b..85828bdb 100644
--- a/tests/test_client_server.py
+++ b/tests/test_client_server.py
@@ -1381,13 +1381,16 @@ def test_client(self):
start_server = serve(handler, "localhost", 0)
server = self.loop.run_until_complete(start_server)
- @asyncio.coroutine
- def run_client():
- # Yield from connect.
- client = yield from connect(get_server_uri(server))
- self.assertEqual(client.state, State.OPEN)
- yield from client.close()
- self.assertEqual(client.state, State.CLOSED)
+ # @asyncio.coroutine is deprecated on Python ≥ 3.8
+ with warnings.catch_warnings(record=True):
+
+ @asyncio.coroutine
+ def run_client():
+ # Yield from connect.
+ client = yield from connect(get_server_uri(server))
+ self.assertEqual(client.state, State.OPEN)
+ yield from client.close()
+ self.assertEqual(client.state, State.CLOSED)
self.loop.run_until_complete(run_client())
@@ -1395,14 +1398,17 @@ def run_client():
self.loop.run_until_complete(server.wait_closed())
def test_server(self):
- @asyncio.coroutine
- def run_server():
- # Yield from serve.
- server = yield from serve(handler, "localhost", 0)
- self.assertTrue(server.sockets)
- server.close()
- yield from server.wait_closed()
- self.assertFalse(server.sockets)
+ # @asyncio.coroutine is deprecated on Python ≥ 3.8
+ with warnings.catch_warnings(record=True):
+
+ @asyncio.coroutine
+ def run_server():
+ # Yield from serve.
+ server = yield from serve(handler, "localhost", 0)
+ self.assertTrue(server.sockets)
+ server.close()
+ yield from server.wait_closed()
+ self.assertFalse(server.sockets)
self.loop.run_until_complete(run_server())
diff --git a/tests/test_protocol.py b/tests/test_protocol.py
index d2793faf..04e2a38f 100644
--- a/tests/test_protocol.py
+++ b/tests/test_protocol.py
@@ -1,6 +1,7 @@
import asyncio
import contextlib
import logging
+import sys
import unittest
import unittest.mock
import warnings
@@ -100,7 +101,9 @@ def make_drain_slow(self, delay=MS):
original_drain = self.protocol._drain
async def delayed_drain():
- await asyncio.sleep(delay, loop=self.loop)
+ await asyncio.sleep(
+ delay, loop=self.loop if sys.version_info[:2] < (3, 8) else None
+ )
await original_drain()
self.protocol._drain = delayed_drain
From 504c66cf01bc7e2c4f5fbcceb1387cadf056a678 Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 14:02:51 +0200
Subject: [PATCH 16/20] Document and test support for Python 3.8.
---
.circleci/config.yml | 12 ++++++++++++
docs/changelog.rst | 2 ++
setup.py | 1 +
tox.ini | 2 +-
4 files changed, 16 insertions(+), 1 deletion(-)
diff --git a/.circleci/config.yml b/.circleci/config.yml
index a6c85d23..0877c161 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -29,6 +29,15 @@ jobs:
- checkout
- run: sudo pip install tox
- run: tox -e py37
+ py38:
+ docker:
+ - image: circleci/python:3.8.0rc1
+ steps:
+ # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway.
+ - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc
+ - checkout
+ - run: sudo pip install tox
+ - run: tox -e py38
workflows:
version: 2
@@ -41,3 +50,6 @@ workflows:
- py37:
requires:
- main
+ - py38:
+ requires:
+ - main
diff --git a/docs/changelog.rst b/docs/changelog.rst
index 87b2e438..2a106fbc 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -8,6 +8,8 @@ Changelog
*In development*
+* Added compatibility with Python 3.8.
+
8.0.2
.....
diff --git a/setup.py b/setup.py
index c7643010..f3581924 100644
--- a/setup.py
+++ b/setup.py
@@ -53,6 +53,7 @@
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
+ 'Programming Language :: Python :: 3.8',
],
package_dir = {'': 'src'},
package_data = {'websockets': ['py.typed']},
diff --git a/tox.ini b/tox.ini
index 7397c90a..825e3406 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
[tox]
-envlist = py36,py37,coverage,black,flake8,isort,mypy
+envlist = py36,py37,py38,coverage,black,flake8,isort,mypy
[testenv]
commands = python -W default -m unittest {posargs}
From fc99c16a53e5d0c85448714021a5deb74021920a Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 21:18:30 +0200
Subject: [PATCH 17/20] Move test logging configuration to a single place.
---
tests/__init__.py | 5 +++++
tests/test_client_server.py | 5 -----
tests/test_protocol.py | 5 -----
3 files changed, 5 insertions(+), 10 deletions(-)
diff --git a/tests/__init__.py b/tests/__init__.py
index e69de29b..dd78609f 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -0,0 +1,5 @@
+import logging
+
+
+# Avoid displaying stack traces at the ERROR logging level.
+logging.basicConfig(level=logging.CRITICAL)
diff --git a/tests/test_client_server.py b/tests/test_client_server.py
index 85828bdb..ce0f66ce 100644
--- a/tests/test_client_server.py
+++ b/tests/test_client_server.py
@@ -2,7 +2,6 @@
import contextlib
import functools
import http
-import logging
import pathlib
import random
import socket
@@ -37,10 +36,6 @@
from .utils import AsyncioTestCase
-# Avoid displaying stack traces at the ERROR logging level.
-logging.basicConfig(level=logging.CRITICAL)
-
-
# Generate TLS certificate with:
# $ openssl req -x509 -config test_localhost.cnf -days 15340 -newkey rsa:2048 \
# -out test_localhost.crt -keyout test_localhost.key
diff --git a/tests/test_protocol.py b/tests/test_protocol.py
index 04e2a38f..d95260a8 100644
--- a/tests/test_protocol.py
+++ b/tests/test_protocol.py
@@ -1,6 +1,5 @@
import asyncio
import contextlib
-import logging
import sys
import unittest
import unittest.mock
@@ -13,10 +12,6 @@
from .utils import MS, AsyncioTestCase
-# Avoid displaying stack traces at the ERROR logging level.
-logging.basicConfig(level=logging.CRITICAL)
-
-
async def async_iterable(iterable):
for item in iterable:
yield item
From 91dea74a9430d8573f7fd58c35fd731eddbb6cdb Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 21:33:27 +0200
Subject: [PATCH 18/20] Remove test that no longer makes sense.
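Since version 7.0, when the server closes, it terminates connections
with close code 1001 instead of canceling them.
Also rename SlowServerProtocol to SlowOpeningHandshakeProtocol to make
clear that it slows down the opening handshake (in process_request).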
---
tests/test_client_server.py | 18 ++----------------
1 file changed, 2 insertions(+), 16 deletions(-)
diff --git a/tests/test_client_server.py b/tests/test_client_server.py
index ce0f66ce..35913666 100644
--- a/tests/test_client_server.py
+++ b/tests/test_client_server.py
@@ -173,7 +173,7 @@ class HealthCheckServerProtocol(WebSocketServerProtocol):
return http.HTTPStatus.OK, [("X-Access", "OK")], b"status = green\n"
-class SlowServerProtocol(WebSocketServerProtocol):
+class SlowOpeningHandshakeProtocol(WebSocketServerProtocol):
async def process_request(self, path, request_headers):
await asyncio.sleep(10 * MS)
@@ -1165,7 +1165,7 @@ def test_client_closes_connection_before_handshake(self, handshake):
# The server should stop properly anyway. It used to hang because the
# task handling the connection was waiting for the opening handshake.
- @with_server(create_protocol=SlowServerProtocol)
+ @with_server(create_protocol=SlowOpeningHandshakeProtocol)
def test_server_shuts_down_during_opening_handshake(self):
self.loop.call_later(5 * MS, self.server.close)
with self.assertRaises(InvalidStatusCode) as raised:
@@ -1188,20 +1188,6 @@ def test_server_shuts_down_during_connection_handling(self):
self.assertEqual(self.client.close_code, 1001)
self.assertEqual(server_ws.close_code, 1001)
- @with_server()
- @unittest.mock.patch("websockets.server.WebSocketServerProtocol.close")
- def test_server_shuts_down_during_connection_close(self, _close):
- _close.side_effect = asyncio.CancelledError
-
- self.server.closing = True
- with self.temp_client():
- self.loop.run_until_complete(self.client.send("Hello!"))
- reply = self.loop.run_until_complete(self.client.recv())
- self.assertEqual(reply, "Hello!")
-
- # Websocket connection terminates abnormally.
- self.assertEqual(self.client.close_code, 1006)
-
@with_server()
def test_server_shuts_down_waits_until_handlers_terminate(self):
# This handler waits a bit after the connection is closed in order
From 88073fe7da7e8c07a5f46ae7c77e1a4ef80736ee Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 21:41:13 +0200
Subject: [PATCH 19/20] Fix refactoring error.
WebSocketCommonProtocol.transport can be unset, but it cannot be None.
---
src/websockets/protocol.py | 14 ++++++++++----
tests/test_protocol.py | 8 ++++----
2 files changed, 14 insertions(+), 8 deletions(-)
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
index 76d46ad9..0623e136 100644
--- a/src/websockets/protocol.py
+++ b/src/websockets/protocol.py
@@ -366,9 +366,12 @@ def local_address(self) -> Any:
been established yet.
"""
- if self.transport is None:
+ try:
+ transport = self.transport
+ except AttributeError:
return None
- return self.transport.get_extra_info("sockname")
+ else:
+ return transport.get_extra_info("sockname")
@property
def remote_address(self) -> Any:
@@ -379,9 +382,12 @@ def remote_address(self) -> Any:
been established yet.
"""
- if self.transport is None:
+ try:
+ transport = self.transport
+ except AttributeError:
return None
- return self.transport.get_extra_info("peername")
+ else:
+ return transport.get_extra_info("peername")
@property
def open(self) -> bool:
diff --git a/tests/test_protocol.py b/tests/test_protocol.py
index d95260a8..66a822e7 100644
--- a/tests/test_protocol.py
+++ b/tests/test_protocol.py
@@ -323,8 +323,8 @@ def test_local_address(self):
def test_local_address_before_connection(self):
# Emulate the situation before connection_open() runs.
- self.protocol.transport, _transport = None, self.protocol.transport
-
+ _transport = self.protocol.transport
+ del self.protocol.transport
try:
self.assertEqual(self.protocol.local_address, None)
finally:
@@ -339,8 +339,8 @@ def test_remote_address(self):
def test_remote_address_before_connection(self):
# Emulate the situation before connection_open() runs.
- self.protocol.transport, _transport = None, self.protocol.transport
-
+ _transport = self.protocol.transport
+ del self.protocol.transport
try:
self.assertEqual(self.protocol.remote_address, None)
finally:
From d75d4a4d34a7485ab8e5f2e65dee06dc37b7cfba Mon Sep 17 00:00:00 2001
From: Aymeric Augustin <aymeric.augustin@m4x.org>
Date: Sat, 5 Oct 2019 21:39:14 +0200
Subject: [PATCH 20/20] Remove useless type declaration.
---
src/websockets/protocol.py | 13 +++++--------
1 file changed, 5 insertions(+), 8 deletions(-)
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
index 0623e136..6c29b2a5 100644
--- a/src/websockets/protocol.py
+++ b/src/websockets/protocol.py
@@ -226,19 +226,16 @@ def __init__(
# ``self.read_limit``. The ``limit`` argument controls the line length
# limit and half the buffer limit of :class:`~asyncio.StreamReader`.
# That's why it must be set to half of ``self.read_limit``.
- self.reader: asyncio.StreamReader = asyncio.StreamReader(
- limit=read_limit // 2, loop=loop
- )
-
- self.transport: asyncio.Transport
- self._drain_lock = asyncio.Lock(
- loop=loop if sys.version_info[:2] < (3, 8) else None
- )
+ self.reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop)
# Copied from asyncio.FlowControlMixin
self._paused = False
self._drain_waiter: Optional[asyncio.Future[None]] = None
+ self._drain_lock = asyncio.Lock(
+ loop=loop if sys.version_info[:2] < (3, 8) else None
+ )
+
# This class implements the data transfer and closing handshake, which
# are shared between the client-side and the server-side.
# Subclasses implement the opening handshake and, on success, execute