forked from pool/python-websockets
against Python 3.8. OBS-URL: https://build.opensuse.org/package/show/devel:languages:python/python-websockets?expand=0&rev=19
2202 lines
81 KiB
Diff
2202 lines
81 KiB
Diff
From fbab68fad41b1ddad2bbe7eccd793f5d6c52f761 Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 11:41:21 +0200
|
|
Subject: [PATCH 01/20] Copy FlowControlMixin and StreamReaderProtocol.
|
|
|
|
This is the official recommendation of Python core devs.
|
|
|
|
The code is taken from the current 3.7 branch.
|
|
---
|
|
src/websockets/protocol.py | 127 ++++++++++++++++++++++++++++++++++++-
|
|
1 file changed, 126 insertions(+), 1 deletion(-)
|
|
|
|
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
|
|
index 1f0edcce..2f74cd23 100644
|
|
--- a/src/websockets/protocol.py
|
|
+++ b/src/websockets/protocol.py
|
|
@@ -61,7 +61,132 @@ class State(enum.IntEnum):
|
|
# between the check and the assignment.
|
|
|
|
|
|
-class WebSocketCommonProtocol(asyncio.StreamReaderProtocol):
|
|
+class FlowControlMixin(asyncio.Protocol):
|
|
+ """Reusable flow control logic for StreamWriter.drain().
|
|
+ This implements the protocol methods pause_writing(),
|
|
+ resume_writing() and connection_lost(). If the subclass overrides
|
|
+ these it must call the super methods.
|
|
+ StreamWriter.drain() must wait for _drain_helper() coroutine.
|
|
+ """
|
|
+
|
|
+ def __init__(self, loop=None):
|
|
+ if loop is None:
|
|
+ self._loop = asyncio.get_event_loop()
|
|
+ else:
|
|
+ self._loop = loop
|
|
+ self._paused = False
|
|
+ self._drain_waiter = None
|
|
+ self._connection_lost = False
|
|
+
|
|
+ def pause_writing(self):
|
|
+ assert not self._paused
|
|
+ self._paused = True
|
|
+ if self._loop.get_debug():
|
|
+ logger.debug("%r pauses writing", self)
|
|
+
|
|
+ def resume_writing(self):
|
|
+ assert self._paused
|
|
+ self._paused = False
|
|
+ if self._loop.get_debug():
|
|
+ logger.debug("%r resumes writing", self)
|
|
+
|
|
+ waiter = self._drain_waiter
|
|
+ if waiter is not None:
|
|
+ self._drain_waiter = None
|
|
+ if not waiter.done():
|
|
+ waiter.set_result(None)
|
|
+
|
|
+ def connection_lost(self, exc):
|
|
+ self._connection_lost = True
|
|
+ # Wake up the writer if currently paused.
|
|
+ if not self._paused:
|
|
+ return
|
|
+ waiter = self._drain_waiter
|
|
+ if waiter is None:
|
|
+ return
|
|
+ self._drain_waiter = None
|
|
+ if waiter.done():
|
|
+ return
|
|
+ if exc is None:
|
|
+ waiter.set_result(None)
|
|
+ else:
|
|
+ waiter.set_exception(exc)
|
|
+
|
|
+ async def _drain_helper(self):
|
|
+ if self._connection_lost:
|
|
+ raise ConnectionResetError("Connection lost")
|
|
+ if not self._paused:
|
|
+ return
|
|
+ waiter = self._drain_waiter
|
|
+ assert waiter is None or waiter.cancelled()
|
|
+ waiter = self._loop.create_future()
|
|
+ self._drain_waiter = waiter
|
|
+ await waiter
|
|
+
|
|
+
|
|
+class StreamReaderProtocol(FlowControlMixin, asyncio.Protocol):
|
|
+ """Helper class to adapt between Protocol and StreamReader.
|
|
+ (This is a helper class instead of making StreamReader itself a
|
|
+ Protocol subclass, because the StreamReader has other potential
|
|
+ uses, and to prevent the user of the StreamReader to accidentally
|
|
+ call inappropriate methods of the protocol.)
|
|
+ """
|
|
+
|
|
+ def __init__(self, stream_reader, client_connected_cb=None, loop=None):
|
|
+ super().__init__(loop=loop)
|
|
+ self._stream_reader = stream_reader
|
|
+ self._stream_writer = None
|
|
+ self._client_connected_cb = client_connected_cb
|
|
+ self._over_ssl = False
|
|
+ self._closed = self._loop.create_future()
|
|
+
|
|
+ def connection_made(self, transport):
|
|
+ self._stream_reader.set_transport(transport)
|
|
+ self._over_ssl = transport.get_extra_info("sslcontext") is not None
|
|
+ if self._client_connected_cb is not None:
|
|
+ self._stream_writer = asyncio.StreamWriter(
|
|
+ transport, self, self._stream_reader, self._loop
|
|
+ )
|
|
+ res = self._client_connected_cb(self._stream_reader, self._stream_writer)
|
|
+ if asyncio.iscoroutine(res):
|
|
+ self._loop.create_task(res)
|
|
+
|
|
+ def connection_lost(self, exc):
|
|
+ if self._stream_reader is not None:
|
|
+ if exc is None:
|
|
+ self._stream_reader.feed_eof()
|
|
+ else:
|
|
+ self._stream_reader.set_exception(exc)
|
|
+ if not self._closed.done():
|
|
+ if exc is None:
|
|
+ self._closed.set_result(None)
|
|
+ else:
|
|
+ self._closed.set_exception(exc)
|
|
+ super().connection_lost(exc)
|
|
+ self._stream_reader = None
|
|
+ self._stream_writer = None
|
|
+
|
|
+ def data_received(self, data):
|
|
+ self._stream_reader.feed_data(data)
|
|
+
|
|
+ def eof_received(self):
|
|
+ self._stream_reader.feed_eof()
|
|
+ if self._over_ssl:
|
|
+ # Prevent a warning in SSLProtocol.eof_received:
|
|
+ # "returning true from eof_received()
|
|
+ # has no effect when using ssl"
|
|
+ return False
|
|
+ return True
|
|
+
|
|
+ def __del__(self):
|
|
+ # Prevent reports about unhandled exceptions.
|
|
+ # Better than self._closed._log_traceback = False hack
|
|
+ closed = self._closed
|
|
+ if closed.done() and not closed.cancelled():
|
|
+ closed.exception()
|
|
+
|
|
+
|
|
+class WebSocketCommonProtocol(StreamReaderProtocol):
|
|
"""
|
|
:class:`~asyncio.Protocol` subclass implementing the data transfer phase.
|
|
|
|
|
|
From 57907330fe789f0b5f6b71b2810a0e36fee14844 Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 11:54:04 +0200
|
|
Subject: [PATCH 02/20] Remove docstrings and debug logs.
|
|
|
|
---
|
|
src/websockets/protocol.py | 16 ----------------
|
|
1 file changed, 16 deletions(-)
|
|
|
|
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
|
|
index 2f74cd23..d74c8157 100644
|
|
--- a/src/websockets/protocol.py
|
|
+++ b/src/websockets/protocol.py
|
|
@@ -62,12 +62,6 @@ class State(enum.IntEnum):
|
|
|
|
|
|
class FlowControlMixin(asyncio.Protocol):
|
|
- """Reusable flow control logic for StreamWriter.drain().
|
|
- This implements the protocol methods pause_writing(),
|
|
- resume_writing() and connection_lost(). If the subclass overrides
|
|
- these it must call the super methods.
|
|
- StreamWriter.drain() must wait for _drain_helper() coroutine.
|
|
- """
|
|
|
|
def __init__(self, loop=None):
|
|
if loop is None:
|
|
@@ -81,14 +75,10 @@ def __init__(self, loop=None):
|
|
def pause_writing(self):
|
|
assert not self._paused
|
|
self._paused = True
|
|
- if self._loop.get_debug():
|
|
- logger.debug("%r pauses writing", self)
|
|
|
|
def resume_writing(self):
|
|
assert self._paused
|
|
self._paused = False
|
|
- if self._loop.get_debug():
|
|
- logger.debug("%r resumes writing", self)
|
|
|
|
waiter = self._drain_waiter
|
|
if waiter is not None:
|
|
@@ -125,12 +115,6 @@ def connection_lost(self, exc):
|
|
|
|
|
|
class StreamReaderProtocol(FlowControlMixin, asyncio.Protocol):
|
|
- """Helper class to adapt between Protocol and StreamReader.
|
|
- (This is a helper class instead of making StreamReader itself a
|
|
- Protocol subclass, because the StreamReader has other potential
|
|
- uses, and to prevent the user of the StreamReader to accidentally
|
|
- call inappropriate methods of the protocol.)
|
|
- """
|
|
|
|
def __init__(self, stream_reader, client_connected_cb=None, loop=None):
|
|
super().__init__(loop=loop)
|
|
|
|
From 53523d06c7752b394c1f1e5566f464db741e2617 Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 11:54:43 +0200
|
|
Subject: [PATCH 03/20] Merge FlowControlMixin in StreamReaderProtocol.
|
|
|
|
---
|
|
src/websockets/protocol.py | 54 +++++++++++++++++---------------------
|
|
1 file changed, 24 insertions(+), 30 deletions(-)
|
|
|
|
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
|
|
index d74c8157..49d8b4f2 100644
|
|
--- a/src/websockets/protocol.py
|
|
+++ b/src/websockets/protocol.py
|
|
@@ -61,9 +61,9 @@ class State(enum.IntEnum):
|
|
# between the check and the assignment.
|
|
|
|
|
|
-class FlowControlMixin(asyncio.Protocol):
|
|
+class StreamReaderProtocol(asyncio.Protocol):
|
|
|
|
- def __init__(self, loop=None):
|
|
+ def __init__(self, stream_reader, client_connected_cb=None, loop=None):
|
|
if loop is None:
|
|
self._loop = asyncio.get_event_loop()
|
|
else:
|
|
@@ -72,6 +72,12 @@ def __init__(self, loop=None):
|
|
self._drain_waiter = None
|
|
self._connection_lost = False
|
|
|
|
+ self._stream_reader = stream_reader
|
|
+ self._stream_writer = None
|
|
+ self._client_connected_cb = client_connected_cb
|
|
+ self._over_ssl = False
|
|
+ self._closed = self._loop.create_future()
|
|
+
|
|
def pause_writing(self):
|
|
assert not self._paused
|
|
self._paused = True
|
|
@@ -86,22 +92,6 @@ def resume_writing(self):
|
|
if not waiter.done():
|
|
waiter.set_result(None)
|
|
|
|
- def connection_lost(self, exc):
|
|
- self._connection_lost = True
|
|
- # Wake up the writer if currently paused.
|
|
- if not self._paused:
|
|
- return
|
|
- waiter = self._drain_waiter
|
|
- if waiter is None:
|
|
- return
|
|
- self._drain_waiter = None
|
|
- if waiter.done():
|
|
- return
|
|
- if exc is None:
|
|
- waiter.set_result(None)
|
|
- else:
|
|
- waiter.set_exception(exc)
|
|
-
|
|
async def _drain_helper(self):
|
|
if self._connection_lost:
|
|
raise ConnectionResetError("Connection lost")
|
|
@@ -113,17 +103,6 @@ def connection_lost(self, exc):
|
|
self._drain_waiter = waiter
|
|
await waiter
|
|
|
|
-
|
|
-class StreamReaderProtocol(FlowControlMixin, asyncio.Protocol):
|
|
-
|
|
- def __init__(self, stream_reader, client_connected_cb=None, loop=None):
|
|
- super().__init__(loop=loop)
|
|
- self._stream_reader = stream_reader
|
|
- self._stream_writer = None
|
|
- self._client_connected_cb = client_connected_cb
|
|
- self._over_ssl = False
|
|
- self._closed = self._loop.create_future()
|
|
-
|
|
def connection_made(self, transport):
|
|
self._stream_reader.set_transport(transport)
|
|
self._over_ssl = transport.get_extra_info("sslcontext") is not None
|
|
@@ -146,7 +125,22 @@ def connection_lost(self, exc):
|
|
self._closed.set_result(None)
|
|
else:
|
|
self._closed.set_exception(exc)
|
|
- super().connection_lost(exc)
|
|
+
|
|
+ self._connection_lost = True
|
|
+ # Wake up the writer if currently paused.
|
|
+ if not self._paused:
|
|
+ return
|
|
+ waiter = self._drain_waiter
|
|
+ if waiter is None:
|
|
+ return
|
|
+ self._drain_waiter = None
|
|
+ if waiter.done():
|
|
+ return
|
|
+ if exc is None:
|
|
+ waiter.set_result(None)
|
|
+ else:
|
|
+ waiter.set_exception(exc)
|
|
+
|
|
self._stream_reader = None
|
|
self._stream_writer = None
|
|
|
|
|
|
From e6d8bcfe2a0f16c50260cafab45f06da6baa1689 Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 11:55:53 +0200
|
|
Subject: [PATCH 04/20] Deduplicate loop and _loop attributes.
|
|
|
|
---
|
|
src/websockets/protocol.py | 18 ++++++------------
|
|
1 file changed, 6 insertions(+), 12 deletions(-)
|
|
|
|
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
|
|
index 49d8b4f2..98c23ab1 100644
|
|
--- a/src/websockets/protocol.py
|
|
+++ b/src/websockets/protocol.py
|
|
@@ -63,11 +63,7 @@ class State(enum.IntEnum):
|
|
|
|
class StreamReaderProtocol(asyncio.Protocol):
|
|
|
|
- def __init__(self, stream_reader, client_connected_cb=None, loop=None):
|
|
- if loop is None:
|
|
- self._loop = asyncio.get_event_loop()
|
|
- else:
|
|
- self._loop = loop
|
|
+ def __init__(self, stream_reader, client_connected_cb=None):
|
|
self._paused = False
|
|
self._drain_waiter = None
|
|
self._connection_lost = False
|
|
@@ -76,7 +72,7 @@ def __init__(self, stream_reader, client_connected_cb=None, loop=None):
|
|
self._stream_writer = None
|
|
self._client_connected_cb = client_connected_cb
|
|
self._over_ssl = False
|
|
- self._closed = self._loop.create_future()
|
|
+ self._closed = self.loop.create_future()
|
|
|
|
def pause_writing(self):
|
|
assert not self._paused
|
|
@@ -99,7 +95,7 @@ def resume_writing(self):
|
|
return
|
|
waiter = self._drain_waiter
|
|
assert waiter is None or waiter.cancelled()
|
|
- waiter = self._loop.create_future()
|
|
+ waiter = self.loop.create_future()
|
|
self._drain_waiter = waiter
|
|
await waiter
|
|
|
|
@@ -108,11 +104,11 @@ def connection_made(self, transport):
|
|
self._over_ssl = transport.get_extra_info("sslcontext") is not None
|
|
if self._client_connected_cb is not None:
|
|
self._stream_writer = asyncio.StreamWriter(
|
|
- transport, self, self._stream_reader, self._loop
|
|
+ transport, self, self._stream_reader, self.loop
|
|
)
|
|
res = self._client_connected_cb(self._stream_reader, self._stream_writer)
|
|
if asyncio.iscoroutine(res):
|
|
- self._loop.create_task(res)
|
|
+ self.loop.create_task(res)
|
|
|
|
def connection_lost(self, exc):
|
|
if self._stream_reader is not None:
|
|
@@ -315,8 +311,6 @@ def __init__(
|
|
self.read_limit = read_limit
|
|
self.write_limit = write_limit
|
|
|
|
- # Store a reference to loop to avoid relying on self._loop, a private
|
|
- # attribute of StreamReaderProtocol, inherited from FlowControlMixin.
|
|
if loop is None:
|
|
loop = asyncio.get_event_loop()
|
|
self.loop = loop
|
|
@@ -331,7 +325,7 @@ def __init__(
|
|
# limit and half the buffer limit of :class:`~asyncio.StreamReader`.
|
|
# That's why it must be set to half of ``self.read_limit``.
|
|
stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop)
|
|
- super().__init__(stream_reader, self.client_connected, loop)
|
|
+ super().__init__(stream_reader, self.client_connected)
|
|
|
|
self.reader: asyncio.StreamReader
|
|
self.writer: asyncio.StreamWriter
|
|
|
|
From 4205cb176e62c938247af59e94214814242360ec Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 11:57:34 +0200
|
|
Subject: [PATCH 05/20] Remove client_connected callback.
|
|
|
|
---
|
|
src/websockets/protocol.py | 31 +++++++------------------------
|
|
1 file changed, 7 insertions(+), 24 deletions(-)
|
|
|
|
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
|
|
index 98c23ab1..bfc354a8 100644
|
|
--- a/src/websockets/protocol.py
|
|
+++ b/src/websockets/protocol.py
|
|
@@ -63,14 +63,13 @@ class State(enum.IntEnum):
|
|
|
|
class StreamReaderProtocol(asyncio.Protocol):
|
|
|
|
- def __init__(self, stream_reader, client_connected_cb=None):
|
|
+ def __init__(self, stream_reader):
|
|
self._paused = False
|
|
self._drain_waiter = None
|
|
self._connection_lost = False
|
|
|
|
self._stream_reader = stream_reader
|
|
self._stream_writer = None
|
|
- self._client_connected_cb = client_connected_cb
|
|
self._over_ssl = False
|
|
self._closed = self.loop.create_future()
|
|
|
|
@@ -102,13 +101,11 @@ def resume_writing(self):
|
|
def connection_made(self, transport):
|
|
self._stream_reader.set_transport(transport)
|
|
self._over_ssl = transport.get_extra_info("sslcontext") is not None
|
|
- if self._client_connected_cb is not None:
|
|
- self._stream_writer = asyncio.StreamWriter(
|
|
- transport, self, self._stream_reader, self.loop
|
|
- )
|
|
- res = self._client_connected_cb(self._stream_reader, self._stream_writer)
|
|
- if asyncio.iscoroutine(res):
|
|
- self.loop.create_task(res)
|
|
+ self._stream_writer = asyncio.StreamWriter(
|
|
+ transport, self, self._stream_reader, self.loop
|
|
+ )
|
|
+ self.reader = self._stream_reader
|
|
+ self.writer = self._stream_writer
|
|
|
|
def connection_lost(self, exc):
|
|
if self._stream_reader is not None:
|
|
@@ -325,7 +322,7 @@ def __init__(
|
|
# limit and half the buffer limit of :class:`~asyncio.StreamReader`.
|
|
# That's why it must be set to half of ``self.read_limit``.
|
|
stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop)
|
|
- super().__init__(stream_reader, self.client_connected)
|
|
+ super().__init__(stream_reader)
|
|
|
|
self.reader: asyncio.StreamReader
|
|
self.writer: asyncio.StreamWriter
|
|
@@ -381,20 +378,6 @@ def __init__(
|
|
# Task closing the TCP connection.
|
|
self.close_connection_task: asyncio.Task[None]
|
|
|
|
- def client_connected(
|
|
- self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
|
- ) -> None:
|
|
- """
|
|
- Callback when the TCP connection is established.
|
|
-
|
|
- Record references to the stream reader and the stream writer to avoid
|
|
- using private attributes ``_stream_reader`` and ``_stream_writer`` of
|
|
- :class:`~asyncio.StreamReaderProtocol`.
|
|
-
|
|
- """
|
|
- self.reader = reader
|
|
- self.writer = writer
|
|
-
|
|
def connection_open(self) -> None:
|
|
"""
|
|
Callback when the WebSocket opening handshake completes.
|
|
|
|
From 2ebb4b4eff76be388ee37570e891ed8432e19a9f Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 12:00:31 +0200
|
|
Subject: [PATCH 06/20] Deduplicate reader/writer and _stream_reader/writer
|
|
attributes.
|
|
|
|
---
|
|
src/websockets/protocol.py | 36 +++++++++++++++---------------------
|
|
1 file changed, 15 insertions(+), 21 deletions(-)
|
|
|
|
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
|
|
index bfc354a8..9c61a409 100644
|
|
--- a/src/websockets/protocol.py
|
|
+++ b/src/websockets/protocol.py
|
|
@@ -62,14 +62,11 @@ class State(enum.IntEnum):
|
|
|
|
|
|
class StreamReaderProtocol(asyncio.Protocol):
|
|
-
|
|
- def __init__(self, stream_reader):
|
|
+ def __init__(self):
|
|
self._paused = False
|
|
self._drain_waiter = None
|
|
self._connection_lost = False
|
|
|
|
- self._stream_reader = stream_reader
|
|
- self._stream_writer = None
|
|
self._over_ssl = False
|
|
self._closed = self.loop.create_future()
|
|
|
|
@@ -99,20 +96,16 @@ def resume_writing(self):
|
|
await waiter
|
|
|
|
def connection_made(self, transport):
|
|
- self._stream_reader.set_transport(transport)
|
|
+ self.reader.set_transport(transport)
|
|
self._over_ssl = transport.get_extra_info("sslcontext") is not None
|
|
- self._stream_writer = asyncio.StreamWriter(
|
|
- transport, self, self._stream_reader, self.loop
|
|
- )
|
|
- self.reader = self._stream_reader
|
|
- self.writer = self._stream_writer
|
|
+ self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop)
|
|
|
|
def connection_lost(self, exc):
|
|
- if self._stream_reader is not None:
|
|
+ if self.reader is not None:
|
|
if exc is None:
|
|
- self._stream_reader.feed_eof()
|
|
+ self.reader.feed_eof()
|
|
else:
|
|
- self._stream_reader.set_exception(exc)
|
|
+ self.reader.set_exception(exc)
|
|
if not self._closed.done():
|
|
if exc is None:
|
|
self._closed.set_result(None)
|
|
@@ -134,14 +127,14 @@ def connection_lost(self, exc):
|
|
else:
|
|
waiter.set_exception(exc)
|
|
|
|
- self._stream_reader = None
|
|
- self._stream_writer = None
|
|
+ del self.reader
|
|
+ del self.writer
|
|
|
|
def data_received(self, data):
|
|
- self._stream_reader.feed_data(data)
|
|
+ self.reader.feed_data(data)
|
|
|
|
def eof_received(self):
|
|
- self._stream_reader.feed_eof()
|
|
+ self.reader.feed_eof()
|
|
if self._over_ssl:
|
|
# Prevent a warning in SSLProtocol.eof_received:
|
|
# "returning true from eof_received()
|
|
@@ -321,13 +314,14 @@ def __init__(
|
|
# ``self.read_limit``. The ``limit`` argument controls the line length
|
|
# limit and half the buffer limit of :class:`~asyncio.StreamReader`.
|
|
# That's why it must be set to half of ``self.read_limit``.
|
|
- stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop)
|
|
- super().__init__(stream_reader)
|
|
-
|
|
- self.reader: asyncio.StreamReader
|
|
+ self.reader: asyncio.StreamReader = asyncio.StreamReader(
|
|
+ limit=read_limit // 2, loop=loop
|
|
+ )
|
|
self.writer: asyncio.StreamWriter
|
|
self._drain_lock = asyncio.Lock(loop=loop)
|
|
|
|
+ super().__init__()
|
|
+
|
|
# This class implements the data transfer and closing handshake, which
|
|
# are shared between the client-side and the server-side.
|
|
# Subclasses implement the opening handshake and, on success, execute
|
|
|
|
From c7ae53f4dfc586f21af1fdbdcfba321bbccd491e Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 12:25:40 +0200
|
|
Subject: [PATCH 07/20] Merge asyncio.Protocol methods.
|
|
|
|
---
|
|
src/websockets/protocol.py | 160 ++++++++++++++++---------------------
|
|
1 file changed, 70 insertions(+), 90 deletions(-)
|
|
|
|
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
|
|
index 9c61a409..89e3464a 100644
|
|
--- a/src/websockets/protocol.py
|
|
+++ b/src/websockets/protocol.py
|
|
@@ -70,20 +70,6 @@ def __init__(self):
|
|
self._over_ssl = False
|
|
self._closed = self.loop.create_future()
|
|
|
|
- def pause_writing(self):
|
|
- assert not self._paused
|
|
- self._paused = True
|
|
-
|
|
- def resume_writing(self):
|
|
- assert self._paused
|
|
- self._paused = False
|
|
-
|
|
- waiter = self._drain_waiter
|
|
- if waiter is not None:
|
|
- self._drain_waiter = None
|
|
- if not waiter.done():
|
|
- waiter.set_result(None)
|
|
-
|
|
async def _drain_helper(self):
|
|
if self._connection_lost:
|
|
raise ConnectionResetError("Connection lost")
|
|
@@ -95,53 +81,6 @@ def resume_writing(self):
|
|
self._drain_waiter = waiter
|
|
await waiter
|
|
|
|
- def connection_made(self, transport):
|
|
- self.reader.set_transport(transport)
|
|
- self._over_ssl = transport.get_extra_info("sslcontext") is not None
|
|
- self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop)
|
|
-
|
|
- def connection_lost(self, exc):
|
|
- if self.reader is not None:
|
|
- if exc is None:
|
|
- self.reader.feed_eof()
|
|
- else:
|
|
- self.reader.set_exception(exc)
|
|
- if not self._closed.done():
|
|
- if exc is None:
|
|
- self._closed.set_result(None)
|
|
- else:
|
|
- self._closed.set_exception(exc)
|
|
-
|
|
- self._connection_lost = True
|
|
- # Wake up the writer if currently paused.
|
|
- if not self._paused:
|
|
- return
|
|
- waiter = self._drain_waiter
|
|
- if waiter is None:
|
|
- return
|
|
- self._drain_waiter = None
|
|
- if waiter.done():
|
|
- return
|
|
- if exc is None:
|
|
- waiter.set_result(None)
|
|
- else:
|
|
- waiter.set_exception(exc)
|
|
-
|
|
- del self.reader
|
|
- del self.writer
|
|
-
|
|
- def data_received(self, data):
|
|
- self.reader.feed_data(data)
|
|
-
|
|
- def eof_received(self):
|
|
- self.reader.feed_eof()
|
|
- if self._over_ssl:
|
|
- # Prevent a warning in SSLProtocol.eof_received:
|
|
- # "returning true from eof_received()
|
|
- # has no effect when using ssl"
|
|
- return False
|
|
- return True
|
|
-
|
|
def __del__(self):
|
|
# Prevent reports about unhandled exceptions.
|
|
# Better than self._closed._log_traceback = False hack
|
|
@@ -1363,7 +1302,7 @@ def abort_pings(self) -> None:
|
|
"%s - aborted pending ping%s: %s", self.side, plural, pings_hex
|
|
)
|
|
|
|
- # asyncio.StreamReaderProtocol methods
|
|
+ # asyncio.Protocol methods
|
|
|
|
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
|
"""
|
|
@@ -1382,34 +1321,11 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
|
logger.debug("%s - event = connection_made(%s)", self.side, transport)
|
|
# mypy thinks transport is a BaseTransport, not a Transport.
|
|
transport.set_write_buffer_limits(self.write_limit) # type: ignore
|
|
- super().connection_made(transport)
|
|
-
|
|
- def eof_received(self) -> bool:
|
|
- """
|
|
- Close the transport after receiving EOF.
|
|
-
|
|
- Since Python 3.5, `:meth:~StreamReaderProtocol.eof_received` returns
|
|
- ``True`` on non-TLS connections.
|
|
-
|
|
- See http://bugs.python.org/issue24539 for more information.
|
|
-
|
|
- This is inappropriate for ``websockets`` for at least three reasons:
|
|
-
|
|
- 1. The use case is to read data until EOF with self.reader.read(-1).
|
|
- Since WebSocket is a TLV protocol, this never happens.
|
|
-
|
|
- 2. It doesn't work on TLS connections. A falsy value must be
|
|
- returned to have the same behavior on TLS and plain connections.
|
|
-
|
|
- 3. The WebSocket protocol has its own closing handshake. Endpoints
|
|
- close the TCP connection after sending a close frame.
|
|
-
|
|
- As a consequence we revert to the previous, more useful behavior.
|
|
|
|
- """
|
|
- logger.debug("%s - event = eof_received()", self.side)
|
|
- super().eof_received()
|
|
- return False
|
|
+ # Copied from asyncio.StreamReaderProtocol
|
|
+ self.reader.set_transport(transport)
|
|
+ self._over_ssl = transport.get_extra_info("sslcontext") is not None
|
|
+ self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop)
|
|
|
|
def connection_lost(self, exc: Optional[Exception]) -> None:
|
|
"""
|
|
@@ -1434,4 +1350,68 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
|
|
# - it's set only here in connection_lost() which is called only once;
|
|
# - it must never be canceled.
|
|
self.connection_lost_waiter.set_result(None)
|
|
- super().connection_lost(exc)
|
|
+
|
|
+ # Copied from asyncio.StreamReaderProtocol
|
|
+ if self.reader is not None:
|
|
+ if exc is None:
|
|
+ self.reader.feed_eof()
|
|
+ else:
|
|
+ self.reader.set_exception(exc)
|
|
+ if not self._closed.done():
|
|
+ if exc is None:
|
|
+ self._closed.set_result(None)
|
|
+ else:
|
|
+ self._closed.set_exception(exc)
|
|
+
|
|
+ # Copied from asyncio.FlowControlMixin
|
|
+ self._connection_lost = True
|
|
+ # Wake up the writer if currently paused.
|
|
+ if not self._paused:
|
|
+ return
|
|
+ waiter = self._drain_waiter
|
|
+ if waiter is None:
|
|
+ return
|
|
+ self._drain_waiter = None
|
|
+ if waiter.done():
|
|
+ return
|
|
+ if exc is None:
|
|
+ waiter.set_result(None)
|
|
+ else:
|
|
+ waiter.set_exception(exc)
|
|
+
|
|
+ del self.reader
|
|
+ del self.writer
|
|
+
|
|
+ def pause_writing(self) -> None:
|
|
+ assert not self._paused
|
|
+ self._paused = True
|
|
+
|
|
+ def resume_writing(self) -> None:
|
|
+ assert self._paused
|
|
+ self._paused = False
|
|
+
|
|
+ waiter = self._drain_waiter
|
|
+ if waiter is not None:
|
|
+ self._drain_waiter = None
|
|
+ if not waiter.done():
|
|
+ waiter.set_result(None)
|
|
+
|
|
+ def data_received(self, data: bytes) -> None:
|
|
+ logger.debug("%s - event = data_received(<%d bytes>)", self.side, len(data))
|
|
+ self.reader.feed_data(data)
|
|
+
|
|
+ def eof_received(self) -> None:
|
|
+ """
|
|
+ Close the transport after receiving EOF.
|
|
+
|
|
+ The WebSocket protocol has its own closing handshake: endpoints close
|
|
+ the TCP or TLS connection after sending and receiving a close frame.
|
|
+
|
|
+ As a consequence, they never need to write after receiving EOF, so
|
|
+ there's no reason to keep the transport open by returning ``True``.
|
|
+
|
|
+ Besides, that doesn't work on TLS connections.
|
|
+
|
|
+ """
|
|
+ logger.debug("%s - event = eof_received()", self.side)
|
|
+ self.reader.feed_eof()
|
|
|
|
From ab5dd72aebd86b878ff08d1ebd5cbb3256e2d6e3 Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 13:09:40 +0200
|
|
Subject: [PATCH 08/20] Finish merging StreamReaderProtocol and
|
|
FlowControlMixin.
|
|
|
|
---
|
|
src/websockets/protocol.py | 59 +++++++++++++++++++-------------------
|
|
1 file changed, 29 insertions(+), 30 deletions(-)
|
|
|
|
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
|
|
index 89e3464a..2db44e5d 100644
|
|
--- a/src/websockets/protocol.py
|
|
+++ b/src/websockets/protocol.py
|
|
@@ -61,35 +61,7 @@ class State(enum.IntEnum):
|
|
# between the check and the assignment.
|
|
|
|
|
|
-class StreamReaderProtocol(asyncio.Protocol):
|
|
- def __init__(self):
|
|
- self._paused = False
|
|
- self._drain_waiter = None
|
|
- self._connection_lost = False
|
|
-
|
|
- self._over_ssl = False
|
|
- self._closed = self.loop.create_future()
|
|
-
|
|
- async def _drain_helper(self):
|
|
- if self._connection_lost:
|
|
- raise ConnectionResetError("Connection lost")
|
|
- if not self._paused:
|
|
- return
|
|
- waiter = self._drain_waiter
|
|
- assert waiter is None or waiter.cancelled()
|
|
- waiter = self.loop.create_future()
|
|
- self._drain_waiter = waiter
|
|
- await waiter
|
|
-
|
|
- def __del__(self):
|
|
- # Prevent reports about unhandled exceptions.
|
|
- # Better than self._closed._log_traceback = False hack
|
|
- closed = self._closed
|
|
- if closed.done() and not closed.cancelled():
|
|
- closed.exception()
|
|
-
|
|
-
|
|
-class WebSocketCommonProtocol(StreamReaderProtocol):
|
|
+class WebSocketCommonProtocol(asyncio.Protocol):
|
|
"""
|
|
:class:`~asyncio.Protocol` subclass implementing the data transfer phase.
|
|
|
|
@@ -259,7 +231,14 @@ def __init__(
|
|
self.writer: asyncio.StreamWriter
|
|
self._drain_lock = asyncio.Lock(loop=loop)
|
|
|
|
- super().__init__()
|
|
+ # Copied from asyncio.FlowControlMixin
|
|
+ self._paused = False
|
|
+ self._drain_waiter: Optional[asyncio.Future[None]] = None
|
|
+ self._connection_lost = False
|
|
+
|
|
+ # Copied from asyncio.StreamReaderProtocol
|
|
+ self._over_ssl = False
|
|
+ self._closed = self.loop.create_future()
|
|
|
|
# This class implements the data transfer and closing handshake, which
|
|
# are shared between the client-side and the server-side.
|
|
@@ -311,6 +290,26 @@ def __init__(
|
|
# Task closing the TCP connection.
|
|
self.close_connection_task: asyncio.Task[None]
|
|
|
|
+ # Copied from asyncio.StreamReaderProtocol
|
|
+ def __del__(self) -> None:
|
|
+ # Prevent reports about unhandled exceptions.
|
|
+ # Better than self._closed._log_traceback = False hack
|
|
+ closed = self._closed
|
|
+ if closed.done() and not closed.cancelled():
|
|
+ closed.exception()
|
|
+
|
|
+ # Copied from asyncio.FlowControlMixin
|
|
+ async def _drain_helper(self) -> None:
|
|
+ if self._connection_lost:
|
|
+ raise ConnectionResetError("Connection lost")
|
|
+ if not self._paused:
|
|
+ return
|
|
+ waiter = self._drain_waiter
|
|
+ assert waiter is None or waiter.cancelled()
|
|
+ waiter = self.loop.create_future()
|
|
+ self._drain_waiter = waiter
|
|
+ await waiter
|
|
+
|
|
def connection_open(self) -> None:
|
|
"""
|
|
Callback when the WebSocket opening handshake completes.
|
|
|
|
From 080543c4961037fcd8a207956e48474366e13ad5 Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 13:17:26 +0200
|
|
Subject: [PATCH 09/20] Deduplicate connection termination tracking.
|
|
|
|
---
|
|
src/websockets/protocol.py | 21 +++++----------------
|
|
1 file changed, 5 insertions(+), 16 deletions(-)
|
|
|
|
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
|
|
index 2db44e5d..1cd5a91c 100644
|
|
--- a/src/websockets/protocol.py
|
|
+++ b/src/websockets/protocol.py
|
|
@@ -234,11 +234,9 @@ def __init__(
|
|
# Copied from asyncio.FlowControlMixin
|
|
self._paused = False
|
|
self._drain_waiter: Optional[asyncio.Future[None]] = None
|
|
- self._connection_lost = False
|
|
|
|
# Copied from asyncio.StreamReaderProtocol
|
|
self._over_ssl = False
|
|
- self._closed = self.loop.create_future()
|
|
|
|
# This class implements the data transfer and closing handshake, which
|
|
# are shared between the client-side and the server-side.
|
|
@@ -290,17 +288,14 @@ def __init__(
|
|
# Task closing the TCP connection.
|
|
self.close_connection_task: asyncio.Task[None]
|
|
|
|
- # Copied from asyncio.StreamReaderProtocol
|
|
- def __del__(self) -> None:
|
|
- # Prevent reports about unhandled exceptions.
|
|
- # Better than self._closed._log_traceback = False hack
|
|
- closed = self._closed
|
|
- if closed.done() and not closed.cancelled():
|
|
- closed.exception()
|
|
+ # asyncio.StreamWriter expects this attribute on the Protocol
|
|
+ @property
|
|
+ def _closed(self) -> asyncio.Future:
|
|
+ return self.connection_lost_waiter
|
|
|
|
# Copied from asyncio.FlowControlMixin
|
|
async def _drain_helper(self) -> None:
|
|
- if self._connection_lost:
|
|
+ if self.connection_lost_waiter.done():
|
|
raise ConnectionResetError("Connection lost")
|
|
if not self._paused:
|
|
return
|
|
@@ -1356,14 +1351,8 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
|
|
self.reader.feed_eof()
|
|
else:
|
|
self.reader.set_exception(exc)
|
|
- if not self._closed.done():
|
|
- if exc is None:
|
|
- self._closed.set_result(None)
|
|
- else:
|
|
- self._closed.set_exception(exc)
|
|
|
|
# Copied from asyncio.FlowControlMixin
|
|
- self._connection_lost = True
|
|
# Wake up the writer if currently paused.
|
|
if not self._paused:
|
|
return
|
|
|
|
From 284932f967699467908e69226118f9bcd5ba8f9d Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 13:21:25 +0200
|
|
Subject: [PATCH 10/20] Remove unused attribute.
|
|
|
|
---
|
|
src/websockets/protocol.py | 4 ----
|
|
1 file changed, 4 deletions(-)
|
|
|
|
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
|
|
index 1cd5a91c..a1c90916 100644
|
|
--- a/src/websockets/protocol.py
|
|
+++ b/src/websockets/protocol.py
|
|
@@ -235,9 +235,6 @@ def __init__(
|
|
self._paused = False
|
|
self._drain_waiter: Optional[asyncio.Future[None]] = None
|
|
|
|
- # Copied from asyncio.StreamReaderProtocol
|
|
- self._over_ssl = False
|
|
-
|
|
# This class implements the data transfer and closing handshake, which
|
|
# are shared between the client-side and the server-side.
|
|
# Subclasses implement the opening handshake and, on success, execute
|
|
@@ -1318,7 +1315,6 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
|
|
|
# Copied from asyncio.StreamReaderProtocol
|
|
self.reader.set_transport(transport)
|
|
- self._over_ssl = transport.get_extra_info("sslcontext") is not None
|
|
self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop)
|
|
|
|
def connection_lost(self, exc: Optional[Exception]) -> None:
|
|
|
|
From 9eeeeb3547eb7abc2253c75e462a50a8f68ba16b Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 13:28:49 +0200
|
|
Subject: [PATCH 11/20] Ignore quality checks for code copied from asyncio.
|
|
|
|
---
|
|
src/websockets/protocol.py | 51 +++++++++++++++++++-------------------
|
|
1 file changed, 25 insertions(+), 26 deletions(-)
|
|
|
|
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
|
|
index a1c90916..0bb12fd5 100644
|
|
--- a/src/websockets/protocol.py
|
|
+++ b/src/websockets/protocol.py
|
|
@@ -287,11 +287,11 @@ def __init__(
|
|
|
|
# asyncio.StreamWriter expects this attribute on the Protocol
|
|
@property
|
|
- def _closed(self) -> asyncio.Future:
|
|
+ def _closed(self) -> Any: # pragma: no cover
|
|
return self.connection_lost_waiter
|
|
|
|
# Copied from asyncio.FlowControlMixin
|
|
- async def _drain_helper(self) -> None:
|
|
+ async def _drain_helper(self) -> None: # pragma: no cover
|
|
if self.connection_lost_waiter.done():
|
|
raise ConnectionResetError("Connection lost")
|
|
if not self._paused:
|
|
@@ -1341,36 +1341,35 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
|
|
# - it must never be canceled.
|
|
self.connection_lost_waiter.set_result(None)
|
|
|
|
- # Copied from asyncio.StreamReaderProtocol
|
|
- if self.reader is not None:
|
|
- if exc is None:
|
|
- self.reader.feed_eof()
|
|
- else:
|
|
- self.reader.set_exception(exc)
|
|
+ if True: # pragma: no cover
|
|
|
|
- # Copied from asyncio.FlowControlMixin
|
|
- # Wake up the writer if currently paused.
|
|
- if not self._paused:
|
|
- return
|
|
- waiter = self._drain_waiter
|
|
- if waiter is None:
|
|
- return
|
|
- self._drain_waiter = None
|
|
- if waiter.done():
|
|
- return
|
|
- if exc is None:
|
|
- waiter.set_result(None)
|
|
- else:
|
|
- waiter.set_exception(exc)
|
|
+ # Copied from asyncio.StreamReaderProtocol
|
|
+ if self.reader is not None:
|
|
+ if exc is None:
|
|
+ self.reader.feed_eof()
|
|
+ else:
|
|
+ self.reader.set_exception(exc)
|
|
|
|
- del self.reader
|
|
- del self.writer
|
|
+ # Copied from asyncio.FlowControlMixin
|
|
+ # Wake up the writer if currently paused.
|
|
+ if not self._paused:
|
|
+ return
|
|
+ waiter = self._drain_waiter
|
|
+ if waiter is None:
|
|
+ return
|
|
+ self._drain_waiter = None
|
|
+ if waiter.done():
|
|
+ return
|
|
+ if exc is None:
|
|
+ waiter.set_result(None)
|
|
+ else:
|
|
+ waiter.set_exception(exc)
|
|
|
|
- def pause_writing(self) -> None:
|
|
+ def pause_writing(self) -> None: # pragma: no cover
|
|
assert not self._paused
|
|
self._paused = True
|
|
|
|
- def resume_writing(self) -> None:
|
|
+ def resume_writing(self) -> None: # pragma: no cover
|
|
assert self._paused
|
|
self._paused = False
|
|
|
|
|
|
From f79cdd8e50e82433bbb0c003279a6b6f13e6c17e Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 13:54:13 +0200
|
|
Subject: [PATCH 12/20] Remove asyncio.StreamWriter.
|
|
|
|
It adds only one method for flow control. Copy it, as we've already
|
|
copied the rest of the flow control implementation.
|
|
---
|
|
src/websockets/client.py | 2 +-
|
|
src/websockets/protocol.py | 64 +++++++++++++++++++++----------------
|
|
src/websockets/server.py | 6 ++--
|
|
tests/test_client_server.py | 2 +-
|
|
tests/test_protocol.py | 18 +++++------
|
|
5 files changed, 51 insertions(+), 41 deletions(-)
|
|
|
|
diff --git a/src/websockets/client.py b/src/websockets/client.py
|
|
index c1fdf88a..34cd8624 100644
|
|
--- a/src/websockets/client.py
|
|
+++ b/src/websockets/client.py
|
|
@@ -85,7 +85,7 @@ def write_http_request(self, path: str, headers: Headers) -> None:
|
|
request = f"GET {path} HTTP/1.1\r\n"
|
|
request += str(headers)
|
|
|
|
- self.writer.write(request.encode())
|
|
+ self.transport.write(request.encode())
|
|
|
|
async def read_http_response(self) -> Tuple[int, Headers]:
|
|
"""
|
|
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
|
|
index 0bb12fd5..eb3d6bcc 100644
|
|
--- a/src/websockets/protocol.py
|
|
+++ b/src/websockets/protocol.py
|
|
@@ -228,7 +228,8 @@ def __init__(
|
|
self.reader: asyncio.StreamReader = asyncio.StreamReader(
|
|
limit=read_limit // 2, loop=loop
|
|
)
|
|
- self.writer: asyncio.StreamWriter
|
|
+
|
|
+ self.transport: asyncio.Transport
|
|
self._drain_lock = asyncio.Lock(loop=loop)
|
|
|
|
# Copied from asyncio.FlowControlMixin
|
|
@@ -285,11 +286,6 @@ def __init__(
|
|
# Task closing the TCP connection.
|
|
self.close_connection_task: asyncio.Task[None]
|
|
|
|
- # asyncio.StreamWriter expects this attribute on the Protocol
|
|
- @property
|
|
- def _closed(self) -> Any: # pragma: no cover
|
|
- return self.connection_lost_waiter
|
|
-
|
|
# Copied from asyncio.FlowControlMixin
|
|
async def _drain_helper(self) -> None: # pragma: no cover
|
|
if self.connection_lost_waiter.done():
|
|
@@ -302,6 +298,23 @@ def _closed(self) -> Any: # pragma: no cover
|
|
self._drain_waiter = waiter
|
|
await waiter
|
|
|
|
+ # Copied from asyncio.StreamWriter
|
|
+ async def _drain(self) -> None: # pragma: no cover
|
|
+ if self.reader is not None:
|
|
+ exc = self.reader.exception()
|
|
+ if exc is not None:
|
|
+ raise exc
|
|
+ if self.transport is not None:
|
|
+ if self.transport.is_closing():
|
|
+ # Yield to the event loop so connection_lost() may be
|
|
+ # called. Without this, _drain_helper() would return
|
|
+ # immediately, and code that calls
|
|
+ # write(...); yield from drain()
|
|
+ # in a loop would never call connection_lost(), so it
|
|
+ # would not see an error when the socket is closed.
|
|
+ await asyncio.sleep(0)
|
|
+ await self._drain_helper()
|
|
+
|
|
def connection_open(self) -> None:
|
|
"""
|
|
Callback when the WebSocket opening handshake completes.
|
|
@@ -348,9 +361,9 @@ def local_address(self) -> Any:
|
|
been established yet.
|
|
|
|
"""
|
|
- if self.writer is None:
|
|
+ if self.transport is None:
|
|
return None
|
|
- return self.writer.get_extra_info("sockname")
|
|
+ return self.transport.get_extra_info("sockname")
|
|
|
|
@property
|
|
def remote_address(self) -> Any:
|
|
@@ -361,9 +374,9 @@ def remote_address(self) -> Any:
|
|
been established yet.
|
|
|
|
"""
|
|
- if self.writer is None:
|
|
+ if self.transport is None:
|
|
return None
|
|
- return self.writer.get_extra_info("peername")
|
|
+ return self.transport.get_extra_info("peername")
|
|
|
|
@property
|
|
def open(self) -> bool:
|
|
@@ -1037,7 +1050,9 @@ def append(frame: Frame) -> None:
|
|
|
|
frame = Frame(fin, opcode, data)
|
|
logger.debug("%s > %r", self.side, frame)
|
|
- frame.write(self.writer.write, mask=self.is_client, extensions=self.extensions)
|
|
+ frame.write(
|
|
+ self.transport.write, mask=self.is_client, extensions=self.extensions
|
|
+ )
|
|
|
|
try:
|
|
# drain() cannot be called concurrently by multiple coroutines:
|
|
@@ -1045,7 +1060,7 @@ def append(frame: Frame) -> None:
|
|
# version of Python where this bugs exists is supported anymore.
|
|
async with self._drain_lock:
|
|
# Handle flow control automatically.
|
|
- await self.writer.drain()
|
|
+ await self._drain()
|
|
except ConnectionError:
|
|
# Terminate the connection if the socket died.
|
|
self.fail_connection()
|
|
@@ -1147,9 +1162,9 @@ def append(frame: Frame) -> None:
|
|
logger.debug("%s ! timed out waiting for TCP close", self.side)
|
|
|
|
# Half-close the TCP connection if possible (when there's no TLS).
|
|
- if self.writer.can_write_eof():
|
|
+ if self.transport.can_write_eof():
|
|
logger.debug("%s x half-closing TCP connection", self.side)
|
|
- self.writer.write_eof()
|
|
+ self.transport.write_eof()
|
|
|
|
if await self.wait_for_connection_lost():
|
|
return
|
|
@@ -1162,17 +1177,12 @@ def append(frame: Frame) -> None:
|
|
# If connection_lost() was called, the TCP connection is closed.
|
|
# However, if TLS is enabled, the transport still needs closing.
|
|
# Else asyncio complains: ResourceWarning: unclosed transport.
|
|
- try:
|
|
- writer_is_closing = self.writer.is_closing # type: ignore
|
|
- except AttributeError: # pragma: no cover
|
|
- # Python < 3.7
|
|
- writer_is_closing = self.writer.transport.is_closing
|
|
- if self.connection_lost_waiter.done() and writer_is_closing():
|
|
+ if self.connection_lost_waiter.done() and self.transport.is_closing():
|
|
return
|
|
|
|
# Close the TCP connection. Buffers are flushed asynchronously.
|
|
logger.debug("%s x closing TCP connection", self.side)
|
|
- self.writer.close()
|
|
+ self.transport.close()
|
|
|
|
if await self.wait_for_connection_lost():
|
|
return
|
|
@@ -1180,8 +1190,7 @@ def append(frame: Frame) -> None:
|
|
|
|
# Abort the TCP connection. Buffers are discarded.
|
|
logger.debug("%s x aborting TCP connection", self.side)
|
|
- # mypy thinks self.writer.transport is a BaseTransport, not a Transport.
|
|
- self.writer.transport.abort() # type: ignore
|
|
+ self.transport.abort()
|
|
|
|
# connection_lost() is called quickly after aborting.
|
|
await self.wait_for_connection_lost()
|
|
@@ -1261,7 +1270,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None:
|
|
frame = Frame(True, OP_CLOSE, frame_data)
|
|
logger.debug("%s > %r", self.side, frame)
|
|
frame.write(
|
|
- self.writer.write, mask=self.is_client, extensions=self.extensions
|
|
+ self.transport.write, mask=self.is_client, extensions=self.extensions
|
|
)
|
|
|
|
# Start close_connection_task if the opening handshake didn't succeed.
|
|
@@ -1310,12 +1319,13 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
|
|
|
"""
|
|
logger.debug("%s - event = connection_made(%s)", self.side, transport)
|
|
- # mypy thinks transport is a BaseTransport, not a Transport.
|
|
- transport.set_write_buffer_limits(self.write_limit) # type: ignore
|
|
+
|
|
+ transport = cast(asyncio.Transport, transport)
|
|
+ transport.set_write_buffer_limits(self.write_limit)
|
|
+ self.transport = transport
|
|
|
|
# Copied from asyncio.StreamReaderProtocol
|
|
self.reader.set_transport(transport)
|
|
- self.writer = asyncio.StreamWriter(transport, self, self.reader, self.loop)
|
|
|
|
def connection_lost(self, exc: Optional[Exception]) -> None:
|
|
"""
|
|
diff --git a/src/websockets/server.py b/src/websockets/server.py
|
|
index b220a1b8..1e8ae861 100644
|
|
--- a/src/websockets/server.py
|
|
+++ b/src/websockets/server.py
|
|
@@ -211,7 +211,7 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
|
except Exception:
|
|
# Last-ditch attempt to avoid leaking connections on errors.
|
|
try:
|
|
- self.writer.close()
|
|
+ self.transport.close()
|
|
except Exception: # pragma: no cover
|
|
pass
|
|
|
|
@@ -265,11 +265,11 @@ def write_http_response(
|
|
response = f"HTTP/1.1 {status.value} {status.phrase}\r\n"
|
|
response += str(headers)
|
|
|
|
- self.writer.write(response.encode())
|
|
+ self.transport.write(response.encode())
|
|
|
|
if body is not None:
|
|
logger.debug("%s > body (%d bytes)", self.side, len(body))
|
|
- self.writer.write(body)
|
|
+ self.transport.write(body)
|
|
|
|
async def process_request(
|
|
self, path: str, request_headers: Headers
|
|
diff --git a/tests/test_client_server.py b/tests/test_client_server.py
|
|
index e74ec6bf..6171f21b 100644
|
|
--- a/tests/test_client_server.py
|
|
+++ b/tests/test_client_server.py
|
|
@@ -1166,7 +1166,7 @@ def test_server_close_crashes(self, close):
|
|
def test_client_closes_connection_before_handshake(self, handshake):
|
|
# We have mocked the handshake() method to prevent the client from
|
|
# performing the opening handshake. Force it to close the connection.
|
|
- self.client.writer.close()
|
|
+ self.client.transport.close()
|
|
# The server should stop properly anyway. It used to hang because the
|
|
# task handling the connection was waiting for the opening handshake.
|
|
|
|
diff --git a/tests/test_protocol.py b/tests/test_protocol.py
|
|
index a6c42018..dfc2c6d4 100644
|
|
--- a/tests/test_protocol.py
|
|
+++ b/tests/test_protocol.py
|
|
@@ -94,16 +94,16 @@ def tearDown(self):
|
|
# Utilities for writing tests.
|
|
|
|
def make_drain_slow(self, delay=MS):
|
|
- # Process connection_made in order to initialize self.protocol.writer.
|
|
+ # Process connection_made in order to initialize self.protocol.transport.
|
|
self.run_loop_once()
|
|
|
|
- original_drain = self.protocol.writer.drain
|
|
+ original_drain = self.protocol._drain
|
|
|
|
async def delayed_drain():
|
|
await asyncio.sleep(delay, loop=self.loop)
|
|
await original_drain()
|
|
|
|
- self.protocol.writer.drain = delayed_drain
|
|
+ self.protocol._drain = delayed_drain
|
|
|
|
close_frame = Frame(True, OP_CLOSE, serialize_close(1000, "close"))
|
|
local_close = Frame(True, OP_CLOSE, serialize_close(1000, "local"))
|
|
@@ -321,32 +321,32 @@ def test_local_address(self):
|
|
self.transport.get_extra_info = get_extra_info
|
|
|
|
self.assertEqual(self.protocol.local_address, ("host", 4312))
|
|
- get_extra_info.assert_called_with("sockname", None)
|
|
+ get_extra_info.assert_called_with("sockname")
|
|
|
|
def test_local_address_before_connection(self):
|
|
# Emulate the situation before connection_open() runs.
|
|
- self.protocol.writer, _writer = None, self.protocol.writer
|
|
+ self.protocol.transport, _transport = None, self.protocol.transport
|
|
|
|
try:
|
|
self.assertEqual(self.protocol.local_address, None)
|
|
finally:
|
|
- self.protocol.writer = _writer
|
|
+ self.protocol.transport = _transport
|
|
|
|
def test_remote_address(self):
|
|
get_extra_info = unittest.mock.Mock(return_value=("host", 4312))
|
|
self.transport.get_extra_info = get_extra_info
|
|
|
|
self.assertEqual(self.protocol.remote_address, ("host", 4312))
|
|
- get_extra_info.assert_called_with("peername", None)
|
|
+ get_extra_info.assert_called_with("peername")
|
|
|
|
def test_remote_address_before_connection(self):
|
|
# Emulate the situation before connection_open() runs.
|
|
- self.protocol.writer, _writer = None, self.protocol.writer
|
|
+ self.protocol.transport, _transport = None, self.protocol.transport
|
|
|
|
try:
|
|
self.assertEqual(self.protocol.remote_address, None)
|
|
finally:
|
|
- self.protocol.writer = _writer
|
|
+ self.protocol.transport = _transport
|
|
|
|
def test_open(self):
|
|
self.assertTrue(self.protocol.open)
|
|
|
|
From c55f374a93f7e9b14bc2bdcc8868f38f8a834a87 Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 13:55:16 +0200
|
|
Subject: [PATCH 13/20] Rename writer to write.
|
|
|
|
It's a better name for a function that writes bytes.
|
|
---
|
|
src/websockets/framing.py | 8 ++++----
|
|
tests/test_framing.py | 14 +++++++-------
|
|
tests/test_protocol.py | 4 ++--
|
|
3 files changed, 13 insertions(+), 13 deletions(-)
|
|
|
|
diff --git a/src/websockets/framing.py b/src/websockets/framing.py
|
|
index 81a3185b..c24b8a73 100644
|
|
--- a/src/websockets/framing.py
|
|
+++ b/src/websockets/framing.py
|
|
@@ -147,7 +147,7 @@ class Frame(NamedTuple):
|
|
|
|
def write(
|
|
frame,
|
|
- writer: Callable[[bytes], Any],
|
|
+ write: Callable[[bytes], Any],
|
|
*,
|
|
mask: bool,
|
|
extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None,
|
|
@@ -156,7 +156,7 @@ def write(
|
|
Write a WebSocket frame.
|
|
|
|
:param frame: frame to write
|
|
- :param writer: function that writes bytes
|
|
+ :param write: function that writes bytes
|
|
:param mask: whether the frame should be masked i.e. whether the write
|
|
happens on the client side
|
|
:param extensions: list of classes with an ``encode()`` method that
|
|
@@ -210,10 +210,10 @@ def write(
|
|
|
|
# Send the frame.
|
|
|
|
- # The frame is written in a single call to writer in order to prevent
|
|
+ # The frame is written in a single call to write in order to prevent
|
|
# TCP fragmentation. See #68 for details. This also makes it safe to
|
|
# send frames concurrently from multiple coroutines.
|
|
- writer(output.getvalue())
|
|
+ write(output.getvalue())
|
|
|
|
def check(frame) -> None:
|
|
"""
|
|
diff --git a/tests/test_framing.py b/tests/test_framing.py
|
|
index 9e6f1871..5def415d 100644
|
|
--- a/tests/test_framing.py
|
|
+++ b/tests/test_framing.py
|
|
@@ -27,15 +27,15 @@ def decode(self, message, mask=False, max_size=None, extensions=None):
|
|
return frame
|
|
|
|
def encode(self, frame, mask=False, extensions=None):
|
|
- writer = unittest.mock.Mock()
|
|
- frame.write(writer, mask=mask, extensions=extensions)
|
|
- # Ensure the entire frame is sent with a single call to writer().
|
|
+ write = unittest.mock.Mock()
|
|
+ frame.write(write, mask=mask, extensions=extensions)
|
|
+ # Ensure the entire frame is sent with a single call to write().
|
|
# Multiple calls cause TCP fragmentation and degrade performance.
|
|
- self.assertEqual(writer.call_count, 1)
|
|
+ self.assertEqual(write.call_count, 1)
|
|
# The frame data is the single positional argument of that call.
|
|
- self.assertEqual(len(writer.call_args[0]), 1)
|
|
- self.assertEqual(len(writer.call_args[1]), 0)
|
|
- return writer.call_args[0][0]
|
|
+ self.assertEqual(len(write.call_args[0]), 1)
|
|
+ self.assertEqual(len(write.call_args[1]), 0)
|
|
+ return write.call_args[0][0]
|
|
|
|
def round_trip(self, message, expected, mask=False, extensions=None):
|
|
decoded = self.decode(message, mask, extensions=extensions)
|
|
diff --git a/tests/test_protocol.py b/tests/test_protocol.py
|
|
index dfc2c6d4..d2793faf 100644
|
|
--- a/tests/test_protocol.py
|
|
+++ b/tests/test_protocol.py
|
|
@@ -114,9 +114,9 @@ def receive_frame(self, frame):
|
|
Make the protocol receive a frame.
|
|
|
|
"""
|
|
- writer = self.protocol.data_received
|
|
+ write = self.protocol.data_received
|
|
mask = not self.protocol.is_client
|
|
- frame.write(writer, mask=mask)
|
|
+ frame.write(write, mask=mask)
|
|
|
|
def receive_eof(self):
|
|
"""
|
|
|
|
From 3e03d1a19dbbf54e1c714c933b90ec283caa7bb0 Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 15:49:08 +0200
|
|
Subject: [PATCH 14/20] Update to the latest version of mypy.
|
|
|
|
The bugs that were locking us on an old version are fixed.
|
|
---
|
|
src/websockets/__init__.py | 15 ++++++++-------
|
|
src/websockets/__main__.py | 8 ++++----
|
|
src/websockets/client.py | 3 +--
|
|
src/websockets/handshake.py | 15 ++++++++++-----
|
|
src/websockets/protocol.py | 14 +++++++++-----
|
|
src/websockets/server.py | 11 +++--------
|
|
tox.ini | 2 +-
|
|
7 files changed, 36 insertions(+), 32 deletions(-)
|
|
|
|
diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py
|
|
index e7ba31ce..6bad0f7b 100644
|
|
--- a/src/websockets/__init__.py
|
|
+++ b/src/websockets/__init__.py
|
|
@@ -1,12 +1,13 @@
|
|
# This relies on each of the submodules having an __all__ variable.
|
|
|
|
-from .auth import *
|
|
-from .client import *
|
|
-from .exceptions import *
|
|
-from .protocol import *
|
|
-from .server import *
|
|
-from .typing import *
|
|
-from .uri import *
|
|
+from . import auth, client, exceptions, protocol, server, typing, uri
|
|
+from .auth import * # noqa
|
|
+from .client import * # noqa
|
|
+from .exceptions import * # noqa
|
|
+from .protocol import * # noqa
|
|
+from .server import * # noqa
|
|
+from .typing import * # noqa
|
|
+from .uri import * # noqa
|
|
from .version import version as __version__ # noqa
|
|
|
|
|
|
diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py
|
|
index bccb8aa5..394f7ac7 100644
|
|
--- a/src/websockets/__main__.py
|
|
+++ b/src/websockets/__main__.py
|
|
@@ -6,8 +6,8 @@
|
|
import threading
|
|
from typing import Any, Set
|
|
|
|
-import websockets
|
|
-from websockets.exceptions import format_close
|
|
+from .client import connect
|
|
+from .exceptions import ConnectionClosed, format_close
|
|
|
|
|
|
if sys.platform == "win32":
|
|
@@ -95,7 +95,7 @@ def print_over_input(string: str) -> None:
|
|
stop: "asyncio.Future[None]",
|
|
) -> None:
|
|
try:
|
|
- websocket = await websockets.connect(uri)
|
|
+ websocket = await connect(uri)
|
|
except Exception as exc:
|
|
print_over_input(f"Failed to connect to {uri}: {exc}.")
|
|
exit_from_event_loop_thread(loop, stop)
|
|
@@ -122,7 +122,7 @@ def print_over_input(string: str) -> None:
|
|
if incoming in done:
|
|
try:
|
|
message = incoming.result()
|
|
- except websockets.ConnectionClosed:
|
|
+ except ConnectionClosed:
|
|
break
|
|
else:
|
|
if isinstance(message, str):
|
|
diff --git a/src/websockets/client.py b/src/websockets/client.py
|
|
index 34cd8624..725ec1e7 100644
|
|
--- a/src/websockets/client.py
|
|
+++ b/src/websockets/client.py
|
|
@@ -24,7 +24,6 @@
|
|
from .extensions.permessage_deflate import ClientPerMessageDeflateFactory
|
|
from .handshake import build_request, check_response
|
|
from .headers import (
|
|
- ExtensionHeader,
|
|
build_authorization_basic,
|
|
build_extension,
|
|
build_subprotocol,
|
|
@@ -33,7 +32,7 @@
|
|
)
|
|
from .http import USER_AGENT, Headers, HeadersLike, read_response
|
|
from .protocol import WebSocketCommonProtocol
|
|
-from .typing import Origin, Subprotocol
|
|
+from .typing import ExtensionHeader, Origin, Subprotocol
|
|
from .uri import WebSocketURI, parse_uri
|
|
|
|
|
|
diff --git a/src/websockets/handshake.py b/src/websockets/handshake.py
|
|
index 17332d15..9bfe2775 100644
|
|
--- a/src/websockets/handshake.py
|
|
+++ b/src/websockets/handshake.py
|
|
@@ -29,9 +29,10 @@
|
|
import binascii
|
|
import hashlib
|
|
import random
|
|
+from typing import List
|
|
|
|
from .exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade
|
|
-from .headers import parse_connection, parse_upgrade
|
|
+from .headers import ConnectionOption, UpgradeProtocol, parse_connection, parse_upgrade
|
|
from .http import Headers, MultipleValuesError
|
|
|
|
|
|
@@ -74,14 +75,16 @@ def check_request(headers: Headers) -> str:
|
|
is invalid; then the server must return 400 Bad Request error
|
|
|
|
"""
|
|
- connection = sum(
|
|
+ connection: List[ConnectionOption] = sum(
|
|
[parse_connection(value) for value in headers.get_all("Connection")], []
|
|
)
|
|
|
|
if not any(value.lower() == "upgrade" for value in connection):
|
|
raise InvalidUpgrade("Connection", ", ".join(connection))
|
|
|
|
- upgrade = sum([parse_upgrade(value) for value in headers.get_all("Upgrade")], [])
|
|
+ upgrade: List[UpgradeProtocol] = sum(
|
|
+ [parse_upgrade(value) for value in headers.get_all("Upgrade")], []
|
|
+ )
|
|
|
|
# For compatibility with non-strict implementations, ignore case when
|
|
# checking the Upgrade header. It's supposed to be 'WebSocket'.
|
|
@@ -148,14 +151,16 @@ def check_response(headers: Headers, key: str) -> None:
|
|
is invalid
|
|
|
|
"""
|
|
- connection = sum(
|
|
+ connection: List[ConnectionOption] = sum(
|
|
[parse_connection(value) for value in headers.get_all("Connection")], []
|
|
)
|
|
|
|
if not any(value.lower() == "upgrade" for value in connection):
|
|
raise InvalidUpgrade("Connection", " ".join(connection))
|
|
|
|
- upgrade = sum([parse_upgrade(value) for value in headers.get_all("Upgrade")], [])
|
|
+ upgrade: List[UpgradeProtocol] = sum(
|
|
+ [parse_upgrade(value) for value in headers.get_all("Upgrade")], []
|
|
+ )
|
|
|
|
# For compatibility with non-strict implementations, ignore case when
|
|
# checking the Upgrade header. It's supposed to be 'WebSocket'.
|
|
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
|
|
index eb3d6bcc..b7c1f19c 100644
|
|
--- a/src/websockets/protocol.py
|
|
+++ b/src/websockets/protocol.py
|
|
@@ -601,10 +601,14 @@ def closed(self) -> bool:
|
|
|
|
elif isinstance(message, AsyncIterable):
|
|
# aiter_message = aiter(message) without aiter
|
|
- aiter_message = type(message).__aiter__(message)
|
|
+ # https://github.com/python/mypy/issues/5738
|
|
+ aiter_message = type(message).__aiter__(message) # type: ignore
|
|
try:
|
|
# message_chunk = anext(aiter_message) without anext
|
|
- message_chunk = await type(aiter_message).__anext__(aiter_message)
|
|
+ # https://github.com/python/mypy/issues/5738
|
|
+ message_chunk = await type(aiter_message).__anext__( # type: ignore
|
|
+ aiter_message
|
|
+ )
|
|
except StopAsyncIteration:
|
|
return
|
|
opcode, data = prepare_data(message_chunk)
|
|
@@ -615,7 +619,8 @@ def closed(self) -> bool:
|
|
await self.write_frame(False, opcode, data)
|
|
|
|
# Other fragments.
|
|
- async for message_chunk in aiter_message:
|
|
+ # https://github.com/python/mypy/issues/5738
|
|
+ async for message_chunk in aiter_message: # type: ignore
|
|
confirm_opcode, data = prepare_data(message_chunk)
|
|
if confirm_opcode != opcode:
|
|
raise TypeError("data contains inconsistent types")
|
|
@@ -899,8 +904,7 @@ def connection_closed_exc(self) -> ConnectionClosed:
|
|
max_size = self.max_size
|
|
if text:
|
|
decoder_factory = codecs.getincrementaldecoder("utf-8")
|
|
- # https://github.com/python/typeshed/pull/2752
|
|
- decoder = decoder_factory(errors="strict") # type: ignore
|
|
+ decoder = decoder_factory(errors="strict")
|
|
if max_size is None:
|
|
|
|
def append(frame: Frame) -> None:
|
|
diff --git a/src/websockets/server.py b/src/websockets/server.py
|
|
index 1e8ae861..5114646d 100644
|
|
--- a/src/websockets/server.py
|
|
+++ b/src/websockets/server.py
|
|
@@ -39,15 +39,10 @@
|
|
from .extensions.base import Extension, ServerExtensionFactory
|
|
from .extensions.permessage_deflate import ServerPerMessageDeflateFactory
|
|
from .handshake import build_response, check_request
|
|
-from .headers import (
|
|
- ExtensionHeader,
|
|
- build_extension,
|
|
- parse_extension,
|
|
- parse_subprotocol,
|
|
-)
|
|
+from .headers import build_extension, parse_extension, parse_subprotocol
|
|
from .http import USER_AGENT, Headers, HeadersLike, MultipleValuesError, read_request
|
|
from .protocol import WebSocketCommonProtocol
|
|
-from .typing import Origin, Subprotocol
|
|
+from .typing import ExtensionHeader, Origin, Subprotocol
|
|
|
|
|
|
__all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"]
|
|
@@ -662,7 +657,7 @@ def is_serving(self) -> bool:
|
|
"""
|
|
try:
|
|
# Python ≥ 3.7
|
|
- return self.server.is_serving() # type: ignore
|
|
+ return self.server.is_serving()
|
|
except AttributeError: # pragma: no cover
|
|
# Python < 3.7
|
|
return self.server.sockets is not None
|
|
diff --git a/tox.ini b/tox.ini
|
|
index 801d4d5d..7397c90a 100644
|
|
--- a/tox.ini
|
|
+++ b/tox.ini
|
|
@@ -25,4 +25,4 @@ deps = isort
|
|
|
|
[testenv:mypy]
|
|
commands = mypy --strict src
|
|
-deps = mypy==0.670
|
|
+deps = mypy
|
|
|
|
From aef1e8781826653720f4db110a4d2af4418fa5b8 Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 15:49:51 +0200
|
|
Subject: [PATCH 15/20] Fix deprecation warnings on Python 3.8.
|
|
|
|
* Don't pass the deprecated loop argument.
|
|
* Ignore deprecation warnings for @asyncio.coroutine.
|
|
---
|
|
src/websockets/protocol.py | 28 ++++++++++++++++++++--------
|
|
src/websockets/server.py | 10 +++++++---
|
|
tests/test_client_server.py | 36 +++++++++++++++++++++---------------
|
|
tests/test_protocol.py | 5 ++++-
|
|
4 files changed, 52 insertions(+), 27 deletions(-)
|
|
|
|
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
|
|
index b7c1f19c..76d46ad9 100644
|
|
--- a/src/websockets/protocol.py
|
|
+++ b/src/websockets/protocol.py
|
|
@@ -14,6 +14,7 @@
|
|
import logging
|
|
import random
|
|
import struct
|
|
+import sys
|
|
import warnings
|
|
from typing import (
|
|
Any,
|
|
@@ -230,7 +231,9 @@ def __init__(
|
|
)
|
|
|
|
self.transport: asyncio.Transport
|
|
- self._drain_lock = asyncio.Lock(loop=loop)
|
|
+ self._drain_lock = asyncio.Lock(
|
|
+ loop=loop if sys.version_info[:2] < (3, 8) else None
|
|
+ )
|
|
|
|
# Copied from asyncio.FlowControlMixin
|
|
self._paused = False
|
|
@@ -312,7 +315,9 @@ def __init__(
|
|
# write(...); yield from drain()
|
|
# in a loop would never call connection_lost(), so it
|
|
# would not see an error when the socket is closed.
|
|
- await asyncio.sleep(0)
|
|
+ await asyncio.sleep(
|
|
+ 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None
|
|
+ )
|
|
await self._drain_helper()
|
|
|
|
def connection_open(self) -> None:
|
|
@@ -483,7 +488,7 @@ def closed(self) -> bool:
|
|
# pop_message_waiter and self.transfer_data_task.
|
|
await asyncio.wait(
|
|
[pop_message_waiter, self.transfer_data_task],
|
|
- loop=self.loop,
|
|
+ loop=self.loop if sys.version_info[:2] < (3, 8) else None,
|
|
return_when=asyncio.FIRST_COMPLETED,
|
|
)
|
|
finally:
|
|
@@ -668,7 +673,7 @@ def closed(self) -> bool:
|
|
await asyncio.wait_for(
|
|
self.write_close_frame(serialize_close(code, reason)),
|
|
self.close_timeout,
|
|
- loop=self.loop,
|
|
+ loop=self.loop if sys.version_info[:2] < (3, 8) else None,
|
|
)
|
|
except asyncio.TimeoutError:
|
|
# If the close frame cannot be sent because the send buffers
|
|
@@ -687,7 +692,9 @@ def closed(self) -> bool:
|
|
# If close() is canceled during the wait, self.transfer_data_task
|
|
# is canceled before the timeout elapses.
|
|
await asyncio.wait_for(
|
|
- self.transfer_data_task, self.close_timeout, loop=self.loop
|
|
+ self.transfer_data_task,
|
|
+ self.close_timeout,
|
|
+ loop=self.loop if sys.version_info[:2] < (3, 8) else None,
|
|
)
|
|
except (asyncio.TimeoutError, asyncio.CancelledError):
|
|
pass
|
|
@@ -1106,7 +1113,10 @@ def append(frame: Frame) -> None:
|
|
|
|
try:
|
|
while True:
|
|
- await asyncio.sleep(self.ping_interval, loop=self.loop)
|
|
+ await asyncio.sleep(
|
|
+ self.ping_interval,
|
|
+ loop=self.loop if sys.version_info[:2] < (3, 8) else None,
|
|
+ )
|
|
|
|
# ping() raises CancelledError if the connection is closed,
|
|
# when close_connection() cancels self.keepalive_ping_task.
|
|
@@ -1119,7 +1129,9 @@ def append(frame: Frame) -> None:
|
|
if self.ping_timeout is not None:
|
|
try:
|
|
await asyncio.wait_for(
|
|
- ping_waiter, self.ping_timeout, loop=self.loop
|
|
+ ping_waiter,
|
|
+ self.ping_timeout,
|
|
+ loop=self.loop if sys.version_info[:2] < (3, 8) else None,
|
|
)
|
|
except asyncio.TimeoutError:
|
|
logger.debug("%s ! timed out waiting for pong", self.side)
|
|
@@ -1211,7 +1223,7 @@ def append(frame: Frame) -> None:
|
|
await asyncio.wait_for(
|
|
asyncio.shield(self.connection_lost_waiter),
|
|
self.close_timeout,
|
|
- loop=self.loop,
|
|
+ loop=self.loop if sys.version_info[:2] < (3, 8) else None,
|
|
)
|
|
except asyncio.TimeoutError:
|
|
pass
|
|
diff --git a/src/websockets/server.py b/src/websockets/server.py
|
|
index 5114646d..4f5e9e0e 100644
|
|
--- a/src/websockets/server.py
|
|
+++ b/src/websockets/server.py
|
|
@@ -10,6 +10,7 @@
|
|
import http
|
|
import logging
|
|
import socket
|
|
+import sys
|
|
import warnings
|
|
from types import TracebackType
|
|
from typing import (
|
|
@@ -698,7 +699,9 @@ def close(self) -> None:
|
|
|
|
# Wait until all accepted connections reach connection_made() and call
|
|
# register(). See https://bugs.python.org/issue34852 for details.
|
|
- await asyncio.sleep(0)
|
|
+ await asyncio.sleep(
|
|
+ 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None
|
|
+ )
|
|
|
|
# Close OPEN connections with status code 1001. Since the server was
|
|
# closed, handshake() closes OPENING conections with a HTTP 503 error.
|
|
@@ -707,7 +710,8 @@ def close(self) -> None:
|
|
# asyncio.wait doesn't accept an empty first argument
|
|
if self.websockets:
|
|
await asyncio.wait(
|
|
- [websocket.close(1001) for websocket in self.websockets], loop=self.loop
|
|
+ [websocket.close(1001) for websocket in self.websockets],
|
|
+ loop=self.loop if sys.version_info[:2] < (3, 8) else None,
|
|
)
|
|
|
|
# Wait until all connection handlers are complete.
|
|
@@ -716,7 +720,7 @@ def close(self) -> None:
|
|
if self.websockets:
|
|
await asyncio.wait(
|
|
[websocket.handler_task for websocket in self.websockets],
|
|
- loop=self.loop,
|
|
+ loop=self.loop if sys.version_info[:2] < (3, 8) else None,
|
|
)
|
|
|
|
# Tell wait_closed() to return.
|
|
diff --git a/tests/test_client_server.py b/tests/test_client_server.py
|
|
index 6171f21b..85828bdb 100644
|
|
--- a/tests/test_client_server.py
|
|
+++ b/tests/test_client_server.py
|
|
@@ -1381,13 +1381,16 @@ def test_client(self):
|
|
start_server = serve(handler, "localhost", 0)
|
|
server = self.loop.run_until_complete(start_server)
|
|
|
|
- @asyncio.coroutine
|
|
- def run_client():
|
|
- # Yield from connect.
|
|
- client = yield from connect(get_server_uri(server))
|
|
- self.assertEqual(client.state, State.OPEN)
|
|
- yield from client.close()
|
|
- self.assertEqual(client.state, State.CLOSED)
|
|
+ # @asyncio.coroutine is deprecated on Python ≥ 3.8
|
|
+ with warnings.catch_warnings(record=True):
|
|
+
|
|
+ @asyncio.coroutine
|
|
+ def run_client():
|
|
+ # Yield from connect.
|
|
+ client = yield from connect(get_server_uri(server))
|
|
+ self.assertEqual(client.state, State.OPEN)
|
|
+ yield from client.close()
|
|
+ self.assertEqual(client.state, State.CLOSED)
|
|
|
|
self.loop.run_until_complete(run_client())
|
|
|
|
@@ -1395,14 +1398,17 @@ def run_client():
|
|
self.loop.run_until_complete(server.wait_closed())
|
|
|
|
def test_server(self):
|
|
- @asyncio.coroutine
|
|
- def run_server():
|
|
- # Yield from serve.
|
|
- server = yield from serve(handler, "localhost", 0)
|
|
- self.assertTrue(server.sockets)
|
|
- server.close()
|
|
- yield from server.wait_closed()
|
|
- self.assertFalse(server.sockets)
|
|
+ # @asyncio.coroutine is deprecated on Python ≥ 3.8
|
|
+ with warnings.catch_warnings(record=True):
|
|
+
|
|
+ @asyncio.coroutine
|
|
+ def run_server():
|
|
+ # Yield from serve.
|
|
+ server = yield from serve(handler, "localhost", 0)
|
|
+ self.assertTrue(server.sockets)
|
|
+ server.close()
|
|
+ yield from server.wait_closed()
|
|
+ self.assertFalse(server.sockets)
|
|
|
|
self.loop.run_until_complete(run_server())
|
|
|
|
diff --git a/tests/test_protocol.py b/tests/test_protocol.py
|
|
index d2793faf..04e2a38f 100644
|
|
--- a/tests/test_protocol.py
|
|
+++ b/tests/test_protocol.py
|
|
@@ -1,6 +1,7 @@
|
|
import asyncio
|
|
import contextlib
|
|
import logging
|
|
+import sys
|
|
import unittest
|
|
import unittest.mock
|
|
import warnings
|
|
@@ -100,7 +101,9 @@ def make_drain_slow(self, delay=MS):
|
|
original_drain = self.protocol._drain
|
|
|
|
async def delayed_drain():
|
|
- await asyncio.sleep(delay, loop=self.loop)
|
|
+ await asyncio.sleep(
|
|
+ delay, loop=self.loop if sys.version_info[:2] < (3, 8) else None
|
|
+ )
|
|
await original_drain()
|
|
|
|
self.protocol._drain = delayed_drain
|
|
|
|
From 504c66cf01bc7e2c4f5fbcceb1387cadf056a678 Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 14:02:51 +0200
|
|
Subject: [PATCH 16/20] Document and test support for Python 3.8.
|
|
|
|
---
|
|
.circleci/config.yml | 12 ++++++++++++
|
|
docs/changelog.rst | 2 ++
|
|
setup.py | 1 +
|
|
tox.ini | 2 +-
|
|
4 files changed, 16 insertions(+), 1 deletion(-)
|
|
|
|
diff --git a/.circleci/config.yml b/.circleci/config.yml
|
|
index a6c85d23..0877c161 100644
|
|
--- a/.circleci/config.yml
|
|
+++ b/.circleci/config.yml
|
|
@@ -29,6 +29,15 @@ jobs:
|
|
- checkout
|
|
- run: sudo pip install tox
|
|
- run: tox -e py37
|
|
+ py38:
|
|
+ docker:
|
|
+ - image: circleci/python:3.8.0rc1
|
|
+ steps:
|
|
+ # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway.
|
|
+ - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc
|
|
+ - checkout
|
|
+ - run: sudo pip install tox
|
|
+ - run: tox -e py38
|
|
|
|
workflows:
|
|
version: 2
|
|
@@ -41,3 +50,6 @@ workflows:
|
|
- py37:
|
|
requires:
|
|
- main
|
|
+ - py38:
|
|
+ requires:
|
|
+ - main
|
|
diff --git a/docs/changelog.rst b/docs/changelog.rst
|
|
index 87b2e438..2a106fbc 100644
|
|
--- a/docs/changelog.rst
|
|
+++ b/docs/changelog.rst
|
|
@@ -8,6 +8,8 @@ Changelog
|
|
|
|
*In development*
|
|
|
|
+* Added compatibility with Python 3.8.
|
|
+
|
|
8.0.2
|
|
.....
|
|
|
|
diff --git a/setup.py b/setup.py
|
|
index c7643010..f3581924 100644
|
|
--- a/setup.py
|
|
+++ b/setup.py
|
|
@@ -53,6 +53,7 @@
|
|
'Programming Language :: Python :: 3',
|
|
'Programming Language :: Python :: 3.6',
|
|
'Programming Language :: Python :: 3.7',
|
|
+ 'Programming Language :: Python :: 3.8',
|
|
],
|
|
package_dir = {'': 'src'},
|
|
package_data = {'websockets': ['py.typed']},
|
|
diff --git a/tox.ini b/tox.ini
|
|
index 7397c90a..825e3406 100644
|
|
--- a/tox.ini
|
|
+++ b/tox.ini
|
|
@@ -1,5 +1,5 @@
|
|
[tox]
|
|
-envlist = py36,py37,coverage,black,flake8,isort,mypy
|
|
+envlist = py36,py37,py38,coverage,black,flake8,isort,mypy
|
|
|
|
[testenv]
|
|
commands = python -W default -m unittest {posargs}
|
|
|
|
From fc99c16a53e5d0c85448714021a5deb74021920a Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 21:18:30 +0200
|
|
Subject: [PATCH 17/20] Move test logging configuration to a single place.
|
|
|
|
---
|
|
tests/__init__.py | 5 +++++
|
|
tests/test_client_server.py | 5 -----
|
|
tests/test_protocol.py | 5 -----
|
|
3 files changed, 5 insertions(+), 10 deletions(-)
|
|
|
|
diff --git a/tests/__init__.py b/tests/__init__.py
|
|
index e69de29b..dd78609f 100644
|
|
--- a/tests/__init__.py
|
|
+++ b/tests/__init__.py
|
|
@@ -0,0 +1,5 @@
|
|
+import logging
|
|
+
|
|
+
|
|
+# Avoid displaying stack traces at the ERROR logging level.
|
|
+logging.basicConfig(level=logging.CRITICAL)
|
|
diff --git a/tests/test_client_server.py b/tests/test_client_server.py
|
|
index 85828bdb..ce0f66ce 100644
|
|
--- a/tests/test_client_server.py
|
|
+++ b/tests/test_client_server.py
|
|
@@ -2,7 +2,6 @@
|
|
import contextlib
|
|
import functools
|
|
import http
|
|
-import logging
|
|
import pathlib
|
|
import random
|
|
import socket
|
|
@@ -37,10 +36,6 @@
|
|
from .utils import AsyncioTestCase
|
|
|
|
|
|
-# Avoid displaying stack traces at the ERROR logging level.
|
|
-logging.basicConfig(level=logging.CRITICAL)
|
|
-
|
|
-
|
|
# Generate TLS certificate with:
|
|
# $ openssl req -x509 -config test_localhost.cnf -days 15340 -newkey rsa:2048 \
|
|
# -out test_localhost.crt -keyout test_localhost.key
|
|
diff --git a/tests/test_protocol.py b/tests/test_protocol.py
|
|
index 04e2a38f..d95260a8 100644
|
|
--- a/tests/test_protocol.py
|
|
+++ b/tests/test_protocol.py
|
|
@@ -1,6 +1,5 @@
|
|
import asyncio
|
|
import contextlib
|
|
-import logging
|
|
import sys
|
|
import unittest
|
|
import unittest.mock
|
|
@@ -13,10 +12,6 @@
|
|
from .utils import MS, AsyncioTestCase
|
|
|
|
|
|
-# Avoid displaying stack traces at the ERROR logging level.
|
|
-logging.basicConfig(level=logging.CRITICAL)
|
|
-
|
|
-
|
|
async def async_iterable(iterable):
|
|
for item in iterable:
|
|
yield item
|
|
|
|
From 91dea74a9430d8573f7fd58c35fd731eddbb6cdb Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 21:33:27 +0200
|
|
Subject: [PATCH 18/20] Remove test that no longer makes sense.
|
|
|
|
Since version 7.0, when the server closes, it terminates connections
|
|
with close code 1001 instead of canceling them.
|
|
---
|
|
tests/test_client_server.py | 18 ++----------------
|
|
1 file changed, 2 insertions(+), 16 deletions(-)
|
|
|
|
diff --git a/tests/test_client_server.py b/tests/test_client_server.py
|
|
index ce0f66ce..35913666 100644
|
|
--- a/tests/test_client_server.py
|
|
+++ b/tests/test_client_server.py
|
|
@@ -173,7 +173,7 @@ class HealthCheckServerProtocol(WebSocketServerProtocol):
|
|
return http.HTTPStatus.OK, [("X-Access", "OK")], b"status = green\n"
|
|
|
|
|
|
-class SlowServerProtocol(WebSocketServerProtocol):
|
|
+class SlowOpeningHandshakeProtocol(WebSocketServerProtocol):
|
|
async def process_request(self, path, request_headers):
|
|
await asyncio.sleep(10 * MS)
|
|
|
|
@@ -1165,7 +1165,7 @@ def test_client_closes_connection_before_handshake(self, handshake):
|
|
# The server should stop properly anyway. It used to hang because the
|
|
# task handling the connection was waiting for the opening handshake.
|
|
|
|
- @with_server(create_protocol=SlowServerProtocol)
|
|
+ @with_server(create_protocol=SlowOpeningHandshakeProtocol)
|
|
def test_server_shuts_down_during_opening_handshake(self):
|
|
self.loop.call_later(5 * MS, self.server.close)
|
|
with self.assertRaises(InvalidStatusCode) as raised:
|
|
@@ -1188,20 +1188,6 @@ def test_server_shuts_down_during_connection_handling(self):
|
|
self.assertEqual(self.client.close_code, 1001)
|
|
self.assertEqual(server_ws.close_code, 1001)
|
|
|
|
- @with_server()
|
|
- @unittest.mock.patch("websockets.server.WebSocketServerProtocol.close")
|
|
- def test_server_shuts_down_during_connection_close(self, _close):
|
|
- _close.side_effect = asyncio.CancelledError
|
|
-
|
|
- self.server.closing = True
|
|
- with self.temp_client():
|
|
- self.loop.run_until_complete(self.client.send("Hello!"))
|
|
- reply = self.loop.run_until_complete(self.client.recv())
|
|
- self.assertEqual(reply, "Hello!")
|
|
-
|
|
- # Websocket connection terminates abnormally.
|
|
- self.assertEqual(self.client.close_code, 1006)
|
|
-
|
|
@with_server()
|
|
def test_server_shuts_down_waits_until_handlers_terminate(self):
|
|
# This handler waits a bit after the connection is closed in order
|
|
|
|
From 88073fe7da7e8c07a5f46ae7c77e1a4ef80736ee Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 21:41:13 +0200
|
|
Subject: [PATCH 19/20] Fix refactoring error.
|
|
|
|
WebSocketCommonProtocol.transport can be unset, but it cannot be None.
|
|
---
|
|
src/websockets/protocol.py | 14 ++++++++++----
|
|
tests/test_protocol.py | 8 ++++----
|
|
2 files changed, 14 insertions(+), 8 deletions(-)
|
|
|
|
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
|
|
index 76d46ad9..0623e136 100644
|
|
--- a/src/websockets/protocol.py
|
|
+++ b/src/websockets/protocol.py
|
|
@@ -366,9 +366,12 @@ def local_address(self) -> Any:
|
|
been established yet.
|
|
|
|
"""
|
|
- if self.transport is None:
|
|
+ try:
|
|
+ transport = self.transport
|
|
+ except AttributeError:
|
|
return None
|
|
- return self.transport.get_extra_info("sockname")
|
|
+ else:
|
|
+ return transport.get_extra_info("sockname")
|
|
|
|
@property
|
|
def remote_address(self) -> Any:
|
|
@@ -379,9 +382,12 @@ def remote_address(self) -> Any:
|
|
been established yet.
|
|
|
|
"""
|
|
- if self.transport is None:
|
|
+ try:
|
|
+ transport = self.transport
|
|
+ except AttributeError:
|
|
return None
|
|
- return self.transport.get_extra_info("peername")
|
|
+ else:
|
|
+ return transport.get_extra_info("peername")
|
|
|
|
@property
|
|
def open(self) -> bool:
|
|
diff --git a/tests/test_protocol.py b/tests/test_protocol.py
|
|
index d95260a8..66a822e7 100644
|
|
--- a/tests/test_protocol.py
|
|
+++ b/tests/test_protocol.py
|
|
@@ -323,8 +323,8 @@ def test_local_address(self):
|
|
|
|
def test_local_address_before_connection(self):
|
|
# Emulate the situation before connection_open() runs.
|
|
- self.protocol.transport, _transport = None, self.protocol.transport
|
|
-
|
|
+ _transport = self.protocol.transport
|
|
+ del self.protocol.transport
|
|
try:
|
|
self.assertEqual(self.protocol.local_address, None)
|
|
finally:
|
|
@@ -339,8 +339,8 @@ def test_remote_address(self):
|
|
|
|
def test_remote_address_before_connection(self):
|
|
# Emulate the situation before connection_open() runs.
|
|
- self.protocol.transport, _transport = None, self.protocol.transport
|
|
-
|
|
+ _transport = self.protocol.transport
|
|
+ del self.protocol.transport
|
|
try:
|
|
self.assertEqual(self.protocol.remote_address, None)
|
|
finally:
|
|
|
|
From d75d4a4d34a7485ab8e5f2e65dee06dc37b7cfba Mon Sep 17 00:00:00 2001
|
|
From: Aymeric Augustin <aymeric.augustin@m4x.org>
|
|
Date: Sat, 5 Oct 2019 21:39:14 +0200
|
|
Subject: [PATCH 20/20] Remove useless type declaration.
|
|
|
|
---
|
|
src/websockets/protocol.py | 13 +++++--------
|
|
1 file changed, 5 insertions(+), 8 deletions(-)
|
|
|
|
diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py
|
|
index 0623e136..6c29b2a5 100644
|
|
--- a/src/websockets/protocol.py
|
|
+++ b/src/websockets/protocol.py
|
|
@@ -226,19 +226,16 @@ def __init__(
|
|
# ``self.read_limit``. The ``limit`` argument controls the line length
|
|
# limit and half the buffer limit of :class:`~asyncio.StreamReader`.
|
|
# That's why it must be set to half of ``self.read_limit``.
|
|
- self.reader: asyncio.StreamReader = asyncio.StreamReader(
|
|
- limit=read_limit // 2, loop=loop
|
|
- )
|
|
-
|
|
- self.transport: asyncio.Transport
|
|
- self._drain_lock = asyncio.Lock(
|
|
- loop=loop if sys.version_info[:2] < (3, 8) else None
|
|
- )
|
|
+ self.reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop)
|
|
|
|
# Copied from asyncio.FlowControlMixin
|
|
self._paused = False
|
|
self._drain_waiter: Optional[asyncio.Future[None]] = None
|
|
|
|
+ self._drain_lock = asyncio.Lock(
|
|
+ loop=loop if sys.version_info[:2] < (3, 8) else None
|
|
+ )
|
|
+
|
|
# This class implements the data transfer and closing handshake, which
|
|
# are shared between the client-side and the server-side.
|
|
# Subclasses implement the opening handshake and, on success, execute
|