diff --git a/rest_framework_simplejwt/exceptions.py b/rest_framework_simplejwt/exceptions.py index 8cc58e976..80365a343 100644 --- a/rest_framework_simplejwt/exceptions.py +++ b/rest_framework_simplejwt/exceptions.py @@ -54,3 +54,11 @@ class InvalidToken(AuthenticationFailed): status_code = status.HTTP_401_UNAUTHORIZED default_detail = _("Token is invalid or expired") default_code = "token_not_valid" + + +class TokenBlacklistNotConfigured(DetailDictMixin, exceptions.APIException): + status_code = status.HTTP_501_NOT_IMPLEMENTED + default_detail = _( + "Token blacklist functionality is not enabled or available. Please check your configuration." + ) + default_code = "blacklist_not_configured" diff --git a/rest_framework_simplejwt/serializers.py b/rest_framework_simplejwt/serializers.py index 45c5a771c..db79a38d5 100644 --- a/rest_framework_simplejwt/serializers.py +++ b/rest_framework_simplejwt/serializers.py @@ -7,6 +7,7 @@ from rest_framework import exceptions, serializers from rest_framework.exceptions import AuthenticationFailed, ValidationError +from .exceptions import TokenBlacklistNotConfigured from .models import TokenUser from .settings import api_settings from .tokens import RefreshToken, SlidingToken, Token, UntypedToken @@ -189,5 +190,6 @@ def validate(self, attrs: dict[str, Any]) -> dict[Any, Any]: try: refresh.blacklist() except AttributeError: - pass - return {} + raise TokenBlacklistNotConfigured() + + return {"message": "Token blacklisted"} diff --git a/tests/test_serializers.py b/tests/test_serializers.py index 7ce35c085..b6f5aa019 100644 --- a/tests/test_serializers.py +++ b/tests/test_serializers.py @@ -8,7 +8,7 @@ from django.test import TestCase, override_settings from rest_framework import exceptions as drf_exceptions -from rest_framework_simplejwt.exceptions import TokenError +from rest_framework_simplejwt.exceptions import TokenBlacklistNotConfigured, TokenError from rest_framework_simplejwt.serializers import ( TokenBlacklistSerializer, TokenObtainPairSerializer, @@ -561,7 +561,7 @@ def test_it_should_raise_token_error_if_token_has_wrong_type(self): self.assertIn("wrong type", e.exception.args[0]) - def test_it_should_return_nothing_if_everything_ok(self): + def test_it_should_return_message_if_everything_ok(self): refresh = RefreshToken() refresh["test_claim"] = "arst" @@ -574,7 +574,7 @@ def test_it_should_return_nothing_if_everything_ok(self): fake_aware_utcnow.return_value = now self.assertTrue(s.is_valid()) - self.assertDictEqual(s.validated_data, {}) + self.assertDictEqual(s.validated_data, {"message": "Token blacklisted"}) def test_it_should_blacklist_refresh_token_if_everything_ok(self): self.assertEqual(OutstandingToken.objects.count(), 0) @@ -601,24 +601,31 @@ def test_it_should_blacklist_refresh_token_if_everything_ok(self): # Assert old refresh token is blacklisted self.assertEqual(BlacklistedToken.objects.first().token.jti, old_jti) - def test_blacklist_app_not_installed_should_pass(self): + def test_blacklist_app_not_installed_should_raise_token_blacklist_not_configured( + self, + ): from rest_framework_simplejwt import serializers, tokens # Remove blacklist app new_apps = list(settings.INSTALLED_APPS) new_apps.remove("rest_framework_simplejwt.token_blacklist") - with self.settings(INSTALLED_APPS=tuple(new_apps)): - # Reload module that blacklist app not installed - reload(tokens) - reload(serializers) + try: + with self.settings(INSTALLED_APPS=tuple(new_apps)): + # Reload module that blacklist app not installed + reload(tokens) + reload(serializers) - refresh = tokens.RefreshToken() + refresh = tokens.RefreshToken() - # Serializer validates - ser = serializers.TokenBlacklistSerializer(data={"refresh": str(refresh)}) - ser.validate({"refresh": str(refresh)}) + # Serializer validates + ser = serializers.TokenBlacklistSerializer( + data={"refresh": str(refresh)} + ) - # Restore origin module without mock - reload(tokens) - reload(serializers) + with self.assertRaises(TokenBlacklistNotConfigured): + ser.validate({"refresh": str(refresh)}) + finally: + # Restore origin module without mock + reload(tokens) + reload(serializers) diff --git a/tests/test_views.py b/tests/test_views.py index 4927e6ffd..0e7dae81b 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -406,7 +406,7 @@ def test_it_should_return_if_everything_ok(self): self.assertEqual(res.status_code, 200) - self.assertDictEqual(res.data, {}) + self.assertDictEqual(res.data, {"message": "Token blacklisted"}) def test_it_should_return_401_if_token_is_blacklisted(self): refresh = RefreshToken()