Skip to content

Commit

Permalink
allow renew_token as pluggable and tests for same
Browse files Browse the repository at this point in the history
  • Loading branch information
eshaan7 committed Oct 24, 2020
1 parent cc98922 commit a9818e4
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 30 deletions.
2 changes: 1 addition & 1 deletion durin/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"AUTH_HEADER_PREFIX": "Token",
"EXPIRY_DATETIME_FORMAT": api_settings.DATETIME_FORMAT,
"TOKEN_CACHE_TIMEOUT": 60,
"REFRESH_TOKEN_ON_USE": True
"REFRESH_TOKEN_ON_LOGIN": False,
}

IMPORT_STRINGS = {
Expand Down
44 changes: 24 additions & 20 deletions durin/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def validate_and_return_user(request):
return serializer.validated_data["user"]

@staticmethod
def get_token_client(request):
def get_client_obj(request) -> "Client":
client_name = request.data.get("client", None)
if not client_name:
raise ParseError("No client specified.", status.HTTP_400_BAD_REQUEST)
Expand All @@ -44,38 +44,42 @@ def format_expiry_datetime(expiry):
datetime_format = durin_settings.EXPIRY_DATETIME_FORMAT
return DateTimeField(format=datetime_format).to_representation(expiry)

def get_post_response_data(self, request, instance):
def get_post_response_data(self, request, token_obj: "AuthToken"):
UserSerializer = durin_settings.USER_SERIALIZER

data = {
"expiry": self.format_expiry_datetime(instance.expiry),
"token": instance.token,
"expiry": self.format_expiry_datetime(token_obj.expiry),
"token": token_obj.token,
}
if UserSerializer is not None:
data["user"] = UserSerializer(request.user, context=self.get_context()).data
return data

@staticmethod
def get_new_token(user, client):
return AuthToken.objects.create(user, client)
@classmethod
def renew_token(cls, token_obj: "AuthToken"):
token_obj.renew_token(renewed_by=cls)

def post(self, request, format=None):
request.user = self.validate_and_return_user(request)
client = self.get_token_client(request)
@classmethod
def get_token_obj(cls, request, client: "Client") -> "AuthToken":
try:
# a token for this user and client already exists,
# so we can return the same one by renewing it's expiry
instance = AuthToken.objects.get(user=request.user, client=client)
if durin_settings.REFRESH_TOKEN_ON_USE:
instance.renew_token(renewed_by=self.__class__)
# a token for this user and client already exists, so we can just return it
token = AuthToken.objects.get(user=request.user, client=client)
if durin_settings.REFRESH_TOKEN_ON_LOGIN:
cls.renew_token(token)
except ObjectDoesNotExist:
# create new token
instance = self.get_new_token(request.user, client)
token = AuthToken.objects.create(request.user, client)

return token

def post(self, request, *args, **kwargs):
request.user = self.validate_and_return_user(request)
client = self.get_client_obj(request)
token_obj = self.get_token_obj(request, client)
user_logged_in.send(
sender=request.user.__class__, request=request, user=request.user
)
data = self.get_post_response_data(request, instance)
data = self.get_post_response_data(request, token_obj)
return Response(data)


Expand All @@ -92,7 +96,7 @@ def format_expiry_datetime(expiry):
datetime_format = durin_settings.EXPIRY_DATETIME_FORMAT
return DateTimeField(format=datetime_format).to_representation(expiry)

def post(self, request, format=None):
def post(self, request, *args, **kwargs):
auth_token = request._auth
new_expiry = auth_token.renew_token(renewed_by=self.__class__)
new_expiry_repr = self.format_expiry_datetime(new_expiry)
Expand All @@ -109,7 +113,7 @@ class LogoutView(APIView):
authentication_classes = (TokenAuthentication,)
permission_classes = (IsAuthenticated,)

def post(self, request, format=None):
def post(self, request, *args, **kwargs):
request._auth.delete()
user_logged_out.send(
sender=request.user.__class__, request=request, user=request.user
Expand All @@ -126,7 +130,7 @@ class LogoutAllView(APIView):
authentication_classes = (TokenAuthentication,)
permission_classes = (IsAuthenticated,)

def post(self, request, format=None):
def post(self, request, *args, **kwargs):
request.user.auth_token_set.all().delete()
user_logged_out.send(
sender=request.user.__class__, request=request, user=request.user
Expand Down
46 changes: 37 additions & 9 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,19 @@
refresh_url = reverse("durin_refresh")

new_settings = durin_settings.defaults.copy()
EXPIRY_DATETIME_FORMAT = "%H:%M %d/%m/%y"
new_settings["EXPIRY_DATETIME_FORMAT"] = EXPIRY_DATETIME_FORMAT


class AuthTestCase(APITestCase):
def setUp(self):
self.authclient = Client.objects.create(name="authclientfortest")
username = "john.doe"
email = "[email protected]"
password = "hunter2"
self.user = User.objects.create_user(username, email, password)
self.creds = {
"username": username,
"password": password,
"client": "authclientfortest",
"client": self.authclient.name,
}

username2 = "jane.doe"
Expand All @@ -47,11 +46,10 @@ def setUp(self):
self.creds2 = {
"username": username2,
"password": password2,
"client": "authclientfortest",
"client": self.authclient.name,
}

self.client_names = ["web", "mobile", "cli"]
self.authclient = Client.objects.create(name="authclientfortest")

def test_create_clients(self):
self.assertEqual(Client.objects.count(), 1)
Expand Down Expand Up @@ -114,7 +112,8 @@ def test_login_returns_serialized_token_and_username_field(self):

def test_login_returns_configured_expiry_datetime_format(self):
self.assertEqual(AuthToken.objects.count(), 0)

EXPIRY_DATETIME_FORMAT = "%H:%M %d/%m/%y"
new_settings["EXPIRY_DATETIME_FORMAT"] = EXPIRY_DATETIME_FORMAT
with override_settings(REST_DURIN=new_settings):
reload(views)
self.assertEqual(
Expand Down Expand Up @@ -267,7 +266,7 @@ def test_invalid_auth_header_return_401(self):
resp2 = self.client.get(root_url)
self.assertEqual(resp2.status_code, 401)

def test_login_should_renew_token_for_existing_client(self):
def test_login_same_token_existing_client(self):
self.assertEqual(AuthToken.objects.count(), 0)
resp1 = self.client.post(login_url, self.creds, format="json")
self.assertEqual(resp1.status_code, 200)
Expand All @@ -281,17 +280,46 @@ def test_login_should_renew_token_for_existing_client(self):
1,
"should renew token, instead of creating new.",
)
self.assertNotEqual(
self.assertEqual(
resp1.data["expiry"],
resp2.data["expiry"],
"token expiry should be renewed by login",
"token expiry should be same after login",
)
self.assertEqual(
resp1.data["token"],
resp2.data["token"],
"login should return existing token",
)

def test_login_renew_token_existing_client(self):
self.assertEqual(AuthToken.objects.count(), 0)
new_settings["REFRESH_TOKEN_ON_LOGIN"] = True
with override_settings(REST_DURIN=new_settings):
reload(views)
resp1 = self.client.post(login_url, self.creds, format="json")
self.assertEqual(resp1.status_code, 200)
self.assertIn("token", resp1.data)
resp2 = self.client.post(login_url, self.creds, format="json")
self.assertEqual(resp2.status_code, 200)
self.assertIn("token", resp2.data)

reload(views)
self.assertEqual(
AuthToken.objects.count(),
1,
"should renew token, instead of creating new.",
)
self.assertNotEqual(
resp1.data["expiry"],
resp2.data["expiry"],
"token expiry should be renewed after login",
)
self.assertEqual(
resp1.data["token"],
resp2.data["token"],
"token key must remain same",
)

def test_refresh_view_and_renewed_signal(self):
self.signal_was_called = False

Expand Down

0 comments on commit a9818e4

Please sign in to comment.