diff --git a/components/renku_data_services/session/apispec_base.py b/components/renku_data_services/session/apispec_base.py index a16833290..71347548c 100644 --- a/components/renku_data_services/session/apispec_base.py +++ b/components/renku_data_services/session/apispec_base.py @@ -1,5 +1,7 @@ """Base models for API specifications.""" +from typing import Any + from pydantic import BaseModel, field_validator from ulid import ULID @@ -12,8 +14,10 @@ class Config: from_attributes = True - @field_validator("id", mode="before", check_fields=False) + @field_validator("*", mode="before", check_fields=False) @classmethod - def serialize_id(cls, id: str | ULID) -> str: - """Custom serializer that can handle ULIDs.""" - return str(id) + def serialize_ulid(cls, value: Any) -> Any: + """Handle ULIDs.""" + if isinstance(value, ULID): + return str(value) + return value diff --git a/components/renku_data_services/session/blueprints.py b/components/renku_data_services/session/blueprints.py index 8fe42995b..c92df46a7 100644 --- a/components/renku_data_services/session/blueprints.py +++ b/components/renku_data_services/session/blueprints.py @@ -2,15 +2,22 @@ from dataclasses import dataclass -from sanic import HTTPResponse, Request, json +from sanic import HTTPResponse, Request from sanic.response import JSONResponse from sanic_ext import validate from ulid import ULID -import renku_data_services.base_models as base_models +from renku_data_services import base_models from renku_data_services.base_api.auth import authenticate, validate_path_project_id from renku_data_services.base_api.blueprint import BlueprintFactoryResponse, CustomBlueprint +from renku_data_services.base_models.validation import validated_json from renku_data_services.session import apispec +from renku_data_services.session.core import ( + validate_environment_patch, + validate_session_launcher_patch, + validate_unsaved_environment, + validate_unsaved_session_launcher, +) from renku_data_services.session.db import SessionRepository @@ -26,9 +33,7 @@ def get_all(self) -> BlueprintFactoryResponse: async def _get_all(_: Request) -> JSONResponse: environments = await self.session_repo.get_environments() - return json( - [apispec.Environment.model_validate(e).model_dump(exclude_none=True, mode="json") for e in environments] - ) + return validated_json(apispec.EnvironmentList, environments) return "/environments", ["GET"], _get_all @@ -37,7 +42,7 @@ def get_one(self) -> BlueprintFactoryResponse: async def _get_one(_: Request, environment_id: ULID) -> JSONResponse: environment = await self.session_repo.get_environment(environment_id=environment_id) - return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json")) + return validated_json(apispec.Environment, environment) return "/environments/", ["GET"], _get_one @@ -47,8 +52,9 @@ def post(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) @validate(json=apispec.EnvironmentPost) async def _post(_: Request, user: base_models.APIUser, body: apispec.EnvironmentPost) -> JSONResponse: - environment = await self.session_repo.insert_environment(user=user, new_environment=body) - return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json"), 201) + new_environment = validate_unsaved_environment(body) + environment = await self.session_repo.insert_environment(user=user, environment=new_environment) + return validated_json(apispec.Environment, environment, status=201) return "/environments", ["POST"], _post @@ -60,11 +66,11 @@ def patch(self) -> BlueprintFactoryResponse: async def _patch( _: Request, user: base_models.APIUser, environment_id: ULID, body: apispec.EnvironmentPatch ) -> JSONResponse: - body_dict = body.model_dump(exclude_none=True) + environment_patch = validate_environment_patch(body) environment = await self.session_repo.update_environment( - user=user, environment_id=environment_id, **body_dict + user=user, environment_id=environment_id, patch=environment_patch ) - return json(apispec.Environment.model_validate(environment).model_dump(exclude_none=True, mode="json")) + return validated_json(apispec.Environment, environment) return "/environments/", ["PATCH"], _patch @@ -92,12 +98,7 @@ def get_all(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) async def _get_all(_: Request, user: base_models.APIUser) -> JSONResponse: launchers = await self.session_repo.get_launchers(user=user) - return json( - [ - apispec.SessionLauncher.model_validate(item).model_dump(exclude_none=True, mode="json") - for item in launchers - ] - ) + return validated_json(apispec.SessionLaunchersList, launchers) return "/session_launchers", ["GET"], _get_all @@ -107,7 +108,7 @@ def get_one(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) async def _get_one(_: Request, user: base_models.APIUser, launcher_id: ULID) -> JSONResponse: launcher = await self.session_repo.get_launcher(user=user, launcher_id=launcher_id) - return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json")) + return validated_json(apispec.SessionLauncher, launcher) return "/session_launchers/", ["GET"], _get_one @@ -117,10 +118,9 @@ def post(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) @validate(json=apispec.SessionLauncherPost) async def _post(_: Request, user: base_models.APIUser, body: apispec.SessionLauncherPost) -> JSONResponse: - launcher = await self.session_repo.insert_launcher(user=user, new_launcher=body) - return json( - apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json"), 201 - ) + new_launcher = validate_unsaved_session_launcher(body) + launcher = await self.session_repo.insert_launcher(user=user, launcher=new_launcher) + return validated_json(apispec.SessionLauncher, launcher, status=201) return "/session_launchers", ["POST"], _post @@ -132,9 +132,9 @@ def patch(self) -> BlueprintFactoryResponse: async def _patch( _: Request, user: base_models.APIUser, launcher_id: ULID, body: apispec.SessionLauncherPatch ) -> JSONResponse: - body_dict = body.model_dump(exclude_none=True) - launcher = await self.session_repo.update_launcher(user=user, launcher_id=launcher_id, **body_dict) - return json(apispec.SessionLauncher.model_validate(launcher).model_dump(exclude_none=True, mode="json")) + launcher_patch = validate_session_launcher_patch(body) + launcher = await self.session_repo.update_launcher(user=user, launcher_id=launcher_id, patch=launcher_patch) + return validated_json(apispec.SessionLauncher, launcher) return "/session_launchers/", ["PATCH"], _patch @@ -155,11 +155,6 @@ def get_project_launchers(self) -> BlueprintFactoryResponse: @validate_path_project_id async def _get_launcher(_: Request, user: base_models.APIUser, project_id: str) -> JSONResponse: launchers = await self.session_repo.get_project_launchers(user=user, project_id=project_id) - return json( - [ - apispec.SessionLauncher.model_validate(item).model_dump(exclude_none=True, mode="json") - for item in launchers - ] - ) + return validated_json(apispec.SessionLaunchersList, launchers) return "/projects//session_launchers", ["GET"], _get_launcher diff --git a/components/renku_data_services/session/core.py b/components/renku_data_services/session/core.py new file mode 100644 index 000000000..09a29839b --- /dev/null +++ b/components/renku_data_services/session/core.py @@ -0,0 +1,52 @@ +"""Business logic for sessions.""" + +from ulid import ULID + +from renku_data_services.session import apispec, models + + +def validate_unsaved_environment(environment: apispec.EnvironmentPost) -> models.UnsavedEnvironment: + """Validate an unsaved session environment.""" + return models.UnsavedEnvironment( + name=environment.name, + description=environment.description, + container_image=environment.container_image, + default_url=environment.default_url, + ) + + +def validate_environment_patch(patch: apispec.EnvironmentPatch) -> models.EnvironmentPatch: + """Validate the update to a session environment.""" + return models.EnvironmentPatch( + name=patch.name, + description=patch.description, + container_image=patch.container_image, + default_url=patch.default_url, + ) + + +def validate_unsaved_session_launcher(launcher: apispec.SessionLauncherPost) -> models.UnsavedSessionLauncher: + """Validate an unsaved session launcher.""" + return models.UnsavedSessionLauncher( + project_id=ULID.from_str(launcher.project_id), + name=launcher.name, + description=launcher.description, + environment_kind=launcher.environment_kind, + environment_id=launcher.environment_id, + resource_class_id=launcher.resource_class_id, + container_image=launcher.container_image, + default_url=launcher.default_url, + ) + + +def validate_session_launcher_patch(patch: apispec.SessionLauncherPatch) -> models.SessionLauncherPatch: + """Validate the update to a session launcher.""" + return models.SessionLauncherPatch( + name=patch.name, + description=patch.description, + environment_kind=patch.environment_kind, + environment_id=patch.environment_id, + resource_class_id=patch.resource_class_id, + container_image=patch.container_image, + default_url=patch.default_url, + ) diff --git a/components/renku_data_services/session/db.py b/components/renku_data_services/session/db.py index 06f9c89b7..3d2b2483d 100644 --- a/components/renku_data_services/session/db.py +++ b/components/renku_data_services/session/db.py @@ -4,7 +4,6 @@ from collections.abc import Callable from datetime import UTC, datetime -from typing import Any from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -15,7 +14,7 @@ from renku_data_services.authz.authz import Authz, ResourceType from renku_data_services.authz.models import Scope from renku_data_services.crc.db import ResourcePoolRepository -from renku_data_services.session import apispec, models +from renku_data_services.session import models from renku_data_services.session import orm as schemas from renku_data_services.session.apispec import EnvironmentKind @@ -51,9 +50,7 @@ async def get_environment(self, environment_id: ULID) -> models.Environment: return environment.dump() async def insert_environment( - self, - user: base_models.APIUser, - new_environment: apispec.EnvironmentPost, + self, user: base_models.APIUser, environment: models.UnsavedEnvironment ) -> models.Environment: """Insert a new session environment.""" if user.id is None: @@ -61,23 +58,21 @@ async def insert_environment( if not user.is_admin: raise errors.ForbiddenError(message="You do not have the required permissions for this operation.") - environment_model = models.Environment( - id=None, - name=new_environment.name, - description=new_environment.description, - container_image=new_environment.container_image, - default_url=new_environment.default_url, - created_by=models.Member(id=user.id), + environment_orm = schemas.EnvironmentORM( + name=environment.name, + description=environment.description if environment.description else None, + container_image=environment.container_image, + default_url=environment.default_url if environment.default_url else None, + created_by_id=user.id, creation_date=datetime.now(UTC).replace(microsecond=0), ) - environment = schemas.EnvironmentORM.load(environment_model) async with self.session_maker() as session, session.begin(): - session.add(environment) - return environment.dump() + session.add(environment_orm) + return environment_orm.dump() async def update_environment( - self, user: base_models.APIUser, environment_id: ULID, **kwargs: dict + self, user: base_models.APIUser, environment_id: ULID, patch: models.EnvironmentPatch ) -> models.Environment: """Update a session environment entry.""" if not user.is_admin: @@ -93,10 +88,14 @@ async def update_environment( message=f"Session environment with id '{environment_id}' does not exist." ) - for key, value in kwargs.items(): - # NOTE: Only ``name``, ``description``, ``container_image`` and ``default_url`` can be edited - if key in ["name", "description", "container_image", "default_url"]: - setattr(environment, key, value) + if patch.name is not None: + environment.name = patch.name + if patch.description is not None: + environment.description = patch.description if patch.description else None + if patch.container_image is not None: + environment.container_image = patch.container_image + if patch.default_url is not None: + environment.default_url = patch.default_url if patch.default_url else None return environment.dump() @@ -171,37 +170,32 @@ async def get_launcher(self, user: base_models.APIUser, launcher_id: ULID) -> mo return launcher.dump() async def insert_launcher( - self, user: base_models.APIUser, new_launcher: apispec.SessionLauncherPost + self, user: base_models.APIUser, launcher: models.UnsavedSessionLauncher ) -> models.SessionLauncher: """Insert a new session launcher.""" if not user.is_authenticated or user.id is None: raise errors.UnauthorizedError(message="You do not have the required permissions for this operation.") - project_id = new_launcher.project_id - authorized = await self.project_authz.has_permission( - user, ResourceType.project, ULID.from_str(project_id), Scope.WRITE - ) + project_id = launcher.project_id + authorized = await self.project_authz.has_permission(user, ResourceType.project, project_id, Scope.WRITE) if not authorized: raise errors.MissingResourceError( message=f"Project with id '{project_id}' does not exist or you do not have access to it." ) - launcher_model = models.SessionLauncher( - id=None, - name=new_launcher.name, - project_id=new_launcher.project_id, - description=new_launcher.description, - environment_kind=new_launcher.environment_kind, - environment_id=new_launcher.environment_id, - resource_class_id=new_launcher.resource_class_id, - container_image=new_launcher.container_image, - default_url=new_launcher.default_url, - created_by=models.Member(id=user.id), + launcher_orm = schemas.SessionLauncherORM( + name=launcher.name, + project_id=launcher.project_id, + description=launcher.description if launcher.description else None, + environment_kind=launcher.environment_kind, + environment_id=launcher.environment_id, + resource_class_id=launcher.resource_class_id, + container_image=launcher.container_image if launcher.container_image else None, + default_url=launcher.default_url if launcher.default_url else None, + created_by_id=user.id, creation_date=datetime.now(UTC).replace(microsecond=0), ) - models.SessionLauncher.model_validate(launcher_model) - async with self.session_maker() as session, session.begin(): res = await session.scalars(select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id)) project = res.one_or_none() @@ -210,7 +204,7 @@ async def insert_launcher( message=f"Project with id '{project_id}' does not exist or you do not have access to it." ) - environment_id = new_launcher.environment_id + environment_id = launcher_orm.environment_id if environment_id is not None: res = await session.scalars( select(schemas.EnvironmentORM).where(schemas.EnvironmentORM.id == environment_id) @@ -221,7 +215,7 @@ async def insert_launcher( message=f"Session environment with id '{environment_id}' does not exist or you do not have access to it." # noqa: E501 ) - resource_class_id = new_launcher.resource_class_id + resource_class_id = launcher_orm.resource_class_id if resource_class_id is not None: res = await session.scalars( select(schemas.ResourceClassORM).where(schemas.ResourceClassORM.id == resource_class_id) @@ -239,12 +233,11 @@ async def insert_launcher( message=f"You do not have access to resource class with id '{resource_class_id}'." ) - launcher = schemas.SessionLauncherORM.load(launcher_model) - session.add(launcher) - return launcher.dump() + session.add(launcher_orm) + return launcher_orm.dump() async def update_launcher( - self, user: base_models.APIUser, launcher_id: ULID, **kwargs: Any + self, user: base_models.APIUser, launcher_id: ULID, patch: models.SessionLauncherPatch ) -> models.SessionLauncher: """Update a session launcher entry.""" if not user.is_authenticated or user.id is None: @@ -269,7 +262,7 @@ async def update_launcher( if not authorized: raise errors.ForbiddenError(message="You do not have the required permissions for this operation.") - environment_id = kwargs.get("environment_id") + environment_id = patch.environment_id if environment_id is not None: res = await session.scalars( select(schemas.EnvironmentORM).where(schemas.EnvironmentORM.id == environment_id) @@ -280,7 +273,7 @@ async def update_launcher( message=f"Session environment with id '{environment_id}' does not exist or you do not have access to it." # noqa: E501 ) - resource_class_id = kwargs.get("resource_class_id") + resource_class_id = patch.resource_class_id if resource_class_id is not None: res = await session.scalars( select(schemas.ResourceClassORM).where(schemas.ResourceClassORM.id == resource_class_id) @@ -298,20 +291,20 @@ async def update_launcher( message=f"You do not have access to resource class with id '{resource_class_id}'." ) - for key, value in kwargs.items(): - # NOTE: Only ``name``, ``description``, ``environment_kind``, - # ``environment_id``, ``resource_class_id``, ``container_image`` and - # ``default_url`` can be edited. - if key in [ - "name", - "description", - "environment_kind", - "environment_id", - "resource_class_id", - "container_image", - "default_url", - ]: - setattr(launcher, key, value) + if patch.name is not None: + launcher.name = patch.name + if patch.description is not None: + launcher.description = patch.description if patch.description else None + if patch.environment_kind is not None: + launcher.environment_kind = patch.environment_kind + if patch.environment_id is not None: + launcher.environment_id = patch.environment_id + if patch.resource_class_id is not None: + launcher.resource_class_id = patch.resource_class_id + if patch.container_image is not None: + launcher.container_image = patch.container_image if patch.container_image else None + if patch.default_url is not None: + launcher.default_url = patch.default_url if patch.default_url else None if launcher.environment_kind == EnvironmentKind.global_environment: launcher.container_image = None diff --git a/components/renku_data_services/session/models.py b/components/renku_data_services/session/models.py index 422524152..99626ff74 100644 --- a/components/renku_data_services/session/models.py +++ b/components/renku_data_services/session/models.py @@ -1,4 +1,4 @@ -"""Models for Sessions.""" +"""Models for sessions.""" from dataclasses import dataclass from datetime import datetime @@ -11,43 +11,55 @@ @dataclass(frozen=True, eq=True, kw_only=True) -class Member(BaseModel): +class Member: """Member model.""" id: str @dataclass(frozen=True, eq=True, kw_only=True) -class Environment(BaseModel): - """Session environment model.""" +class UnsavedEnvironment: + """A session environment that hasn't been stored in the database.""" - id: str | None name: str - creation_date: datetime description: str | None container_image: str default_url: str | None + + +@dataclass(frozen=True, eq=True, kw_only=True) +class Environment(UnsavedEnvironment): + """Session environment model.""" + + id: str + creation_date: datetime created_by: Member @dataclass(frozen=True, eq=True, kw_only=True) -class SessionLauncher(BaseModel): - """Session launcher model.""" +class EnvironmentPatch: + """Model for changes requested on a session environment.""" + + name: str | None = None + description: str | None = None + container_image: str | None = None + default_url: str | None = None + - id: ULID | None - project_id: str +class UnsavedSessionLauncher(BaseModel): + """A session launcher that hasn't been stored in the database.""" + + project_id: ULID name: str - creation_date: datetime description: str | None environment_kind: EnvironmentKind environment_id: str | None resource_class_id: int | None container_image: str | None default_url: str | None - created_by: Member @model_validator(mode="after") - def check_launcher_environment_kind(self) -> "SessionLauncher": + def check_launcher_environment_kind(self) -> "UnsavedSessionLauncher": """Validates the environment of a launcher.""" environment_kind = self.environment_kind @@ -61,3 +73,24 @@ def check_launcher_environment_kind(self) -> "SessionLauncher": raise errors.ValidationError(message="'container_image' not set when environment_kind=container_image") return self + + +class SessionLauncher(UnsavedSessionLauncher): + """Session launcher model.""" + + id: ULID + creation_date: datetime + created_by: Member + + +@dataclass(frozen=True, eq=True, kw_only=True) +class SessionLauncherPatch: + """Model for changes requested on a session launcher.""" + + name: str | None = None + description: str | None = None + environment_kind: EnvironmentKind | None = None + environment_id: str | None = None + resource_class_id: int | None = None + container_image: str | None = None + default_url: str | None = None diff --git a/components/renku_data_services/session/orm.py b/components/renku_data_services/session/orm.py index 4b61d548c..8ed0792b4 100644 --- a/components/renku_data_services/session/orm.py +++ b/components/renku_data_services/session/orm.py @@ -48,18 +48,6 @@ class EnvironmentORM(BaseORM): default_url: Mapped[str | None] = mapped_column("default_url", String(200)) """Default URL path to open in a session.""" - @classmethod - def load(cls, environment: models.Environment) -> "EnvironmentORM": - """Create EnvironmentORM from the session environment model.""" - return cls( - name=environment.name, - created_by_id=environment.created_by.id, - creation_date=environment.creation_date, - description=environment.description, - container_image=environment.container_image, - default_url=environment.default_url, - ) - def dump(self) -> models.Environment: """Create a session environment model from the EnvironmentORM.""" return models.Environment( @@ -124,34 +112,18 @@ class SessionLauncherORM(BaseORM): ) """Id of the resource class.""" - @classmethod - def load(cls, launcher: models.SessionLauncher) -> "SessionLauncherORM": - """Create SessionLauncherORM from the session launcher model.""" - return cls( - name=launcher.name, - created_by_id=launcher.created_by.id, - creation_date=launcher.creation_date, - description=launcher.description, - environment_kind=launcher.environment_kind, - container_image=launcher.container_image, - project_id=ULID.from_str(launcher.project_id), - environment_id=launcher.environment_id, - resource_class_id=launcher.resource_class_id, - default_url=launcher.default_url, - ) - def dump(self) -> models.SessionLauncher: """Create a session launcher model from the SessionLauncherORM.""" return models.SessionLauncher( id=self.id, - project_id=str(self.project_id), + project_id=self.project_id, name=self.name, created_by=models.Member(id=self.created_by_id), creation_date=self.creation_date, description=self.description, environment_kind=self.environment_kind, - environment_id=self.environment_id if self.environment_id is not None else None, - resource_class_id=self.resource_class_id if self.resource_class_id is not None else None, + environment_id=self.environment_id, + resource_class_id=self.resource_class_id, container_image=self.container_image, default_url=self.default_url, )