Skip to content

Commit

Permalink
Merge pull request #217 from bento-platform/feat/auth/include-patterns
Browse files Browse the repository at this point in the history
feat(auth): add include method/path arg for authz middleware
  • Loading branch information
davidlougheed authored Aug 5, 2024
2 parents 7407162 + 9b16807 commit c3346d3
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 10 deletions.
41 changes: 34 additions & 7 deletions bento_lib/auth/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,41 @@
__all__ = ["BaseAuthMiddleware"]


def _compile_to_regex_if_needed(pattern: str | re.Pattern) -> re.Pattern:
NonNormalizedPattern = re.Pattern | str

# Order: method pattern, path pattern
NonNormalizedRequestPattern = tuple[NonNormalizedPattern, NonNormalizedPattern]
NonNormalizedRequestPatterns = tuple[NonNormalizedRequestPattern, ...]
RequestPattern = tuple[re.Pattern, re.Pattern]
RequestPatterns = frozenset[RequestPattern]


def _compile_to_regex_if_needed(pattern: NonNormalizedPattern) -> re.Pattern:
if isinstance(pattern, str):
return re.compile(pattern)
return pattern


def _normalize_request_patterns(patterns: NonNormalizedRequestPatterns) -> RequestPatterns:
return frozenset(
(_compile_to_regex_if_needed(method_pattern), _compile_to_regex_if_needed(path_pattern))
for method_pattern, path_pattern in patterns
)


def _request_pattern_match(method: str, path: str, patterns: RequestPatterns) -> tuple[bool, ...]:
return tuple(bool(mp.fullmatch(method) and pp.fullmatch(path)) for mp, pp in patterns)


class BaseAuthMiddleware(ABC, MarkAuthzDoneMixin):
def __init__(
self,
bento_authz_service_url: str,
drs_compat: bool = False,
sr_compat: bool = False,
beacon_meta_callback: Callable[[], dict] | None = None,
exempt_request_patterns: tuple[tuple[re.Pattern | str, re.Pattern | str], ...] = (),
include_request_patterns: NonNormalizedRequestPatterns | None = None,
exempt_request_patterns: NonNormalizedRequestPatterns = (),
debug_mode: bool = False,
enabled: bool = True,
logger: logging.Logger | None = None,
Expand All @@ -43,10 +64,10 @@ def __init__(
self._sr_compat: bool = sr_compat
self._beacon_meta_callback: Callable[[], dict] | None = beacon_meta_callback

self._exempt_request_patterns: frozenset[tuple[re.Pattern, re.Pattern]] = frozenset(
(_compile_to_regex_if_needed(method_pattern), _compile_to_regex_if_needed(path_pattern))
for method_pattern, path_pattern in exempt_request_patterns
self._include_request_patterns: RequestPatterns | None = (
_normalize_request_patterns(include_request_patterns) if include_request_patterns is not None else None
)
self._exempt_request_patterns: RequestPatterns = _normalize_request_patterns(exempt_request_patterns)

self._bento_authz_service_url: str = bento_authz_service_url

Expand All @@ -65,8 +86,14 @@ def enabled(self) -> bool:
return self._enabled

def request_is_exempt(self, method: str, path: str) -> bool:
return method == "OPTIONS" or any(
mp.fullmatch(method) and pp.fullmatch(path) for mp, pp in self._exempt_request_patterns)
return (
method == "OPTIONS"
or (
self._include_request_patterns is not None
and not any(_request_pattern_match(method, path, self._include_request_patterns))
)
or any(_request_pattern_match(method, path, self._exempt_request_patterns))
)

@abstractmethod
def get_authz_header_value(self, request: Any) -> str | None: # pragma: no cover
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "bento-lib"
version = "11.11.2"
version = "11.12.0"
description = "A set of common utilities and helpers for Bento platform services."
authors = [
"David Lougheed <[email protected]>",
Expand Down
2 changes: 2 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path

__all__ = [
"authz_test_include_patterns",
"authz_test_exempt_patterns",
"authz_test_case_params",
"authz_test_cases",
Expand All @@ -14,6 +15,7 @@
]

# cases: (authz response code, authz response result, test client URL, auth header included, assert final response)
authz_test_include_patterns = ((r".*", re.compile(r"^/(get|post).*$")),)
authz_test_exempt_patterns = ((r"POST", re.compile(r"/post-exempted")),)
authz_test_case_params = "authz_code, authz_res, test_url, inc_headers, test_code"
authz_test_cases = (
Expand Down
17 changes: 16 additions & 1 deletion tests/test_platform_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from bento_lib.workflows.fastapi import build_workflow_router

from .common import (
authz_test_include_patterns,
authz_test_exempt_patterns,
authz_test_case_params,
authz_test_cases,
Expand Down Expand Up @@ -105,7 +106,11 @@ def get_403():
cors_origins=("*",),
)
auth_middleware = FastApiAuthMiddleware.build_from_pydantic_config(
app_test_auth_config, logger, exempt_request_patterns=authz_test_exempt_patterns)
app_test_auth_config,
logger,
include_request_patterns=authz_test_include_patterns,
exempt_request_patterns=authz_test_exempt_patterns,
)
app_test_auth = BentoFastAPI(auth_middleware, app_test_auth_config, logger, {}, TEST_APP_SERVICE_TYPE,
TEST_APP_VERSION)

Expand Down Expand Up @@ -208,6 +213,11 @@ async def auth_post_with_token_evaluate_to_dict(request: Request, body: TestToke
)})


@app_test_auth.put("/put-test")
async def auth_put_not_included(body: TestBody):
return JSONResponse(body.model_dump(mode="json"))


# Auth test app (disabled auth middleware) ------------------------------------

app_test_auth_disabled = FastAPI()
Expand Down Expand Up @@ -388,6 +398,11 @@ def test_fastapi_auth_post_with_token_evaluate_to_dict(aioresponse: aioresponses
assert r.text == '{"payload":[{"ingest:data":true}]}'


def test_fastapi_auth_put_not_included(fastapi_client_auth: TestClient):
r = fastapi_client_auth.put("/put-test", json=TEST_AUTHZ_VALID_POST_BODY) # no authz needed, not included
assert r.status_code == 200


@pytest.mark.asyncio
async def test_fastapi_auth_disabled(aioresponse: aioresponses, fastapi_client_auth_disabled: TestClient):
# middleware is disabled, should work anyway
Expand Down
14 changes: 13 additions & 1 deletion tests/test_platform_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from bento_lib.auth.resources import RESOURCE_EVERYTHING

from .common import (
authz_test_include_patterns,
authz_test_exempt_patterns,
authz_test_case_params,
authz_test_cases,
Expand Down Expand Up @@ -63,7 +64,9 @@ def flask_client_auth():
auth_middleware = FlaskAuthMiddleware(
bento_authz_service_url="https://bento-auth.local",
logger=logger,
exempt_request_patterns=authz_test_exempt_patterns)
include_request_patterns=authz_test_include_patterns,
exempt_request_patterns=authz_test_exempt_patterns,
)
auth_middleware.attach(test_app_auth)

test_app_auth.register_error_handler(
Expand Down Expand Up @@ -148,6 +151,10 @@ def auth_post_with_token_evaluate__to_dict():
headers_getter=(lambda _r: {"Authorization": f"Bearer {token}"}),
)})

@test_app_auth.route("/put-test", methods=["PUT"])
def auth_put_not_included():
return jsonify(request.json)

with test_app_auth.test_client() as client:
yield client

Expand Down Expand Up @@ -287,6 +294,11 @@ def test_flask_auth_post_with_token_evaluate_to_dict(flask_client_auth: FlaskCli
assert r.text == '{"payload":[{"ingest:data":true}]}\n'


def test_flask_auth_put_not_included(flask_client_auth: FlaskClient):
r = flask_client_auth.put("/put-test", json=TEST_AUTHZ_VALID_POST_BODY) # no authz needed, not included
assert r.status_code == 200


@responses.activate
def test_flask_auth_disabled(flask_client_auth_disabled_with_middleware: tuple[FlaskClient, FlaskAuthMiddleware]):
flask_client_auth_disabled, auth_middleware_disabled = flask_client_auth_disabled_with_middleware
Expand Down

0 comments on commit c3346d3

Please sign in to comment.