Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use typed patch instance for user updates #497

Merged
merged 4 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 49 additions & 27 deletions components/renku_data_services/users/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import secrets
from collections.abc import AsyncGenerator, Callable
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from typing import Any, cast

from sanic.log import logger
Expand All @@ -28,9 +28,11 @@
DeletedUser,
KeycloakAdminEvent,
PinnedProjects,
UnsavedUserInfo,
UserInfo,
UserInfoFieldUpdate,
UserInfoUpdate,
UserPatch,
UserPreferences,
)
from renku_data_services.users.orm import LastKeycloakEventTimestamp, UserORM, UserPreferencesORM
Expand Down Expand Up @@ -65,12 +67,12 @@ async def _add_api_user(self, user: APIUser) -> UserInfo:
if not user.id:
raise errors.UnauthorizedError(message="The user has to be authenticated to be inserted in the DB.")
result = await self._users_sync.update_or_insert_user(
user_id=user.id,
payload=dict(
user=UnsavedUserInfo(
id=user.id,
email=user.email,
first_name=user.first_name,
last_name=user.last_name,
email=user.email,
),
)
)
return result.new

Expand Down Expand Up @@ -214,37 +216,54 @@ async def _get_user(self, id: str) -> UserInfo | None:
@Authz.authz_change(AuthzOperation.update_or_insert, ResourceType.user)
@dispatch_message(events.UpdateOrInsertUser)
async def update_or_insert_user(
self, user_id: str, payload: dict[str, Any], *, session: AsyncSession | None = None
self,
# user_id: str,
# payload: dict[str, Any],
leafty marked this conversation as resolved.
Show resolved Hide resolved
leafty marked this conversation as resolved.
Show resolved Hide resolved
user: UnsavedUserInfo,
*,
session: AsyncSession | None = None,
) -> UserInfoUpdate:
"""Update a user or insert it if it does not exist."""
if not session:
raise errors.ProgrammingError(message="A database session is required")
res = await session.execute(select(UserORM).where(UserORM.keycloak_id == user_id))
res = await session.execute(select(UserORM).where(UserORM.keycloak_id == user.id))
existing_user = res.scalar_one_or_none()
if existing_user:
return await self._update_user(session=session, user_id=user_id, existing_user=existing_user, **payload)
return await self._update_user(
session=session,
user_id=user.id,
existing_user=existing_user,
patch=UserPatch.from_unsaved_user_info(user),
)
else:
return await self._insert_user(session=session, user_id=user_id, **payload)
return await self._insert_user(session=session, user=user)

async def _insert_user(self, session: AsyncSession, user_id: str, **kwargs: Any) -> UserInfoUpdate:
async def _insert_user(self, session: AsyncSession, user: UnsavedUserInfo) -> UserInfoUpdate:
"""Insert a user."""
kwargs.pop("keycloak_id", None)
kwargs.pop("id", None)
slug = base_models.Slug.from_user(
kwargs.get("email"), kwargs.get("first_name"), kwargs.get("last_name"), user_id
).value
slug = base_models.Slug.from_user(user.email, user.first_name, user.last_name, user.id).value
namespace = await self.group_repo._create_user_namespace_slug(
session, user_slug=slug, retry_enumerate=5, retry_random=True
)
slug = base_models.Slug.from_name(namespace)
new_user = UserORM(keycloak_id=user_id, namespace=NamespaceORM(slug=slug.value, user_id=user_id), **kwargs)
new_user = UserORM(
keycloak_id=user.id,
namespace=NamespaceORM(slug=slug.value, user_id=user.id),
email=user.email,
first_name=user.first_name,
last_name=user.last_name,
)
new_user.namespace.user = new_user
session.add(new_user)
await session.flush()
return UserInfoUpdate(None, new_user.dump())

async def _update_user(
self, session: AsyncSession, user_id: str, existing_user: UserORM | None, **kwargs: Any
self,
session: AsyncSession,
user_id: str,
existing_user: UserORM | None,
# **kwargs: Any
leafty marked this conversation as resolved.
Show resolved Hide resolved
leafty marked this conversation as resolved.
Show resolved Hide resolved
patch: UserPatch,
) -> UserInfoUpdate:
"""Update a user."""
if not existing_user:
Expand All @@ -254,13 +273,13 @@ async def _update_user(
if not existing_user:
raise errors.MissingResourceError(message=f"The user with id '{user_id}' cannot be found")
old_user = existing_user.dump()

kwargs.pop("keycloak_id", None)
kwargs.pop("id", None)
session.add(existing_user) # reattach to session
for field_name, field_value in kwargs.items():
if getattr(existing_user, field_name, None) != field_value:
setattr(existing_user, field_name, field_value)
if patch.email is not None:
existing_user.email = patch.email if patch.email else None
if patch.first_name is not None:
existing_user.first_name = patch.first_name if patch.first_name else None
if patch.last_name is not None:
existing_user.last_name = patch.last_name if patch.last_name else None
namespace = await self.group_repo.get_user_namespace(user_id)
if not namespace:
raise errors.ProgrammingError(
Expand All @@ -279,7 +298,7 @@ async def _do_update(raw_kc_user: dict[str, Any]) -> None:
db_user = await self._get_user(kc_user.id)
if db_user != kc_user:
logger.info(f"Inserting or updating user {db_user} -> {kc_user}")
await self.update_or_insert_user(kc_user.id, asdict(kc_user))
await self.update_or_insert_user(user=kc_user)

# NOTE: If asyncio.gather is used here you quickly exhaust all DB connections
# or timeout on waiting for available connections
Expand All @@ -300,7 +319,7 @@ async def events_sync(self, kc_api: IKeycloakAPI) -> None:
latest_utc_timestamp_orm.timestamp_utc if latest_utc_timestamp_orm is not None else None
)
logger.info(f"The previous sync latest event is {previous_sync_latest_utc_timestamp} UTC")
now_utc = datetime.utcnow()
now_utc = datetime.now(tz=UTC)
start_date = now_utc.date() - timedelta(days=1)
logger.info(f"Pulling events with a start date of {start_date} UTC")
user_events = kc_api.get_user_events(start_date=start_date)
Expand All @@ -324,7 +343,10 @@ async def events_sync(self, kc_api: IKeycloakAPI) -> None:
latest_delete_timestamp = None
for update in parsed_updates:
logger.info(f"Processing update event {update}")
await self.update_or_insert_user(update.user_id, {update.field_name: update.new_value})
# TODO: add typing to `update.field_name` for safer updates
await self.update_or_insert_user(
user=UnsavedUserInfo(id=update.user_id, **{update.field_name: update.new_value})
)
latest_update_timestamp = update.timestamp_utc
for deletion in parsed_deletions:
logger.info(f"Processing deletion event {deletion}")
Expand Down
18 changes: 18 additions & 0 deletions components/renku_data_services/users/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,24 @@ class UserInfo(UnsavedUserInfo):
namespace: Namespace


@dataclass(frozen=True, eq=True, kw_only=True)
class UserPatch:
"""Model for changes requested on a user."""

first_name: str | None = None
last_name: str | None = None
email: str | None = None

@classmethod
def from_unsaved_user_info(cls, user: UnsavedUserInfo) -> "UserPatch":
"""Create a user patch from a UnsavedUserInfo instance."""
return UserPatch(
first_name=user.first_name,
last_name=user.last_name,
email=user.email,
)


@dataclass
class DeletedUser:
"""A user that was deleted from the database."""
Expand Down
Loading