Skip to content

Commit

Permalink
refactor: use typed patch instance for user updates (#497)
Browse files Browse the repository at this point in the history
See #485.
  • Loading branch information
leafty authored Nov 1, 2024
1 parent 0cb8a17 commit f1e525a
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 27 deletions.
66 changes: 39 additions & 27 deletions components/renku_data_services/users/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import secrets
from collections.abc import AsyncGenerator, Callable
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from typing import Any, cast

from sanic.log import logger
Expand All @@ -28,9 +28,11 @@
DeletedUser,
KeycloakAdminEvent,
PinnedProjects,
UnsavedUserInfo,
UserInfo,
UserInfoFieldUpdate,
UserInfoUpdate,
UserPatch,
UserPreferences,
)
from renku_data_services.users.orm import LastKeycloakEventTimestamp, UserORM, UserPreferencesORM
Expand Down Expand Up @@ -65,12 +67,12 @@ async def _add_api_user(self, user: APIUser) -> UserInfo:
if not user.id:
raise errors.UnauthorizedError(message="The user has to be authenticated to be inserted in the DB.")
result = await self._users_sync.update_or_insert_user(
user_id=user.id,
payload=dict(
user=UnsavedUserInfo(
id=user.id,
email=user.email,
first_name=user.first_name,
last_name=user.last_name,
email=user.email,
),
)
)
return result.new

Expand Down Expand Up @@ -214,37 +216,44 @@ async def _get_user(self, id: str) -> UserInfo | None:
@Authz.authz_change(AuthzOperation.update_or_insert, ResourceType.user)
@dispatch_message(events.UpdateOrInsertUser)
async def update_or_insert_user(
self, user_id: str, payload: dict[str, Any], *, session: AsyncSession | None = None
self, user: UnsavedUserInfo, *, session: AsyncSession | None = None
) -> UserInfoUpdate:
"""Update a user or insert it if it does not exist."""
if not session:
raise errors.ProgrammingError(message="A database session is required")
res = await session.execute(select(UserORM).where(UserORM.keycloak_id == user_id))
res = await session.execute(select(UserORM).where(UserORM.keycloak_id == user.id))
existing_user = res.scalar_one_or_none()
if existing_user:
return await self._update_user(session=session, user_id=user_id, existing_user=existing_user, **payload)
return await self._update_user(
session=session,
user_id=user.id,
existing_user=existing_user,
patch=UserPatch.from_unsaved_user_info(user),
)
else:
return await self._insert_user(session=session, user_id=user_id, **payload)
return await self._insert_user(session=session, user=user)

async def _insert_user(self, session: AsyncSession, user_id: str, **kwargs: Any) -> UserInfoUpdate:
async def _insert_user(self, session: AsyncSession, user: UnsavedUserInfo) -> UserInfoUpdate:
"""Insert a user."""
kwargs.pop("keycloak_id", None)
kwargs.pop("id", None)
slug = base_models.Slug.from_user(
kwargs.get("email"), kwargs.get("first_name"), kwargs.get("last_name"), user_id
).value
slug = base_models.Slug.from_user(user.email, user.first_name, user.last_name, user.id).value
namespace = await self.group_repo._create_user_namespace_slug(
session, user_slug=slug, retry_enumerate=5, retry_random=True
)
slug = base_models.Slug.from_name(namespace)
new_user = UserORM(keycloak_id=user_id, namespace=NamespaceORM(slug=slug.value, user_id=user_id), **kwargs)
new_user = UserORM(
keycloak_id=user.id,
namespace=NamespaceORM(slug=slug.value, user_id=user.id),
email=user.email,
first_name=user.first_name,
last_name=user.last_name,
)
new_user.namespace.user = new_user
session.add(new_user)
await session.flush()
return UserInfoUpdate(None, new_user.dump())

async def _update_user(
self, session: AsyncSession, user_id: str, existing_user: UserORM | None, **kwargs: Any
self, session: AsyncSession, user_id: str, existing_user: UserORM | None, patch: UserPatch
) -> UserInfoUpdate:
"""Update a user."""
if not existing_user:
Expand All @@ -254,13 +263,13 @@ async def _update_user(
if not existing_user:
raise errors.MissingResourceError(message=f"The user with id '{user_id}' cannot be found")
old_user = existing_user.dump()

kwargs.pop("keycloak_id", None)
kwargs.pop("id", None)
session.add(existing_user) # reattach to session
for field_name, field_value in kwargs.items():
if getattr(existing_user, field_name, None) != field_value:
setattr(existing_user, field_name, field_value)
if patch.email is not None:
existing_user.email = patch.email if patch.email else None
if patch.first_name is not None:
existing_user.first_name = patch.first_name if patch.first_name else None
if patch.last_name is not None:
existing_user.last_name = patch.last_name if patch.last_name else None
namespace = await self.group_repo.get_user_namespace(user_id)
if not namespace:
raise errors.ProgrammingError(
Expand All @@ -279,7 +288,7 @@ async def _do_update(raw_kc_user: dict[str, Any]) -> None:
db_user = await self._get_user(kc_user.id)
if db_user != kc_user:
logger.info(f"Inserting or updating user {db_user} -> {kc_user}")
await self.update_or_insert_user(kc_user.id, asdict(kc_user))
await self.update_or_insert_user(user=kc_user)

# NOTE: If asyncio.gather is used here you quickly exhaust all DB connections
# or timeout on waiting for available connections
Expand All @@ -300,7 +309,7 @@ async def events_sync(self, kc_api: IKeycloakAPI) -> None:
latest_utc_timestamp_orm.timestamp_utc if latest_utc_timestamp_orm is not None else None
)
logger.info(f"The previous sync latest event is {previous_sync_latest_utc_timestamp} UTC")
now_utc = datetime.utcnow()
now_utc = datetime.now(tz=UTC)
start_date = now_utc.date() - timedelta(days=1)
logger.info(f"Pulling events with a start date of {start_date} UTC")
user_events = kc_api.get_user_events(start_date=start_date)
Expand All @@ -324,7 +333,10 @@ async def events_sync(self, kc_api: IKeycloakAPI) -> None:
latest_delete_timestamp = None
for update in parsed_updates:
logger.info(f"Processing update event {update}")
await self.update_or_insert_user(update.user_id, {update.field_name: update.new_value})
# TODO: add typing to `update.field_name` for safer updates
await self.update_or_insert_user(
user=UnsavedUserInfo(id=update.user_id, **{update.field_name: update.new_value})
)
latest_update_timestamp = update.timestamp_utc
for deletion in parsed_deletions:
logger.info(f"Processing deletion event {deletion}")
Expand Down
18 changes: 18 additions & 0 deletions components/renku_data_services/users/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,24 @@ class UserInfo(UnsavedUserInfo):
namespace: Namespace


@dataclass(frozen=True, eq=True, kw_only=True)
class UserPatch:
"""Model for changes requested on a user."""

first_name: str | None = None
last_name: str | None = None
email: str | None = None

@classmethod
def from_unsaved_user_info(cls, user: UnsavedUserInfo) -> "UserPatch":
"""Create a user patch from a UnsavedUserInfo instance."""
return UserPatch(
first_name=user.first_name,
last_name=user.last_name,
email=user.email,
)


@dataclass
class DeletedUser:
"""A user that was deleted from the database."""
Expand Down

0 comments on commit f1e525a

Please sign in to comment.