1
0
python-djangorestframework-.../jwt2.patch

165 lines
7.7 KiB
Diff
Raw Normal View History

From ae1ea58daacb76a1d72e202198734053022a6efe Mon Sep 17 00:00:00 2001
From: Riccardo Magliocchetti <riccardo.magliocchetti@gmail.com>
Date: Mon, 22 Feb 2021 16:51:38 +0100
Subject: [PATCH] Support jwt 2 (#376)
* Add PyJWT 2.0.0 support
* Many tests relied on passing a payload into `jwt.encode` and having it modify that payload by reference.
In PyJWT 2.0.0, the payload is copied (`dict.copy()`).
* The `verify` argument was deprecated in favor of the `verify_signature` option. This is reflected in backends.py. Unless you wrote verify=False, you are not affected by this change.
* Isort lint
Co-authored-by: Andrew-Chen-Wang <acwangpython@gmail.com>
---
rest_framework_simplejwt/backends.py | 8 +++--
tests/test_backends.py | 44 ++++++++++++++++++----------
2 files changed, 33 insertions(+), 19 deletions(-)
diff --git a/rest_framework_simplejwt/backends.py b/rest_framework_simplejwt/backends.py
index 887ecac..de7128a 100644
--- a/rest_framework_simplejwt/backends.py
+++ b/rest_framework_simplejwt/backends.py
@@ -65,9 +65,11 @@ def decode(self, token, verify=True):
signature check fails, or if its 'exp' claim indicates it has expired.
"""
try:
- return jwt.decode(token, self.verifying_key, algorithms=[self.algorithm], verify=verify,
- audience=self.audience, issuer=self.issuer,
- options={'verify_aud': self.audience is not None})
+ return jwt.decode(
+ token, self.verifying_key, algorithms=[self.algorithm], verify=verify,
+ audience=self.audience, issuer=self.issuer,
+ options={'verify_aud': self.audience is not None, "verify_signature": verify}
+ )
except InvalidAlgorithmError as ex:
raise TokenBackendError(_('Invalid algorithm specified')) from ex
except InvalidTokenError:
diff --git a/tests/test_backends.py b/tests/test_backends.py
index f9a9a23..bfb8da6 100644
--- a/tests/test_backends.py
+++ b/tests/test_backends.py
@@ -3,11 +3,13 @@
import jwt
from django.test import TestCase
-from jwt import PyJWT, algorithms
+from jwt import PyJWS, algorithms
from rest_framework_simplejwt.backends import TokenBackend
from rest_framework_simplejwt.exceptions import TokenBackendError
-from rest_framework_simplejwt.utils import aware_utcnow, make_utc
+from rest_framework_simplejwt.utils import (
+ aware_utcnow, datetime_to_epoch, make_utc,
+)
SECRET = 'not_secret'
@@ -163,9 +165,9 @@ def test_decode_hmac_with_expiry(self):
def test_decode_hmac_with_invalid_sig(self):
self.payload['exp'] = aware_utcnow() + timedelta(days=1)
- token_1 = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8')
+ token_1 = jwt.encode(self.payload, SECRET, algorithm='HS256')
self.payload['foo'] = 'baz'
- token_2 = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8')
+ token_2 = jwt.encode(self.payload, SECRET, algorithm='HS256')
token_2_payload = token_2.rsplit('.', 1)[0]
token_1_sig = token_1.rsplit('.', 1)[-1]
@@ -176,9 +178,11 @@ def test_decode_hmac_with_invalid_sig(self):
def test_decode_hmac_with_invalid_sig_no_verify(self):
self.payload['exp'] = aware_utcnow() + timedelta(days=1)
- token_1 = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8')
+ token_1 = jwt.encode(self.payload, SECRET, algorithm='HS256')
self.payload['foo'] = 'baz'
- token_2 = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8')
+ token_2 = jwt.encode(self.payload, SECRET, algorithm='HS256')
+ # Payload copied
+ self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
token_2_payload = token_2.rsplit('.', 1)[0]
token_1_sig = token_1.rsplit('.', 1)[-1]
@@ -193,7 +197,9 @@ def test_decode_hmac_success(self):
self.payload['exp'] = aware_utcnow() + timedelta(days=1)
self.payload['foo'] = 'baz'
- token = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8')
+ token = jwt.encode(self.payload, SECRET, algorithm='HS256')
+ # Payload copied
+ self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
self.assertEqual(self.hmac_token_backend.decode(token), self.payload)
@@ -220,9 +226,9 @@ def test_decode_rsa_with_expiry(self):
def test_decode_rsa_with_invalid_sig(self):
self.payload['exp'] = aware_utcnow() + timedelta(days=1)
- token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
+ token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
self.payload['foo'] = 'baz'
- token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
+ token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
token_2_payload = token_2.rsplit('.', 1)[0]
token_1_sig = token_1.rsplit('.', 1)[-1]
@@ -233,13 +239,15 @@ def test_decode_rsa_with_invalid_sig(self):
def test_decode_rsa_with_invalid_sig_no_verify(self):
self.payload['exp'] = aware_utcnow() + timedelta(days=1)
- token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
+ token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
self.payload['foo'] = 'baz'
- token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
+ token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
token_2_payload = token_2.rsplit('.', 1)[0]
token_1_sig = token_1.rsplit('.', 1)[-1]
invalid_token = token_2_payload + '.' + token_1_sig
+ # Payload copied
+ self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
self.assertEqual(
self.hmac_token_backend.decode(invalid_token, verify=False),
@@ -250,7 +258,9 @@ def test_decode_rsa_success(self):
self.payload['exp'] = aware_utcnow() + timedelta(days=1)
self.payload['foo'] = 'baz'
- token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
+ token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
+ # Payload copied
+ self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
self.assertEqual(self.rsa_token_backend.decode(token), self.payload)
@@ -260,21 +270,23 @@ def test_decode_aud_iss_success(self):
self.payload['aud'] = AUDIENCE
self.payload['iss'] = ISSUER
- token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
+ token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
+ # Payload copied
+ self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
self.assertEqual(self.aud_iss_token_backend.decode(token), self.payload)
def test_decode_when_algorithm_not_available(self):
- token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
+ token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
- pyjwt_without_rsa = PyJWT()
+ pyjwt_without_rsa = PyJWS()
pyjwt_without_rsa.unregister_algorithm('RS256')
with patch.object(jwt, 'decode', new=pyjwt_without_rsa.decode):
with self.assertRaisesRegex(TokenBackendError, 'Invalid algorithm specified'):
self.rsa_token_backend.decode(token)
def test_decode_when_token_algorithm_does_not_match(self):
- token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
+ token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
with self.assertRaisesRegex(TokenBackendError, 'Invalid algorithm specified'):
self.hmac_token_backend.decode(token)