Skip to content

Commit

Permalink
refactored _add_user_filter usage (#3674)
Browse files Browse the repository at this point in the history
* refactored db.connector_credential_pair

* Rerfactored the db.credentials user filtering

* the restr
  • Loading branch information
hagen-danswer authored Jan 14, 2025
1 parent 5f46205 commit b195773
Show file tree
Hide file tree
Showing 30 changed files with 255 additions and 168 deletions.
4 changes: 2 additions & 2 deletions backend/ee/onyx/db/token_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def insert_user_group_token_rate_limit(
return token_limit


def fetch_user_group_token_rate_limits(
def fetch_user_group_token_rate_limits_for_user(
db_session: Session,
group_id: int,
user: User | None = None,
user: User | None,
enabled_only: bool = False,
ordered: bool = True,
get_editable: bool = True,
Expand Down
5 changes: 4 additions & 1 deletion backend/ee/onyx/db/user_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,10 @@ def delete_user_group_cc_pair_relationship__no_commit(
connector_credential_pair_id matches the given cc_pair_id.
Should be used very carefully (only for connectors that are being deleted)."""
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
raise ValueError(f"Connector Credential Pair '{cc_pair_id}' does not exist")

Expand Down
5 changes: 4 additions & 1 deletion backend/ee/onyx/external_permissions/salesforce/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,10 @@ def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Sales

cc_pair_id = _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id]
if cc_pair_id not in _CC_PAIR_ID_SALESFORCE_CLIENT_MAP:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if cc_pair is None:
raise ValueError(f"CC pair {cc_pair_id} not found")
credential_json = cc_pair.credential.credential_json
Expand Down
8 changes: 5 additions & 3 deletions backend/ee/onyx/server/token_rate_limits/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlalchemy.orm import Session

from ee.onyx.db.token_limit import fetch_all_user_group_token_rate_limits_by_group
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
Expand Down Expand Up @@ -51,8 +51,10 @@ def get_group_token_limit_settings(
) -> list[TokenRateLimitDisplay]:
return [
TokenRateLimitDisplay.from_db(token_rate_limit)
for token_rate_limit in fetch_user_group_token_rate_limits(
db_session, group_id, user
for token_rate_limit in fetch_user_group_token_rate_limits_for_user(
db_session=db_session,
group_id=group_id,
user=user,
)
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,10 @@ def try_generate_document_cc_pair_cleanup_tasks(
# we need to load the state of the object inside the fence
# to avoid a race condition with db.commit/fence deletion
# at the end of this taskset
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,10 @@ def connector_permission_sync_generator_task(

try:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,10 @@ def connector_external_group_sync_generator_task(
return None

with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
Expand Down
5 changes: 3 additions & 2 deletions backend/onyx/background/celery/tasks/indexing/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,8 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
)

cc_pair = get_connector_credential_pair_from_id(
cc_pair_id, db_session
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
continue
Expand Down Expand Up @@ -1198,8 +1199,8 @@ def connector_indexing_task(
attempt_found = True

cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
cc_pair_id=cc_pair_id,
)

if not cc_pair:
Expand Down
5 changes: 4 additions & 1 deletion backend/onyx/background/celery/tasks/pruning/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
continue

Expand Down
10 changes: 8 additions & 2 deletions backend/onyx/background/celery/tasks/vespa/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,10 @@ def try_generate_document_set_sync_tasks(

# don't generate sync tasks if we're up to date
# race condition with the monitor/cleanup function if we use a cached result!
document_set = get_document_set_by_id(db_session, document_set_id)
document_set = get_document_set_by_id(
db_session=db_session,
document_set_id=document_set_id,
)
if not document_set:
return None

Expand Down Expand Up @@ -532,7 +535,10 @@ def monitor_connector_deletion_taskset(
return

with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
Expand Down
3 changes: 2 additions & 1 deletion backend/onyx/background/indexing/run_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def _get_connector_runner(
# it will never succeed

cc_pair = get_connector_credential_pair_from_id(
attempt.connector_credential_pair.id, db_session
db_session=db_session,
cc_pair_id=attempt.connector_credential_pair.id,
)
if cc_pair and cc_pair.status == ConnectorCredentialPairStatus.ACTIVE:
update_connector_credential_pair(
Expand Down
88 changes: 55 additions & 33 deletions backend/onyx/db/connector_credential_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
from sqlalchemy.orm import Session

from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.constants import DocumentSource
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
Expand Down Expand Up @@ -92,10 +91,9 @@ def _add_user_filters(
return stmt.where(where_clause)


def get_connector_credential_pairs(
def get_connector_credential_pairs_for_user(
db_session: Session,
include_disabled: bool = True,
user: User | None = None,
user: User | None,
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
Expand All @@ -106,11 +104,18 @@ def get_connector_credential_pairs(
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))

stmt = _add_user_filters(stmt, user, get_editable)
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))

return list(db_session.scalars(stmt).all())


def get_connector_credential_pairs(
db_session: Session,
ids: list[int] | None = None,
) -> list[ConnectorCredentialPair]:
stmt = select(ConnectorCredentialPair).distinct()

if not include_disabled:
stmt = stmt.where(
ConnectorCredentialPair.status == ConnectorCredentialPairStatus.ACTIVE
)
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))

Expand All @@ -122,7 +127,10 @@ def add_deletion_failure_message(
cc_pair_id: int,
failure_message: str,
) -> None:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
return
cc_pair.deletion_failure_message = failure_message
Expand All @@ -132,24 +140,21 @@ def add_deletion_failure_message(
def get_cc_pair_groups_for_ids(
db_session: Session,
cc_pair_ids: list[int],
user: User | None = None,
get_editable: bool = True,
) -> list[UserGroup__ConnectorCredentialPair]:
stmt = select(UserGroup__ConnectorCredentialPair).distinct()
stmt = stmt.outerjoin(
ConnectorCredentialPair,
UserGroup__ConnectorCredentialPair.cc_pair_id == ConnectorCredentialPair.id,
)
stmt = _add_user_filters(stmt, user, get_editable)
stmt = stmt.where(UserGroup__ConnectorCredentialPair.cc_pair_id.in_(cc_pair_ids))
return list(db_session.scalars(stmt).all())


def get_connector_credential_pair(
def get_connector_credential_pair_for_user(
db_session: Session,
connector_id: int,
credential_id: int,
db_session: Session,
user: User | None = None,
user: User | None,
get_editable: bool = True,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair)
Expand All @@ -160,28 +165,36 @@ def get_connector_credential_pair(
return result.scalar_one_or_none()


def get_connector_credential_source_from_id(
def get_connector_credential_pair(
db_session: Session,
connector_id: int,
credential_id: int,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair)
stmt = stmt.where(ConnectorCredentialPair.connector_id == connector_id)
stmt = stmt.where(ConnectorCredentialPair.credential_id == credential_id)
result = db_session.execute(stmt)
return result.scalar_one_or_none()


def get_connector_credential_pair_from_id_for_user(
cc_pair_id: int,
db_session: Session,
user: User | None = None,
user: User | None,
get_editable: bool = True,
) -> DocumentSource | None:
stmt = select(ConnectorCredentialPair)
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair).distinct()
stmt = _add_user_filters(stmt, user, get_editable)
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
result = db_session.execute(stmt)
cc_pair = result.scalar_one_or_none()
return cc_pair.connector.source if cc_pair else None
return result.scalar_one_or_none()


def get_connector_credential_pair_from_id(
cc_pair_id: int,
db_session: Session,
user: User | None = None,
get_editable: bool = True,
cc_pair_id: int,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair).distinct()
stmt = _add_user_filters(stmt, user, get_editable)
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
result = db_session.execute(stmt)
return result.scalar_one_or_none()
Expand All @@ -198,7 +211,9 @@ def get_last_successful_attempt_time(
the CC Pair row in the database"""
if search_settings.status == IndexModelStatus.PRESENT:
connector_credential_pair = get_connector_credential_pair(
connector_id, credential_id, db_session
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if (
connector_credential_pair is None
Expand Down Expand Up @@ -259,7 +274,10 @@ def update_connector_credential_pair_from_id(
net_docs: int | None = None,
run_dt: datetime | None = None,
) -> None:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
logger.warning(
f"Attempted to update pair for Connector Credential Pair '{cc_pair_id}'"
Expand All @@ -284,7 +302,11 @@ def update_connector_credential_pair(
net_docs: int | None = None,
run_dt: datetime | None = None,
) -> None:
cc_pair = get_connector_credential_pair(connector_id, credential_id, db_session)
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair:
logger.warning(
f"Attempted to update pair for connector id {connector_id} "
Expand Down Expand Up @@ -368,7 +390,7 @@ def add_credential_to_connector(
last_successful_index_time: datetime | None = None,
) -> StatusResponse:
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(
credential = fetch_credential_by_id_for_user(
credential_id,
user,
db_session,
Expand Down Expand Up @@ -450,7 +472,7 @@ def remove_credential_from_connector(
db_session: Session,
) -> StatusResponse[int]:
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(
credential = fetch_credential_by_id_for_user(
credential_id,
user,
db_session,
Expand All @@ -466,10 +488,10 @@ def remove_credential_from_connector(
detail="Credential does not exist or does not belong to user",
)

association = get_connector_credential_pair(
association = get_connector_credential_pair_for_user(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
db_session=db_session,
user=user,
get_editable=True,
)
Expand Down
Loading

0 comments on commit b195773

Please sign in to comment.