diff --git a/backend/danswer/db/users.py b/backend/danswer/db/users.py index 24f7dc501c7..5014c35cbec 100644 --- a/backend/danswer/db/users.py +++ b/backend/danswer/db/users.py @@ -103,17 +103,6 @@ def list_users( return db_session.scalars(stmt).unique().all() -def get_users_by_emails( - db_session: Session, emails: list[str] -) -> tuple[list[User], list[str]]: - # Use distinct to avoid duplicates - stmt = select(User).filter(User.email.in_(emails)) # type: ignore - found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list - found_users_emails = [user.email for user in found_users] - missing_user_emails = [email for email in emails if email not in found_users_emails] - return found_users, missing_user_emails - - def get_user_by_email(email: str, db_session: Session) -> User | None: user = ( db_session.query(User) @@ -128,7 +117,7 @@ def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None: return db_session.query(User).filter(User.id == user_id).first() # type: ignore -def _generate_non_web_slack_user(email: str) -> User: +def _generate_slack_user(email: str) -> User: fastapi_users_pw_helper = PasswordHelper() password = fastapi_users_pw_helper.generate() hashed_pass = fastapi_users_pw_helper.hash(password) @@ -149,13 +138,29 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User: db_session.commit() return user - user = _generate_non_web_slack_user(email=email) + user = _generate_slack_user(email=email) db_session.add(user) db_session.commit() return user -def _generate_non_web_permissioned_user(email: str) -> User: +def _get_users_by_emails( + db_session: Session, lower_emails: list[str] +) -> tuple[list[User], list[str]]: + stmt = select(User).filter(func.lower(User.email).in_(lower_emails)) # type: ignore + found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list + + # Extract found emails and convert to lowercase to avoid case sensitivity issues + found_users_emails = [user.email.lower() for user in found_users] + + # Separate emails for users that were not found + missing_user_emails = [ + email for email in lower_emails if email not in found_users_emails + ] + return found_users, missing_user_emails + + +def _generate_ext_permissioned_user(email: str) -> User: fastapi_users_pw_helper = PasswordHelper() password = fastapi_users_pw_helper.generate() hashed_pass = fastapi_users_pw_helper.hash(password) @@ -169,12 +174,12 @@ def _generate_non_web_permissioned_user(email: str) -> User: def batch_add_ext_perm_user_if_not_exists( db_session: Session, emails: list[str] ) -> list[User]: - emails = [email.lower() for email in emails] - found_users, missing_user_emails = get_users_by_emails(db_session, emails) + lower_emails = [email.lower() for email in emails] + found_users, missing_lower_emails = _get_users_by_emails(db_session, lower_emails) new_users: list[User] = [] - for email in missing_user_emails: - new_users.append(_generate_non_web_permissioned_user(email=email)) + for email in missing_lower_emails: + new_users.append(_generate_ext_permissioned_user(email=email)) db_session.add_all(new_users) db_session.commit()