From 2808378611dd6fb2532b189a9087877d8f0c0489 Mon Sep 17 00:00:00 2001 From: Hsiaoming Yang Date: Fri, 12 Dec 2025 16:37:44 +0900 Subject: [PATCH] Merge commit from fork --- .../base_client/framework_integration.py | 25 +++++----- tests/clients/test_flask/test_oauth_client.py | 49 +++++++++++++++++-- 2 files changed, 59 insertions(+), 15 deletions(-) Index: authlib-1.5.2/authlib/integrations/base_client/framework_integration.py =================================================================== --- authlib-1.5.2.orig/authlib/integrations/base_client/framework_integration.py +++ authlib-1.5.2/authlib/integrations/base_client/framework_integration.py @@ -20,11 +20,9 @@ class FrameworkIntegration: def _clear_session_state(self, session): now = time.time() + prefix = f"_state_{self.name}" for key in dict(session): - if "_authlib_" in key: - # TODO: remove in future - session.pop(key) - elif key.startswith("_state_"): + if key.startswith(prefix): value = session[key] exp = value.get("exp") if not exp or exp < now: @@ -32,29 +30,32 @@ class FrameworkIntegration: def get_state_data(self, session, state): key = f"_state_{self.name}_{state}" + session_data = session.get(key) + if not session_data: + return None if self.cache: - value = self._get_cache_data(key) + cached_value = self._get_cache_data(key) else: - value = session.get(key) - if value: - return value.get("data") + cached_value = session_data + if cached_value: + return cached_value.get("data") return None def set_state_data(self, session, state, data): key = f"_state_{self.name}_{state}" + now = time.time() if self.cache: self.cache.set(key, json.dumps({"data": data}), self.expires_in) + session[key] = {"exp": now + self.expires_in} else: - now = time.time() session[key] = {"data": data, "exp": now + self.expires_in} def clear_state_data(self, session, state): key = f"_state_{self.name}_{state}" if self.cache: self.cache.delete(key) - else: - session.pop(key, None) - self._clear_session_state(session) + session.pop(key, None) + self._clear_session_state(session) def update_token(self, token, refresh_token=None, access_token=None): raise NotImplementedError() Index: authlib-1.5.2/tests/clients/test_flask/test_oauth_client.py =================================================================== --- authlib-1.5.2.orig/tests/clients/test_flask/test_oauth_client.py +++ authlib-1.5.2/tests/clients/test_flask/test_oauth_client.py @@ -143,9 +143,13 @@ class FlaskOAuthTest(TestCase): self.assertEqual(resp.status_code, 302) url = resp.headers.get("Location") self.assertIn("oauth_token=foo", url) + session_data = session["_state_dev_foo"] + assert "exp" in session_data + assert "data" not in session_data with app.test_request_context("/?oauth_token=foo"): with mock.patch("requests.sessions.Session.send") as send: + session["_state_dev_foo"] = session_data send.return_value = mock_send_value( "oauth_token=a&oauth_token_secret=b" ) @@ -203,7 +207,44 @@ class FlaskOAuthTest(TestCase): session = oauth.dev._get_oauth_client() self.assertIsNotNone(session.update_token) - def test_oauth2_authorize(self): + def test_oauth2_authorize_cache(self): + app = Flask(__name__) + app.secret_key = "!" + cache = SimpleCache() + oauth = OAuth(app, cache=cache) + client = oauth.register( + "dev", + client_id="dev", + client_secret="dev", + api_base_url="https://resource.test/api", + access_token_url="https://provider.test/token", + authorize_url="https://provider.test/authorize", + ) + with app.test_request_context(): + resp = client.authorize_redirect("https://client.test/callback") + assert resp.status_code == 302 + url = resp.headers.get("Location") + assert "state=" in url + state = dict(url_decode(urlparse.urlparse(url).query))["state"] + assert state is not None + session_data = session[f"_state_dev_{state}"] + assert "exp" in session_data + assert "data" not in session_data + + with app.test_request_context(path=f"/?code=a&state={state}"): + # session is cleared in tests + session[f"_state_dev_{state}"] = session_data + + with mock.patch("requests.sessions.Session.send") as send: + send.return_value = mock_send_value(get_bearer_token()) + token = client.authorize_access_token() + assert token["access_token"] == "a" + + with app.test_request_context(): + assert client.token is None + + + def test_oauth2_authorize_session(self): app = Flask(__name__) app.secret_key = "!" oauth = OAuth(app) @@ -223,11 +264,12 @@ class FlaskOAuthTest(TestCase): self.assertIn("state=", url) state = dict(url_decode(urlparse.urlparse(url).query))["state"] self.assertIsNotNone(state) - data = session[f"_state_dev_{state}"] - + session_data = session[f"_state_dev_{state}"] + assert "exp" in session_data + assert "data" in session_data with app.test_request_context(path=f"/?code=a&state={state}"): # session is cleared in tests - session[f"_state_dev_{state}"] = data + session[f"_state_dev_{state}"] = session_data with mock.patch("requests.sessions.Session.send") as send: send.return_value = mock_send_value(get_bearer_token())