From b1957737f2cb9419f4cff7a3714a06870a8d5027 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Tue, 14 Jan 2025 15:35:52 -0800 Subject: [PATCH] refactored _add_user_filter usage (#3674) * refactored db.connector_credential_pair * Rerfactored the db.credentials user filtering * the restr --- backend/ee/onyx/db/token_limit.py | 4 +- backend/ee/onyx/db/user_group.py | 5 +- .../external_permissions/salesforce/utils.py | 5 +- .../ee/onyx/server/token_rate_limits/api.py | 8 +- .../celery/tasks/connector_deletion/tasks.py | 5 +- .../tasks/doc_permission_syncing/tasks.py | 5 +- .../tasks/external_group_syncing/tasks.py | 5 +- .../background/celery/tasks/indexing/tasks.py | 5 +- .../background/celery/tasks/pruning/tasks.py | 5 +- .../background/celery/tasks/vespa/tasks.py | 10 ++- .../onyx/background/indexing/run_indexing.py | 3 +- backend/onyx/db/connector_credential_pair.py | 88 ++++++++++++------- backend/onyx/db/credentials.py | 75 ++++++++++------ backend/onyx/db/document.py | 6 +- backend/onyx/db/document_set.py | 19 ++-- backend/onyx/db/feedback.py | 15 ++-- backend/onyx/db/persona.py | 21 +++-- .../onyxbot/slack/handlers/handle_buttons.py | 8 -- .../redis/redis_connector_credential_pair.py | 5 +- backend/onyx/redis/redis_connector_delete.py | 5 +- backend/onyx/redis/redis_connector_prune.py | 5 +- backend/onyx/server/documents/cc_pair.py | 24 ++--- backend/onyx/server/documents/connector.py | 27 ++++-- backend/onyx/server/documents/credential.py | 14 +-- backend/onyx/server/features/persona/api.py | 6 +- backend/onyx/server/manage/administrative.py | 24 ++--- backend/onyx/server/onyx_api/ingestion.py | 3 +- .../openai_assistants_api/asssistants_api.py | 4 +- .../server/query_and_chat/chat_backend.py | 9 -- .../scripts/force_delete_connector_by_id.py | 5 +- 30 files changed, 255 insertions(+), 168 deletions(-) diff --git a/backend/ee/onyx/db/token_limit.py b/backend/ee/onyx/db/token_limit.py index ca5249e6923..1dbf3c53383 100644 --- a/backend/ee/onyx/db/token_limit.py +++ b/backend/ee/onyx/db/token_limit.py @@ -111,10 +111,10 @@ def insert_user_group_token_rate_limit( return token_limit -def fetch_user_group_token_rate_limits( +def fetch_user_group_token_rate_limits_for_user( db_session: Session, group_id: int, - user: User | None = None, + user: User | None, enabled_only: bool = False, ordered: bool = True, get_editable: bool = True, diff --git a/backend/ee/onyx/db/user_group.py b/backend/ee/onyx/db/user_group.py index 791a6cebce5..827cdcae559 100644 --- a/backend/ee/onyx/db/user_group.py +++ b/backend/ee/onyx/db/user_group.py @@ -705,7 +705,10 @@ def delete_user_group_cc_pair_relationship__no_commit( connector_credential_pair_id matches the given cc_pair_id. Should be used very carefully (only for connectors that are being deleted).""" - cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + cc_pair = get_connector_credential_pair_from_id( + db_session=db_session, + cc_pair_id=cc_pair_id, + ) if not cc_pair: raise ValueError(f"Connector Credential Pair '{cc_pair_id}' does not exist") diff --git a/backend/ee/onyx/external_permissions/salesforce/utils.py b/backend/ee/onyx/external_permissions/salesforce/utils.py index 289e14e37e2..6875d19405e 100644 --- a/backend/ee/onyx/external_permissions/salesforce/utils.py +++ b/backend/ee/onyx/external_permissions/salesforce/utils.py @@ -161,7 +161,10 @@ def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Sales cc_pair_id = _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id] if cc_pair_id not in _CC_PAIR_ID_SALESFORCE_CLIENT_MAP: - cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + cc_pair = get_connector_credential_pair_from_id( + db_session=db_session, + cc_pair_id=cc_pair_id, + ) if cc_pair is None: raise ValueError(f"CC pair {cc_pair_id} not found") credential_json = cc_pair.credential.credential_json diff --git a/backend/ee/onyx/server/token_rate_limits/api.py b/backend/ee/onyx/server/token_rate_limits/api.py index fe382a50f09..458d602efda 100644 --- a/backend/ee/onyx/server/token_rate_limits/api.py +++ b/backend/ee/onyx/server/token_rate_limits/api.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from ee.onyx.db.token_limit import fetch_all_user_group_token_rate_limits_by_group -from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits +from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user from ee.onyx.db.token_limit import insert_user_group_token_rate_limit from onyx.auth.users import current_admin_user from onyx.auth.users import current_curator_or_admin_user @@ -51,8 +51,10 @@ def get_group_token_limit_settings( ) -> list[TokenRateLimitDisplay]: return [ TokenRateLimitDisplay.from_db(token_rate_limit) - for token_rate_limit in fetch_user_group_token_rate_limits( - db_session, group_id, user + for token_rate_limit in fetch_user_group_token_rate_limits_for_user( + db_session=db_session, + group_id=group_id, + user=user, ) ] diff --git a/backend/onyx/background/celery/tasks/connector_deletion/tasks.py b/backend/onyx/background/celery/tasks/connector_deletion/tasks.py index bf7e949d3f8..c64fb97a0c0 100644 --- a/backend/onyx/background/celery/tasks/connector_deletion/tasks.py +++ b/backend/onyx/background/celery/tasks/connector_deletion/tasks.py @@ -116,7 +116,10 @@ def try_generate_document_cc_pair_cleanup_tasks( # we need to load the state of the object inside the fence # to avoid a race condition with db.commit/fence deletion # at the end of this taskset - cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + cc_pair = get_connector_credential_pair_from_id( + db_session=db_session, + cc_pair_id=cc_pair_id, + ) if not cc_pair: return None diff --git a/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py b/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py index 5e1e3c2c0f4..69d425bdb45 100644 --- a/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py +++ b/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py @@ -279,7 +279,10 @@ def connector_permission_sync_generator_task( try: with get_session_with_tenant(tenant_id) as db_session: - cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + cc_pair = get_connector_credential_pair_from_id( + db_session=db_session, + cc_pair_id=cc_pair_id, + ) if cc_pair is None: raise ValueError( f"No connector credential pair found for id: {cc_pair_id}" diff --git a/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py b/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py index 238e147c9af..9af9f381520 100644 --- a/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py +++ b/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py @@ -250,7 +250,10 @@ def connector_external_group_sync_generator_task( return None with get_session_with_tenant(tenant_id) as db_session: - cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + cc_pair = get_connector_credential_pair_from_id( + cc_pair_id=cc_pair_id, + db_session=db_session, + ) if cc_pair is None: raise ValueError( f"No connector credential pair found for id: {cc_pair_id}" diff --git a/backend/onyx/background/celery/tasks/indexing/tasks.py b/backend/onyx/background/celery/tasks/indexing/tasks.py index 7c64ba61787..dd3d8117d2e 100644 --- a/backend/onyx/background/celery/tasks/indexing/tasks.py +++ b/backend/onyx/background/celery/tasks/indexing/tasks.py @@ -304,7 +304,8 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: ) cc_pair = get_connector_credential_pair_from_id( - cc_pair_id, db_session + db_session=db_session, + cc_pair_id=cc_pair_id, ) if not cc_pair: continue @@ -1198,8 +1199,8 @@ def connector_indexing_task( attempt_found = True cc_pair = get_connector_credential_pair_from_id( - cc_pair_id=cc_pair_id, db_session=db_session, + cc_pair_id=cc_pair_id, ) if not cc_pair: diff --git a/backend/onyx/background/celery/tasks/pruning/tasks.py b/backend/onyx/background/celery/tasks/pruning/tasks.py index a1e891365f3..c9483e97f54 100644 --- a/backend/onyx/background/celery/tasks/pruning/tasks.py +++ b/backend/onyx/background/celery/tasks/pruning/tasks.py @@ -103,7 +103,10 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None: for cc_pair_id in cc_pair_ids: lock_beat.reacquire() with get_session_with_tenant(tenant_id) as db_session: - cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + cc_pair = get_connector_credential_pair_from_id( + db_session=db_session, + cc_pair_id=cc_pair_id, + ) if not cc_pair: continue diff --git a/backend/onyx/background/celery/tasks/vespa/tasks.py b/backend/onyx/background/celery/tasks/vespa/tasks.py index dea1981f0fa..821f03189e9 100644 --- a/backend/onyx/background/celery/tasks/vespa/tasks.py +++ b/backend/onyx/background/celery/tasks/vespa/tasks.py @@ -285,7 +285,10 @@ def try_generate_document_set_sync_tasks( # don't generate sync tasks if we're up to date # race condition with the monitor/cleanup function if we use a cached result! - document_set = get_document_set_by_id(db_session, document_set_id) + document_set = get_document_set_by_id( + db_session=db_session, + document_set_id=document_set_id, + ) if not document_set: return None @@ -532,7 +535,10 @@ def monitor_connector_deletion_taskset( return with get_session_with_tenant(tenant_id) as db_session: - cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + cc_pair = get_connector_credential_pair_from_id( + db_session=db_session, + cc_pair_id=cc_pair_id, + ) if not cc_pair: task_logger.warning( f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}" diff --git a/backend/onyx/background/indexing/run_indexing.py b/backend/onyx/background/indexing/run_indexing.py index 8a7889bbb19..623eb8edbe3 100644 --- a/backend/onyx/background/indexing/run_indexing.py +++ b/backend/onyx/background/indexing/run_indexing.py @@ -75,7 +75,8 @@ def _get_connector_runner( # it will never succeed cc_pair = get_connector_credential_pair_from_id( - attempt.connector_credential_pair.id, db_session + db_session=db_session, + cc_pair_id=attempt.connector_credential_pair.id, ) if cc_pair and cc_pair.status == ConnectorCredentialPairStatus.ACTIVE: update_connector_credential_pair( diff --git a/backend/onyx/db/connector_credential_pair.py b/backend/onyx/db/connector_credential_pair.py index 3378a8d493b..c8651e30637 100644 --- a/backend/onyx/db/connector_credential_pair.py +++ b/backend/onyx/db/connector_credential_pair.py @@ -11,9 +11,8 @@ from sqlalchemy.orm import Session from onyx.configs.app_configs import DISABLE_AUTH -from onyx.configs.constants import DocumentSource from onyx.db.connector import fetch_connector_by_id -from onyx.db.credentials import fetch_credential_by_id +from onyx.db.credentials import fetch_credential_by_id_for_user from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.models import ConnectorCredentialPair @@ -92,10 +91,9 @@ def _add_user_filters( return stmt.where(where_clause) -def get_connector_credential_pairs( +def get_connector_credential_pairs_for_user( db_session: Session, - include_disabled: bool = True, - user: User | None = None, + user: User | None, get_editable: bool = True, ids: list[int] | None = None, eager_load_connector: bool = False, @@ -106,11 +104,18 @@ def get_connector_credential_pairs( stmt = stmt.options(joinedload(ConnectorCredentialPair.connector)) stmt = _add_user_filters(stmt, user, get_editable) + if ids: + stmt = stmt.where(ConnectorCredentialPair.id.in_(ids)) + + return list(db_session.scalars(stmt).all()) + + +def get_connector_credential_pairs( + db_session: Session, + ids: list[int] | None = None, +) -> list[ConnectorCredentialPair]: + stmt = select(ConnectorCredentialPair).distinct() - if not include_disabled: - stmt = stmt.where( - ConnectorCredentialPair.status == ConnectorCredentialPairStatus.ACTIVE - ) if ids: stmt = stmt.where(ConnectorCredentialPair.id.in_(ids)) @@ -122,7 +127,10 @@ def add_deletion_failure_message( cc_pair_id: int, failure_message: str, ) -> None: - cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + cc_pair = get_connector_credential_pair_from_id( + db_session=db_session, + cc_pair_id=cc_pair_id, + ) if not cc_pair: return cc_pair.deletion_failure_message = failure_message @@ -132,24 +140,21 @@ def add_deletion_failure_message( def get_cc_pair_groups_for_ids( db_session: Session, cc_pair_ids: list[int], - user: User | None = None, - get_editable: bool = True, ) -> list[UserGroup__ConnectorCredentialPair]: stmt = select(UserGroup__ConnectorCredentialPair).distinct() stmt = stmt.outerjoin( ConnectorCredentialPair, UserGroup__ConnectorCredentialPair.cc_pair_id == ConnectorCredentialPair.id, ) - stmt = _add_user_filters(stmt, user, get_editable) stmt = stmt.where(UserGroup__ConnectorCredentialPair.cc_pair_id.in_(cc_pair_ids)) return list(db_session.scalars(stmt).all()) -def get_connector_credential_pair( +def get_connector_credential_pair_for_user( + db_session: Session, connector_id: int, credential_id: int, - db_session: Session, - user: User | None = None, + user: User | None, get_editable: bool = True, ) -> ConnectorCredentialPair | None: stmt = select(ConnectorCredentialPair) @@ -160,28 +165,36 @@ def get_connector_credential_pair( return result.scalar_one_or_none() -def get_connector_credential_source_from_id( +def get_connector_credential_pair( + db_session: Session, + connector_id: int, + credential_id: int, +) -> ConnectorCredentialPair | None: + stmt = select(ConnectorCredentialPair) + stmt = stmt.where(ConnectorCredentialPair.connector_id == connector_id) + stmt = stmt.where(ConnectorCredentialPair.credential_id == credential_id) + result = db_session.execute(stmt) + return result.scalar_one_or_none() + + +def get_connector_credential_pair_from_id_for_user( cc_pair_id: int, db_session: Session, - user: User | None = None, + user: User | None, get_editable: bool = True, -) -> DocumentSource | None: - stmt = select(ConnectorCredentialPair) +) -> ConnectorCredentialPair | None: + stmt = select(ConnectorCredentialPair).distinct() stmt = _add_user_filters(stmt, user, get_editable) stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id) result = db_session.execute(stmt) - cc_pair = result.scalar_one_or_none() - return cc_pair.connector.source if cc_pair else None + return result.scalar_one_or_none() def get_connector_credential_pair_from_id( - cc_pair_id: int, db_session: Session, - user: User | None = None, - get_editable: bool = True, + cc_pair_id: int, ) -> ConnectorCredentialPair | None: stmt = select(ConnectorCredentialPair).distinct() - stmt = _add_user_filters(stmt, user, get_editable) stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id) result = db_session.execute(stmt) return result.scalar_one_or_none() @@ -198,7 +211,9 @@ def get_last_successful_attempt_time( the CC Pair row in the database""" if search_settings.status == IndexModelStatus.PRESENT: connector_credential_pair = get_connector_credential_pair( - connector_id, credential_id, db_session + db_session=db_session, + connector_id=connector_id, + credential_id=credential_id, ) if ( connector_credential_pair is None @@ -259,7 +274,10 @@ def update_connector_credential_pair_from_id( net_docs: int | None = None, run_dt: datetime | None = None, ) -> None: - cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + cc_pair = get_connector_credential_pair_from_id( + db_session=db_session, + cc_pair_id=cc_pair_id, + ) if not cc_pair: logger.warning( f"Attempted to update pair for Connector Credential Pair '{cc_pair_id}'" @@ -284,7 +302,11 @@ def update_connector_credential_pair( net_docs: int | None = None, run_dt: datetime | None = None, ) -> None: - cc_pair = get_connector_credential_pair(connector_id, credential_id, db_session) + cc_pair = get_connector_credential_pair( + db_session=db_session, + connector_id=connector_id, + credential_id=credential_id, + ) if not cc_pair: logger.warning( f"Attempted to update pair for connector id {connector_id} " @@ -368,7 +390,7 @@ def add_credential_to_connector( last_successful_index_time: datetime | None = None, ) -> StatusResponse: connector = fetch_connector_by_id(connector_id, db_session) - credential = fetch_credential_by_id( + credential = fetch_credential_by_id_for_user( credential_id, user, db_session, @@ -450,7 +472,7 @@ def remove_credential_from_connector( db_session: Session, ) -> StatusResponse[int]: connector = fetch_connector_by_id(connector_id, db_session) - credential = fetch_credential_by_id( + credential = fetch_credential_by_id_for_user( credential_id, user, db_session, @@ -466,10 +488,10 @@ def remove_credential_from_connector( detail="Credential does not exist or does not belong to user", ) - association = get_connector_credential_pair( + association = get_connector_credential_pair_for_user( + db_session=db_session, connector_id=connector_id, credential_id=credential_id, - db_session=db_session, user=user, get_editable=True, ) diff --git a/backend/onyx/db/credentials.py b/backend/onyx/db/credentials.py index 86cb31aa811..ca82ceb1624 100644 --- a/backend/onyx/db/credentials.py +++ b/backend/onyx/db/credentials.py @@ -9,6 +9,7 @@ from sqlalchemy.sql.expression import or_ from onyx.auth.schemas import UserRole +from onyx.configs.app_configs import DISABLE_AUTH from onyx.configs.constants import DocumentSource from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, @@ -42,22 +43,21 @@ def _add_user_filters( stmt: Select, user: User | None, - assume_admin: bool = False, # Used with API key get_editable: bool = True, ) -> Select: """Attaches filters to the statement to ensure that the user can only access the appropriate credentials""" - if not user: - if assume_admin: - # apply admin filters minus the user_id check - stmt = stmt.where( - or_( - Credential.user_id.is_(None), - Credential.admin_public == True, # noqa: E712 - Credential.source.in_(CREDENTIAL_PERMISSIONS_TO_IGNORE), - ) + if user is None: + if not DISABLE_AUTH: + raise ValueError("Anonymous users are not allowed to access credentials") + # If user is None and auth is disabled, assume the user is an admin + return stmt.where( + or_( + Credential.user_id.is_(None), + Credential.admin_public == True, # noqa: E712 + Credential.source.in_(CREDENTIAL_PERMISSIONS_TO_IGNORE), ) - return stmt + ) if user.role == UserRole.ADMIN: # Admins can access all credentials that are public or owned by them @@ -138,9 +138,9 @@ def _relate_credential_to_user_groups__no_commit( db_session.add_all(credential_user_groups) -def fetch_credentials( +def fetch_credentials_for_user( db_session: Session, - user: User | None = None, + user: User | None, get_editable: bool = True, ) -> list[Credential]: stmt = select(Credential) @@ -149,11 +149,10 @@ def fetch_credentials( return list(results.all()) -def fetch_credential_by_id( +def fetch_credential_by_id_for_user( credential_id: int, user: User | None, db_session: Session, - assume_admin: bool = False, get_editable: bool = True, ) -> Credential | None: stmt = select(Credential).distinct() @@ -161,7 +160,6 @@ def fetch_credential_by_id( stmt = _add_user_filters( stmt=stmt, user=user, - assume_admin=assume_admin, get_editable=get_editable, ) result = db_session.execute(stmt) @@ -169,7 +167,18 @@ def fetch_credential_by_id( return credential -def fetch_credentials_by_source( +def fetch_credential_by_id( + db_session: Session, + credential_id: int, +) -> Credential | None: + stmt = select(Credential).distinct() + stmt = stmt.where(Credential.id == credential_id) + result = db_session.execute(stmt) + credential = result.scalar_one_or_none() + return credential + + +def fetch_credentials_by_source_for_user( db_session: Session, user: User | None, document_source: DocumentSource | None = None, @@ -181,11 +190,22 @@ def fetch_credentials_by_source( return list(credentials) +def fetch_credentials_by_source( + db_session: Session, + document_source: DocumentSource | None = None, +) -> list[Credential]: + base_query = select(Credential).where(Credential.source == document_source) + credentials = db_session.execute(base_query).scalars().all() + return list(credentials) + + def swap_credentials_connector( new_credential_id: int, connector_id: int, user: User | None, db_session: Session ) -> ConnectorCredentialPair: # Check if the user has permission to use the new credential - new_credential = fetch_credential_by_id(new_credential_id, user, db_session) + new_credential = fetch_credential_by_id_for_user( + new_credential_id, user, db_session + ) if not new_credential: raise ValueError( f"No Credential found with id {new_credential_id} or user doesn't have permission to use it" @@ -275,7 +295,7 @@ def alter_credential( db_session: Session, ) -> Credential | None: # TODO: add user group relationship update - credential = fetch_credential_by_id(credential_id, user, db_session) + credential = fetch_credential_by_id_for_user(credential_id, user, db_session) if credential is None: return None @@ -299,7 +319,7 @@ def update_credential( user: User, db_session: Session, ) -> Credential | None: - credential = fetch_credential_by_id(credential_id, user, db_session) + credential = fetch_credential_by_id_for_user(credential_id, user, db_session) if credential is None: return None @@ -316,7 +336,7 @@ def update_credential_json( user: User, db_session: Session, ) -> Credential | None: - credential = fetch_credential_by_id(credential_id, user, db_session) + credential = fetch_credential_by_id_for_user(credential_id, user, db_session) if credential is None: return None @@ -341,7 +361,7 @@ def delete_credential( db_session: Session, force: bool = False, ) -> None: - credential = fetch_credential_by_id(credential_id, user, db_session) + credential = fetch_credential_by_id_for_user(credential_id, user, db_session) if credential is None: raise ValueError( f"Credential by provided id {credential_id} does not exist or does not belong to user" @@ -396,7 +416,10 @@ def create_initial_public_credential(db_session: Session) -> None: "DB is not in a valid initial state." "There must exist an empty public credential for data connectors that do not require additional Auth." ) - first_credential = fetch_credential_by_id(PUBLIC_CREDENTIAL_ID, None, db_session) + first_credential = fetch_credential_by_id( + db_session=db_session, + credential_id=PUBLIC_CREDENTIAL_ID, + ) if first_credential is not None: if first_credential.credential_json != {} or first_credential.user is not None: @@ -414,7 +437,7 @@ def create_initial_public_credential(db_session: Session) -> None: def cleanup_gmail_credentials(db_session: Session) -> None: gmail_credentials = fetch_credentials_by_source( - db_session=db_session, user=None, document_source=DocumentSource.GMAIL + db_session=db_session, document_source=DocumentSource.GMAIL ) for credential in gmail_credentials: db_session.delete(credential) @@ -423,7 +446,7 @@ def cleanup_gmail_credentials(db_session: Session) -> None: def cleanup_google_drive_credentials(db_session: Session) -> None: google_drive_credentials = fetch_credentials_by_source( - db_session=db_session, user=None, document_source=DocumentSource.GOOGLE_DRIVE + db_session=db_session, document_source=DocumentSource.GOOGLE_DRIVE ) for credential in google_drive_credentials: db_session.delete(credential) @@ -433,7 +456,7 @@ def cleanup_google_drive_credentials(db_session: Session) -> None: def delete_service_account_credentials( user: User | None, db_session: Session, source: DocumentSource ) -> None: - credentials = fetch_credentials(db_session=db_session, user=user) + credentials = fetch_credentials_for_user(db_session=db_session, user=user) for credential in credentials: if ( credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY) diff --git a/backend/onyx/db/document.py b/backend/onyx/db/document.py index d9ff82d797a..27778fe0686 100644 --- a/backend/onyx/db/document.py +++ b/backend/onyx/db/document.py @@ -107,7 +107,8 @@ def get_all_documents_needing_vespa_sync_for_cc_pair( db_session: Session, cc_pair_id: int ) -> list[DbDocument]: cc_pair = get_connector_credential_pair_from_id( - cc_pair_id=cc_pair_id, db_session=db_session + db_session=db_session, + cc_pair_id=cc_pair_id, ) if not cc_pair: raise ValueError(f"No CC pair found with ID: {cc_pair_id}") @@ -137,7 +138,8 @@ def get_documents_for_cc_pair( cc_pair_id: int, ) -> list[DbDocument]: cc_pair = get_connector_credential_pair_from_id( - cc_pair_id=cc_pair_id, db_session=db_session + db_session=db_session, + cc_pair_id=cc_pair_id, ) if not cc_pair: raise ValueError(f"No CC pair found with ID: {cc_pair_id}") diff --git a/backend/onyx/db/document_set.py b/backend/onyx/db/document_set.py index 7df2ca0ac12..6383a6c02ca 100644 --- a/backend/onyx/db/document_set.py +++ b/backend/onyx/db/document_set.py @@ -116,10 +116,10 @@ def delete_document_set_privacy__no_commit( """No private document sets in Onyx MIT""" -def get_document_set_by_id( +def get_document_set_by_id_for_user( db_session: Session, document_set_id: int, - user: User | None = None, + user: User | None, get_editable: bool = True, ) -> DocumentSetDBModel | None: stmt = select(DocumentSetDBModel).distinct() @@ -128,6 +128,15 @@ def get_document_set_by_id( return db_session.scalar(stmt) +def get_document_set_by_id( + db_session: Session, + document_set_id: int, +) -> DocumentSetDBModel | None: + stmt = select(DocumentSetDBModel).distinct() + stmt = stmt.where(DocumentSetDBModel.id == document_set_id) + return db_session.scalar(stmt) + + def get_document_set_by_name( db_session: Session, document_set_name: str ) -> DocumentSetDBModel | None: @@ -275,7 +284,7 @@ def update_document_set( try: # update the description - document_set_row = get_document_set_by_id( + document_set_row = get_document_set_by_id_for_user( db_session=db_session, document_set_id=document_set_update_request.id, user=user, @@ -366,7 +375,7 @@ def mark_document_set_as_to_be_deleted( job which syncs these changes to Vespa.""" try: - document_set_row = get_document_set_by_id( + document_set_row = get_document_set_by_id_for_user( db_session=db_session, document_set_id=document_set_id, user=user, @@ -478,7 +487,7 @@ def fetch_document_sets( def fetch_all_document_sets_for_user( db_session: Session, - user: User | None = None, + user: User | None, get_editable: bool = True, ) -> Sequence[DocumentSetDBModel]: stmt = select(DocumentSetDBModel).distinct() diff --git a/backend/onyx/db/feedback.py b/backend/onyx/db/feedback.py index 0a8f9e969c6..762a754ab9e 100644 --- a/backend/onyx/db/feedback.py +++ b/backend/onyx/db/feedback.py @@ -27,7 +27,6 @@ from onyx.db.models import User__UserGroup from onyx.db.models import UserGroup__ConnectorCredentialPair from onyx.db.models import UserRole -from onyx.document_index.interfaces import DocumentIndex from onyx.utils.logger import setup_logger logger = setup_logger() @@ -108,9 +107,9 @@ def _add_user_filters( return stmt.where(where_clause) -def fetch_docs_ranked_by_boost( +def fetch_docs_ranked_by_boost_for_user( db_session: Session, - user: User | None = None, + user: User | None, ascending: bool = False, limit: int = 100, ) -> list[DbDocument]: @@ -129,11 +128,11 @@ def fetch_docs_ranked_by_boost( return list(doc_list) -def update_document_boost( +def update_document_boost_for_user( db_session: Session, document_id: str, boost: int, - user: User | None = None, + user: User | None, ) -> None: stmt = select(DbDocument).where(DbDocument.id == document_id) stmt = _add_user_filters(stmt, user, get_editable=True) @@ -151,12 +150,11 @@ def update_document_boost( db_session.commit() -def update_document_hidden( +def update_document_hidden_for_user( db_session: Session, document_id: str, hidden: bool, - document_index: DocumentIndex, - user: User | None = None, + user: User | None, ) -> None: stmt = select(DbDocument).where(DbDocument.id == document_id) stmt = _add_user_filters(stmt, user, get_editable=True) @@ -178,7 +176,6 @@ def create_doc_retrieval_feedback( message_id: int, document_id: str, document_rank: int, - document_index: DocumentIndex, db_session: Session, clicked: bool = False, feedback: SearchFeedbackType | None = None, diff --git a/backend/onyx/db/persona.py b/backend/onyx/db/persona.py index 002ee0d4edb..6da756eba10 100644 --- a/backend/onyx/db/persona.py +++ b/backend/onyx/db/persona.py @@ -110,7 +110,7 @@ def _add_user_filters( # fetch_persona_by_id is used to fetch a persona by its ID. It is used to fetch a persona by its ID. -def fetch_persona_by_id( +def fetch_persona_by_id_for_user( db_session: Session, persona_id: int, user: User | None, get_editable: bool = True ) -> Persona: stmt = select(Persona).where(Persona.id == persona_id).distinct() @@ -229,7 +229,7 @@ def update_persona_shared_users( """Simplified version of `create_update_persona` which only touches the accessibility rather than any of the logic (e.g. prompt, connected data sources, etc.).""" - persona = fetch_persona_by_id( + persona = fetch_persona_by_id_for_user( db_session=db_session, persona_id=persona_id, user=user, get_editable=True ) @@ -255,7 +255,7 @@ def update_persona_public_status( db_session: Session, user: User | None, ) -> None: - persona = fetch_persona_by_id( + persona = fetch_persona_by_id_for_user( db_session=db_session, persona_id=persona_id, user=user, get_editable=True ) if user and user.role != UserRole.ADMIN and persona.user_id != user.id: @@ -283,7 +283,7 @@ def get_prompts( return db_session.scalars(stmt).all() -def get_personas( +def get_personas_for_user( # if user is `None` assume the user is an admin or auth is disabled user: User | None, db_session: Session, @@ -314,6 +314,13 @@ def get_personas( return db_session.execute(stmt).unique().scalars().all() +def get_personas(db_session: Session) -> Sequence[Persona]: + stmt = select(Persona).distinct() + stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX))) + stmt = stmt.where(Persona.deleted.is_(False)) + return db_session.execute(stmt).unique().scalars().all() + + def mark_persona_as_deleted( persona_id: int, user: User | None, @@ -357,7 +364,7 @@ def update_all_personas_display_priority( db_session: Session, ) -> None: """Updates the display priority of all lives Personas""" - personas = get_personas(user=None, db_session=db_session) + personas = get_personas(db_session=db_session) available_persona_ids = {persona.id for persona in personas} if available_persona_ids != set(display_priority_map.keys()): raise ValueError("Invalid persona IDs provided") @@ -511,7 +518,7 @@ def upsert_persona( # this checks if the user has permission to edit the persona # will raise an Exception if the user does not have permission - existing_persona = fetch_persona_by_id( + existing_persona = fetch_persona_by_id_for_user( db_session=db_session, persona_id=existing_persona.id, user=user, @@ -637,7 +644,7 @@ def update_persona_visibility( db_session: Session, user: User | None = None, ) -> None: - persona = fetch_persona_by_id( + persona = fetch_persona_by_id_for_user( db_session=db_session, persona_id=persona_id, user=user, get_editable=True ) diff --git a/backend/onyx/onyxbot/slack/handlers/handle_buttons.py b/backend/onyx/onyxbot/slack/handlers/handle_buttons.py index 6079b22f026..7b26e0bfea1 100644 --- a/backend/onyx/onyxbot/slack/handlers/handle_buttons.py +++ b/backend/onyx/onyxbot/slack/handlers/handle_buttons.py @@ -14,8 +14,6 @@ from onyx.db.engine import get_session_with_tenant from onyx.db.feedback import create_chat_message_feedback from onyx.db.feedback import create_doc_retrieval_feedback -from onyx.document_index.document_index_utils import get_both_index_names -from onyx.document_index.factory import get_default_document_index from onyx.onyxbot.slack.blocks import build_follow_up_resolved_blocks from onyx.onyxbot.slack.blocks import get_document_feedback_blocks from onyx.onyxbot.slack.config import get_slack_channel_config_for_bot_and_channel @@ -186,16 +184,10 @@ def handle_slack_feedback( else: feedback = SearchFeedbackType.HIDE - curr_ind_name, sec_ind_name = get_both_index_names(db_session) - document_index = get_default_document_index( - primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name - ) - create_doc_retrieval_feedback( message_id=message_id, document_id=doc_id, document_rank=doc_rank, - document_index=document_index, db_session=db_session, clicked=False, # Not tracking this for Slack feedback=feedback, diff --git a/backend/onyx/redis/redis_connector_credential_pair.py b/backend/onyx/redis/redis_connector_credential_pair.py index 0d53c2dc806..e648bc563f6 100644 --- a/backend/onyx/redis/redis_connector_credential_pair.py +++ b/backend/onyx/redis/redis_connector_credential_pair.py @@ -83,7 +83,10 @@ def generate_tasks( last_lock_time = time.monotonic() async_results = [] - cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session) + cc_pair = get_connector_credential_pair_from_id( + db_session=db_session, + cc_pair_id=int(self._id), + ) if not cc_pair: return None diff --git a/backend/onyx/redis/redis_connector_delete.py b/backend/onyx/redis/redis_connector_delete.py index b5285fb71ad..17651bf6636 100644 --- a/backend/onyx/redis/redis_connector_delete.py +++ b/backend/onyx/redis/redis_connector_delete.py @@ -92,7 +92,10 @@ def generate_tasks( last_lock_time = time.monotonic() async_results = [] - cc_pair = get_connector_credential_pair_from_id(int(self.id), db_session) + cc_pair = get_connector_credential_pair_from_id( + db_session=db_session, + cc_pair_id=int(self.id), + ) if not cc_pair: return None diff --git a/backend/onyx/redis/redis_connector_prune.py b/backend/onyx/redis/redis_connector_prune.py index bbecc1b8cbf..371c7ad1852 100644 --- a/backend/onyx/redis/redis_connector_prune.py +++ b/backend/onyx/redis/redis_connector_prune.py @@ -115,7 +115,10 @@ def generate_tasks( last_lock_time = time.monotonic() async_results = [] - cc_pair = get_connector_credential_pair_from_id(int(self.id), db_session) + cc_pair = get_connector_credential_pair_from_id( + db_session=db_session, + cc_pair_id=int(self.id), + ) if not cc_pair: return None diff --git a/backend/onyx/server/documents/cc_pair.py b/backend/onyx/server/documents/cc_pair.py index 64086de5df0..593e84f8fa6 100644 --- a/backend/onyx/server/documents/cc_pair.py +++ b/backend/onyx/server/documents/cc_pair.py @@ -20,7 +20,9 @@ ) from onyx.background.celery.versioned_apps.primary import app as primary_app from onyx.db.connector_credential_pair import add_credential_to_connector -from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id +from onyx.db.connector_credential_pair import ( + get_connector_credential_pair_from_id_for_user, +) from onyx.db.connector_credential_pair import remove_credential_from_connector from onyx.db.connector_credential_pair import ( update_connector_credential_pair_from_id, @@ -65,7 +67,7 @@ def get_cc_pair_index_attempts( user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> PaginatedReturn[IndexAttemptSnapshot]: - cc_pair = get_connector_credential_pair_from_id( + cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id, db_session, user, get_editable=False ) if not cc_pair: @@ -98,14 +100,14 @@ def get_cc_pair_full_info( db_session: Session = Depends(get_session), tenant_id: str | None = Depends(get_current_tenant_id), ) -> CCPairFullInfo: - cc_pair = get_connector_credential_pair_from_id( + cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id, db_session, user, get_editable=False ) if not cc_pair: raise HTTPException( status_code=404, detail="CC Pair not found for current user permissions" ) - editable_cc_pair = get_connector_credential_pair_from_id( + editable_cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id, db_session, user, get_editable=True ) is_editable_for_current_user = editable_cc_pair is not None @@ -170,7 +172,7 @@ def update_cc_pair_status( Returns HTTPStatus.OK if everything finished. """ - cc_pair = get_connector_credential_pair_from_id( + cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id=cc_pair_id, db_session=db_session, user=user, @@ -235,7 +237,7 @@ def update_cc_pair_name( user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[int]: - cc_pair = get_connector_credential_pair_from_id( + cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id=cc_pair_id, db_session=db_session, user=user, @@ -264,7 +266,7 @@ def update_cc_pair_property( user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[int]: - cc_pair = get_connector_credential_pair_from_id( + cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id=cc_pair_id, db_session=db_session, user=user, @@ -303,7 +305,7 @@ def get_cc_pair_last_pruned( user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> datetime | None: - cc_pair = get_connector_credential_pair_from_id( + cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id=cc_pair_id, db_session=db_session, user=user, @@ -327,7 +329,7 @@ def prune_cc_pair( ) -> StatusResponse[list[int]]: """Triggers pruning on a particular cc_pair immediately""" - cc_pair = get_connector_credential_pair_from_id( + cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id=cc_pair_id, db_session=db_session, user=user, @@ -375,7 +377,7 @@ def get_cc_pair_latest_sync( user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> datetime | None: - cc_pair = get_connector_credential_pair_from_id( + cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id=cc_pair_id, db_session=db_session, user=user, @@ -399,7 +401,7 @@ def sync_cc_pair( ) -> StatusResponse[list[int]]: """Triggers permissions sync on a particular cc_pair immediately""" - cc_pair = get_connector_credential_pair_from_id( + cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id=cc_pair_id, db_session=db_session, user=user, diff --git a/backend/onyx/server/documents/connector.py b/backend/onyx/server/documents/connector.py index 6be024cb227..4182568fbcb 100644 --- a/backend/onyx/server/documents/connector.py +++ b/backend/onyx/server/documents/connector.py @@ -67,12 +67,12 @@ from onyx.db.connector_credential_pair import add_credential_to_connector from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids from onyx.db.connector_credential_pair import get_connector_credential_pair -from onyx.db.connector_credential_pair import get_connector_credential_pairs +from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user from onyx.db.credentials import cleanup_gmail_credentials from onyx.db.credentials import cleanup_google_drive_credentials from onyx.db.credentials import create_credential from onyx.db.credentials import delete_service_account_credentials -from onyx.db.credentials import fetch_credential_by_id +from onyx.db.credentials import fetch_credential_by_id_for_user from onyx.db.deletion_attempt import check_deletion_attempt_is_allowed from onyx.db.document import get_document_counts_for_cc_pairs from onyx.db.engine import get_current_tenant_id @@ -361,7 +361,7 @@ def check_drive_tokens( user: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> AuthStatus: - db_credentials = fetch_credential_by_id(credential_id, user, db_session) + db_credentials = fetch_credential_by_id_for_user(credential_id, user, db_session) if ( not db_credentials or DB_CREDENTIALS_DICT_TOKEN_KEY not in db_credentials.credential_json @@ -467,7 +467,7 @@ def get_currently_failed_indexing_status( ) # Get all connector credential pairs - cc_pairs = get_connector_credential_pairs( + cc_pairs = get_connector_credential_pairs_for_user( db_session=db_session, user=user, get_editable=get_editable, @@ -536,7 +536,7 @@ def get_connector_status( user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> list[ConnectorStatus]: - cc_pairs = get_connector_credential_pairs( + cc_pairs = get_connector_credential_pairs_for_user( db_session=db_session, user=user, ) @@ -583,7 +583,7 @@ def get_connector_indexing_status( # Additional checks are done to make sure the connector and credential still exist. # TODO: make this one query ... possibly eager load or wrap in a read transaction # to avoid the complexity of trying to error check throughout the function - cc_pairs = get_connector_credential_pairs( + cc_pairs = get_connector_credential_pairs_for_user( db_session=db_session, user=user, get_editable=get_editable, @@ -936,7 +936,11 @@ def connector_run_once( ] connector_credential_pairs = [ - get_connector_credential_pair(connector_id, credential_id, db_session) + get_connector_credential_pair( + db_session=db_session, + connector_id=connector_id, + credential_id=credential_id, + ) for credential_id in credential_ids if credential_id not in skipped_credentials ] @@ -1118,10 +1122,15 @@ class BasicCCPairInfo(BaseModel): @router.get("/connector-status") def get_basic_connector_indexing_status( - _: User = Depends(current_chat_accesssible_user), + user: User = Depends(current_chat_accesssible_user), db_session: Session = Depends(get_session), ) -> list[BasicCCPairInfo]: - cc_pairs = get_connector_credential_pairs(db_session, eager_load_connector=True) + cc_pairs = get_connector_credential_pairs_for_user( + db_session=db_session, + eager_load_connector=True, + get_editable=False, + user=user, + ) return [ BasicCCPairInfo( has_successful_run=cc_pair.last_successful_index_time is not None, diff --git a/backend/onyx/server/documents/credential.py b/backend/onyx/server/documents/credential.py index b68ee660cb7..61a84792e99 100644 --- a/backend/onyx/server/documents/credential.py +++ b/backend/onyx/server/documents/credential.py @@ -12,9 +12,9 @@ from onyx.db.credentials import create_credential from onyx.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE from onyx.db.credentials import delete_credential -from onyx.db.credentials import fetch_credential_by_id -from onyx.db.credentials import fetch_credentials -from onyx.db.credentials import fetch_credentials_by_source +from onyx.db.credentials import fetch_credential_by_id_for_user +from onyx.db.credentials import fetch_credentials_by_source_for_user +from onyx.db.credentials import fetch_credentials_for_user from onyx.db.credentials import swap_credentials_connector from onyx.db.credentials import update_credential from onyx.db.engine import get_session @@ -48,7 +48,7 @@ def list_credentials_admin( db_session: Session = Depends(get_session), ) -> list[CredentialSnapshot]: """Lists all public credentials""" - credentials = fetch_credentials( + credentials = fetch_credentials_for_user( db_session=db_session, user=user, get_editable=False, @@ -68,7 +68,7 @@ def get_cc_source_full_info( False, description="If true, return editable credentials" ), ) -> list[CredentialSnapshot]: - credentials = fetch_credentials_by_source( + credentials = fetch_credentials_by_source_for_user( db_session=db_session, user=user, document_source=source_type, @@ -148,7 +148,7 @@ def list_credentials( user: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> list[CredentialSnapshot]: - credentials = fetch_credentials(db_session=db_session, user=user) + credentials = fetch_credentials_for_user(db_session=db_session, user=user) return [ CredentialSnapshot.from_credential_db_model(credential) for credential in credentials @@ -161,7 +161,7 @@ def get_credential_by_id( user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> CredentialSnapshot | StatusResponse[int]: - credential = fetch_credential_by_id( + credential = fetch_credential_by_id_for_user( credential_id, user, db_session, diff --git a/backend/onyx/server/features/persona/api.py b/backend/onyx/server/features/persona/api.py index 814d81595ce..0d4ecc4e2ee 100644 --- a/backend/onyx/server/features/persona/api.py +++ b/backend/onyx/server/features/persona/api.py @@ -28,7 +28,7 @@ from onyx.db.persona import delete_persona_category from onyx.db.persona import get_assistant_categories from onyx.db.persona import get_persona_by_id -from onyx.db.persona import get_personas +from onyx.db.persona import get_personas_for_user from onyx.db.persona import mark_persona_as_deleted from onyx.db.persona import mark_persona_as_not_deleted from onyx.db.persona import update_all_personas_display_priority @@ -125,7 +125,7 @@ def list_personas_admin( ) -> list[PersonaSnapshot]: return [ PersonaSnapshot.from_model(persona) - for persona in get_personas( + for persona in get_personas_for_user( db_session=db_session, user=user, get_editable=get_editable, @@ -329,7 +329,7 @@ def list_personas( include_deleted: bool = False, persona_ids: list[int] = Query(None), ) -> list[PersonaSnapshot]: - personas = get_personas( + personas = get_personas_for_user( user=user, include_deleted=include_deleted, db_session=db_session, diff --git a/backend/onyx/server/manage/administrative.py b/backend/onyx/server/manage/administrative.py index 82577f714ca..687d0c4adff 100644 --- a/backend/onyx/server/manage/administrative.py +++ b/backend/onyx/server/manage/administrative.py @@ -16,20 +16,18 @@ from onyx.configs.constants import KV_GEN_AI_KEY_CHECK_TIME from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryTask -from onyx.db.connector_credential_pair import get_connector_credential_pair +from onyx.db.connector_credential_pair import get_connector_credential_pair_for_user from onyx.db.connector_credential_pair import ( update_connector_credential_pair_from_id, ) from onyx.db.engine import get_current_tenant_id from onyx.db.engine import get_session from onyx.db.enums import ConnectorCredentialPairStatus -from onyx.db.feedback import fetch_docs_ranked_by_boost -from onyx.db.feedback import update_document_boost -from onyx.db.feedback import update_document_hidden +from onyx.db.feedback import fetch_docs_ranked_by_boost_for_user +from onyx.db.feedback import update_document_boost_for_user +from onyx.db.feedback import update_document_hidden_for_user from onyx.db.index_attempt import cancel_indexing_attempts_for_ccpair from onyx.db.models import User -from onyx.document_index.document_index_utils import get_both_index_names -from onyx.document_index.factory import get_default_document_index from onyx.file_store.file_store import get_default_file_store from onyx.key_value_store.factory import get_kv_store from onyx.key_value_store.interface import KvKeyNotFoundError @@ -55,7 +53,7 @@ def get_most_boosted_docs( user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> list[BoostDoc]: - boost_docs = fetch_docs_ranked_by_boost( + boost_docs = fetch_docs_ranked_by_boost_for_user( ascending=ascending, limit=limit, db_session=db_session, @@ -80,7 +78,7 @@ def document_boost_update( user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse: - update_document_boost( + update_document_boost_for_user( db_session=db_session, document_id=boost_update.document_id, boost=boost_update.boost, @@ -95,16 +93,10 @@ def document_hidden_update( user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse: - curr_ind_name, sec_ind_name = get_both_index_names(db_session) - document_index = get_default_document_index( - primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name - ) - - update_document_hidden( + update_document_hidden_for_user( db_session=db_session, document_id=hidden_update.document_id, hidden=hidden_update.hidden, - document_index=document_index, user=user, ) return StatusResponse(success=True, message="Updated document boost") @@ -152,7 +144,7 @@ def create_deletion_attempt_for_connector_id( connector_id = connector_credential_pair_identifier.connector_id credential_id = connector_credential_pair_identifier.credential_id - cc_pair = get_connector_credential_pair( + cc_pair = get_connector_credential_pair_for_user( db_session=db_session, connector_id=connector_id, credential_id=credential_id, diff --git a/backend/onyx/server/onyx_api/ingestion.py b/backend/onyx/server/onyx_api/ingestion.py index cd3f90850da..06c059aa802 100644 --- a/backend/onyx/server/onyx_api/ingestion.py +++ b/backend/onyx/server/onyx_api/ingestion.py @@ -80,7 +80,8 @@ def upsert_ingestion_doc( document.source = DocumentSource.FILE cc_pair = get_connector_credential_pair_from_id( - cc_pair_id=doc_info.cc_pair_id or DEFAULT_CC_PAIR_ID, db_session=db_session + db_session=db_session, + cc_pair_id=doc_info.cc_pair_id or DEFAULT_CC_PAIR_ID, ) if cc_pair is None: raise HTTPException( diff --git a/backend/onyx/server/openai_assistants_api/asssistants_api.py b/backend/onyx/server/openai_assistants_api/asssistants_api.py index 53bd228258e..78ef8d45f8a 100644 --- a/backend/onyx/server/openai_assistants_api/asssistants_api.py +++ b/backend/onyx/server/openai_assistants_api/asssistants_api.py @@ -15,7 +15,7 @@ from onyx.db.models import Persona from onyx.db.models import User from onyx.db.persona import get_persona_by_id -from onyx.db.persona import get_personas +from onyx.db.persona import get_personas_for_user from onyx.db.persona import mark_persona_as_deleted from onyx.db.persona import upsert_persona from onyx.db.persona import upsert_prompt @@ -243,7 +243,7 @@ def list_assistants( db_session: Session = Depends(get_session), ) -> ListAssistantsResponse: personas = list( - get_personas( + get_personas_for_user( user=user, db_session=db_session, get_editable=False, diff --git a/backend/onyx/server/query_and_chat/chat_backend.py b/backend/onyx/server/query_and_chat/chat_backend.py index be7ab9eeaaa..0613556576a 100644 --- a/backend/onyx/server/query_and_chat/chat_backend.py +++ b/backend/onyx/server/query_and_chat/chat_backend.py @@ -54,8 +54,6 @@ from onyx.db.feedback import create_doc_retrieval_feedback from onyx.db.models import User from onyx.db.persona import get_persona_by_id -from onyx.document_index.document_index_utils import get_both_index_names -from onyx.document_index.factory import get_default_document_index from onyx.file_processing.extract_file_text import docx_to_txt_filename from onyx.file_processing.extract_file_text import extract_file_text from onyx.file_store.file_store import get_default_file_store @@ -450,19 +448,12 @@ def create_search_feedback( """This endpoint isn't protected - it does not check if the user has access to the document Users could try changing boosts of arbitrary docs but this does not leak any data. """ - - curr_ind_name, sec_ind_name = get_both_index_names(db_session) - document_index = get_default_document_index( - primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name - ) - create_doc_retrieval_feedback( message_id=feedback.message_id, document_id=feedback.document_id, document_rank=feedback.document_rank, clicked=feedback.click, feedback=feedback.search_feedback, - document_index=document_index, db_session=db_session, ) diff --git a/backend/scripts/force_delete_connector_by_id.py b/backend/scripts/force_delete_connector_by_id.py index 39b98b9bcb8..89038aeed66 100755 --- a/backend/scripts/force_delete_connector_by_id.py +++ b/backend/scripts/force_delete_connector_by_id.py @@ -141,7 +141,10 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None: return logger.notice("Getting connector credential pair") - cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + cc_pair = get_connector_credential_pair_from_id( + db_session=db_session, + cc_pair_id=cc_pair_id, + ) if not cc_pair: logger.error(f"Connector credential pair with ID {cc_pair_id} not found")