165 lines
7.7 KiB
Diff
165 lines
7.7 KiB
Diff
|
From ae1ea58daacb76a1d72e202198734053022a6efe Mon Sep 17 00:00:00 2001
|
||
|
From: Riccardo Magliocchetti <riccardo.magliocchetti@gmail.com>
|
||
|
Date: Mon, 22 Feb 2021 16:51:38 +0100
|
||
|
Subject: [PATCH] Support jwt 2 (#376)
|
||
|
|
||
|
* Add PyJWT 2.0.0 support
|
||
|
|
||
|
* Many tests relied on passing a payload into encode and having encode modify that payload in place (by reference).
|
||
|
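In PyJWT 2.0.0, encode works on a copy of the payload (`dict.copy()`), so the tests now convert `exp` to an epoch timestamp themselves before comparing decoded payloads.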
|
||
|
* The `verify` argument was deprecated in favor of the `verify_signature` option. This is reflected in backends.py. Unless you passed `verify=False`, you are not affected by this change.
|
||
|
|
||
|
* Fix isort lint (import ordering)
|
||
|
|
||
|
Co-authored-by: Andrew-Chen-Wang <acwangpython@gmail.com>
|
||
|
---
|
||
|
rest_framework_simplejwt/backends.py | 8 +++--
|
||
|
tests/test_backends.py | 44 ++++++++++++++++++----------
|
||
|
2 files changed, 33 insertions(+), 19 deletions(-)
|
||
|
|
||
|
diff --git a/rest_framework_simplejwt/backends.py b/rest_framework_simplejwt/backends.py
|
||
|
index 887ecac..de7128a 100644
|
||
|
--- a/rest_framework_simplejwt/backends.py
|
||
|
+++ b/rest_framework_simplejwt/backends.py
|
||
|
@@ -65,9 +65,11 @@ def decode(self, token, verify=True):
|
||
|
signature check fails, or if its 'exp' claim indicates it has expired.
|
||
|
"""
|
||
|
try:
|
||
|
- return jwt.decode(token, self.verifying_key, algorithms=[self.algorithm], verify=verify,
|
||
|
- audience=self.audience, issuer=self.issuer,
|
||
|
- options={'verify_aud': self.audience is not None})
|
||
|
+ return jwt.decode(
|
||
|
+ token, self.verifying_key, algorithms=[self.algorithm], verify=verify,
|
||
|
+ audience=self.audience, issuer=self.issuer,
|
||
|
+ options={'verify_aud': self.audience is not None, "verify_signature": verify}
|
||
|
+ )
|
||
|
except InvalidAlgorithmError as ex:
|
||
|
raise TokenBackendError(_('Invalid algorithm specified')) from ex
|
||
|
except InvalidTokenError:
|
||
|
diff --git a/tests/test_backends.py b/tests/test_backends.py
|
||
|
index f9a9a23..bfb8da6 100644
|
||
|
--- a/tests/test_backends.py
|
||
|
+++ b/tests/test_backends.py
|
||
|
@@ -3,11 +3,13 @@
|
||
|
|
||
|
import jwt
|
||
|
from django.test import TestCase
|
||
|
-from jwt import PyJWT, algorithms
|
||
|
+from jwt import PyJWS, algorithms
|
||
|
|
||
|
from rest_framework_simplejwt.backends import TokenBackend
|
||
|
from rest_framework_simplejwt.exceptions import TokenBackendError
|
||
|
-from rest_framework_simplejwt.utils import aware_utcnow, make_utc
|
||
|
+from rest_framework_simplejwt.utils import (
|
||
|
+ aware_utcnow, datetime_to_epoch, make_utc,
|
||
|
+)
|
||
|
|
||
|
SECRET = 'not_secret'
|
||
|
|
||
|
@@ -163,9 +165,9 @@ def test_decode_hmac_with_expiry(self):
|
||
|
|
||
|
def test_decode_hmac_with_invalid_sig(self):
|
||
|
self.payload['exp'] = aware_utcnow() + timedelta(days=1)
|
||
|
- token_1 = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8')
|
||
|
+ token_1 = jwt.encode(self.payload, SECRET, algorithm='HS256')
|
||
|
self.payload['foo'] = 'baz'
|
||
|
- token_2 = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8')
|
||
|
+ token_2 = jwt.encode(self.payload, SECRET, algorithm='HS256')
|
||
|
|
||
|
token_2_payload = token_2.rsplit('.', 1)[0]
|
||
|
token_1_sig = token_1.rsplit('.', 1)[-1]
|
||
|
@@ -176,9 +178,11 @@ def test_decode_hmac_with_invalid_sig(self):
|
||
|
|
||
|
def test_decode_hmac_with_invalid_sig_no_verify(self):
|
||
|
self.payload['exp'] = aware_utcnow() + timedelta(days=1)
|
||
|
- token_1 = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8')
|
||
|
+ token_1 = jwt.encode(self.payload, SECRET, algorithm='HS256')
|
||
|
self.payload['foo'] = 'baz'
|
||
|
- token_2 = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8')
|
||
|
+ token_2 = jwt.encode(self.payload, SECRET, algorithm='HS256')
|
||
|
+ # Payload copied
|
||
|
+ self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
|
||
|
|
||
|
token_2_payload = token_2.rsplit('.', 1)[0]
|
||
|
token_1_sig = token_1.rsplit('.', 1)[-1]
|
||
|
@@ -193,7 +197,9 @@ def test_decode_hmac_success(self):
|
||
|
self.payload['exp'] = aware_utcnow() + timedelta(days=1)
|
||
|
self.payload['foo'] = 'baz'
|
||
|
|
||
|
- token = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8')
|
||
|
+ token = jwt.encode(self.payload, SECRET, algorithm='HS256')
|
||
|
+ # Payload copied
|
||
|
+ self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
|
||
|
|
||
|
self.assertEqual(self.hmac_token_backend.decode(token), self.payload)
|
||
|
|
||
|
@@ -220,9 +226,9 @@ def test_decode_rsa_with_expiry(self):
|
||
|
|
||
|
def test_decode_rsa_with_invalid_sig(self):
|
||
|
self.payload['exp'] = aware_utcnow() + timedelta(days=1)
|
||
|
- token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
|
||
|
+ token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
|
||
|
self.payload['foo'] = 'baz'
|
||
|
- token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
|
||
|
+ token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
|
||
|
|
||
|
token_2_payload = token_2.rsplit('.', 1)[0]
|
||
|
token_1_sig = token_1.rsplit('.', 1)[-1]
|
||
|
@@ -233,13 +239,15 @@ def test_decode_rsa_with_invalid_sig(self):
|
||
|
|
||
|
def test_decode_rsa_with_invalid_sig_no_verify(self):
|
||
|
self.payload['exp'] = aware_utcnow() + timedelta(days=1)
|
||
|
- token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
|
||
|
+ token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
|
||
|
self.payload['foo'] = 'baz'
|
||
|
- token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
|
||
|
+ token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
|
||
|
|
||
|
token_2_payload = token_2.rsplit('.', 1)[0]
|
||
|
token_1_sig = token_1.rsplit('.', 1)[-1]
|
||
|
invalid_token = token_2_payload + '.' + token_1_sig
|
||
|
+ # Payload copied
|
||
|
+ self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
|
||
|
|
||
|
self.assertEqual(
|
||
|
self.hmac_token_backend.decode(invalid_token, verify=False),
|
||
|
@@ -250,7 +258,9 @@ def test_decode_rsa_success(self):
|
||
|
self.payload['exp'] = aware_utcnow() + timedelta(days=1)
|
||
|
self.payload['foo'] = 'baz'
|
||
|
|
||
|
- token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
|
||
|
+ token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
|
||
|
+ # Payload copied
|
||
|
+ self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
|
||
|
|
||
|
self.assertEqual(self.rsa_token_backend.decode(token), self.payload)
|
||
|
|
||
|
@@ -260,21 +270,23 @@ def test_decode_aud_iss_success(self):
|
||
|
self.payload['aud'] = AUDIENCE
|
||
|
self.payload['iss'] = ISSUER
|
||
|
|
||
|
- token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
|
||
|
+ token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
|
||
|
+ # Payload copied
|
||
|
+ self.payload["exp"] = datetime_to_epoch(self.payload["exp"])
|
||
|
|
||
|
self.assertEqual(self.aud_iss_token_backend.decode(token), self.payload)
|
||
|
|
||
|
def test_decode_when_algorithm_not_available(self):
|
||
|
- token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
|
||
|
+ token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
|
||
|
|
||
|
- pyjwt_without_rsa = PyJWT()
|
||
|
+ pyjwt_without_rsa = PyJWS()
|
||
|
pyjwt_without_rsa.unregister_algorithm('RS256')
|
||
|
with patch.object(jwt, 'decode', new=pyjwt_without_rsa.decode):
|
||
|
with self.assertRaisesRegex(TokenBackendError, 'Invalid algorithm specified'):
|
||
|
self.rsa_token_backend.decode(token)
|
||
|
|
||
|
def test_decode_when_token_algorithm_does_not_match(self):
|
||
|
- token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8')
|
||
|
+ token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256')
|
||
|
|
||
|
with self.assertRaisesRegex(TokenBackendError, 'Invalid algorithm specified'):
|
||
|
self.hmac_token_backend.decode(token)
|