diff --git a/backend/danswer/access/models.py b/backend/danswer/access/models.py
index 46b9c0efa93..ad4c3eb5812 100644
--- a/backend/danswer/access/models.py
+++ b/backend/danswer/access/models.py
@@ -16,6 +16,13 @@ class ExternalAccess:
     is_public: bool
 
 
+@dataclass(frozen=True)
+class DocumentExternalAccess:
+    external_access: ExternalAccess
+    # The document ID
+    doc_id: str
+
+
 @dataclass(frozen=True)
 class DocumentAccess(ExternalAccess):
     # User emails for Danswer users, None indicates admin
diff --git a/backend/danswer/background/celery/apps/app_base.py b/backend/danswer/background/celery/apps/app_base.py
index 117293f1c2f..7df390929d0 100644
--- a/backend/danswer/background/celery/apps/app_base.py
+++ b/backend/danswer/background/celery/apps/app_base.py
@@ -19,6 +19,7 @@
 from danswer.redis.redis_connector import RedisConnector
 from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
 from danswer.redis.redis_connector_delete import RedisConnectorDelete
+from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorDocPermSyncs
 from danswer.redis.redis_connector_prune import RedisConnectorPrune
 from danswer.redis.redis_document_set import RedisDocumentSet
 from danswer.redis.redis_pool import get_redis_client
@@ -132,6 +133,12 @@ def on_task_postrun(
             RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r)
         return
 
+    if task_id.startswith(RedisConnectorDocPermSyncs.SUBTASK_PREFIX):
+        cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
+        if cc_pair_id is not None:
+            RedisConnectorDocPermSyncs.remove_from_taskset(int(cc_pair_id), task_id, r)
+        return
+
 
 def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
     """The first signal sent on celery worker startup"""
diff --git a/backend/danswer/background/celery/apps/primary.py b/backend/danswer/background/celery/apps/primary.py
index b0950fc8f03..4af3db17b92 100644
--- a/backend/danswer/background/celery/apps/primary.py
+++ b/backend/danswer/background/celery/apps/primary.py
@@ -20,6 +20,7 @@
 from danswer.db.engine import SqlEngine
 from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
 from danswer.redis.redis_connector_delete import RedisConnectorDelete
+from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorDocPermSyncs
 from danswer.redis.redis_connector_index import RedisConnectorIndex
 from danswer.redis.redis_connector_prune import RedisConnectorPrune
 from danswer.redis.redis_connector_stop import RedisConnectorStop
@@ -131,6 +132,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
 
     RedisConnectorStop.reset_all(r)
 
+    RedisConnectorDocPermSyncs.reset_all(r)
+
 
 @worker_ready.connect
 def on_worker_ready(sender: Any, **kwargs: Any) -> None:
diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py
index d0df7af02d7..c8a125d11b1 100644
--- a/backend/danswer/background/celery/celery_utils.py
+++ b/backend/danswer/background/celery/celery_utils.py
@@ -81,7 +81,7 @@ def extract_ids_from_runnable_connector(
     callback: RunIndexingCallbackInterface | None = None,
 ) -> set[str]:
     """
-    If the PruneConnector hasnt been implemented for the given connector, just pull
+    If the SlimConnector hasn't been implemented for the given connector, just pull
     all docs using the load_from_state and grab out the IDs.
 
     Optionally, a callback can be passed to handle the length of each document batch.
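For orientation, this is how the new `DocumentExternalAccess` container is meant to be populated: one entry per document, pairing the document id with its external access. The values below are placeholders; the `ExternalAccess` field names are taken from its call sites later in this diff.

```python
# Illustrative only: constructing the payload that the doc_sync functions
# further down in this diff now return. All values are placeholders.
from danswer.access.models import DocumentExternalAccess, ExternalAccess

access = ExternalAccess(
    external_user_emails={"alice@example.com"},
    external_user_group_ids={"confluence_engineering"},
    is_public=False,
)

doc_access = DocumentExternalAccess(
    external_access=access,
    doc_id="confluence__SPACE__12345",  # placeholder document id
)
```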
diff --git a/backend/danswer/background/celery/tasks/beat_schedule.py b/backend/danswer/background/celery/tasks/beat_schedule.py
index 6a20c6ba5c1..72a0dcd3bec 100644
--- a/backend/danswer/background/celery/tasks/beat_schedule.py
+++ b/backend/danswer/background/celery/tasks/beat_schedule.py
@@ -41,6 +41,12 @@
         "schedule": timedelta(seconds=5),
         "options": {"priority": DanswerCeleryPriority.HIGH},
     },
+    {
+        "name": "check-for-doc-permissions-sync",
+        "task": "check_for_doc_permissions_sync",
+        "schedule": timedelta(seconds=10),
+        "options": {"priority": DanswerCeleryPriority.HIGH},
+    },
 ]
diff --git a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
index 360481015bb..85d27b2ba16 100644
--- a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
+++ b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
@@ -143,6 +143,12 @@ def try_generate_document_cc_pair_cleanup_tasks(
             f"cc_pair={cc_pair_id}"
         )
 
+    if redis_connector.permissions.fenced:
+        raise TaskDependencyError(
+            f"Connector deletion - Delayed (permissions in progress): "
+            f"cc_pair={cc_pair_id}"
+        )
+
     # add tasks to celery and build up the task set to monitor in redis
     redis_connector.delete.taskset_clear()
diff --git a/backend/danswer/background/celery/tasks/pruning/tasks.py b/backend/danswer/background/celery/tasks/pruning/tasks.py
index af80e6b886c..56065c980a2 100644
--- a/backend/danswer/background/celery/tasks/pruning/tasks.py
+++ b/backend/danswer/background/celery/tasks/pruning/tasks.py
@@ -38,6 +38,35 @@
 logger = setup_logger()
 
 
+def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
+    """Returns boolean indicating if pruning is due."""
+
+    # skip pruning if no prune frequency is set
+    # pruning can still be forced via the API which will run a pruning task directly
+    if not cc_pair.connector.prune_freq:
+        return False
+
+    # skip pruning if not active
+    if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
+        return False
+
+    # skip pruning if the next scheduled prune time hasn't been reached yet
+    last_pruned = cc_pair.last_pruned
+    if not last_pruned:
+        if not cc_pair.last_successful_index_time:
+            # if we've never indexed, we can't prune
+            return False
+
+        # if never pruned, use the last time the connector indexed successfully
+        last_pruned = cc_pair.last_successful_index_time
+
+    next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
+    if datetime.now(timezone.utc) < next_prune:
+        return False
+
+    return True
+
+
 @shared_task(
     name="check_for_pruning",
     soft_time_limit=JOB_TIMEOUT,
@@ -69,7 +98,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
                 if not cc_pair:
                     continue
 
-                if not is_pruning_due(cc_pair, db_session, r):
+                if not _is_pruning_due(cc_pair):
                     continue
 
                 tasks_created = try_creating_prune_generator_task(
@@ -90,47 +119,6 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
             lock_beat.release()
 
 
-def is_pruning_due(
-    cc_pair: ConnectorCredentialPair,
-    db_session: Session,
-    r: Redis,
-) -> bool:
-    """Returns an int if pruning is triggered.
-    The int represents the number of prune tasks generated (in this case, only one
-    because the task is a long running generator task.)
-    Returns None if no pruning is triggered (due to not being needed or
-    other reasons such as simultaneous pruning restrictions.
-
-    Checks for scheduling related conditions, then delegates the rest of the checks to
-    try_creating_prune_generator_task.
-    """
-
-    # skip pruning if no prune frequency is set
-    # pruning can still be forced via the API which will run a pruning task directly
-    if not cc_pair.connector.prune_freq:
-        return False
-
-    # skip pruning if not active
-    if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
-        return False
-
-    # skip pruning if the next scheduled prune time hasn't been reached yet
-    last_pruned = cc_pair.last_pruned
-    if not last_pruned:
-        if not cc_pair.last_successful_index_time:
-            # if we've never indexed, we can't prune
-            return False
-
-        # if never pruned, use the last time the connector indexed successfully
-        last_pruned = cc_pair.last_successful_index_time
-
-    next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
-    if datetime.now(timezone.utc) < next_prune:
-        return False
-
-    return True
-
-
 def try_creating_prune_generator_task(
     celery_app: Celery,
     cc_pair: ConnectorCredentialPair,
@@ -172,6 +160,10 @@ def try_creating_prune_generator_task(
         if redis_connector.delete.fenced:  # skip pruning if the cc_pair is deleting
             return None
 
+        if redis_connector.permissions.fenced:
+            # skip pruning if a doc permissions sync is in progress
+            return None
+
         db_session.refresh(cc_pair)
         if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
             return None
diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py
index b01a0eac815..4c15f5753bb 100644
--- a/backend/danswer/background/celery/tasks/vespa/tasks.py
+++ b/backend/danswer/background/celery/tasks/vespa/tasks.py
@@ -27,6 +27,7 @@
 from danswer.configs.constants import DanswerCeleryQueues
 from danswer.configs.constants import DanswerRedisLocks
 from danswer.db.connector import fetch_connector_by_id
+from danswer.db.connector import mark_ccpair_as_permissions_synced
 from danswer.db.connector import mark_ccpair_as_pruned
 from danswer.db.connector_credential_pair import add_deletion_failure_message
 from danswer.db.connector_credential_pair import (
@@ -58,6 +59,7 @@
 from danswer.redis.redis_connector import RedisConnector
 from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
 from danswer.redis.redis_connector_delete import RedisConnectorDelete
+from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorDocPermSyncs
 from danswer.redis.redis_connector_index import RedisConnectorIndex
 from danswer.redis.redis_connector_prune import RedisConnectorPrune
 from danswer.redis.redis_document_set import RedisDocumentSet
@@ -546,6 +548,42 @@ def monitor_ccpair_pruning_taskset(
     redis_connector.prune.set_fence(False)
 
 
+def monitor_ccpair_permissions_taskset(
+    tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
+) -> None:
+    fence_key = key_bytes.decode("utf-8")
+    cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
+    if cc_pair_id_str is None:
+        task_logger.warning(
+            f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
+        )
+        return
+
+    cc_pair_id = int(cc_pair_id_str)
+
+    redis_connector = RedisConnector(tenant_id, cc_pair_id)
+    if not redis_connector.permissions.fenced:
+        return
+
+    initial = redis_connector.permissions.generator_complete
+    if initial is None:
+        return
+
+    remaining = redis_connector.permissions.get_remaining()
+    task_logger.info(
+        f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
+    )
+    if remaining > 0:
+        return
+
+    mark_ccpair_as_permissions_synced(cc_pair_id, db_session)
+    task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
+
+    redis_connector.permissions.taskset_clear()
+    redis_connector.permissions.generator_clear()
+    redis_connector.permissions.set_fence(False)
+
+
 def monitor_ccpair_indexing_taskset(
     tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
 ) -> None:
@@ -741,6 +779,12 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
                 with get_session_with_tenant(tenant_id) as db_session:
                     monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
 
+        lock_beat.reacquire()
+        for key_bytes in r.scan_iter(RedisConnectorDocPermSyncs.FENCE_PREFIX + "*"):
+            lock_beat.reacquire()
+            with get_session_with_tenant(tenant_id) as db_session:
+                monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
+
         # uncomment for debugging if needed
         # r_celery = celery_app.broker_connection().channel().client
         # length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
diff --git a/backend/danswer/background/task_utils.py b/backend/danswer/background/task_utils.py
index c1c24bf92a1..f4562892460 100644
--- a/backend/danswer/background/task_utils.py
+++ b/backend/danswer/background/task_utils.py
@@ -14,15 +14,6 @@
 from danswer.db.tasks import register_task
 
 
-def name_cc_prune_task(
-    connector_id: int | None = None, credential_id: int | None = None
-) -> str:
-    task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
-    if not connector_id or not credential_id:
-        task_name = "prune_connector_credential_pair"
-    return task_name
-
-
 T = TypeVar("T", bound=Callable)
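The `monitor_ccpair_permissions_taskset` function added above follows the same fence/taskset convention as the pruning monitor. A minimal standalone sketch of that completion check against a plain `redis` client; the key strings mirror the `RedisConnectorDocPermSyncs` prefixes introduced later in this diff and should be treated as illustrative rather than authoritative.

```python
# Sketch of the fence/taskset completion check used by the monitor functions.
import redis

r = redis.Redis()

def permissions_sync_finished(cc_pair_id: int) -> bool:
    fence_key = f"connector_doc_permissions_fence_{cc_pair_id}"
    complete_key = f"connector_doc_permissions_generator_complete_{cc_pair_id}"
    taskset_key = f"connector_doc_permissions_taskset_{cc_pair_id}"

    if not r.exists(fence_key):
        return False  # no sync in flight for this cc_pair

    if r.get(complete_key) is None:
        return False  # generator is still enumerating documents

    # each finished subtask removes its id from the taskset via on_task_postrun,
    # so an empty set means every per-document update has completed
    return r.scard(taskset_key) == 0
```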
diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py
index 671fd13e2a0..ebf23adbd84 100644
--- a/backend/danswer/configs/constants.py
+++ b/backend/danswer/configs/constants.py
@@ -80,6 +80,8 @@
 # if we can get callbacks as object bytes download, we could lower this a lot.
 CELERY_PRUNING_LOCK_TIMEOUT = 300  # 5 min
 
+CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 300  # 5 min
+
 DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
 
 
@@ -212,6 +214,7 @@ class DanswerCeleryQueues:
     CONNECTOR_DELETION = "connector_deletion"
     CONNECTOR_PRUNING = "connector_pruning"
     CONNECTOR_INDEXING = "connector_indexing"
+    CONNECTOR_DOC_PERMISSIONS_SYNC = "connector_doc_permissions_sync"
 
 
 class DanswerRedisLocks:
@@ -220,6 +223,9 @@ class DanswerRedisLocks:
     CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
     CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
     CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
+    CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = (
+        "da_lock:check_connector_doc_permissions_sync_beat"
+    )
     MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
 
     PRUNING_LOCK_PREFIX = "da_lock:pruning"
diff --git a/backend/danswer/db/connector.py b/backend/danswer/db/connector.py
index 835f74d437c..81733734946 100644
--- a/backend/danswer/db/connector.py
+++ b/backend/danswer/db/connector.py
@@ -282,3 +282,15 @@ def mark_ccpair_as_pruned(cc_pair_id: int, db_session: Session) -> None:
 
     cc_pair.last_pruned = datetime.now(timezone.utc)
     db_session.commit()
+
+
+def mark_ccpair_as_permissions_synced(cc_pair_id: int, db_session: Session) -> None:
+    stmt = select(ConnectorCredentialPair).where(
+        ConnectorCredentialPair.id == cc_pair_id
+    )
+    cc_pair = db_session.scalar(stmt)
+    if cc_pair is None:
+        raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
+
+    cc_pair.last_time_perm_sync = datetime.now(timezone.utc)
+    db_session.commit()
diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py
index ce8e21c52e8..f0fd61b0145 100644
--- a/backend/danswer/db/models.py
+++ b/backend/danswer/db/models.py
@@ -420,6 +420,9 @@ class ConnectorCredentialPair(Base):
     last_time_perm_sync: Mapped[datetime.datetime | None] = mapped_column(
         DateTime(timezone=True), nullable=True
     )
+    last_time_external_group_sync: Mapped[datetime.datetime | None] = mapped_column(
+        DateTime(timezone=True), nullable=True
+    )
     # Time finished, not used for calculating backend jobs which uses time started (created)
     last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
         DateTime(timezone=True), default=None
diff --git a/backend/danswer/db/users.py b/backend/danswer/db/users.py
index 1ff21b71006..2c173f04443 100644
--- a/backend/danswer/db/users.py
+++ b/backend/danswer/db/users.py
@@ -97,3 +97,18 @@ def batch_add_non_web_user_if_not_exists__no_commit(
     db_session.flush()  # generate ids
 
     return found_users + new_users
+
+
+def batch_add_non_web_user_if_not_exists(
+    db_session: Session, emails: list[str]
+) -> list[User]:
+    found_users, missing_user_emails = get_users_by_emails(db_session, emails)
+
+    new_users: list[User] = []
+    for email in missing_user_emails:
+        new_users.append(_generate_non_web_user(email=email))
+
+    db_session.add_all(new_users)
+    db_session.commit()
+
+    return found_users + new_users
diff --git a/backend/danswer/redis/redis_connector.py b/backend/danswer/redis/redis_connector.py
index df61f986ede..d946cae0527 100644
--- a/backend/danswer/redis/redis_connector.py
+++ b/backend/danswer/redis/redis_connector.py
@@ -1,6 +1,7 @@
 import redis
 
 from danswer.redis.redis_connector_delete import RedisConnectorDelete
+from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorDocPermSyncs
 from danswer.redis.redis_connector_index import RedisConnectorIndex
 from danswer.redis.redis_connector_prune import RedisConnectorPrune
 from danswer.redis.redis_connector_stop import RedisConnectorStop
@@ -19,6 +20,7 @@ def __init__(self, tenant_id: str | None, id: int) -> None:
         self.stop = RedisConnectorStop(tenant_id, id, self.redis)
         self.prune = RedisConnectorPrune(tenant_id, id, self.redis)
         self.delete = RedisConnectorDelete(tenant_id, id, self.redis)
+        self.permissions = RedisConnectorDocPermSyncs(tenant_id, id, self.redis)
 
     def new_index(self, search_settings_id: int) -> RedisConnectorIndex:
         return RedisConnectorIndex(
diff --git a/backend/danswer/redis/redis_connector_doc_perm_sync.py b/backend/danswer/redis/redis_connector_doc_perm_sync.py
new file mode 100644
index 00000000000..221ca2fd972
--- /dev/null
+++ b/backend/danswer/redis/redis_connector_doc_perm_sync.py
@@ -0,0 +1,162 @@
+import time
+from typing import cast
+from uuid import uuid4
+
+import redis
+from celery import Celery
+from sqlalchemy.orm import Session
+
+from danswer.access.models import DocumentExternalAccess
+from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
+from danswer.configs.constants import DanswerCeleryPriority
+from danswer.configs.constants import DanswerCeleryQueues
+from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
+
+
+class RedisConnectorDocPermSyncs:
+    """Manages interactions with redis for permission sync tasks. Should only be accessed
+    through RedisConnector."""
+
+    PREFIX = "connector_doc_permissions"
+
+    FENCE_PREFIX = f"{PREFIX}_fence"
+
+    # phase 1 - generator task and progress signals
+    GENERATORTASK_PREFIX = f"{PREFIX}+generator"  # connector_doc_permissions+generator
+    GENERATOR_PROGRESS_PREFIX = (
+        PREFIX + "_generator_progress"
+    )  # connector_doc_permissions_generator_progress
+    GENERATOR_COMPLETE_PREFIX = (
+        PREFIX + "_generator_complete"
+    )  # connector_doc_permissions_generator_complete
+
+    TASKSET_PREFIX = f"{PREFIX}_taskset"  # connector_doc_permissions_taskset
+    SUBTASK_PREFIX = f"{PREFIX}+sub"  # connector_doc_permissions+sub
+
+    def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
+        self.tenant_id: str | None = tenant_id
+        self.id = id
+        self.redis = redis
+
+        self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
+        self.generator_task_key = f"{self.GENERATORTASK_PREFIX}_{id}"
+        self.generator_progress_key = f"{self.GENERATOR_PROGRESS_PREFIX}_{id}"
+        self.generator_complete_key = f"{self.GENERATOR_COMPLETE_PREFIX}_{id}"
+
+        self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
+
+        self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
+
+    def taskset_clear(self) -> None:
+        self.redis.delete(self.taskset_key)
+
+    def generator_clear(self) -> None:
+        self.redis.delete(self.generator_progress_key)
+        self.redis.delete(self.generator_complete_key)
+
+    def get_remaining(self) -> int:
+        remaining = cast(int, self.redis.scard(self.taskset_key))
+        return remaining
+
+    def get_active_task_count(self) -> int:
+        """Count of active permission sync fences across all connectors"""
+        count = 0
+        for _ in self.redis.scan_iter(RedisConnectorDocPermSyncs.FENCE_PREFIX + "*"):
+            count += 1
+        return count
+
+    @property
+    def fenced(self) -> bool:
+        if self.redis.exists(self.fence_key):
+            return True
+        return False
+
+    def set_fence(self, value: bool) -> None:
+        if not value:
+            self.redis.delete(self.fence_key)
+            return
+
+        self.redis.set(self.fence_key, 0)
+
+    @property
+    def generator_complete(self) -> int | None:
+        """The payload is an int representing the number of permission sync
+        tasks to be processed, set just after the generator completes."""
+        fence_bytes = self.redis.get(self.generator_complete_key)
+        if fence_bytes is None:
+            return None
+
+        fence_int = int(cast(bytes, fence_bytes))
+        return fence_int
+
+    @generator_complete.setter
+    def generator_complete(self, payload: int | None) -> None:
+        """Set the payload to an int to set the fence, otherwise if None it will
+        be deleted"""
+        if payload is None:
+            self.redis.delete(self.generator_complete_key)
+            return
+
+        self.redis.set(self.generator_complete_key, payload)
+
+    def generate_tasks(
+        self,
+        celery_app: Celery,
+        db_session: Session,
+        lock: redis.lock.Lock | None,
+        new_permissions: list[DocumentExternalAccess],
+    ) -> int | None:
+        last_lock_time = time.monotonic()
+
+        async_results = []
+        cc_pair = get_connector_credential_pair_from_id(int(self.id), db_session)
+        if not cc_pair:
+            return None
+
+        for doc_perm in new_permissions:
+            current_time = time.monotonic()
+            if lock and current_time - last_lock_time >= (
+                CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
+            ):
+                lock.reacquire()
+                last_lock_time = current_time
+
+            # Add task for document permissions sync
+            custom_task_id = f"{self.subtask_prefix}_{uuid4()}"
+            self.redis.sadd(self.taskset_key, custom_task_id)
+
+            result = celery_app.send_task(
+                "update_external_document_permissions_task",
+                kwargs=dict(
+                    cc_pair_id=cc_pair.id,
+                    tenant_id=self.tenant_id,
+                    document_external_access=doc_perm,
+                ),
+                queue=DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
+                task_id=custom_task_id,
+                priority=DanswerCeleryPriority.MEDIUM,
+            )
+            async_results.append(result)
+
+        return len(async_results)
+
+    @staticmethod
+    def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
+        taskset_key = f"{RedisConnectorDocPermSyncs.TASKSET_PREFIX}_{id}"
+        r.srem(taskset_key, task_id)
+        return
+
+    @staticmethod
+    def reset_all(r: redis.Redis) -> None:
+        """Deletes all redis values for all connectors"""
+        for key in r.scan_iter(RedisConnectorDocPermSyncs.TASKSET_PREFIX + "*"):
+            r.delete(key)
+
+        for key in r.scan_iter(
+            RedisConnectorDocPermSyncs.GENERATOR_COMPLETE_PREFIX + "*"
+        ):
+            r.delete(key)
+
+        for key in r.scan_iter(
+            RedisConnectorDocPermSyncs.GENERATOR_PROGRESS_PREFIX + "*"
+        ):
+            r.delete(key)
+
+        for key in r.scan_iter(RedisConnectorDocPermSyncs.FENCE_PREFIX + "*"):
+            r.delete(key)
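Taken together, the class above exposes the whole sync lifecycle. The sketch below drives one cycle by hand through the `RedisConnector` facade; the ids and the task count are placeholders, and in production the celery tasks introduced later in this diff perform each step.

```python
# Sketch: one sync cycle driven by hand through the facade wired up in
# redis_connector.py above. Placeholder ids; not how production code calls it.
from danswer.redis.redis_connector import RedisConnector

redis_connector = RedisConnector(tenant_id=None, id=42)
perms = redis_connector.permissions

assert not perms.fenced     # nothing in flight for this cc_pair
perms.set_fence(True)       # claim the cc_pair for a sync run

# ... the generator task calls perms.generate_tasks(...), which sadd's one
# custom task id per document, then records how many it dispatched:
perms.generator_complete = 17

# the monitor polls until every subtask has removed itself from the taskset
if perms.generator_complete is not None and perms.get_remaining() == 0:
    perms.taskset_clear()
    perms.generator_clear()
    perms.set_fence(False)
```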
diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py
index 68b48b85b0f..be0fd017c84 100644
--- a/backend/danswer/server/documents/cc_pair.py
+++ b/backend/danswer/server/documents/cc_pair.py
@@ -35,20 +35,17 @@
 from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
 from danswer.db.models import User
 from danswer.db.search_settings import get_current_search_settings
-from danswer.db.tasks import check_task_is_live_and_not_timed_out
-from danswer.db.tasks import get_latest_task
 from danswer.redis.redis_connector import RedisConnector
 from danswer.redis.redis_pool import get_redis_client
 from danswer.server.documents.models import CCPairFullInfo
 from danswer.server.documents.models import CCStatusUpdateRequest
-from danswer.server.documents.models import CeleryTaskStatus
 from danswer.server.documents.models import ConnectorCredentialPairIdentifier
 from danswer.server.documents.models import ConnectorCredentialPairMetadata
 from danswer.server.documents.models import PaginatedIndexAttempts
 from danswer.server.models import StatusResponse
 from danswer.utils.logger import setup_logger
-from ee.danswer.background.task_name_builders import (
-    name_sync_external_doc_permissions_task,
+from ee.danswer.background.celery.tasks.permissions.tasks import (
+    try_creating_permissions_sync_task,
 )
 from ee.danswer.db.user_group import validate_user_creation_permissions
 
@@ -294,7 +291,7 @@
 def get_cc_pair_latest_sync(
     cc_pair_id: int,
     user: User = Depends(current_curator_or_admin_user),
     db_session: Session = Depends(get_session),
-) -> CeleryTaskStatus:
+) -> datetime | None:
     cc_pair = get_connector_credential_pair_from_id(
         cc_pair_id=cc_pair_id,
         db_session=db_session,
@@ -304,25 +301,10 @@ def get_cc_pair_latest_sync(
     if not cc_pair:
         raise HTTPException(
             status_code=400,
-            detail="Connection not found for current user's permissions",
+            detail="cc_pair not found for current user's permissions",
         )
 
-    # look up the last sync task for this connector (if it exists)
-    sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
-    last_sync_task = get_latest_task(sync_task_name, db_session)
-    if not last_sync_task:
-        raise HTTPException(
-            status_code=HTTPStatus.NOT_FOUND,
-            detail="No sync task found.",
-        )
-
-    return CeleryTaskStatus(
-        id=last_sync_task.task_id,
-        name=last_sync_task.task_name,
-        status=last_sync_task.status,
-        start_time=last_sync_task.start_time,
-        register_time=last_sync_task.register_time,
-    )
+    return cc_pair.last_time_perm_sync
 
 
 @router.post("/admin/cc-pair/{cc_pair_id}/sync")
@@ -330,11 +312,9 @@ def sync_cc_pair(
     cc_pair_id: int,
     user: User = Depends(current_curator_or_admin_user),
     db_session: Session = Depends(get_session),
+    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> StatusResponse[list[int]]:
-    # avoiding circular refs
-    from ee.danswer.background.celery.apps.primary import (
-        sync_external_doc_permissions_task,
-    )
+    """Triggers permissions sync on a particular cc_pair immediately"""
 
     cc_pair = get_connector_credential_pair_from_id(
         cc_pair_id=cc_pair_id,
@@ -348,27 +328,33 @@ def sync_cc_pair(
             detail="Connection not found for current user's permissions",
         )
 
-    sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
-    last_sync_task = get_latest_task(sync_task_name, db_session)
+    r = get_redis_client(tenant_id=tenant_id)
 
-    if last_sync_task and check_task_is_live_and_not_timed_out(
-        last_sync_task, db_session
-    ):
+    redis_connector = RedisConnector(tenant_id, cc_pair_id)
+    if redis_connector.permissions.fenced:
         raise HTTPException(
             status_code=HTTPStatus.CONFLICT,
-            detail="Sync task already in progress.",
+            detail="Permissions sync task already in progress.",
         )
 
-    logger.info(f"Syncing the {cc_pair.connector.name} connector.")
-    sync_external_doc_permissions_task.apply_async(
-        kwargs=dict(
-            cc_pair_id=cc_pair_id, tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get()
-        ),
+    logger.info(
+        f"permissions sync cc_pair: cc_pair_id={cc_pair_id} "
+        f"connector_id={cc_pair.connector_id} "
+        f"credential_id={cc_pair.credential_id} "
+        f"{cc_pair.connector.name} connector."
     )
+    tasks_created = try_creating_permissions_sync_task(
+        primary_app, cc_pair, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
+    )
+    if not tasks_created:
+        raise HTTPException(
+            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+            detail="Permissions sync task creation failed.",
+        )
 
     return StatusResponse(
         success=True,
-        message="Successfully created the sync task.",
+        message="Successfully created the permissions sync task.",
     )
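Client-side, the reworked endpoints can be exercised as below. The `/manage` route prefix, the authenticated session, and the cc_pair id are assumptions for illustration; only the `/admin/cc-pair/{id}/sync` path and the 409-on-conflict behavior come from the handlers above.

```python
# Sketch of driving the endpoints above from a client; paths and auth are
# assumptions, not confirmed by this diff.
import requests

BASE = "http://localhost:8080/manage"
session = requests.Session()  # assumed to already carry an admin auth cookie

# kick off a permissions sync; 409 CONFLICT means one is already fenced
resp = session.post(f"{BASE}/admin/cc-pair/42/sync")
resp.raise_for_status()

# the "latest sync" endpoint now returns the last_time_perm_sync timestamp
# (or null), not a celery task status object
print(session.get(f"{BASE}/admin/cc-pair/42/sync").json())
```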
diff --git a/backend/danswer/utils/logger.py b/backend/danswer/utils/logger.py
index bd784513898..c347eacb077 100644
--- a/backend/danswer/utils/logger.py
+++ b/backend/danswer/utils/logger.py
@@ -21,6 +21,10 @@
     "pruning_ctx", default=dict()
 )
 
+external_doc_permissions_ctx: contextvars.ContextVar[
+    dict[str, Any]
+] = contextvars.ContextVar("external_doc_permissions_ctx", default=dict())
+
 
 class IndexAttemptSingleton:
     """Used to tell if this process is an indexing job, and if so what is the
@@ -69,6 +73,7 @@ def process(
         index_attempt_id = IndexAttemptSingleton.get_index_attempt_id()
         cc_pair_id = IndexAttemptSingleton.get_connector_credential_pair_id()
 
+        external_doc_permissions_ctx_dict = external_doc_permissions_ctx.get()
         pruning_ctx_dict = pruning_ctx.get()
         if len(pruning_ctx_dict) > 0:
             if "request_id" in pruning_ctx_dict:
@@ -76,6 +81,9 @@ def process(
 
             if "cc_pair_id" in pruning_ctx_dict:
                 msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}"
+        elif len(external_doc_permissions_ctx_dict) > 0:
+            if "request_id" in external_doc_permissions_ctx_dict:
+                msg = f"[External Doc Permissions: {external_doc_permissions_ctx_dict['request_id']}] {msg}"
         else:
             if index_attempt_id is not None:
                 msg = f"[Index Attempt: {index_attempt_id}] {msg}"
diff --git a/backend/ee/danswer/background/celery/apps/primary.py b/backend/ee/danswer/background/celery/apps/primary.py
index fecc21b58ef..21644228484 100644
--- a/backend/ee/danswer/background/celery/apps/primary.py
+++ b/backend/ee/danswer/background/celery/apps/primary.py
@@ -5,28 +5,8 @@
 from danswer.db.engine import get_session_with_tenant
 from danswer.server.settings.store import load_settings
 from danswer.utils.logger import setup_logger
-from danswer.utils.variable_functionality import global_version
 from ee.danswer.background.celery_utils import should_perform_chat_ttl_check
-from ee.danswer.background.celery_utils import (
-    should_perform_external_doc_permissions_check,
-)
-from ee.danswer.background.celery_utils import (
-    should_perform_external_group_permissions_check,
-)
 from ee.danswer.background.task_name_builders import name_chat_ttl_task
-from ee.danswer.background.task_name_builders import (
-    name_sync_external_doc_permissions_task,
-)
-from ee.danswer.background.task_name_builders import (
-    name_sync_external_group_permissions_task,
-)
-from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
-from ee.danswer.external_permissions.permission_sync import (
-    run_external_doc_permission_sync,
-)
-from ee.danswer.external_permissions.permission_sync import (
-    run_external_group_permission_sync,
-)
 from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
 from shared_configs.configs import MULTI_TENANT
 from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
@@ -34,25 +14,6 @@
 
 logger = setup_logger()
 
-# mark as EE for all tasks in this file
-global_version.set_ee()
-
-
-@build_celery_task_wrapper(name_sync_external_doc_permissions_task)
-@celery_app.task(soft_time_limit=JOB_TIMEOUT)
-def sync_external_doc_permissions_task(
-    cc_pair_id: int, *, tenant_id: str | None
-) -> None:
-    with get_session_with_tenant(tenant_id) as db_session:
-        run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
-
-
-@build_celery_task_wrapper(name_sync_external_group_permissions_task)
-@celery_app.task(soft_time_limit=JOB_TIMEOUT)
-def sync_external_group_permissions_task(
-    cc_pair_id: int, *, tenant_id: str | None
-) -> None:
-    with get_session_with_tenant(tenant_id) as db_session:
-        run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
-
 
 @build_celery_task_wrapper(name_chat_ttl_task)
@@ -67,38 +28,6 @@ def perform_ttl_management_task(
 
 #####
 # Periodic Tasks
 #####
-@celery_app.task(
-    name="check_sync_external_doc_permissions_task",
-    soft_time_limit=JOB_TIMEOUT,
-)
-def check_sync_external_doc_permissions_task(*, tenant_id: str | None) -> None:
-    """Runs periodically to sync external permissions"""
-    with get_session_with_tenant(tenant_id) as db_session:
-        cc_pairs = get_all_auto_sync_cc_pairs(db_session)
-        for cc_pair in cc_pairs:
-            if should_perform_external_doc_permissions_check(
-                cc_pair=cc_pair, db_session=db_session
-            ):
-                sync_external_doc_permissions_task.apply_async(
-                    kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id),
-                )
-
-
-@celery_app.task(
-    name="check_sync_external_group_permissions_task",
-    soft_time_limit=JOB_TIMEOUT,
-)
-def check_sync_external_group_permissions_task(*, tenant_id: str | None) -> None:
-    """Runs periodically to sync external group permissions"""
-    with get_session_with_tenant(tenant_id) as db_session:
-        cc_pairs = get_all_auto_sync_cc_pairs(db_session)
-        for cc_pair in cc_pairs:
-            if should_perform_external_group_permissions_check(
-                cc_pair=cc_pair, db_session=db_session
-            ):
-                sync_external_group_permissions_task.apply_async(
-                    kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id),
-                )
 
 
 @celery_app.task(
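The logging change above leans on a contextvar that the adapter consults on every call. A self-contained sketch of the same pattern, reduced to the request_id prefix; the adapter below is a simplified stand-in for the project's own logger setup.

```python
# Minimal sketch of contextvar-driven log prefixing, as logger.py wires up
# for external_doc_permissions_ctx.
import contextvars
import logging
from typing import Any

logging.basicConfig(level=logging.INFO)

ctx: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar(
    "demo_ctx", default=dict()
)

class PrefixAdapter(logging.LoggerAdapter):
    def process(self, msg: Any, kwargs: Any) -> tuple[Any, Any]:
        ctx_dict = ctx.get()
        if "request_id" in ctx_dict:
            msg = f"[External Doc Permissions: {ctx_dict['request_id']}] {msg}"
        return msg, kwargs

logger = PrefixAdapter(logging.getLogger("demo"), {})

# a task seeds the context once; every later log line picks up the prefix
ctx_dict = ctx.get()
ctx_dict["request_id"] = "abc123"
ctx.set(ctx_dict)

logger.info("starting sync")  # [External Doc Permissions: abc123] starting sync
```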
diff --git a/backend/ee/danswer/background/celery/tasks/permissions/tasks.py b/backend/ee/danswer/background/celery/tasks/permissions/tasks.py
new file mode 100644
index 00000000000..d3ddfc06d1a
--- /dev/null
+++ b/backend/ee/danswer/background/celery/tasks/permissions/tasks.py
@@ -0,0 +1,318 @@
+from datetime import datetime
+from datetime import timedelta
+from datetime import timezone
+from uuid import uuid4
+
+from celery import Celery
+from celery import shared_task
+from celery import Task
+from celery.exceptions import SoftTimeLimitExceeded
+from redis import Redis
+
+from danswer.access.models import DocumentExternalAccess
+from danswer.background.celery.apps.app_base import task_logger
+from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
+from danswer.configs.app_configs import JOB_TIMEOUT
+from danswer.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
+from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
+from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
+from danswer.configs.constants import DanswerCeleryPriority
+from danswer.configs.constants import DanswerCeleryQueues
+from danswer.configs.constants import DanswerRedisLocks
+from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
+from danswer.db.engine import get_session_with_tenant
+from danswer.db.enums import AccessType
+from danswer.db.enums import ConnectorCredentialPairStatus
+from danswer.db.models import ConnectorCredentialPair
+from danswer.db.users import batch_add_non_web_user_if_not_exists
+from danswer.document_index.factory import get_current_primary_default_document_index
+from danswer.document_index.interfaces import UpdateRequest
+from danswer.redis.redis_connector import RedisConnector
+from danswer.redis.redis_pool import get_redis_client
+from danswer.utils.logger import external_doc_permissions_ctx
+from danswer.utils.logger import setup_logger
+from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
+from ee.danswer.db.document import upsert_document_external_perms
+from ee.danswer.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
+from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
+
+logger = setup_logger()
+
+
+DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES = 3
+
+
+# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
+LIGHT_SOFT_TIME_LIMIT = 105
+LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
+
+
+def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
+    """Returns boolean indicating if external doc permissions sync is due."""
+
+    if cc_pair.access_type != AccessType.SYNC:
+        return False
+
+    # skip doc permissions sync if not active
+    if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
+        return False
+
+    # If the last sync is None, it has never been run so we run the sync
+    last_perm_sync = cc_pair.last_time_perm_sync
+    if last_perm_sync is None:
+        return True
+
+    source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
+
+    # If no sync period is configured for the source, we always run the sync
+    if not source_sync_period:
+        return True
+
+    # If the last sync is older than the sync period, we run the sync
+    next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
+    if datetime.now(timezone.utc) >= next_sync:
+        return True
+
+    return False
+
+
+@shared_task(
+    name="check_for_doc_permissions_sync",
+    soft_time_limit=JOB_TIMEOUT,
+    bind=True,
+)
+def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
+    r = get_redis_client(tenant_id=tenant_id)
+
+    lock_beat = r.lock(
+        DanswerRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
+        timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
+    )
+
+    try:
+        # these tasks should never overlap
+        if not lock_beat.acquire(blocking=False):
+            return
+
+        cc_pairs: list[ConnectorCredentialPair] = []
+        with get_session_with_tenant(tenant_id) as db_session:
+            cc_pairs = get_all_auto_sync_cc_pairs(db_session)
+
+        for cc_pair in cc_pairs:
+            with get_session_with_tenant(tenant_id) as db_session:
+                if not _is_external_doc_permissions_sync_due(cc_pair):
+                    continue
+
+                tasks_created = try_creating_permissions_sync_task(
+                    self.app, cc_pair, r, tenant_id
+                )
+                if not tasks_created:
+                    continue
+
+                task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair.id}")
+    except SoftTimeLimitExceeded:
+        task_logger.info(
+            "Soft time limit exceeded, task is being terminated gracefully."
+        )
+    except Exception:
+        task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
+    finally:
+        if lock_beat.owned():
+            lock_beat.release()
+
+
+def try_creating_permissions_sync_task(
+    app: Celery,
+    cc_pair: ConnectorCredentialPair,
+    r: Redis,
+    tenant_id: str | None,
+) -> int | None:
+    """Returns an int if syncing is needed. The int represents the number of sync tasks generated.
+    Returns None if no syncing is required."""
+    redis_connector = RedisConnector(tenant_id, cc_pair.id)
+
+    # reuses the pruning setting to gate simultaneous permission syncs
+    if not ALLOW_SIMULTANEOUS_PRUNING:
+        count = redis_connector.permissions.get_active_task_count()
+        if count > 0:
+            return None
+
+    LOCK_TIMEOUT = 30
+
+    lock = r.lock(
+        DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
+        timeout=LOCK_TIMEOUT,
+    )
+
+    acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
+    if not acquired:
+        return None
+
+    try:
+        if redis_connector.permissions.fenced:
+            return None
+
+        if redis_connector.delete.fenced:
+            return None
+
+        if redis_connector.prune.fenced:
+            return None
+
+        if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
+            return None
+
+        redis_connector.permissions.generator_clear()
+        redis_connector.permissions.taskset_clear()
+
+        custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
+
+        app.send_task(
+            "connector_permission_sync_generator_task",
+            kwargs=dict(
+                cc_pair_id=cc_pair.id,
+                tenant_id=tenant_id,
+            ),
+            queue=DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
+            task_id=custom_task_id,
+            priority=DanswerCeleryPriority.HIGH,
+        )
+
+        redis_connector.permissions.set_fence(True)
+        return 1
+
+    except Exception:
+        task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
+        return None
+    finally:
+        if lock.owned():
+            lock.release()
+
+
+@shared_task(
+    name="connector_permission_sync_generator_task",
+    acks_late=False,
+    soft_time_limit=JOB_TIMEOUT,
+    track_started=True,
+    trail=False,
+    bind=True,
+)
+def connector_permission_sync_generator_task(
+    self: Task,
+    cc_pair_id: int,
+    tenant_id: str | None,
+) -> None:
+    """
+    Permission sync task that handles document permission syncing for a given connector credential pair
+    This task assumes that the task has already been properly fenced
+    """
+
+    external_doc_permissions_ctx_dict = external_doc_permissions_ctx.get()
+    external_doc_permissions_ctx_dict["cc_pair_id"] = cc_pair_id
+    external_doc_permissions_ctx_dict["request_id"] = self.request.id
+    external_doc_permissions_ctx.set(external_doc_permissions_ctx_dict)
+
+    redis_connector = RedisConnector(tenant_id, cc_pair_id)
+
+    r = get_redis_client(tenant_id=tenant_id)
+
+    lock = r.lock(
+        DanswerRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK
+        + f"_{redis_connector.id}",
+        timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
+    )
+
+    acquired = lock.acquire(blocking=False)
+    if not acquired:
+        task_logger.warning(
+            f"Permission sync task already running, exiting...: cc_pair={cc_pair_id}"
+        )
+        return None
+
+    try:
+        with get_session_with_tenant(tenant_id) as db_session:
+            cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
+            if cc_pair is None:
+                raise ValueError(
+                    f"No connector credential pair found for id: {cc_pair_id}"
+                )
+
+            source_type = cc_pair.connector.source
+
+            doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
+            if doc_sync_func is None:
+                raise ValueError(f"No doc sync func found for {source_type}")
+
+            logger.info(f"Syncing docs for {source_type}")
+            document_external_accesses: list[DocumentExternalAccess] = doc_sync_func(
+                cc_pair
+            )
+
+            task_logger.info(
+                f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
+            )
+            tasks_generated = redis_connector.permissions.generate_tasks(
+                self.app, db_session, lock, document_external_accesses
+            )
+            if tasks_generated is None:
+                return None
+
+            task_logger.info(
+                f"RedisConnector.permissions.generate_tasks finished. "
+                f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
+            )
+
+            redis_connector.permissions.generator_complete = tasks_generated
+
+    except Exception as e:
+        task_logger.exception(f"Failed to run permission sync: cc_pair={cc_pair_id}")
+
+        redis_connector.permissions.generator_clear()
+        redis_connector.permissions.taskset_clear()
+        redis_connector.permissions.set_fence(False)
+        raise e
+    finally:
+        if lock.owned():
+            lock.release()
+
+
+@shared_task(
+    name="update_external_document_permissions_task",
+    soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
+    time_limit=LIGHT_TIME_LIMIT,
+    max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
+    bind=True,
+)
+def update_external_document_permissions_task(
+    self: Task,
+    cc_pair_id: int,
+    document_external_access: DocumentExternalAccess,
+    tenant_id: str | None,
+) -> bool:
+    doc_id = document_external_access.doc_id
+    external_access = document_external_access.external_access
+    try:
+        with get_session_with_tenant(tenant_id) as db_session:
+            # the cc_pair is needed to resolve the source type for group prefixing
+            cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
+            if cc_pair is None:
+                raise ValueError(f"No cc_pair found for id: {cc_pair_id}")
+
+            # Add any users for the listed emails that don't exist yet
+            batch_add_non_web_user_if_not_exists(
+                db_session=db_session,
+                emails=list(external_access.external_user_emails),
+            )
+            # Upsert the document's external permissions in postgres
+            doc_access = upsert_document_external_perms(
+                db_session=db_session,
+                doc_id=doc_id,
+                external_access=external_access,
+                source_type=cc_pair.connector.source,
+            )
+
+            # Then we build the update request to update vespa
+            update_request = UpdateRequest(
+                document_ids=[doc_id],
+                access=doc_access,
+            )
+
+            # Don't bother sync-ing secondary, it will be sync-ed after switch anyway
+            document_index = get_current_primary_default_document_index(db_session)
+
+            # update vespa
+            document_index.update([update_request])
+
+            logger.info(f"Successfully synced permissions for doc_id={doc_id}")
+        return True
+    except Exception:
+        logger.exception("Error Syncing Document Permissions")
+        return False
diff --git a/backend/ee/danswer/background/celery_utils.py b/backend/ee/danswer/background/celery_utils.py
index 80278d8c433..8c98c6bcb8e 100644
--- a/backend/ee/danswer/background/celery_utils.py
+++ b/backend/ee/danswer/background/celery_utils.py
@@ -9,19 +9,13 @@
 from danswer.db.tasks import get_latest_task
 from danswer.utils.logger import setup_logger
 from ee.danswer.background.task_name_builders import name_chat_ttl_task
-from ee.danswer.background.task_name_builders import (
-    name_sync_external_doc_permissions_task,
-)
-from ee.danswer.background.task_name_builders import (
-    name_sync_external_group_permissions_task,
-)
-from ee.danswer.external_permissions.sync_params import PERMISSION_SYNC_PERIODS
+from ee.danswer.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
 
 logger = setup_logger()
 
 
 def _is_time_to_run_sync(cc_pair: ConnectorCredentialPair) -> bool:
-    source_sync_period = PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
+    source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
 
     # If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
     if not source_sync_period:
@@ -59,35 +53,16 @@ def should_perform_chat_ttl_check(
     return True
 
 
-def should_perform_external_doc_permissions_check(
-    cc_pair: ConnectorCredentialPair, db_session: Session
-) -> bool:
-    if cc_pair.access_type != AccessType.SYNC:
-        return False
-
-    task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair.id)
-
-    latest_task = get_latest_task(task_name, db_session)
-    if not latest_task:
-        return True
-
-    if check_task_is_live_and_not_timed_out(latest_task, db_session):
-        logger.debug(f"{task_name} is already being performed. Skipping.")
-        return False
-
-    if not _is_time_to_run_sync(cc_pair):
-        return False
-
-    return True
-
-
 def should_perform_external_group_permissions_check(
     cc_pair: ConnectorCredentialPair, db_session: Session
 ) -> bool:
     if cc_pair.access_type != AccessType.SYNC:
         return False
 
-    task_name = name_sync_external_group_permissions_task(cc_pair_id=cc_pair.id)
+    # TODO: group permissions sync has not yet been migrated to the fenced
+    # celery flow; with no task name, get_latest_task never matches, so this
+    # check always reports the group sync as due
+    task_name = ""
 
     latest_task = get_latest_task(task_name, db_session)
     if not latest_task:
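The due-time rule in `_is_external_doc_permissions_sync_due` (and `_is_time_to_run_sync` above) reduces to a small amount of timestamp arithmetic. A runnable toy version, with the model object and the `DocumentSource` enum replaced by plain values:

```python
# Toy version of the due-time rule above; plain strings and datetimes stand
# in for the ConnectorCredentialPair model and DocumentSource enum.
from datetime import datetime, timedelta, timezone

DOC_PERMISSION_SYNC_PERIODS = {"confluence": 5 * 60}  # seconds, mirrors sync_params.py

def is_due(last_perm_sync: datetime | None, source: str) -> bool:
    if last_perm_sync is None:
        return True  # never synced -> sync now
    period = DOC_PERMISSION_SYNC_PERIODS.get(source)
    if not period:
        return True  # no period configured -> sync on every beat
    return datetime.now(timezone.utc) >= last_perm_sync + timedelta(seconds=period)

now = datetime.now(timezone.utc)
assert is_due(None, "confluence")                             # never synced
assert not is_due(now - timedelta(seconds=30), "confluence")  # synced recently
assert is_due(now - timedelta(seconds=600), "confluence")     # period elapsed
```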
diff --git a/backend/ee/danswer/background/task_name_builders.py b/backend/ee/danswer/background/task_name_builders.py
index aea6648a02d..c218cdd3b59 100644
--- a/backend/ee/danswer/background/task_name_builders.py
+++ b/backend/ee/danswer/background/task_name_builders.py
@@ -1,14 +1,2 @@
 def name_chat_ttl_task(retention_limit_days: int, tenant_id: str | None = None) -> str:
     return f"chat_ttl_{retention_limit_days}_days"
-
-
-def name_sync_external_doc_permissions_task(
-    cc_pair_id: int, tenant_id: str | None = None
-) -> str:
-    return f"sync_external_doc_permissions_task__{cc_pair_id}"
-
-
-def name_sync_external_group_permissions_task(
-    cc_pair_id: int, tenant_id: str | None = None
-) -> str:
-    return f"sync_external_group_permissions_task__{cc_pair_id}"
diff --git a/backend/ee/danswer/db/document.py b/backend/ee/danswer/db/document.py
index d67bc0e57e7..6a804e06679 100644
--- a/backend/ee/danswer/db/document.py
+++ b/backend/ee/danswer/db/document.py
@@ -1,6 +1,7 @@
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
+from danswer.access.models import DocumentAccess
 from danswer.access.models import ExternalAccess
 from danswer.access.utils import prefix_group_w_source
 from danswer.configs.constants import DocumentSource
@@ -45,3 +46,51 @@ def upsert_document_external_perms__no_commit(
     document.external_user_emails = list(external_access.external_user_emails)
     document.external_user_group_ids = prefixed_external_groups
     document.is_public = external_access.is_public
+
+
+def upsert_document_external_perms(
+    db_session: Session,
+    doc_id: str,
+    external_access: ExternalAccess,
+    source_type: DocumentSource,
+) -> DocumentAccess:
+    """
+    This sets the permissions for a document in postgres.
+    NOTE: this will replace any existing external access, it will not do a union
+    """
+    document = db_session.scalars(
+        select(DbDocument).where(DbDocument.id == doc_id)
+    ).first()
+
+    prefixed_external_groups = [
+        prefix_group_w_source(
+            ext_group_name=group_id,
+            source=source_type,
+        )
+        for group_id in external_access.external_user_group_ids
+    ]
+
+    if not document:
+        # If the document does not exist, still store the external access
+        # so that if the document is added later, the external access is already stored
+        document = DbDocument(
+            id=doc_id,
+            semantic_id="",
+            external_user_emails=external_access.external_user_emails,
+            external_user_group_ids=prefixed_external_groups,
+            is_public=external_access.is_public,
+        )
+        db_session.add(document)
+    else:
+        document.external_user_emails = list(external_access.external_user_emails)
+        document.external_user_group_ids = prefixed_external_groups
+        document.is_public = external_access.is_public
+
+    db_session.commit()
+    return DocumentAccess(
+        external_user_emails=set(document.external_user_emails or []),
+        external_user_group_ids=set(document.external_user_group_ids or []),
+        user_emails=set(document.external_user_emails or []),
+        user_groups=set(document.external_user_group_ids or []),
+        is_public=document.is_public,
+    )
diff --git a/backend/ee/danswer/external_permissions/confluence/doc_sync.py b/backend/ee/danswer/external_permissions/confluence/doc_sync.py
index a7bc898b8b7..b034e07a0a0 100644
--- a/backend/ee/danswer/external_permissions/confluence/doc_sync.py
+++ b/backend/ee/danswer/external_permissions/confluence/doc_sync.py
@@ -4,17 +4,14 @@
 """
 from typing import Any
 
-from sqlalchemy.orm import Session
-
+from danswer.access.models import DocumentExternalAccess
 from danswer.access.models import ExternalAccess
 from danswer.connectors.confluence.connector import ConfluenceConnector
 from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
 from danswer.connectors.confluence.utils import get_user_email_from_username__server
 from danswer.connectors.models import SlimDocument
 from danswer.db.models import ConnectorCredentialPair
-from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
 from danswer.utils.logger import setup_logger
-from ee.danswer.db.document import upsert_document_external_perms__no_commit
 
 logger = setup_logger()
 
@@ -190,12 +187,12 @@ def _fetch_all_page_restrictions_for_space(
     confluence_client: OnyxConfluence,
     slim_docs: list[SlimDocument],
     space_permissions_by_space_key: dict[str, ExternalAccess],
-) -> dict[str, ExternalAccess]:
+) -> list[DocumentExternalAccess]:
     """
     For all pages, if a page has restrictions, then use those restrictions.
    Otherwise, use the space's restrictions.
     """
-    document_restrictions: dict[str, ExternalAccess] = {}
+    document_restrictions: list[DocumentExternalAccess] = []
 
     for slim_doc in slim_docs:
         if slim_doc.perm_sync_data is None:
@@ -207,11 +204,21 @@ def _fetch_all_page_restrictions_for_space(
             restrictions=slim_doc.perm_sync_data.get("restrictions", {}),
         )
         if restrictions:
-            document_restrictions[slim_doc.id] = restrictions
+            document_restrictions.append(
+                DocumentExternalAccess(
+                    doc_id=slim_doc.id,
+                    external_access=restrictions,
+                )
+            )
         else:
             space_key = slim_doc.perm_sync_data.get("space_key")
             if space_permissions := space_permissions_by_space_key.get(space_key):
-                document_restrictions[slim_doc.id] = space_permissions
+                document_restrictions.append(
+                    DocumentExternalAccess(
+                        doc_id=slim_doc.id,
+                        external_access=space_permissions,
+                    )
+                )
             else:
                 logger.warning(f"No permissions found for document {slim_doc.id}")
 
@@ -219,9 +226,8 @@ def _fetch_all_page_restrictions_for_space(
 
 
 def confluence_doc_sync(
-    db_session: Session,
     cc_pair: ConnectorCredentialPair,
-) -> None:
+) -> list[DocumentExternalAccess]:
     """
     Adds the external permissions to the documents in postgres
     if the document doesn't already exists in postgres, we create
@@ -247,20 +253,8 @@ def confluence_doc_sync(
     for doc_batch in confluence_connector.retrieve_all_slim_documents():
         slim_docs.extend(doc_batch)
 
-    permissions_by_doc_id = _fetch_all_page_restrictions_for_space(
+    return _fetch_all_page_restrictions_for_space(
         confluence_client=confluence_client,
         slim_docs=slim_docs,
         space_permissions_by_space_key=space_permissions_by_space_key,
     )
-
-    all_emails = set()
-    for doc_id, page_specific_access in permissions_by_doc_id.items():
-        upsert_document_external_perms__no_commit(
-            db_session=db_session,
-            doc_id=doc_id,
-            external_access=page_specific_access,
-            source_type=cc_pair.connector.source,
-        )
-        all_emails.update(page_specific_access.external_user_emails)
-
-    batch_add_non_web_user_if_not_exists__no_commit(db_session, list(all_emails))
diff --git a/backend/ee/danswer/external_permissions/gmail/doc_sync.py b/backend/ee/danswer/external_permissions/gmail/doc_sync.py
index 2748443f022..01e570f42ca 100644
--- a/backend/ee/danswer/external_permissions/gmail/doc_sync.py
+++ b/backend/ee/danswer/external_permissions/gmail/doc_sync.py
@@ -1,15 +1,12 @@
 from datetime import datetime
 from datetime import timezone
 
-from sqlalchemy.orm import Session
-
+from danswer.access.models import DocumentExternalAccess
 from danswer.access.models import ExternalAccess
 from danswer.connectors.gmail.connector import GmailConnector
 from danswer.connectors.interfaces import GenerateSlimDocumentOutput
 from danswer.db.models import ConnectorCredentialPair
-from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
 from danswer.utils.logger import setup_logger
-from ee.danswer.db.document import upsert_document_external_perms__no_commit
 
 logger = setup_logger()
 
@@ -31,9 +28,8 @@ def _get_slim_doc_generator(
 
 
 def gmail_doc_sync(
-    db_session: Session,
     cc_pair: ConnectorCredentialPair,
-) -> None:
+) -> list[DocumentExternalAccess]:
     """
     Adds the external permissions to the documents in postgres
     if the document doesn't already exists in postgres, we create
@@ -45,6 +41,7 @@ def gmail_doc_sync(
 
     slim_doc_generator = _get_slim_doc_generator(cc_pair, gmail_connector)
 
+    document_external_access: list[DocumentExternalAccess] = []
     for slim_doc_batch in slim_doc_generator:
         for slim_doc in slim_doc_batch:
             if slim_doc.perm_sync_data is None:
@@ -56,13 +53,11 @@ def gmail_doc_sync(
                 external_user_group_ids=set(),
                 is_public=False,
             )
-            batch_add_non_web_user_if_not_exists__no_commit(
-                db_session=db_session,
-                emails=list(ext_access.external_user_emails),
-            )
-            upsert_document_external_perms__no_commit(
-                db_session=db_session,
-                doc_id=slim_doc.id,
-                external_access=ext_access,
-                source_type=cc_pair.connector.source,
+            document_external_access.append(
+                DocumentExternalAccess(
+                    doc_id=slim_doc.id,
+                    external_access=ext_access,
+                )
             )
+
+    return document_external_access
diff --git a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py
index fddb0e72171..e0512915940 100644
--- a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py
+++ b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py
@@ -2,8 +2,7 @@
 from datetime import timezone
 from typing import Any
 
-from sqlalchemy.orm import Session
-
+from danswer.access.models import DocumentExternalAccess
 from danswer.access.models import ExternalAccess
 from danswer.connectors.google_drive.connector import GoogleDriveConnector
 from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
@@ -11,9 +10,7 @@
 from danswer.connectors.interfaces import GenerateSlimDocumentOutput
 from danswer.connectors.models import SlimDocument
 from danswer.db.models import ConnectorCredentialPair
-from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
 from danswer.utils.logger import setup_logger
-from ee.danswer.db.document import upsert_document_external_perms__no_commit
 
 logger = setup_logger()
 
@@ -126,9 +123,8 @@ def _get_permissions_from_slim_doc(
 
 
 def gdrive_doc_sync(
-    db_session: Session,
     cc_pair: ConnectorCredentialPair,
-) -> None:
+) -> list[DocumentExternalAccess]:
     """
     Adds the external permissions to the documents in postgres
     if the document doesn't already exists in postgres, we create
@@ -142,19 +138,17 @@ def gdrive_doc_sync(
 
     slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)
 
+    document_external_accesses = []
     for slim_doc_batch in slim_doc_generator:
         for slim_doc in slim_doc_batch:
             ext_access = _get_permissions_from_slim_doc(
                 google_drive_connector=google_drive_connector,
                 slim_doc=slim_doc,
             )
-            batch_add_non_web_user_if_not_exists__no_commit(
-                db_session=db_session,
-                emails=list(ext_access.external_user_emails),
-            )
-            upsert_document_external_perms__no_commit(
-                db_session=db_session,
-                doc_id=slim_doc.id,
-                external_access=ext_access,
-                source_type=cc_pair.connector.source,
+            document_external_accesses.append(
+                DocumentExternalAccess(
+                    external_access=ext_access,
+                    doc_id=slim_doc.id,
+                )
             )
+
+    return document_external_accesses
diff --git a/backend/ee/danswer/external_permissions/permission_sync.py b/backend/ee/danswer/external_permissions/permission_sync.py
index 94a0b4bfa8e..41eeaa91725 100644
--- a/backend/ee/danswer/external_permissions/permission_sync.py
+++ b/backend/ee/danswer/external_permissions/permission_sync.py
@@ -1,15 +1,7 @@
-from datetime import datetime
-from datetime import timezone
-
 from sqlalchemy.orm import Session
 
-from danswer.access.access import get_access_for_documents
 from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
-from danswer.db.document import get_document_ids_for_connector_credential_pair
-from danswer.document_index.factory import get_current_primary_default_document_index
-from danswer.document_index.interfaces import UpdateRequest
 from danswer.utils.logger import setup_logger
-from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
 from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
 
 logger = setup_logger()
@@ -37,7 +29,6 @@ def run_external_group_permission_sync(
         logger.debug(f"Syncing groups for {source_type}")
         if group_sync_func is not None:
             group_sync_func(
-                db_session,
                 cc_pair,
             )
 
@@ -46,70 +37,3 @@ def run_external_group_permission_sync(
     except Exception:
         logger.exception("Error Syncing Group Permissions")
         db_session.rollback()
-
-
-def run_external_doc_permission_sync(
-    db_session: Session,
-    cc_pair_id: int,
-) -> None:
-    cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
-    if cc_pair is None:
-        raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")
-
-    source_type = cc_pair.connector.source
-
-    doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
-    last_time_perm_sync = cc_pair.last_time_perm_sync
-
-    if doc_sync_func is None:
-        raise ValueError(
-            f"No permission sync function found for source type: {source_type}"
-        )
-
-    try:
-        # This function updates:
-        # - the user_email <-> document mapping
-        # - the external_user_group_id <-> document mapping
-        # in postgres without committing
-        logger.info(f"Syncing docs for {source_type}")
-        doc_sync_func(
-            db_session,
-            cc_pair,
-        )
-
-        # Get the document ids for the cc pair
-        document_ids_for_cc_pair = get_document_ids_for_connector_credential_pair(
-            db_session=db_session,
-            connector_id=cc_pair.connector_id,
-            credential_id=cc_pair.credential_id,
-        )
-
-        # This function fetches the updated access for the documents
-        # and returns a dictionary of document_ids and access
-        # This is the access we want to update vespa with
-        docs_access = get_access_for_documents(
-            document_ids=document_ids_for_cc_pair,
-            db_session=db_session,
-        )
-
-        # Then we build the update requests to update vespa
-        update_reqs = [
-            UpdateRequest(document_ids=[doc_id], access=doc_access)
-            for doc_id, doc_access in docs_access.items()
-        ]
-
-        # Don't bother sync-ing secondary, it will be sync-ed after switch anyway
-        document_index = get_current_primary_default_document_index(db_session)
-
-        # update vespa
-        document_index.update(update_reqs)
-
-        cc_pair.last_time_perm_sync = datetime.now(timezone.utc)
-
-        # update postgres
-        db_session.commit()
-        logger.info(f"Successfully synced docs for {source_type}")
-    except Exception:
-        logger.exception("Error Syncing Document Permissions")
-        cc_pair.last_time_perm_sync = last_time_perm_sync
-        db_session.rollback()
diff --git a/backend/ee/danswer/external_permissions/slack/doc_sync.py b/backend/ee/danswer/external_permissions/slack/doc_sync.py
index b5f6e9695db..1e02ebff717 100644
--- a/backend/ee/danswer/external_permissions/slack/doc_sync.py
+++ b/backend/ee/danswer/external_permissions/slack/doc_sync.py
@@ -1,16 +1,12 @@
 from slack_sdk import WebClient
-from sqlalchemy.orm import Session
 
+from danswer.access.models import DocumentExternalAccess
 from danswer.access.models import ExternalAccess
-from danswer.connectors.factory import instantiate_connector
-from danswer.connectors.interfaces import SlimConnector
-from danswer.connectors.models import InputType
 from danswer.connectors.slack.connector import get_channels
 from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
+from danswer.connectors.slack.connector import SlackPollConnector
 from danswer.db.models import ConnectorCredentialPair
-from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
 from danswer.utils.logger import setup_logger
-from ee.danswer.db.document import upsert_document_external_perms__no_commit
 from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map
 
 
@@ -18,22 +14,15 @@
 def _get_slack_document_ids_and_channels(
-    db_session: Session,
     cc_pair: ConnectorCredentialPair,
 ) -> dict[str, list[str]]:
-    # Get all document ids that need their permissions updated
-    runnable_connector = instantiate_connector(
-        db_session=db_session,
-        source=cc_pair.connector.source,
-        input_type=InputType.SLIM_RETRIEVAL,
-        connector_specific_config=cc_pair.connector.connector_specific_config,
-        credential=cc_pair.credential,
-    )
+    slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
+    slack_connector.load_credentials(cc_pair.credential.credential_json)
 
-    assert isinstance(runnable_connector, SlimConnector)
+    slim_doc_generator = slack_connector.retrieve_all_slim_documents()
 
     channel_doc_map: dict[str, list[str]] = {}
-    for doc_metadata_batch in runnable_connector.retrieve_all_slim_documents():
+    for doc_metadata_batch in slim_doc_generator:
         for doc_metadata in doc_metadata_batch:
             if doc_metadata.perm_sync_data is None:
                 continue
@@ -46,13 +35,11 @@ def _get_slack_document_ids_and_channels(
 
 
 def _fetch_workspace_permissions(
-    db_session: Session,
     user_id_to_email_map: dict[str, str],
 ) -> ExternalAccess:
     user_emails = set()
     for email in user_id_to_email_map.values():
         user_emails.add(email)
-    batch_add_non_web_user_if_not_exists__no_commit(db_session, list(user_emails))
     return ExternalAccess(
         external_user_emails=user_emails,
         # No group<->document mapping for slack
@@ -63,7 +50,6 @@ def _fetch_workspace_permissions(
 
 
 def _fetch_channel_permissions(
-    db_session: Session,
     slack_client: WebClient,
     workspace_permissions: ExternalAccess,
     user_id_to_email_map: dict[str, str],
@@ -113,9 +99,6 @@ def _fetch_channel_permissions(
                     # If no email is found, we skip the user
                     continue
                 user_id_to_email_map[member_id] = member_email
-                batch_add_non_web_user_if_not_exists__no_commit(
-                    db_session, [member_email]
-                )
 
             member_emails.add(member_email)
 
@@ -131,9 +114,8 @@ def _fetch_channel_permissions(
 
 
 def slack_doc_sync(
-    db_session: Session,
     cc_pair: ConnectorCredentialPair,
-) -> None:
+) -> list[DocumentExternalAccess]:
     """
     Adds the external permissions to the documents in postgres
     if the document doesn't already exists in postgres, we create
@@ -145,19 +127,18 @@ def slack_doc_sync(
     )
     user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
     channel_doc_map = _get_slack_document_ids_and_channels(
-        db_session=db_session,
         cc_pair=cc_pair,
     )
     workspace_permissions = _fetch_workspace_permissions(
-        db_session=db_session,
         user_id_to_email_map=user_id_to_email_map,
     )
     channel_permissions = _fetch_channel_permissions(
-        db_session=db_session,
         slack_client=slack_client,
         workspace_permissions=workspace_permissions,
         user_id_to_email_map=user_id_to_email_map,
     )
+
+    document_external_accesses = []
     for channel_id, ext_access in channel_permissions.items():
         doc_ids = channel_doc_map.get(channel_id)
         if not doc_ids:
@@ -165,9 +146,10 @@ def slack_doc_sync(
             continue
 
         for doc_id in doc_ids:
-            upsert_document_external_perms__no_commit(
-                db_session=db_session,
-                doc_id=doc_id,
-                external_access=ext_access,
-                source_type=cc_pair.connector.source,
+            document_external_accesses.append(
+                DocumentExternalAccess(
+                    external_access=ext_access,
+                    doc_id=doc_id,
+                )
             )
+    return document_external_accesses
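One Slack-specific detail in the hunk above: _get_slack_document_ids_and_channels no longer goes through the generic instantiate_connector factory (which required a db_session) and instead constructs SlackPollConnector directly from the stored connector config and credential. The direct-construction pattern in isolation, as a sketch restating what the diff already does:

from collections.abc import Iterator

from danswer.connectors.models import SlimDocument
from danswer.connectors.slack.connector import SlackPollConnector
from danswer.db.models import ConnectorCredentialPair


def iter_slack_slim_docs(cc_pair: ConnectorCredentialPair) -> Iterator[SlimDocument]:
    # Build the connector directly instead of using the generic
    # instantiate_connector factory with InputType.SLIM_RETRIEVAL.
    slack_connector = SlackPollConnector(
        **cc_pair.connector.connector_specific_config
    )
    slack_connector.load_credentials(cc_pair.credential.credential_json)

    # retrieve_all_slim_documents yields batches of SlimDocument
    for slim_doc_batch in slack_connector.retrieve_all_slim_documents():
        yield from slim_doc_batch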
diff --git a/backend/ee/danswer/external_permissions/sync_params.py b/backend/ee/danswer/external_permissions/sync_params.py
index 1fd09ca1509..e7083565b85 100644
--- a/backend/ee/danswer/external_permissions/sync_params.py
+++ b/backend/ee/danswer/external_permissions/sync_params.py
@@ -1,9 +1,9 @@
 from collections.abc import Callable
 
-from sqlalchemy.orm import Session
-
+from danswer.access.models import DocumentExternalAccess
 from danswer.configs.constants import DocumentSource
 from danswer.db.models import ConnectorCredentialPair
+from ee.danswer.db.external_perm import ExternalUserGroup
 from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_sync
 from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
 from ee.danswer.external_permissions.gmail.doc_sync import gmail_doc_sync
@@ -12,12 +12,18 @@
 from ee.danswer.external_permissions.slack.doc_sync import slack_doc_sync
 
 # Defining the input/output types for the sync functions
-SyncFuncType = Callable[
+DocSyncFuncType = Callable[
+    [
+        ConnectorCredentialPair,
+    ],
+    list[DocumentExternalAccess],
+]
+
+GroupSyncFuncType = Callable[
     [
-        Session,
         ConnectorCredentialPair,
     ],
-    None,
+    list[ExternalUserGroup],
 ]
 
 # These functions update:
@@ -25,7 +31,7 @@
 # - the external_user_group_id <-> document mapping
 # in postgres without committing
 # THIS ONE IS NECESSARY FOR AUTO SYNC TO WORK
-DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
+DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
     DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync,
     DocumentSource.CONFLUENCE: confluence_doc_sync,
     DocumentSource.SLACK: slack_doc_sync,
@@ -36,14 +42,14 @@
 # - the user_email <-> external_user_group_id mapping
 # in postgres without committing
 # THIS ONE IS OPTIONAL ON AN APP BY APP BASIS
-GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
+GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = {
     DocumentSource.GOOGLE_DRIVE: gdrive_group_sync,
     DocumentSource.CONFLUENCE: confluence_group_sync,
 }
 
 # If nothing is specified here, we run the doc_sync every time the celery beat runs
-PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
+DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
     # Polling is not supported so we fetch all doc permissions every 5 minutes
     DocumentSource.CONFLUENCE: 5 * 60,
     DocumentSource.SLACK: 5 * 60,