From ae1ea58daacb76a1d72e202198734053022a6efe Mon Sep 17 00:00:00 2001 From: Riccardo Magliocchetti Date: Mon, 22 Feb 2021 16:51:38 +0100 Subject: [PATCH] Support jwt 2 (#376) * Add PyJWT 2.0.0 support * Many tests relied on passing a payload into encode and having encode modify that payload in place (by reference). In PyJWT 2.0.0, the payload is copied (`dict.copy()`) before encoding, so the tests now convert 'exp' to epoch seconds themselves. * The `verify` argument was deprecated in favor of the `verify_signature` option; backends.py now passes it through `options`. Unless you called decode with verify=False, you are not affected by this change. * Fix isort lint Co-authored-by: Andrew-Chen-Wang --- rest_framework_simplejwt/backends.py | 8 +++-- tests/test_backends.py | 44 ++++++++++++++++++---------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/rest_framework_simplejwt/backends.py b/rest_framework_simplejwt/backends.py index 887ecac..de7128a 100644 --- a/rest_framework_simplejwt/backends.py +++ b/rest_framework_simplejwt/backends.py @@ -65,9 +65,11 @@ def decode(self, token, verify=True): signature check fails, or if its 'exp' claim indicates it has expired. 
""" try: - return jwt.decode(token, self.verifying_key, algorithms=[self.algorithm], verify=verify, - audience=self.audience, issuer=self.issuer, - options={'verify_aud': self.audience is not None}) + return jwt.decode( + token, self.verifying_key, algorithms=[self.algorithm], verify=verify, + audience=self.audience, issuer=self.issuer, + options={'verify_aud': self.audience is not None, "verify_signature": verify} + ) except InvalidAlgorithmError as ex: raise TokenBackendError(_('Invalid algorithm specified')) from ex except InvalidTokenError: diff --git a/tests/test_backends.py b/tests/test_backends.py index f9a9a23..bfb8da6 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -3,11 +3,13 @@ import jwt from django.test import TestCase -from jwt import PyJWT, algorithms +from jwt import PyJWS, algorithms from rest_framework_simplejwt.backends import TokenBackend from rest_framework_simplejwt.exceptions import TokenBackendError -from rest_framework_simplejwt.utils import aware_utcnow, make_utc +from rest_framework_simplejwt.utils import ( + aware_utcnow, datetime_to_epoch, make_utc, +) SECRET = 'not_secret' @@ -163,9 +165,9 @@ def test_decode_hmac_with_expiry(self): def test_decode_hmac_with_invalid_sig(self): self.payload['exp'] = aware_utcnow() + timedelta(days=1) - token_1 = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8') + token_1 = jwt.encode(self.payload, SECRET, algorithm='HS256') self.payload['foo'] = 'baz' - token_2 = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8') + token_2 = jwt.encode(self.payload, SECRET, algorithm='HS256') token_2_payload = token_2.rsplit('.', 1)[0] token_1_sig = token_1.rsplit('.', 1)[-1] @@ -176,9 +178,11 @@ def test_decode_hmac_with_invalid_sig(self): def test_decode_hmac_with_invalid_sig_no_verify(self): self.payload['exp'] = aware_utcnow() + timedelta(days=1) - token_1 = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8') + token_1 = jwt.encode(self.payload, 
SECRET, algorithm='HS256') self.payload['foo'] = 'baz' - token_2 = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8') + token_2 = jwt.encode(self.payload, SECRET, algorithm='HS256') + # Payload copied + self.payload["exp"] = datetime_to_epoch(self.payload["exp"]) token_2_payload = token_2.rsplit('.', 1)[0] token_1_sig = token_1.rsplit('.', 1)[-1] @@ -193,7 +197,9 @@ def test_decode_hmac_success(self): self.payload['exp'] = aware_utcnow() + timedelta(days=1) self.payload['foo'] = 'baz' - token = jwt.encode(self.payload, SECRET, algorithm='HS256').decode('utf-8') + token = jwt.encode(self.payload, SECRET, algorithm='HS256') + # Payload copied + self.payload["exp"] = datetime_to_epoch(self.payload["exp"]) self.assertEqual(self.hmac_token_backend.decode(token), self.payload) @@ -220,9 +226,9 @@ def test_decode_rsa_with_expiry(self): def test_decode_rsa_with_invalid_sig(self): self.payload['exp'] = aware_utcnow() + timedelta(days=1) - token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8') + token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256') self.payload['foo'] = 'baz' - token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8') + token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256') token_2_payload = token_2.rsplit('.', 1)[0] token_1_sig = token_1.rsplit('.', 1)[-1] @@ -233,13 +239,15 @@ def test_decode_rsa_with_invalid_sig(self): def test_decode_rsa_with_invalid_sig_no_verify(self): self.payload['exp'] = aware_utcnow() + timedelta(days=1) - token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8') + token_1 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256') self.payload['foo'] = 'baz' - token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8') + token_2 = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256') token_2_payload = token_2.rsplit('.', 1)[0] token_1_sig = token_1.rsplit('.', 1)[-1] invalid_token = 
token_2_payload + '.' + token_1_sig + # Payload copied + self.payload["exp"] = datetime_to_epoch(self.payload["exp"]) self.assertEqual( self.hmac_token_backend.decode(invalid_token, verify=False), @@ -250,7 +258,9 @@ def test_decode_rsa_success(self): self.payload['exp'] = aware_utcnow() + timedelta(days=1) self.payload['foo'] = 'baz' - token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8') + token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256') + # Payload copied + self.payload["exp"] = datetime_to_epoch(self.payload["exp"]) self.assertEqual(self.rsa_token_backend.decode(token), self.payload) @@ -260,21 +270,23 @@ def test_decode_aud_iss_success(self): self.payload['aud'] = AUDIENCE self.payload['iss'] = ISSUER - token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8') + token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256') + # Payload copied + self.payload["exp"] = datetime_to_epoch(self.payload["exp"]) self.assertEqual(self.aud_iss_token_backend.decode(token), self.payload) def test_decode_when_algorithm_not_available(self): - token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8') + token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256') - pyjwt_without_rsa = PyJWT() + pyjwt_without_rsa = PyJWS() pyjwt_without_rsa.unregister_algorithm('RS256') with patch.object(jwt, 'decode', new=pyjwt_without_rsa.decode): with self.assertRaisesRegex(TokenBackendError, 'Invalid algorithm specified'): self.rsa_token_backend.decode(token) def test_decode_when_token_algorithm_does_not_match(self): - token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256').decode('utf-8') + token = jwt.encode(self.payload, PRIVATE_KEY, algorithm='RS256') with self.assertRaisesRegex(TokenBackendError, 'Invalid algorithm specified'): self.hmac_token_backend.decode(token)