diff --git a/bento_authorization_service/cli.py b/bento_authorization_service/cli.py index a9b574d..bc3eb54 100644 --- a/bento_authorization_service/cli.py +++ b/bento_authorization_service/cli.py @@ -50,7 +50,7 @@ async def list_cmd(_config: Config, db: Database, args): async def create_grant(_config: Config, db: Database, args) -> int: - g, created = await db.create_grant( + g = await db.create_grant( GrantModel( subject=SubjectModel.model_validate_json(getattr(args, "subject", "null")), resource=ResourceModel.model_validate_json(getattr(args, "resource", "null")), @@ -60,7 +60,7 @@ async def create_grant(_config: Config, db: Database, args) -> int: ) ) - if created: + if g: print(f"Grant successfully created: {g}") return 0 @@ -154,7 +154,7 @@ async def delete_cmd(_config: Config, db: Database, args) -> int: async def assign_all_cmd(_config: Config, db: Database, args) -> int: - g, created = await db.create_grant( + g = await db.create_grant( GrantModel( subject=SubjectModel.model_validate({"iss": args.iss, "sub": args.sub}), resource=RESOURCE_EVERYTHING, @@ -164,7 +164,7 @@ async def assign_all_cmd(_config: Config, db: Database, args) -> int: ) ) - if created: + if g: print(f"Grant successfully created: {g}") return 0 diff --git a/bento_authorization_service/db.py b/bento_authorization_service/db.py index 07bfc38..9f38f68 100644 --- a/bento_authorization_service/db.py +++ b/bento_authorization_service/db.py @@ -202,7 +202,7 @@ async def get_grants(self) -> tuple[StoredGrantModel, ...]: ) return tuple(grant_db_deserialize(r) for r in res) - async def create_grant(self, grant: GrantModel) -> tuple[int | None, bool]: # id, created + async def create_grant(self, grant: GrantModel) -> int | None: conn: asyncpg.Connection async with self.connect() as conn: async with conn.transaction(): @@ -229,9 +229,9 @@ async def create_grant(self, grant: GrantModel) -> tuple[int | None, bool]: # i ) except AssertionError: # Failed for some reason - return None, False + return None - return res, res is not None + return res async def delete_grant(self, grant_id: int) -> None: conn: asyncpg.Connection diff --git a/bento_authorization_service/policy_engine/evaluation.py b/bento_authorization_service/policy_engine/evaluation.py index a5c751d..76855c4 100644 --- a/bento_authorization_service/policy_engine/evaluation.py +++ b/bento_authorization_service/policy_engine/evaluation.py @@ -327,10 +327,6 @@ def evaluate_on_resource_and_permission( return permission in permissions -async def _get_token_data(idp_manager: BaseIdPManager, token: TokenData | str | None) -> TokenData | None: - return (await idp_manager.decode(token)) if isinstance(token, str) else token - - async def evaluate( idp_manager: BaseIdPManager, db: Database, diff --git a/bento_authorization_service/routers/grants.py b/bento_authorization_service/routers/grants.py index ae076a7..440fd69 100644 --- a/bento_authorization_service/routers/grants.py +++ b/bento_authorization_service/routers/grants.py @@ -56,6 +56,8 @@ async def create_grant( idp_manager: IdPManagerDependency, authorization: OptionalBearerToken, ) -> StoredGrantModel: + # Make sure the token is allowed to edit permissions (in this case, 'editing permissions' + # extends to creating grants) on the resource in question. await raise_if_no_resource_access( request, extract_token(authorization), grant.resource, P_EDIT_PERMISSIONS, db, idp_manager ) @@ -63,19 +65,15 @@ async def create_grant( # Flag that we have thought about auth MarkAuthzDone.mark_authz_done(request) + # Forbid creating a grant which is expired from the get-go. if grant.expiry is not None and grant.expiry < datetime.now(timezone.utc): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Grant is already expired") - g_id, g_created = await db.create_grant(grant) - if g_id is not None: - if g_created: - if (g := await db.get_grant(g_id)) is not None: - return g # Successfully created, return - raise grant_could_not_be_created() # Somehow immediately removed - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Grant with this subject + resource + permission already exists", - ) + # Create the grant + if (g_id := await db.create_grant(grant)) is not None: + if (g := await db.get_grant(g_id)) is not None: + return g # Successfully created, return + raise grant_could_not_be_created() # Somehow immediately removed raise grant_could_not_be_created() diff --git a/bento_authorization_service/routers/policy.py b/bento_authorization_service/routers/policy.py index dee658a..f76cc36 100644 --- a/bento_authorization_service/routers/policy.py +++ b/bento_authorization_service/routers/policy.py @@ -32,12 +32,6 @@ class ListPermissionsResponse(BaseModel): result: list[list[str]] -def apply_scalar_or_vector(func: Callable[[T], U], v: T | tuple[T, ...]) -> U | tuple[U, ...]: - if isinstance(v, tuple): - return tuple(func(x) for x in v) - return func(v) - - def list_permissions_for_resource( grants: tuple[StoredGrantModel], groups: dict[int, StoredGroupModel], diff --git a/tests/test_grants.py b/tests/test_grants.py index f32d0e6..cad8926 100644 --- a/tests/test_grants.py +++ b/tests/test_grants.py @@ -107,7 +107,7 @@ async def test_grant_endpoints_get(test_client: TestClient, db: Database, db_cle assert res.status_code == status.HTTP_404_NOT_FOUND # create grant in database - g_id, _ = await db.create_grant(sd.TEST_GRANT_DAVID_PROJECT_1_QUERY_DATA) + g_id = await db.create_grant(sd.TEST_GRANT_DAVID_PROJECT_1_QUERY_DATA) db_grant: StoredGrantModel = await db.get_grant(g_id) # test that without a token, we cannot see anything @@ -126,7 +126,7 @@ async def test_grant_endpoints_list(test_client: TestClient, db: Database, db_cl headers = {"Authorization": f"Bearer {sd.make_fresh_david_token_encoded()}"} # create grant in database - g_id, _ = await db.create_grant(sd.TEST_GRANT_DAVID_PROJECT_1_QUERY_DATA) + g_id = await db.create_grant(sd.TEST_GRANT_DAVID_PROJECT_1_QUERY_DATA) db_grant: StoredGrantModel = await db.get_grant(g_id) db_grant_json = json.dumps(db_grant.model_dump(mode="json"), sort_keys=True) @@ -144,7 +144,7 @@ async def test_grant_endpoints_list(test_client: TestClient, db: Database, db_cl @pytest.mark.asyncio async def test_grant_endpoints_delete(auth_headers: dict[str, str], test_client: TestClient, db: Database, db_cleanup): # create grant in database - g_id, _ = await db.create_grant(sd.TEST_GRANT_DAVID_PROJECT_1_QUERY_DATA) + g_id = await db.create_grant(sd.TEST_GRANT_DAVID_PROJECT_1_QUERY_DATA) # test that without a token, we cannot delete it res = test_client.delete(f"/grants/{g_id}")