Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use typed patch instance for session updates #500

Merged
merged 11 commits into from
Nov 5, 2024
12 changes: 8 additions & 4 deletions components/renku_data_services/session/apispec_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Base models for API specifications."""

from typing import Any

from pydantic import BaseModel, field_validator
from ulid import ULID

Expand All @@ -12,8 +14,10 @@ class Config:

from_attributes = True

@field_validator("id", mode="before", check_fields=False)
@field_validator("*", mode="before", check_fields=False)
@classmethod
def serialize_id(cls, id: str | ULID) -> str:
"""Custom serializer that can handle ULIDs."""
return str(id)
def serialize_ulid(cls, value: Any) -> Any:
"""Handle ULIDs."""
if isinstance(value, ULID):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's a nice way to handle it 👍 , we might want to add that for the other apispec_base.py's as well

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking that we should have a common apispec_base.py file because otherwise we have a lot of code fragmentation.

return str(value)
return value
57 changes: 26 additions & 31 deletions components/renku_data_services/session/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,22 @@

from dataclasses import dataclass

from sanic import HTTPResponse, Request, json
from sanic import HTTPResponse, Request
from sanic.response import JSONResponse
from sanic_ext import validate
from ulid import ULID

import renku_data_services.base_models as base_models
from renku_data_services import base_models
from renku_data_services.base_api.auth import authenticate, validate_path_project_id
from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint
from renku_data_services.base_models.validation import validated_json
from renku_data_services.session import apispec
from renku_data_services.session.core import (
validate_environment_patch,
validate_session_launcher_patch,
validate_unsaved_environment,
validate_unsaved_session_launcher,
)
from renku_data_services.session.db import SessionRepository


Expand All @@ -26,9 +33,7 @@ def get_all(self) -> BlueprintFactoryResponse:

async def _get_all(_: Request) -> JSONResponse:
environments = await self.session_repo.get_environments()
return json(
[apispec.Environment.model_validate(e).model_dump(exclude_none=True, mode="json") for e in environments]
)
return validated_json(apispec.EnvironmentList, environments)

return "/environments", ["GET"], _get_all

Expand All @@ -37,7 +42,7 @@ def get_one(self) -> BlueprintFactoryResponse:

async def _get_one(_: Request, environment_id: ULID) -> JSONResponse:
environment = await self.session_repo.get_environment(environment_id=environment_id)
return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"))
return validated_json(apispec.Environment, environment)

return "/environments/<environment_id:ulid>", ["GET"], _get_one

Expand All @@ -47,8 +52,9 @@ def post(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
@validate(json=apispec.EnvironmentPost)
async def _post(_: Request, user: base_models.APIUser, body: apispec.EnvironmentPost) -> JSONResponse:
environment = await self.session_repo.insert_environment(user=user, new_environment=body)
return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"), 201)
new_environment = validate_unsaved_environment(body)
environment = await self.session_repo.insert_environment(user=user, environment=new_environment)
return validated_json(apispec.Environment, environment, status=201)

return "/environments", ["POST"], _post

Expand All @@ -60,11 +66,11 @@ def patch(self) -> BlueprintFactoryResponse:
async def _patch(
_: Request, user: base_models.APIUser, environment_id: ULID, body: apispec.EnvironmentPatch
) -> JSONResponse:
body_dict = body.model_dump(exclude_none=True)
environment_patch = validate_environment_patch(body)
environment = await self.session_repo.update_environment(
user=user, environment_id=environment_id, **body_dict
user=user, environment_id=environment_id, patch=environment_patch
)
return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"))
return validated_json(apispec.Environment, environment)

return "/environments/<environment_id:ulid>", ["PATCH"], _patch

Expand Down Expand Up @@ -92,12 +98,7 @@ def get_all(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
async def _get_all(_: Request, user: base_models.APIUser) -> JSONResponse:
launchers = await self.session_repo.get_launchers(user=user)
return json(
[
apispec.SessionLauncher.model_validate(item).model_dump(exclude_none=True, mode="json")
for item in launchers
]
)
return validated_json(apispec.SessionLaunchersList, launchers)

return "/session_launchers", ["GET"], _get_all

Expand All @@ -107,7 +108,7 @@ def get_one(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
async def _get_one(_: Request, user: base_models.APIUser, launcher_id: ULID) -> JSONResponse:
launcher = await self.session_repo.get_launcher(user=user, launcher_id=launcher_id)
return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"))
return validated_json(apispec.SessionLauncher, launcher)

return "/session_launchers/<launcher_id:ulid>", ["GET"], _get_one

Expand All @@ -117,10 +118,9 @@ def post(self) -> BlueprintFactoryResponse:
@authenticate(self.authenticator)
@validate(json=apispec.SessionLauncherPost)
async def _post(_: Request, user: base_models.APIUser, body: apispec.SessionLauncherPost) -> JSONResponse:
launcher = await self.session_repo.insert_launcher(user=user, new_launcher=body)
return json(
apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"), 201
)
new_launcher = validate_unsaved_session_launcher(body)
launcher = await self.session_repo.insert_launcher(user=user, launcher=new_launcher)
return validated_json(apispec.SessionLauncher, launcher, status=201)

return "/session_launchers", ["POST"], _post

Expand All @@ -132,9 +132,9 @@ def patch(self) -> BlueprintFactoryResponse:
async def _patch(
_: Request, user: base_models.APIUser, launcher_id: ULID, body: apispec.SessionLauncherPatch
) -> JSONResponse:
body_dict = body.model_dump(exclude_none=True)
launcher = await self.session_repo.update_launcher(user=user, launcher_id=launcher_id, **body_dict)
return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"))
launcher_patch = validate_session_launcher_patch(body)
launcher = await self.session_repo.update_launcher(user=user, launcher_id=launcher_id, patch=launcher_patch)
return validated_json(apispec.SessionLauncher, launcher)

return "/session_launchers/<launcher_id:ulid>", ["PATCH"], _patch

Expand All @@ -155,11 +155,6 @@ def get_project_launchers(self) -> BlueprintFactoryResponse:
@validate_path_project_id
async def _get_launcher(_: Request, user: base_models.APIUser, project_id: str) -> JSONResponse:
launchers = await self.session_repo.get_project_launchers(user=user, project_id=project_id)
return json(
[
apispec.SessionLauncher.model_validate(item).model_dump(exclude_none=True, mode="json")
for item in launchers
]
)
return validated_json(apispec.SessionLaunchersList, launchers)

return "/projects/<project_id>/session_launchers", ["GET"], _get_launcher
52 changes: 52 additions & 0 deletions components/renku_data_services/session/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Business logic for sessions."""

from ulid import ULID

from renku_data_services.session import apispec, models


def validate_unsaved_environment(environment: apispec.EnvironmentPost) -> models.UnsavedEnvironment:
"""Validate an unsaved session environment."""
return models.UnsavedEnvironment(
name=environment.name,
description=environment.description,
container_image=environment.container_image,
default_url=environment.default_url,
)


def validate_environment_patch(patch: apispec.EnvironmentPatch) -> models.EnvironmentPatch:
"""Validate the update to a session environment."""
return models.EnvironmentPatch(
name=patch.name,
description=patch.description,
container_image=patch.container_image,
default_url=patch.default_url,
)


def validate_unsaved_session_launcher(launcher: apispec.SessionLauncherPost) -> models.UnsavedSessionLauncher:
"""Validate an unsaved session launcher."""
return models.UnsavedSessionLauncher(
project_id=ULID.from_str(launcher.project_id),
name=launcher.name,
description=launcher.description,
environment_kind=launcher.environment_kind,
environment_id=launcher.environment_id,
resource_class_id=launcher.resource_class_id,
container_image=launcher.container_image,
default_url=launcher.default_url,
)


def validate_session_launcher_patch(patch: apispec.SessionLauncherPatch) -> models.SessionLauncherPatch:
"""Validate the update to a session launcher."""
return models.SessionLauncherPatch(
name=patch.name,
description=patch.description,
environment_kind=patch.environment_kind,
environment_id=patch.environment_id,
resource_class_id=patch.resource_class_id,
container_image=patch.container_image,
default_url=patch.default_url,
)
Loading
Loading