forked from pool/python-Authlib
150 lines
6.1 KiB
Diff
From 2808378611dd6fb2532b189a9087877d8f0c0489 Mon Sep 17 00:00:00 2001
|
|
From: Hsiaoming Yang <me@lepture.com>
|
|
Date: Fri, 12 Dec 2025 16:37:44 +0900
|
|
Subject: [PATCH] Merge commit from fork
|
|
|
|
---
|
|
.../base_client/framework_integration.py | 25 +++++-----
|
|
tests/clients/test_flask/test_oauth_client.py | 49 +++++++++++++++++--
|
|
2 files changed, 59 insertions(+), 15 deletions(-)
|
|
|
|
Index: authlib-1.5.2/authlib/integrations/base_client/framework_integration.py
|
|
===================================================================
|
|
--- authlib-1.5.2.orig/authlib/integrations/base_client/framework_integration.py
|
|
+++ authlib-1.5.2/authlib/integrations/base_client/framework_integration.py
|
|
@@ -20,11 +20,9 @@ class FrameworkIntegration:
|
|
|
|
def _clear_session_state(self, session):
|
|
now = time.time()
|
|
+ prefix = f"_state_{self.name}"
|
|
for key in dict(session):
|
|
- if "_authlib_" in key:
|
|
- # TODO: remove in future
|
|
- session.pop(key)
|
|
- elif key.startswith("_state_"):
|
|
+ if key.startswith(prefix):
|
|
value = session[key]
|
|
exp = value.get("exp")
|
|
if not exp or exp < now:
|
|
@@ -32,29 +30,32 @@ class FrameworkIntegration:
|
|
|
|
def get_state_data(self, session, state):
|
|
key = f"_state_{self.name}_{state}"
|
|
+ session_data = session.get(key)
|
|
+ if not session_data:
|
|
+ return None
|
|
if self.cache:
|
|
- value = self._get_cache_data(key)
|
|
+ cached_value = self._get_cache_data(key)
|
|
else:
|
|
- value = session.get(key)
|
|
- if value:
|
|
- return value.get("data")
|
|
+ cached_value = session_data
|
|
+ if cached_value:
|
|
+ return cached_value.get("data")
|
|
return None
|
|
|
|
def set_state_data(self, session, state, data):
|
|
key = f"_state_{self.name}_{state}"
|
|
+ now = time.time()
|
|
if self.cache:
|
|
self.cache.set(key, json.dumps({"data": data}), self.expires_in)
|
|
+ session[key] = {"exp": now + self.expires_in}
|
|
else:
|
|
- now = time.time()
|
|
session[key] = {"data": data, "exp": now + self.expires_in}
|
|
|
|
def clear_state_data(self, session, state):
|
|
key = f"_state_{self.name}_{state}"
|
|
if self.cache:
|
|
self.cache.delete(key)
|
|
- else:
|
|
- session.pop(key, None)
|
|
- self._clear_session_state(session)
|
|
+ session.pop(key, None)
|
|
+ self._clear_session_state(session)
|
|
|
|
def update_token(self, token, refresh_token=None, access_token=None):
|
|
raise NotImplementedError()
|
|
Index: authlib-1.5.2/tests/clients/test_flask/test_oauth_client.py
|
|
===================================================================
|
|
--- authlib-1.5.2.orig/tests/clients/test_flask/test_oauth_client.py
|
|
+++ authlib-1.5.2/tests/clients/test_flask/test_oauth_client.py
|
|
@@ -143,9 +143,13 @@ class FlaskOAuthTest(TestCase):
|
|
self.assertEqual(resp.status_code, 302)
|
|
url = resp.headers.get("Location")
|
|
self.assertIn("oauth_token=foo", url)
|
|
+ session_data = session["_state_dev_foo"]
|
|
+ assert "exp" in session_data
|
|
+ assert "data" not in session_data
|
|
|
|
with app.test_request_context("/?oauth_token=foo"):
|
|
with mock.patch("requests.sessions.Session.send") as send:
|
|
+ session["_state_dev_foo"] = session_data
|
|
send.return_value = mock_send_value(
|
|
"oauth_token=a&oauth_token_secret=b"
|
|
)
|
|
@@ -203,7 +207,44 @@ class FlaskOAuthTest(TestCase):
|
|
session = oauth.dev._get_oauth_client()
|
|
self.assertIsNotNone(session.update_token)
|
|
|
|
- def test_oauth2_authorize(self):
|
|
+ def test_oauth2_authorize_cache(self):
|
|
+ app = Flask(__name__)
|
|
+ app.secret_key = "!"
|
|
+ cache = SimpleCache()
|
|
+ oauth = OAuth(app, cache=cache)
|
|
+ client = oauth.register(
|
|
+ "dev",
|
|
+ client_id="dev",
|
|
+ client_secret="dev",
|
|
+ api_base_url="https://resource.test/api",
|
|
+ access_token_url="https://provider.test/token",
|
|
+ authorize_url="https://provider.test/authorize",
|
|
+ )
|
|
+ with app.test_request_context():
|
|
+ resp = client.authorize_redirect("https://client.test/callback")
|
|
+ assert resp.status_code == 302
|
|
+ url = resp.headers.get("Location")
|
|
+ assert "state=" in url
|
|
+ state = dict(url_decode(urlparse.urlparse(url).query))["state"]
|
|
+ assert state is not None
|
|
+ session_data = session[f"_state_dev_{state}"]
|
|
+ assert "exp" in session_data
|
|
+ assert "data" not in session_data
|
|
+
|
|
+ with app.test_request_context(path=f"/?code=a&state={state}"):
|
|
+ # session is cleared in tests
|
|
+ session[f"_state_dev_{state}"] = session_data
|
|
+
|
|
+ with mock.patch("requests.sessions.Session.send") as send:
|
|
+ send.return_value = mock_send_value(get_bearer_token())
|
|
+ token = client.authorize_access_token()
|
|
+ assert token["access_token"] == "a"
|
|
+
|
|
+ with app.test_request_context():
|
|
+ assert client.token is None
|
|
+
|
|
+
|
|
+ def test_oauth2_authorize_session(self):
|
|
app = Flask(__name__)
|
|
app.secret_key = "!"
|
|
oauth = OAuth(app)
|
|
@@ -223,11 +264,12 @@ class FlaskOAuthTest(TestCase):
|
|
self.assertIn("state=", url)
|
|
state = dict(url_decode(urlparse.urlparse(url).query))["state"]
|
|
self.assertIsNotNone(state)
|
|
- data = session[f"_state_dev_{state}"]
|
|
-
|
|
+ session_data = session[f"_state_dev_{state}"]
|
|
+ assert "exp" in session_data
|
|
+ assert "data" in session_data
|
|
with app.test_request_context(path=f"/?code=a&state={state}"):
|
|
# session is cleared in tests
|
|
- session[f"_state_dev_{state}"] = data
|
|
+ session[f"_state_dev_{state}"] = session_data
|
|
|
|
with mock.patch("requests.sessions.Session.send") as send:
|
|
send.return_value = mock_send_value(get_bearer_token())
|