Skip to content

Commit

Permalink
fixed permission check before schema validation as described in #116, #…
Browse files Browse the repository at this point in the history
…192 issues
  • Loading branch information
eadwinCode committed Nov 16, 2024
1 parent 4e31ab4 commit 8ae0e27
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 19 deletions.
47 changes: 32 additions & 15 deletions ninja_extra/controllers/route/route_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@

from django.http import HttpRequest, HttpResponse

from ...dependency_resolver import get_injector, service_resolver
from ninja_extra.dependency_resolver import get_injector, service_resolver

from .context import RouteContext, get_route_execution_context

if TYPE_CHECKING: # pragma: no cover
from ninja_extra.controllers.base import APIController, ControllerBase
from ninja_extra.controllers.route import Route
from ninja_extra.operation import Operation

from ...controllers.base import APIController, ControllerBase
from ...controllers.route import Route


class RouteFunctionContext:
def __init__(
Expand Down Expand Up @@ -74,6 +74,13 @@ def _resolve_api_func_signature_(self, context_func: Callable) -> Callable:
context_func.__signature__ = sig_replaced # type: ignore
return context_func

def run_permission_check(self, route_context: RouteContext) -> None:
_route_context = route_context or cast(
RouteContext, service_resolver(RouteContext)
)
with self._prep_controller_route_execution(_route_context) as ctx:
ctx.controller_instance.check_permissions()

def get_view_function(self) -> Callable:
def as_view(
request: HttpRequest,
Expand All @@ -85,23 +92,30 @@ def as_view(
RouteContext, service_resolver(RouteContext)
)
with self._prep_controller_route_execution(_route_context, **kwargs) as ctx:
ctx.controller_instance.check_permissions()
# ctx.controller_instance.check_permissions()
result = self.route.view_func(
ctx.controller_instance, *args, **ctx.view_func_kwargs
)
return self._process_view_function_result(result)
return result

as_view.get_route_function = lambda: self # type:ignore
return as_view

def _process_view_function_result(self, result: Any) -> Any:
"""
This process any an returned value from view_func
and creates an api response if result is ControllerResponseSchema
"""
This process any a returned value from view_func
and creates an api response if a result is ControllerResponseSchema
# if result and isinstance(result, ControllerResponse):
# return result.status_code, result.convert_to_schema()
deprecated:: 0.21.5
This method is deprecated and will be removed in a future version.
The result processing should be handled by the response handlers.
"""
warnings.warn(
"_process_view_function_result() is deprecated and will be removed in a future version. "
"The result processing should be handled by the response handlers.",
DeprecationWarning,
stacklevel=2,
)
return result

def _get_controller_instance(self) -> "ControllerBase":
Expand Down Expand Up @@ -163,24 +177,27 @@ def __repr__(self) -> str: # pragma: no cover


class AsyncRouteFunction(RouteFunction):
async def async_run_check_permissions(self, route_context: RouteContext) -> None:
from asgiref.sync import sync_to_async

await sync_to_async(self.run_permission_check)(route_context)

def get_view_function(self) -> Callable:
async def as_view(
request: HttpRequest,
route_context: Optional[RouteContext] = None,
*args: Any,
**kwargs: Any,
) -> Any:
from asgiref.sync import sync_to_async

_route_context = route_context or cast(
RouteContext, service_resolver(RouteContext)
)
with self._prep_controller_route_execution(_route_context, **kwargs) as ctx:
await sync_to_async(ctx.controller_instance.check_permissions)()
# await sync_to_async(ctx.controller_instance.check_permissions)()
result = await self.route.view_func(
ctx.controller_instance, *args, **ctx.view_func_kwargs
)
return self._process_view_function_result(result)
return result

as_view.get_route_function = lambda: self # type:ignore
return as_view
Expand Down
26 changes: 23 additions & 3 deletions ninja_extra/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@
from .details import ViewSignature

if TYPE_CHECKING: # pragma: no cover
from .controllers.route.route_functions import RouteFunction
from .controllers.route.route_functions import AsyncRouteFunction, RouteFunction


class Operation(NinjaOperation):
view_func: Callable

def __init__(
self,
path: str,
Expand Down Expand Up @@ -88,6 +90,16 @@ def _set_auth(
f"N:B - {get_function_name(callback)} can only be used on Asynchronous view functions"
)

def _get_route_function(
self,
) -> Optional[Union["RouteFunction", "AsyncRouteFunction"]]:
if hasattr(self.view_func, "get_route_function"):
return cast(
Union["RouteFunction", "AsyncRouteFunction"],
self.view_func.get_route_function(),
)
return None

def _log_action(
self,
logger: Callable[..., Any],
Expand All @@ -102,8 +114,8 @@ def _log_action(
f'{self.view_func.__name__} {request.path}" '
f"{duration if duration else ''}"
)
if hasattr(self.view_func, "get_route_function"):
route_function: "RouteFunction" = self.view_func.get_route_function()
route_function = self._get_route_function()
if route_function:
api_controller = route_function.get_api_controller()

msg = (
Expand Down Expand Up @@ -185,6 +197,10 @@ def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase:
with self._prep_run(
request, temporal_response=temporal_response, **kw
) as ctx:
route_function = self._get_route_function()
if route_function:
route_function.run_permission_check(ctx)

error = self._run_checks(request)
if error:
return error
Expand Down Expand Up @@ -309,6 +325,10 @@ async def run(self, request: HttpRequest, **kw: Any) -> HttpResponseBase: # typ
async with self._prep_run(
request, temporal_response=temporal_response, **kw
) as ctx:
route_function = self._get_route_function()
if route_function:
await route_function.async_run_check_permissions(ctx) # type: ignore[attr-defined]

error = await self._run_checks(request)
if error:
return error
Expand Down
30 changes: 29 additions & 1 deletion tests/test_permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from unittest.mock import Mock

import pytest
from asgiref.sync import sync_to_async
from django.contrib.auth.models import AnonymousUser, User

from ninja_extra import ControllerBase, api_controller, http_get, permissions
from ninja_extra.testing import TestClient
from ninja_extra.testing import TestAsyncClient, TestClient

anonymous_request = Mock()
anonymous_request.user = AnonymousUser()
Expand Down Expand Up @@ -250,6 +251,13 @@ def index(self):
def permission_accept_type_and_instance(self):
return {"success": True}

@http_get(
"permission/async/",
permissions=[permissions.IsAdminUser() & permissions.IsAuthenticatedOrReadOnly],
)
async def permission_accept_type_and_instance_async(self):
return {"success": True}


@pytest.mark.django_db
@pytest.mark.parametrize("route", ["permission/", "index/"])
Expand All @@ -269,3 +277,23 @@ def test_permission_controller_instance(route):
res = client.get(route, user=user)
assert res.status_code == 200
assert res.json() == {"success": True}


@pytest.mark.django_db
@pytest.mark.asyncio
async def test_permission_controller_instance_async():
user = await sync_to_async(User.objects.create_user)(
username="eadwin",
email="[email protected]",
password="password",
is_staff=True,
is_superuser=True,
)

client = TestAsyncClient(Some2Controller)
res = await client.get("/permission/async/", user=AnonymousUser())
assert res.status_code == 403

res = await client.get("/permission/async/", user=user)
assert res.status_code == 200
assert res.json() == {"success": True}

0 comments on commit 8ae0e27

Please sign in to comment.