From a9818e4ef92136bf4180ff9c1f02d473ba55b6be Mon Sep 17 00:00:00 2001 From: Eshaan Bansal Date: Sat, 24 Oct 2020 19:04:17 +0530 Subject: [PATCH] allow renew_token as pluggable and tests for same --- durin/settings.py | 2 +- durin/views.py | 44 ++++++++++++++++++++++++-------------------- tests/tests.py | 46 +++++++++++++++++++++++++++++++++++++--------- 3 files changed, 62 insertions(+), 30 deletions(-) diff --git a/durin/settings.py b/durin/settings.py index ee3244b..707a38c 100644 --- a/durin/settings.py +++ b/durin/settings.py @@ -13,7 +13,7 @@ "AUTH_HEADER_PREFIX": "Token", "EXPIRY_DATETIME_FORMAT": api_settings.DATETIME_FORMAT, "TOKEN_CACHE_TIMEOUT": 60, - "REFRESH_TOKEN_ON_USE": True + "REFRESH_TOKEN_ON_LOGIN": False, } IMPORT_STRINGS = { diff --git a/durin/views.py b/durin/views.py index 14b8a59..83d5633 100644 --- a/durin/views.py +++ b/durin/views.py @@ -33,7 +33,7 @@ def validate_and_return_user(request): return serializer.validated_data["user"] @staticmethod - def get_token_client(request): + def get_client_obj(request) -> "Client": client_name = request.data.get("client", None) if not client_name: raise ParseError("No client specified.", status.HTTP_400_BAD_REQUEST) @@ -44,38 +44,42 @@ def format_expiry_datetime(expiry): datetime_format = durin_settings.EXPIRY_DATETIME_FORMAT return DateTimeField(format=datetime_format).to_representation(expiry) - def get_post_response_data(self, request, instance): + def get_post_response_data(self, request, token_obj: "AuthToken"): UserSerializer = durin_settings.USER_SERIALIZER data = { - "expiry": self.format_expiry_datetime(instance.expiry), - "token": instance.token, + "expiry": self.format_expiry_datetime(token_obj.expiry), + "token": token_obj.token, } if UserSerializer is not None: data["user"] = UserSerializer(request.user, context=self.get_context()).data return data - @staticmethod - def get_new_token(user, client): - return AuthToken.objects.create(user, client) + @classmethod + def renew_token(cls, token_obj: "AuthToken"): + token_obj.renew_token(renewed_by=cls) - def post(self, request, format=None): - request.user = self.validate_and_return_user(request) - client = self.get_token_client(request) + @classmethod + def get_token_obj(cls, request, client: "Client") -> "AuthToken": try: - # a token for this user and client already exists, - # so we can return the same one by renewing it's expiry - instance = AuthToken.objects.get(user=request.user, client=client) - if durin_settings.REFRESH_TOKEN_ON_USE: - instance.renew_token(renewed_by=self.__class__) + # a token for this user and client already exists, so we can just return it + token = AuthToken.objects.get(user=request.user, client=client) + if durin_settings.REFRESH_TOKEN_ON_LOGIN: + cls.renew_token(token) except ObjectDoesNotExist: # create new token - instance = self.get_new_token(request.user, client) + token = AuthToken.objects.create(request.user, client) + return token + + def post(self, request, *args, **kwargs): + request.user = self.validate_and_return_user(request) + client = self.get_client_obj(request) + token_obj = self.get_token_obj(request, client) user_logged_in.send( sender=request.user.__class__, request=request, user=request.user ) - data = self.get_post_response_data(request, instance) + data = self.get_post_response_data(request, token_obj) return Response(data) @@ -92,7 +96,7 @@ def format_expiry_datetime(expiry): datetime_format = durin_settings.EXPIRY_DATETIME_FORMAT return DateTimeField(format=datetime_format).to_representation(expiry) - def post(self, request, format=None): + def post(self, request, *args, **kwargs): auth_token = request._auth new_expiry = auth_token.renew_token(renewed_by=self.__class__) new_expiry_repr = self.format_expiry_datetime(new_expiry) @@ -109,7 +113,7 @@ class LogoutView(APIView): authentication_classes = (TokenAuthentication,) permission_classes = (IsAuthenticated,) - def post(self, request, format=None): + def post(self, request, *args, **kwargs): request._auth.delete() user_logged_out.send( sender=request.user.__class__, request=request, user=request.user @@ -126,7 +130,7 @@ class LogoutAllView(APIView): authentication_classes = (TokenAuthentication,) permission_classes = (IsAuthenticated,) - def post(self, request, format=None): + def post(self, request, *args, **kwargs): request.user.auth_token_set.all().delete() user_logged_out.send( sender=request.user.__class__, request=request, user=request.user diff --git a/tests/tests.py b/tests/tests.py index 4af045e..39d7647 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -24,12 +24,11 @@ refresh_url = reverse("durin_refresh") new_settings = durin_settings.defaults.copy() -EXPIRY_DATETIME_FORMAT = "%H:%M %d/%m/%y" -new_settings["EXPIRY_DATETIME_FORMAT"] = EXPIRY_DATETIME_FORMAT class AuthTestCase(APITestCase): def setUp(self): + self.authclient = Client.objects.create(name="authclientfortest") username = "john.doe" email = "john.doe@example.com" password = "hunter2" @@ -37,7 +36,7 @@ def setUp(self): self.creds = { "username": username, "password": password, - "client": "authclientfortest", + "client": self.authclient.name, } username2 = "jane.doe" @@ -47,11 +46,10 @@ def setUp(self): self.creds2 = { "username": username2, "password": password2, - "client": "authclientfortest", + "client": self.authclient.name, } self.client_names = ["web", "mobile", "cli"] - self.authclient = Client.objects.create(name="authclientfortest") def test_create_clients(self): self.assertEqual(Client.objects.count(), 1) @@ -114,7 +112,8 @@ def test_login_returns_serialized_token_and_username_field(self): def test_login_returns_configured_expiry_datetime_format(self): self.assertEqual(AuthToken.objects.count(), 0) - + EXPIRY_DATETIME_FORMAT = "%H:%M %d/%m/%y" + new_settings["EXPIRY_DATETIME_FORMAT"] = EXPIRY_DATETIME_FORMAT with override_settings(REST_DURIN=new_settings): reload(views) self.assertEqual( @@ -267,7 +266,7 @@ def test_invalid_auth_header_return_401(self): resp2 = self.client.get(root_url) self.assertEqual(resp2.status_code, 401) - def test_login_should_renew_token_for_existing_client(self): + def test_login_same_token_existing_client(self): self.assertEqual(AuthToken.objects.count(), 0) resp1 = self.client.post(login_url, self.creds, format="json") self.assertEqual(resp1.status_code, 200) @@ -281,10 +280,10 @@ def test_login_should_renew_token_for_existing_client(self): 1, "should renew token, instead of creating new.", ) - self.assertNotEqual( + self.assertEqual( resp1.data["expiry"], resp2.data["expiry"], - "token expiry should be renewed by login", + "token expiry should be same after login", ) self.assertEqual( resp1.data["token"], @@ -292,6 +291,35 @@ def test_login_should_renew_token_for_existing_client(self): "login should return existing token", ) + def test_login_renew_token_existing_client(self): + self.assertEqual(AuthToken.objects.count(), 0) + new_settings["REFRESH_TOKEN_ON_LOGIN"] = True + with override_settings(REST_DURIN=new_settings): + reload(views) + resp1 = self.client.post(login_url, self.creds, format="json") + self.assertEqual(resp1.status_code, 200) + self.assertIn("token", resp1.data) + resp2 = self.client.post(login_url, self.creds, format="json") + self.assertEqual(resp2.status_code, 200) + self.assertIn("token", resp2.data) + + reload(views) + self.assertEqual( + AuthToken.objects.count(), + 1, + "should renew token, instead of creating new.", + ) + self.assertNotEqual( + resp1.data["expiry"], + resp2.data["expiry"], + "token expiry should be renewed after login", + ) + self.assertEqual( + resp1.data["token"], + resp2.data["token"], + "token key must remain same", + ) + def test_refresh_view_and_renewed_signal(self): self.signal_was_called = False