Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use typed patch instance for session updates #500

Merged
merged 11 commits into from
Nov 5, 2024
88 changes: 56 additions & 32 deletions components/renku_data_services/session/blueprints.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
"""Session blueprint."""

from dataclasses import dataclass
from typing import Any

from sanic import HTTPResponse, Request, json
from sanic import HTTPResponse, Request
from sanic.response import JSONResponse
from sanic_ext import validate
from ulid import ULID

import renku_data_services.base_models as base_models
from renku_data_services import base_models
from renku_data_services.base_api.auth import authenticate, validate_path_project_id
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
from renku_data_services.session import apispec
from renku_data_services.base_models.validation import validated_json
from renku_data_services.session import apispec, models
from renku_data_services.session.core import (
validate_environment_patch,
validate_session_launcher_patch,
validate_unsaved_environment,
validate_unsaved_session_launcher,
)
from renku_data_services.session.db import SessionRepository


Expand All @@ -26,9 +34,7 @@ def get_all(self) -> BlueprintFactoryResponse:

async def _get_all(_: Request) -> JSONResponse:
environments = await self.session_repo.get_environments()
return json(
[apispec.Environment.model_validate(e).model_dump(exclude_none=True, mode="json") for e in environments]
)
return validated_json(apispec.EnvironmentList, [self._dump_environment(e) for e in environments])
leafty marked this conversation as resolved.
Show resolved Hide resolved

return "/environments", ["GET"], _get_all

Expand All @@ -37,7 +43,7 @@ def get_one(self) -> BlueprintFactoryResponse:

async def _get_one(_: Request, environment_id: ULID) -> JSONResponse:
environment = await self.session_repo.get_environment(environment_id=environment_id)
return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"))
return validated_json(apispec.Environment, environment)

return "/environments/<environment_id:ulid>", ["GET"], _get_one

Expand All @@ -47,8 +53,9 @@ def post(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
@validate(json=apispec.EnvironmentPost)
async def _post(_: Request, user: base_models.APIUser, body: apispec.EnvironmentPost) -> JSONResponse:
environment = await self.session_repo.insert_environment(user=user, new_environment=body)
return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"), 201)
new_environment = validate_unsaved_environment(body)
environment = await self.session_repo.insert_environment(user=user, environment=new_environment)
return validated_json(apispec.Environment, environment, status=201)

return "/environments", ["POST"], _post

Expand All @@ -60,11 +67,11 @@ def patch(self) -> BlueprintFactoryResponse:
async def _patch(
_: Request, user: base_models.APIUser, environment_id: ULID, body: apispec.EnvironmentPatch
) -> JSONResponse:
body_dict = body.model_dump(exclude_none=True)
environment_patch = validate_environment_patch(body)
environment = await self.session_repo.update_environment(
user=user, environment_id=environment_id, **body_dict
user=user, environment_id=environment_id, patch=environment_patch
)
return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"))
return validated_json(apispec.Environment, environment)

return "/environments/<environment_id:ulid>", ["PATCH"], _patch

Expand All @@ -78,6 +85,18 @@ async def _delete(_: Request, user: base_models.APIUser, environment_id: ULID) -

return "/environments/<environment_id:ulid>", ["DELETE"], _delete

@staticmethod
def _dump_environment(environment: models.Environment) -> dict[str, Any]:
"""Dumps a session environment for API responses."""
return dict(
id=environment.id,
name=environment.name,
creation_date=environment.creation_date,
description=environment.description,
container_image=environment.container_image,
default_url=environment.default_url,
)


@dataclass(kw_only=True)
class SessionLaunchersBP(CustomBlueprint):
Expand All @@ -92,12 +111,7 @@ def get_all(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
async def _get_all(_: Request, user: base_models.APIUser) -> JSONResponse:
launchers = await self.session_repo.get_launchers(user=user)
return json(
[
apispec.SessionLauncher.model_validate(item).model_dump(exclude_none=True, mode="json")
for item in launchers
]
)
return validated_json(apispec.SessionLaunchersList, [self._dump_launcher(item) for item in launchers])
leafty marked this conversation as resolved.
Show resolved Hide resolved

return "/session_launchers", ["GET"], _get_all

Expand All @@ -107,7 +121,7 @@ def get_one(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
async def _get_one(_: Request, user: base_models.APIUser, launcher_id: ULID) -> JSONResponse:
launcher = await self.session_repo.get_launcher(user=user, launcher_id=launcher_id)
return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"))
return validated_json(apispec.SessionLauncher, self._dump_launcher(launcher))

return "/session_launchers/<launcher_id:ulid>", ["GET"], _get_one

Expand All @@ -117,10 +131,9 @@ def post(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
@validate(json=apispec.SessionLauncherPost)
async def _post(_: Request, user: base_models.APIUser, body: apispec.SessionLauncherPost) -> JSONResponse:
launcher = await self.session_repo.insert_launcher(user=user, new_launcher=body)
return json(
apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"), 201
)
new_launcher = validate_unsaved_session_launcher(body)
launcher = await self.session_repo.insert_launcher(user=user, launcher=new_launcher)
return validated_json(apispec.SessionLauncher, self._dump_launcher(launcher), status=201)

return "/session_launchers", ["POST"], _post

Expand All @@ -132,9 +145,9 @@ def patch(self) -> BlueprintFactoryResponse:
async def _patch(
_: Request, user: base_models.APIUser, launcher_id: ULID, body: apispec.SessionLauncherPatch
) -> JSONResponse:
body_dict = body.model_dump(exclude_none=True)
launcher = await self.session_repo.update_launcher(user=user, launcher_id=launcher_id, **body_dict)
return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"))
launcher_patch = validate_session_launcher_patch(body)
launcher = await self.session_repo.update_launcher(user=user, launcher_id=launcher_id, patch=launcher_patch)
return validated_json(apispec.SessionLauncher, self._dump_launcher(launcher))

return "/session_launchers/<launcher_id:ulid>", ["PATCH"], _patch

Expand All @@ -155,11 +168,22 @@ def get_project_launchers(self) -> BlueprintFactoryResponse:
@validate_path_project_id
async def _get_launcher(_: Request, user: base_models.APIUser, project_id: str) -> JSONResponse:
launchers = await self.session_repo.get_project_launchers(user=user, project_id=project_id)
return json(
[
apispec.SessionLauncher.model_validate(item).model_dump(exclude_none=True, mode="json")
for item in launchers
]
)
return validated_json(apispec.SessionLaunchersList, [self._dump_launcher(item) for item in launchers])

return "/projects/<project_id>/session_launchers", ["GET"], _get_launcher

@staticmethod
def _dump_launcher(launcher: models.SessionLauncher) -> dict[str, Any]:
"""Dumps a session launcher for API responses."""
return dict(
id=str(launcher.id),
project_id=str(launcher.project_id),
name=launcher.name,
creation_date=launcher.creation_date,
description=launcher.description,
environment_kind=launcher.environment_kind,
environment_id=launcher.environment_id,
resource_class_id=launcher.resource_class_id,
container_image=launcher.container_image,
default_url=launcher.default_url,
)
52 changes: 52 additions & 0 deletions components/renku_data_services/session/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Business logic for sessions."""

from ulid import ULID

from renku_data_services.session import apispec, models


def validate_unsaved_environment(environment: apispec.EnvironmentPost) -> models.UnsavedEnvironment:
"""Validate an unsaved session environment."""
return models.UnsavedEnvironment(
name=environment.name,
description=environment.description,
container_image=environment.container_image,
default_url=environment.default_url,
)


def validate_environment_patch(patch: apispec.EnvironmentPatch) -> models.EnvironmentPatch:
"""Validate the update to a session environment."""
return models.EnvironmentPatch(
name=patch.name,
description=patch.description,
container_image=patch.container_image,
default_url=patch.default_url,
)


def validate_unsaved_session_launcher(launcher: apispec.SessionLauncherPost) -> models.UnsavedSessionLauncher:
"""Validate an unsaved session launcher."""
return models.UnsavedSessionLauncher(
project_id=ULID.from_str(launcher.project_id),
name=launcher.name,
description=launcher.description,
environment_kind=launcher.environment_kind,
environment_id=launcher.environment_id,
resource_class_id=launcher.resource_class_id,
container_image=launcher.container_image,
default_url=launcher.default_url,
)


def validate_session_launcher_patch(patch: apispec.SessionLauncherPatch) -> models.SessionLauncherPatch:
"""Validate the update to a session launcher."""
return models.SessionLauncherPatch(
name=patch.name,
description=patch.description,
environment_kind=patch.environment_kind,
environment_id=patch.environment_id,
resource_class_id=patch.resource_class_id,
container_image=patch.container_image,
default_url=patch.default_url,
)
Loading
Loading