diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py
index 51cac314fdb..f4f069a5bac 100644
--- a/backend/danswer/auth/users.py
+++ b/backend/danswer/auth/users.py
@@ -93,7 +93,7 @@
 from danswer.utils.telemetry import optional_telemetry
 from danswer.utils.telemetry import RecordType
 from danswer.utils.variable_functionality import fetch_versioned_implementation
-from shared_configs.configs import current_tenant_id
+from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
 from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

 logger = setup_logger()
@@ -249,7 +249,7 @@ async def create(
         )

         async with get_async_session_with_tenant(tenant_id) as db_session:
-            token = current_tenant_id.set(tenant_id)
+            token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

             verify_email_is_invited(user_create.email)
             verify_email_domain(user_create.email)
@@ -288,7 +288,7 @@ async def create(
             else:
                 raise exceptions.UserAlreadyExists()

-            current_tenant_id.reset(token)
+            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
         return user

     async def on_after_login(
@@ -342,7 +342,7 @@ async def oauth_callback(
         token = None
         async with get_async_session_with_tenant(tenant_id) as db_session:
-            token = current_tenant_id.set(tenant_id)
+            token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

             verify_email_in_whitelist(account_email, tenant_id)
             verify_email_domain(account_email)
@@ -432,7 +432,7 @@ async def oauth_callback(
             user.oidc_expiry = None  # type: ignore

         if token:
-            current_tenant_id.reset(token)
+            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

         return user
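Every hunk in this PR leans on the same ContextVar discipline: `set()` the tenant at the start of a scope, do the work, then `reset()` with the returned token so the previous value is restored. A minimal sketch of the idiom, assuming a hypothetical tenant value (the try/finally here is stricter than some hunks above, which reset without it):

```python
import contextvars

# mirrors shared_configs/configs.py; default schema assumed to be "public"
CURRENT_TENANT_ID_CONTEXTVAR: contextvars.ContextVar[str] = contextvars.ContextVar(
    "current_tenant_id", default="public"
)


def tenant_scoped_work() -> None:
    # set() returns a Token that remembers the previous value
    token = CURRENT_TENANT_ID_CONTEXTVAR.set("tenant_abc")  # hypothetical tenant
    try:
        assert CURRENT_TENANT_ID_CONTEXTVAR.get() == "tenant_abc"
        # ... DB / Redis calls that read the contextvar internally ...
    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)  # restore the prior tenant
```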
diff --git a/backend/danswer/background/celery/apps/app_base.py b/backend/danswer/background/celery/apps/app_base.py
index 2a52abde5d1..05c0eabbbd9 100644
--- a/backend/danswer/background/celery/apps/app_base.py
+++ b/backend/danswer/background/celery/apps/app_base.py
@@ -19,6 +19,7 @@
 from danswer.background.celery.celery_redis import RedisUserGroup
 from danswer.background.celery.celery_utils import celery_is_worker_primary
 from danswer.configs.constants import DanswerRedisLocks
+from danswer.db.engine import get_all_tenant_ids
 from danswer.redis.redis_pool import get_redis_client
 from danswer.utils.logger import ColoredFormatter
 from danswer.utils.logger import PlainFormatter
@@ -56,7 +57,7 @@ def on_task_postrun(
     task_id: str | None = None,
     task: Task | None = None,
     args: tuple | None = None,
-    kwargs: dict | None = None,
+    kwargs: dict[str, Any] | None = None,
     retval: Any | None = None,
     state: str | None = None,
     **kwds: Any,
@@ -83,7 +84,19 @@
     if not task_id:
         return

-    r = get_redis_client()
+    # Get tenant_id directly from kwargs - each celery task has a tenant_id kwarg
+    if not kwargs:
+        logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
+        tenant_id = None
+    else:
+        tenant_id = kwargs.get("tenant_id")
+
+    task_logger.debug(
+        f"Task {task.name} (ID: {task_id}) completed with state: {state} "
+        f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
+    )
+
+    r = get_redis_client(tenant_id=tenant_id)

     if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
         r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
@@ -124,7 +137,7 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:

 def wait_for_redis(sender: Any, **kwargs: Any) -> None:
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=None)

     WAIT_INTERVAL = 5
     WAIT_LIMIT = 60
@@ -157,26 +170,44 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:

 def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
-    r = get_redis_client()
-
     WAIT_INTERVAL = 5
     WAIT_LIMIT = 60

     logger.info("Running as a secondary celery worker.")
-    logger.info("Waiting for primary worker to be ready...")
+    logger.info("Waiting for all tenant primary workers to be ready...")
     time_start = time.monotonic()
+
     while True:
-        if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
+        tenant_ids = get_all_tenant_ids()
+        # Check if we have a primary worker lock for each tenant
+        all_tenants_ready = all(
+            get_redis_client(tenant_id=tenant_id).exists(
+                DanswerRedisLocks.PRIMARY_WORKER
+            )
+            for tenant_id in tenant_ids
+        )
+
+        if all_tenants_ready:
             break

-        time.monotonic()
         time_elapsed = time.monotonic() - time_start
+        ready_tenants = sum(
+            1
+            for tenant_id in tenant_ids
+            if get_redis_client(tenant_id=tenant_id).exists(
+                DanswerRedisLocks.PRIMARY_WORKER
+            )
+        )
+
         logger.info(
-            f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
+            f"Not all tenant primary workers are ready yet. "
+            f"Ready tenants: {ready_tenants}/{len(tenant_ids)} "
+            f"elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
         )
+
         if time_elapsed > WAIT_LIMIT:
             msg = (
-                f"Primary worker was not ready within the timeout. "
+                f"Not all tenant primary workers were ready within the timeout "
                 f"({WAIT_LIMIT} seconds). Exiting..."
             )
             logger.error(msg)
@@ -184,7 +215,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:

         time.sleep(WAIT_INTERVAL)

-    logger.info("Wait for primary worker completed successfully. Continuing...")
+    logger.info("All tenant primary workers are ready. Continuing...")
     return
@@ -196,14 +227,26 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
     if not celery_is_worker_primary(sender):
         return

-    if not sender.primary_worker_lock:
+    if not hasattr(sender, "primary_worker_locks"):
         return

-    logger.info("Releasing primary worker lock.")
-    lock = sender.primary_worker_lock
-    if lock.owned():
-        lock.release()
-        sender.primary_worker_lock = None
+    for tenant_id, lock in sender.primary_worker_locks.items():
+        try:
+            if lock and lock.owned():
+                logger.debug(f"Attempting to release lock for tenant {tenant_id}")
+                try:
+                    lock.release()
+                    logger.debug(f"Successfully released lock for tenant {tenant_id}")
+                except Exception as e:
+                    logger.error(
+                        f"Failed to release lock for tenant {tenant_id}. Error: {str(e)}"
+                    )
+                finally:
+                    sender.primary_worker_locks[tenant_id] = None
+        except Exception as e:
+            logger.error(
+                f"Error checking lock status for tenant {tenant_id}. Error: {str(e)}"
+            )


 def on_setup_logging(
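The on_task_postrun change only works because Celery forwards each task's keyword arguments to task_postrun receivers; anything passed positionally lands in `args` and is invisible to `kwargs.get("tenant_id")`. A rough sketch of the mechanism (receiver name is illustrative, handler body elided):

```python
from typing import Any

from celery.signals import task_postrun


@task_postrun.connect
def example_postrun_receiver(
    task_id: str | None = None,
    task: Any = None,
    args: tuple | None = None,
    kwargs: dict[str, Any] | None = None,
    **kwds: Any,
) -> None:
    # present only if the task was dispatched with kwargs={"tenant_id": ...};
    # a positional tenant_id would sit in `args` and be missed here
    tenant_id = (kwargs or {}).get("tenant_id")
    ...
```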
Error: {str(e)}" + ) def on_setup_logging( diff --git a/backend/danswer/background/celery/apps/beat.py b/backend/danswer/background/celery/apps/beat.py index 47be61e36be..8ddc17efc52 100644 --- a/backend/danswer/background/celery/apps/beat.py +++ b/backend/danswer/background/celery/apps/beat.py @@ -88,7 +88,7 @@ def on_setup_logging( "task": task["task"], "schedule": task["schedule"], "options": task["options"], - "args": (tenant_id,), # Must pass tenant_id as an argument + "kwargs": {"tenant_id": tenant_id}, # Must pass tenant_id as an argument } # Include any existing beat schedules diff --git a/backend/danswer/background/celery/apps/primary.py b/backend/danswer/background/celery/apps/primary.py index 58e464f3768..d86f0d60fe4 100644 --- a/backend/danswer/background/celery/apps/primary.py +++ b/backend/danswer/background/celery/apps/primary.py @@ -1,7 +1,6 @@ import multiprocessing from typing import Any -import redis from celery import bootsteps # type: ignore from celery import Celery from celery import signals @@ -24,6 +23,7 @@ from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT from danswer.configs.constants import DanswerRedisLocks from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME +from danswer.db.engine import get_all_tenant_ids from danswer.db.engine import SqlEngine from danswer.redis.redis_pool import get_redis_client from danswer.utils.logger import setup_logger @@ -80,81 +80,83 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None: # This is singleton work that should be done on startup exactly once # by the primary worker - r = get_redis_client() - - # For the moment, we're assuming that we are the only primary worker - # that should be running. - # TODO: maybe check for or clean up another zombie primary worker if we detect it - r.delete(DanswerRedisLocks.PRIMARY_WORKER) - - # this process wide lock is taken to help other workers start up in order. - # it is planned to use this lock to enforce singleton behavior on the primary - # worker, since the primary worker does redis cleanup on startup, but this isn't - # implemented yet. - lock = r.lock( - DanswerRedisLocks.PRIMARY_WORKER, - timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT, - ) + tenant_ids = get_all_tenant_ids() + for tenant_id in tenant_ids: + r = get_redis_client(tenant_id=tenant_id) + + # For the moment, we're assuming that we are the only primary worker + # that should be running. + # TODO: maybe check for or clean up another zombie primary worker if we detect it + r.delete(DanswerRedisLocks.PRIMARY_WORKER) + + # this process wide lock is taken to help other workers start up in order. + # it is planned to use this lock to enforce singleton behavior on the primary + # worker, since the primary worker does redis cleanup on startup, but this isn't + # implemented yet. 
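The schedule switch from "args" to "kwargs" pairs with the `*` added to the task signatures later in this diff: once tenant_id is keyword-only, a beat entry that still passed it positionally would raise TypeError when the task executes. A small sketch of the pairing, using a hypothetical task name:

```python
from typing import Any

from celery import Celery

app = Celery("example")


@app.task(name="check_something", bind=True)
def check_something(self: Any, *, tenant_id: str | None) -> None:
    ...  # tenant_id can only arrive as a keyword argument


# the beat entry must therefore use "kwargs"; "args": ("tenant_abc",)
# would fail with TypeError once the worker invokes the task
app.conf.beat_schedule = {
    "check-something-tenant-abc": {
        "task": "check_something",
        "schedule": 60.0,
        "kwargs": {"tenant_id": "tenant_abc"},  # hypothetical tenant
    }
}
```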
diff --git a/backend/danswer/background/celery/apps/primary.py b/backend/danswer/background/celery/apps/primary.py
index 58e464f3768..d86f0d60fe4 100644
--- a/backend/danswer/background/celery/apps/primary.py
+++ b/backend/danswer/background/celery/apps/primary.py
@@ -1,7 +1,6 @@
 import multiprocessing
 from typing import Any

-import redis
 from celery import bootsteps  # type: ignore
 from celery import Celery
 from celery import signals
@@ -24,6 +23,7 @@
 from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
 from danswer.configs.constants import DanswerRedisLocks
 from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
+from danswer.db.engine import get_all_tenant_ids
 from danswer.db.engine import SqlEngine
 from danswer.redis.redis_pool import get_redis_client
 from danswer.utils.logger import setup_logger
@@ -80,81 +80,83 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
     # This is singleton work that should be done on startup exactly once
     # by the primary worker
-    r = get_redis_client()
-
-    # For the moment, we're assuming that we are the only primary worker
-    # that should be running.
-    # TODO: maybe check for or clean up another zombie primary worker if we detect it
-    r.delete(DanswerRedisLocks.PRIMARY_WORKER)
-
-    # this process wide lock is taken to help other workers start up in order.
-    # it is planned to use this lock to enforce singleton behavior on the primary
-    # worker, since the primary worker does redis cleanup on startup, but this isn't
-    # implemented yet.
-    lock = r.lock(
-        DanswerRedisLocks.PRIMARY_WORKER,
-        timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
-    )
+    tenant_ids = get_all_tenant_ids()
+    for tenant_id in tenant_ids:
+        r = get_redis_client(tenant_id=tenant_id)
+
+        # For the moment, we're assuming that we are the only primary worker
+        # that should be running.
+        # TODO: maybe check for or clean up another zombie primary worker if we detect it
+        r.delete(DanswerRedisLocks.PRIMARY_WORKER)
+
+        # this process wide lock is taken to help other workers start up in order.
+        # it is planned to use this lock to enforce singleton behavior on the primary
+        # worker, since the primary worker does redis cleanup on startup, but this isn't
+        # implemented yet.
+        lock = r.lock(
+            DanswerRedisLocks.PRIMARY_WORKER,
+            timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
+        )

-    logger.info("Primary worker lock: Acquire starting.")
-    acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
-    if acquired:
-        logger.info("Primary worker lock: Acquire succeeded.")
-    else:
-        logger.error("Primary worker lock: Acquire failed!")
-        raise WorkerShutdown("Primary worker lock could not be acquired!")
+        logger.info("Primary worker lock: Acquire starting.")
+        acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
+        if acquired:
+            logger.info("Primary worker lock: Acquire succeeded.")
+        else:
+            logger.error("Primary worker lock: Acquire failed!")
+            raise WorkerShutdown("Primary worker lock could not be acquired!")

-    sender.primary_worker_lock = lock
+        sender.primary_worker_locks[tenant_id] = lock

-    # As currently designed, when this worker starts as "primary", we reinitialize redis
-    # to a clean state (for our purposes, anyway)
-    r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
-    r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
+        # As currently designed, when this worker starts as "primary", we reinitialize redis
+        # to a clean state (for our purposes, anyway)
+        r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
+        r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)

-    r.delete(RedisConnectorCredentialPair.get_taskset_key())
-    r.delete(RedisConnectorCredentialPair.get_fence_key())
+        r.delete(RedisConnectorCredentialPair.get_taskset_key())
+        r.delete(RedisConnectorCredentialPair.get_fence_key())

-    for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
+            r.delete(key)

-    for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
-        r.delete(key)
+        for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
+            r.delete(key)


 # @worker_process_init.connect
@@ -217,42 +219,58 @@ def start(self, worker: Any) -> None:

     def run_periodic_task(self, worker: Any) -> None:
         try:
-            if not worker.primary_worker_lock:
+            if not celery_is_worker_primary(worker):
                 return

-            if not hasattr(worker, "primary_worker_lock"):
+            if not hasattr(worker, "primary_worker_locks"):
                 return

-            r = get_redis_client()
-
-            lock: redis.lock.Lock = worker.primary_worker_lock
-
-            if lock.owned():
-                task_logger.debug("Reacquiring primary worker lock.")
-                lock.reacquire()
-            else:
-                task_logger.warning(
-                    "Full acquisition of primary worker lock. "
-                    "Reasons could be computer sleep or a clock change."
-                )
-                lock = r.lock(
-                    DanswerRedisLocks.PRIMARY_WORKER,
-                    timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
-                )
-
-                task_logger.info("Primary worker lock: Acquire starting.")
-                acquired = lock.acquire(
-                    blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
-                )
-                if acquired:
-                    task_logger.info("Primary worker lock: Acquire succeeded.")
+            # Retrieve all tenant IDs
+            tenant_ids = get_all_tenant_ids()
+
+            for tenant_id in tenant_ids:
+                lock = worker.primary_worker_locks.get(tenant_id)
+                if not lock:
+                    continue  # Skip if no lock for this tenant
+
+                r = get_redis_client(tenant_id=tenant_id)
+
+                if lock.owned():
+                    task_logger.debug(
+                        f"Reacquiring primary worker lock for tenant {tenant_id}."
+                    )
+                    lock.reacquire()
                 else:
-                    task_logger.error("Primary worker lock: Acquire failed!")
-                    raise TimeoutError("Primary worker lock could not be acquired!")
+                    task_logger.warning(
+                        f"Full acquisition of primary worker lock for tenant {tenant_id}. "
+                        "Reasons could be worker restart or lock expiration."
+                    )
+                    lock = r.lock(
+                        DanswerRedisLocks.PRIMARY_WORKER,
+                        timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
+                    )
+
+                    task_logger.info(
+                        f"Primary worker lock for tenant {tenant_id}: Acquire starting."
+                    )
+                    acquired = lock.acquire(
+                        blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
+                    )
+                    if acquired:
+                        task_logger.info(
+                            f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
+                        )
+                        worker.primary_worker_locks[tenant_id] = lock
+                    else:
+                        task_logger.error(
+                            f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
+                        )
+                        raise TimeoutError(
+                            f"Primary worker lock for tenant {tenant_id} could not be acquired!"
+                        )

-            worker.primary_worker_lock = lock
         except Exception:
-            task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.")
+            task_logger.exception("Periodic task failed.")

     def stop(self, worker: Any) -> None:
         # Cancel the scheduled task when the worker stops
diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py
index 794f89232c5..b1e9c2113e2 100644
--- a/backend/danswer/background/celery/celery_utils.py
+++ b/backend/danswer/background/celery/celery_utils.py
@@ -27,7 +27,10 @@
 def _get_deletion_status(
-    connector_id: int, credential_id: int, db_session: Session
+    connector_id: int,
+    credential_id: int,
+    db_session: Session,
+    tenant_id: str | None = None,
 ) -> TaskQueueState | None:
     """We no longer store TaskQueueState in the DB for a deletion attempt.
     This function populates TaskQueueState by just checking redis.
@@ -40,7 +43,7 @@ def _get_deletion_status(
     rcd = RedisConnectorDeletion(cc_pair.id)

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
     if not r.exists(rcd.fence_key):
         return None
@@ -50,9 +53,14 @@ def _get_deletion_status(

 def get_deletion_attempt_snapshot(
-    connector_id: int, credential_id: int, db_session: Session
+    connector_id: int,
+    credential_id: int,
+    db_session: Session,
+    tenant_id: str | None = None,
 ) -> DeletionAttemptSnapshot | None:
-    deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
+    deletion_task = _get_deletion_status(
+        connector_id, credential_id, db_session, tenant_id
+    )
     if not deletion_task:
         return None
diff --git a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
index b3c2eea30b0..f6a59d03e3a 100644
--- a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
+++ b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
@@ -24,8 +24,8 @@
     trail=False,
     bind=True,
 )
-def check_for_connector_deletion_task(self: Task, tenant_id: str | None) -> None:
-    r = get_redis_client()
+def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
+    r = get_redis_client(tenant_id=tenant_id)

     lock_beat = r.lock(
         DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
diff --git a/backend/danswer/background/celery/tasks/indexing/tasks.py b/backend/danswer/background/celery/tasks/indexing/tasks.py
index ed08787d53e..bdd55f77f32 100644
--- a/backend/danswer/background/celery/tasks/indexing/tasks.py
+++ b/backend/danswer/background/celery/tasks/indexing/tasks.py
@@ -55,10 +55,10 @@
     soft_time_limit=300,
     bind=True,
 )
-def check_for_indexing(self: Task, tenant_id: str | None) -> int | None:
+def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
     tasks_created = 0

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)

     lock_beat = r.lock(
         DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
@@ -398,7 +398,7 @@ def connector_indexing_task(
     attempt = None
     n_final_progress = 0

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)

     rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
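All the check_for_* beat tasks above share one guard: a short-lived per-tenant Redis lock, so overlapping beat ticks for the same tenant become no-ops. The acquire call itself is outside the hunks shown here, so treat this as a sketch of the presumed pattern (lock name and timeout are illustrative), with TenantRedis making the lock key tenant-scoped:

```python
from redis import Redis


def run_beat_task_once(r: Redis) -> None:
    # under TenantRedis, "check_beat_lock" becomes "<tenant_id>:check_beat_lock"
    lock_beat = r.lock("check_beat_lock", timeout=120)
    if not lock_beat.acquire(blocking=False):
        return  # another worker already owns this tenant's tick
    try:
        ...  # enumerate work for the tenant and enqueue subtasks
    finally:
        if lock_beat.owned():
            lock_beat.release()
```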
diff --git a/backend/danswer/background/celery/tasks/pruning/tasks.py b/backend/danswer/background/celery/tasks/pruning/tasks.py
index 698c2937299..9f290d6f23a 100644
--- a/backend/danswer/background/celery/tasks/pruning/tasks.py
+++ b/backend/danswer/background/celery/tasks/pruning/tasks.py
@@ -41,8 +41,8 @@
     soft_time_limit=JOB_TIMEOUT,
     bind=True,
 )
-def check_for_pruning(self: Task, tenant_id: str | None) -> None:
-    r = get_redis_client()
+def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
+    r = get_redis_client(tenant_id=tenant_id)

     lock_beat = r.lock(
         DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@@ -222,7 +222,7 @@ def connector_pruning_generator_task(
     and compares those IDs to locally stored documents and deletes all locally stored IDs missing
     from the most recently pulled document ID list"""

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)

     rcp = RedisConnectorPruning(cc_pair_id)
diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py
index 53e26be6954..812074b91e9 100644
--- a/backend/danswer/background/celery/tasks/vespa/tasks.py
+++ b/backend/danswer/background/celery/tasks/vespa/tasks.py
@@ -60,6 +60,7 @@
 from danswer.document_index.factory import get_default_document_index
 from danswer.document_index.interfaces import VespaDocumentFields
 from danswer.redis.redis_pool import get_redis_client
+from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import fetch_versioned_implementation
 from danswer.utils.variable_functionality import (
     fetch_versioned_implementation_with_fallback,
@@ -67,6 +68,8 @@
 from danswer.utils.variable_functionality import global_version
 from danswer.utils.variable_functionality import noop_fallback

+logger = setup_logger()
+

 # celery auto associates tasks created inside another task,
 # which bloats the result metadata considerably. trail=False prevents this.
@@ -76,11 +79,11 @@
     trail=False,
     bind=True,
 )
-def check_for_vespa_sync_task(self: Task, tenant_id: str | None) -> None:
+def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
     """Runs periodically to check if any document needs syncing.
     Generates sets of tasks for Celery if syncing is needed."""

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)

     lock_beat = r.lock(
         DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
@@ -680,7 +683,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
     Returns True if the task actually did work, False
     """

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)

     lock_beat: redis.lock.Lock = r.lock(
         DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
diff --git a/backend/danswer/connectors/file/connector.py b/backend/danswer/connectors/file/connector.py
index eb79cce579e..d07a224478e 100644
--- a/backend/danswer/connectors/file/connector.py
+++ b/backend/danswer/connectors/file/connector.py
@@ -27,7 +27,7 @@
 from danswer.file_processing.extract_file_text import read_text_file
 from danswer.file_store.file_store import get_default_file_store
 from danswer.utils.logger import setup_logger
-from shared_configs.configs import current_tenant_id
+from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
 from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

 logger = setup_logger()
@@ -175,7 +175,7 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:

     def load_from_state(self) -> GenerateDocumentsOutput:
         documents: list[Document] = []
-        token = current_tenant_id.set(self.tenant_id)
+        token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)

         with get_session_with_tenant(self.tenant_id) as db_session:
             for file_path in self.file_locations:
@@ -199,7 +199,7 @@ def load_from_state(self) -> GenerateDocumentsOutput:
         if documents:
             yield documents

-        current_tenant_id.reset(token)
+        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


 if __name__ == "__main__":
diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py
index b05c3a5ce55..a40dbe9a9b9 100644
--- a/backend/danswer/danswerbot/slack/listener.py
+++ b/backend/danswer/danswerbot/slack/listener.py
@@ -57,10 +57,9 @@
 from danswer.server.manage.models import SlackBotTokens
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
-from shared_configs.configs import current_tenant_id
+from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
 from shared_configs.configs import MODEL_SERVER_HOST
 from shared_configs.configs import MODEL_SERVER_PORT
-from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
 from shared_configs.configs import SLACK_CHANNEL_ID

 logger = setup_logger()
@@ -364,7 +363,7 @@ def process_message(
     # Set the current tenant ID at the beginning for all DB calls within this thread
     if client.tenant_id:
         logger.info(f"Setting tenant ID to {client.tenant_id}")
-        token = current_tenant_id.set(client.tenant_id)
+        token = CURRENT_TENANT_ID_CONTEXTVAR.set(client.tenant_id)
     try:
         with get_session_with_tenant(client.tenant_id) as db_session:
             slack_bot_config = get_slack_bot_config_for_channel(
@@ -413,7 +412,7 @@ def process_message(
             apologize_for_fail(details, client)
     finally:
         if client.tenant_id:
-            current_tenant_id.reset(token)
+            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


 def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
@@ -511,11 +510,9 @@ def _initialize_socket_client(socket_client: TenantSocketModeClient) -> None:
         for tenant_id in tenant_ids:
             with get_session_with_tenant(tenant_id) as db_session:
                 try:
-                    token = current_tenant_id.set(
-                        tenant_id or POSTGRES_DEFAULT_SCHEMA
-                    )
+                    token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
                     latest_slack_bot_tokens = fetch_tokens()
-                    current_tenant_id.reset(token)
+                    CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

                     if (
                         tenant_id not in slack_bot_tokens
diff --git a/backend/danswer/db/credentials.py b/backend/danswer/db/credentials.py
index b4ecfc888fe..5da5099f1e3 100644
--- a/backend/danswer/db/credentials.py
+++ b/backend/danswer/db/credentials.py
@@ -242,7 +242,6 @@ def create_credential(
     )
     db_session.add(credential)
     db_session.flush()  # This ensures the credential gets an ID
-
     _relate_credential_to_user_groups__no_commit(
         db_session=db_session,
         credential_id=credential.id,
diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py
index 7bf813b44f8..c03a47d7f44 100644
--- a/backend/danswer/db/engine.py
+++ b/backend/danswer/db/engine.py
@@ -39,7 +39,7 @@
 from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
 from danswer.configs.constants import TENANT_ID_PREFIX
 from danswer.utils.logger import setup_logger
-from shared_configs.configs import current_tenant_id
+from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
 from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

 logger = setup_logger()
@@ -260,12 +260,12 @@ def get_current_tenant_id(request: Request) -> str:
     """Dependency that extracts the tenant ID from the JWT token in the request
     and sets the context variable."""
     if not MULTI_TENANT:
         tenant_id = POSTGRES_DEFAULT_SCHEMA
-        current_tenant_id.set(tenant_id)
+        CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
         return tenant_id

     token = request.cookies.get("tenant_details")
     if not token:
-        current_value = current_tenant_id.get()
+        current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
         # If no token is present, use the default schema or handle accordingly
         return current_value
@@ -273,14 +273,14 @@ def get_current_tenant_id(request: Request) -> str:
         payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
         tenant_id = payload.get("tenant_id")
         if not tenant_id:
-            return current_tenant_id.get()
+            return CURRENT_TENANT_ID_CONTEXTVAR.get()
         if not is_valid_schema_name(tenant_id):
             raise HTTPException(status_code=400, detail="Invalid tenant ID format")
-        current_tenant_id.set(tenant_id)
+        CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
         return tenant_id
     except jwt.InvalidTokenError:
-        return current_tenant_id.get()
+        return CURRENT_TENANT_ID_CONTEXTVAR.get()
     except Exception as e:
         logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}")
         raise HTTPException(status_code=500, detail="Internal server error")
@@ -291,7 +291,7 @@ async def get_async_session_with_tenant(
     tenant_id: str | None = None,
 ) -> AsyncGenerator[AsyncSession, None]:
     if tenant_id is None:
-        tenant_id = current_tenant_id.get()
+        tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()

     if not is_valid_schema_name(tenant_id):
         logger.error(f"Invalid tenant ID: {tenant_id}")
@@ -319,30 +319,32 @@ async def get_async_session_with_tenant(
 def get_session_with_tenant(
     tenant_id: str | None = None,
 ) -> Generator[Session, None, None]:
-    """Generate a database session with the appropriate tenant schema set."""
+    """Generate a database session bound to a connection with the appropriate tenant schema set."""
     engine = get_sqlalchemy_engine()
+
     if tenant_id is None:
-        tenant_id = current_tenant_id.get()
+        tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
+    else:
+        CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
+
+    event.listen(engine, "checkout", set_search_path_on_checkout)

     if not is_valid_schema_name(tenant_id):
         raise HTTPException(status_code=400, detail="Invalid tenant ID")

-    # Establish a raw connection without starting a transaction
+    # Establish a raw connection
     with engine.connect() as connection:
-        # Access the raw DBAPI connection
+        # Access the raw DBAPI connection and set the search_path
         dbapi_connection = connection.connection

-        # Execute SET search_path outside of any transaction
+        # Set the search_path outside of any transaction
         cursor = dbapi_connection.cursor()
         try:
-            cursor.execute(f'SET search_path TO "{tenant_id}"')
-            # Optionally verify the search_path was set correctly
-            cursor.execute("SHOW search_path")
-            cursor.fetchone()
+            cursor.execute(f'SET search_path = "{tenant_id}"')
         finally:
             cursor.close()

-        # Proceed to create a session using the connection
+        # Bind the session to the connection
         with Session(bind=connection, expire_on_commit=False) as session:
             try:
                 yield session
@@ -356,15 +358,27 @@
                 cursor.close()


+def set_search_path_on_checkout(
+    dbapi_conn: Any, connection_record: Any, connection_proxy: Any
+) -> None:
+    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
+    if tenant_id and is_valid_schema_name(tenant_id):
+        with dbapi_conn.cursor() as cursor:
+            cursor.execute(f'SET search_path TO "{tenant_id}"')
+            logger.debug(
+                f"Set search_path to {tenant_id} for connection {connection_record}"
+            )
+
+
 def get_session_generator_with_tenant() -> Generator[Session, None, None]:
-    tenant_id = current_tenant_id.get()
+    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
     with get_session_with_tenant(tenant_id) as session:
         yield session


 def get_session() -> Generator[Session, None, None]:
     """Generate a database session with the appropriate tenant schema set."""
-    tenant_id = current_tenant_id.get()
+    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
     if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
         raise HTTPException(status_code=401, detail="User must authenticate")
@@ -381,7 +395,7 @@

 async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
     """Generate an async database session with the appropriate tenant schema set."""
-    tenant_id = current_tenant_id.get()
+    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
     engine = get_sqlalchemy_async_engine()
     async with AsyncSession(engine, expire_on_commit=False) as async_session:
         if MULTI_TENANT:
diff --git a/backend/danswer/key_value_store/store.py b/backend/danswer/key_value_store/store.py
index b461ca22feb..d0a17b26565 100644
--- a/backend/danswer/key_value_store/store.py
+++ b/backend/danswer/key_value_store/store.py
@@ -4,6 +4,7 @@
 from typing import cast

 from fastapi import HTTPException
+from redis.client import Redis
 from sqlalchemy import text
 from sqlalchemy.orm import Session
@@ -16,7 +17,7 @@
 from danswer.key_value_store.interface import KvKeyNotFoundError
 from danswer.redis.redis_pool import get_redis_client
 from danswer.utils.logger import setup_logger
-from shared_configs.configs import current_tenant_id
+from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
 from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

 logger = setup_logger()
@@ -27,15 +28,22 @@

 class PgRedisKVStore(KeyValueStore):
-    def __init__(self) -> None:
-        self.redis_client = get_redis_client()
+    def __init__(
+        self, redis_client: Redis | None = None, tenant_id: str | None = None
+    ) -> None:
+        # If no redis_client is provided, fall back to the context var
+        if redis_client is not None:
+            self.redis_client = redis_client
+        else:
+            tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
+            self.redis_client = get_redis_client(tenant_id=tenant_id)

     @contextmanager
     def get_session(self) -> Iterator[Session]:
         engine = get_sqlalchemy_engine()
         with Session(engine, expire_on_commit=False) as session:
             if MULTI_TENANT:
-                tenant_id = current_tenant_id.get()
+                tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
                 if tenant_id == POSTGRES_DEFAULT_SCHEMA:
                     raise HTTPException(
                         status_code=401, detail="User must authenticate"
diff --git a/backend/danswer/redis/redis_pool.py b/backend/danswer/redis/redis_pool.py
index fd08b9157bd..3f2ec03d77f 100644
--- a/backend/danswer/redis/redis_pool.py
+++ b/backend/danswer/redis/redis_pool.py
@@ -1,4 +1,7 @@
+import functools
 import threading
+from collections.abc import Callable
+from typing import Any
 from typing import Optional

 import redis
@@ -14,6 +17,98 @@
 from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
 from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
 from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+class TenantRedis(redis.Redis):
+    def __init__(self, tenant_id: str, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self.tenant_id: str = tenant_id
+
+    def _prefixed(self, key: str | bytes | memoryview) -> str | bytes | memoryview:
+        prefix: str = f"{self.tenant_id}:"
+        if isinstance(key, str):
+            if key.startswith(prefix):
+                return key
+            else:
+                return prefix + key
+        elif isinstance(key, bytes):
+            prefix_bytes = prefix.encode()
+            if key.startswith(prefix_bytes):
+                return key
+            else:
+                return prefix_bytes + key
+        elif isinstance(key, memoryview):
+            key_bytes = key.tobytes()
+            prefix_bytes = prefix.encode()
+            if key_bytes.startswith(prefix_bytes):
+                return key
+            else:
+                return memoryview(prefix_bytes + key_bytes)
+        else:
+            raise TypeError(f"Unsupported key type: {type(key)}")
+
+    def _prefix_method(self, method: Callable) -> Callable:
+        @functools.wraps(method)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            if "name" in kwargs:
+                kwargs["name"] = self._prefixed(kwargs["name"])
+            elif len(args) > 0:
+                args = (self._prefixed(args[0]),) + args[1:]
+            return method(*args, **kwargs)
+
+        return wrapper
+
+    def _prefix_scan_iter(self, method: Callable) -> Callable:
+        @functools.wraps(method)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            # Prefix the match pattern if provided
+            if "match" in kwargs:
+                kwargs["match"] = self._prefixed(kwargs["match"])
+            elif len(args) > 0:
+                args = (self._prefixed(args[0]),) + args[1:]
+
+            # Get the iterator
+            iterator = method(*args, **kwargs)
+
+            # Remove prefix from returned keys
+            prefix = f"{self.tenant_id}:".encode()
+            prefix_len = len(prefix)
+
+            for key in iterator:
+                if isinstance(key, bytes) and key.startswith(prefix):
+                    yield key[prefix_len:]
+                else:
+                    yield key
+
+        return wrapper
+
+    def __getattribute__(self, item: str) -> Any:
+        original_attr = super().__getattribute__(item)
+        methods_to_wrap = [
+            "lock",
+            "unlock",
+            "get",
+            "set",
+            "delete",
+            "exists",
+            "incrby",
+            "hset",
+            "hget",
+            "getset",
+            "owned",
+            "reacquire",
+            "create_lock",
+            "startswith",
+        ]  # Regular methods that need simple prefixing
+
+        if item == "scan_iter":
+            return self._prefix_scan_iter(original_attr)
+        elif item in methods_to_wrap and callable(original_attr):
+            return self._prefix_method(original_attr)
+        return original_attr


 class RedisPool:
@@ -32,8 +127,10 @@ def __new__(cls) -> "RedisPool":
     def _init_pool(self) -> None:
         self._pool = RedisPool.create_pool(ssl=REDIS_SSL)

-    def get_client(self) -> Redis:
-        return redis.Redis(connection_pool=self._pool)
+    def get_client(self, tenant_id: str | None) -> Redis:
+        if tenant_id is None:
+            tenant_id = "public"
+        return TenantRedis(tenant_id, connection_pool=self._pool)

     @staticmethod
     def create_pool(
@@ -84,8 +181,8 @@ def create_pool(

 redis_pool = RedisPool()


-def get_redis_client() -> Redis:
-    return redis_pool.get_client()
+def get_redis_client(*, tenant_id: str | None) -> Redis:
+    return redis_pool.get_client(tenant_id)


 #
 # Usage example
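The net effect of TenantRedis is that call sites keep using plain key names while every wrapped command is transparently namespaced. A small sketch of the observable behavior (tenant value is hypothetical; assumes a reachable Redis):

```python
from danswer.redis.redis_pool import get_redis_client

r = get_redis_client(tenant_id="tenant_abc")

r.set("my_key", "1")  # actually writes "tenant_abc:my_key"
assert r.exists("my_key")  # the read side prefixes too, so this matches

# scan_iter prefixes the match pattern and strips the prefix from results,
# so iteration stays tenant-local and yields unprefixed keys
keys = list(r.scan_iter(match="my_*"))
```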
diff --git a/backend/danswer/search/preprocessing/preprocessing.py b/backend/danswer/search/preprocessing/preprocessing.py
index aa3124617e5..10832ae7cbf 100644
--- a/backend/danswer/search/preprocessing/preprocessing.py
+++ b/backend/danswer/search/preprocessing/preprocessing.py
@@ -10,7 +10,7 @@
 from danswer.configs.chat_configs import HYBRID_ALPHA_KEYWORD
 from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
 from danswer.configs.chat_configs import NUM_RETURNED_HITS
-from danswer.db.engine import current_tenant_id
+from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
 from danswer.db.models import User
 from danswer.db.search_settings import get_current_search_settings
 from danswer.llm.interfaces import LLM
@@ -162,7 +162,7 @@ def retrieval_preprocessing(
         time_cutoff=time_filter or predicted_time_cutoff,
         tags=preset_filters.tags,  # Tags are never auto-extracted
         access_control_list=user_acl_filters,
-        tenant_id=current_tenant_id.get() if MULTI_TENANT else None,
+        tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get() if MULTI_TENANT else None,
     )

     llm_evaluation_type = LLMEvaluationType.BASIC
diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py
index 92a94a63878..ddc084498ba 100644
--- a/backend/danswer/server/documents/cc_pair.py
+++ b/backend/danswer/server/documents/cc_pair.py
@@ -25,7 +25,8 @@
     update_connector_credential_pair_from_id,
 )
 from danswer.db.document import get_document_counts_for_cc_pairs
-from danswer.db.engine import current_tenant_id
+from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
+from danswer.db.engine import get_current_tenant_id
 from danswer.db.engine import get_session
 from danswer.db.enums import AccessType
 from danswer.db.enums import ConnectorCredentialPairStatus
@@ -94,8 +95,9 @@ def get_cc_pair_full_info(
     cc_pair_id: int,
     user: User | None = Depends(current_curator_or_admin_user),
     db_session: Session = Depends(get_session),
+    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> CCPairFullInfo:
-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)

     cc_pair = get_connector_credential_pair_from_id(
         cc_pair_id, db_session, user, get_editable=False
@@ -147,6 +149,7 @@ def get_cc_pair_full_info(
             connector_id=cc_pair.connector_id,
             credential_id=cc_pair.credential_id,
             db_session=db_session,
+            tenant_id=tenant_id,
         ),
         num_docs_indexed=documents_indexed,
         is_editable_for_current_user=is_editable_for_current_user,
@@ -243,6 +246,7 @@ def prune_cc_pair(
     cc_pair_id: int,
     user: User = Depends(current_curator_or_admin_user),
     db_session: Session = Depends(get_session),
+    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> StatusResponse[list[int]]:
     """Triggers pruning on a particular cc_pair immediately"""

@@ -258,7 +262,7 @@ def prune_cc_pair(
             detail="Connection not found for current user's permissions",
         )

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)
     rcp = RedisConnectorPruning(cc_pair_id)
     if rcp.is_pruning(r):
         raise HTTPException(
@@ -273,7 +277,7 @@ def prune_cc_pair(
         f"{cc_pair.connector.name} connector."
     )
     tasks_created = try_creating_prune_generator_task(
-        primary_app, cc_pair, db_session, r, current_tenant_id.get()
+        primary_app, cc_pair, db_session, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
     )
     if not tasks_created:
         raise HTTPException(
@@ -359,7 +363,9 @@ def sync_cc_pair(
     logger.info(f"Syncing the {cc_pair.connector.name} connector.")
     sync_external_doc_permissions_task.apply_async(
-        kwargs=dict(cc_pair_id=cc_pair_id, tenant_id=current_tenant_id.get()),
+        kwargs=dict(
+            cc_pair_id=cc_pair_id, tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get()
+        ),
     )

     return StatusResponse(
diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py
index a6ce87ad8a6..1ba0ab13e2c 100644
--- a/backend/danswer/server/documents/connector.py
+++ b/backend/danswer/server/documents/connector.py
@@ -493,10 +493,11 @@ def get_connector_indexing_status(
     get_editable: bool = Query(
         False, description="If true, return editable document sets"
     ),
+    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> list[ConnectorIndexingStatus]:
     indexing_statuses: list[ConnectorIndexingStatus] = []

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)

     # NOTE: If the connector is deleting behind the scenes,
     # accessing cc_pairs can be inconsistent and members like
@@ -617,6 +618,7 @@ def get_connector_indexing_status(
                 connector_id=connector.id,
                 credential_id=credential.id,
                 db_session=db_session,
+                tenant_id=tenant_id,
             ),
             is_deletable=check_deletion_attempt_is_allowed(
                 connector_credential_pair=cc_pair,
@@ -694,15 +696,18 @@ def create_connector_with_mock_credential(
     connector_response = create_connector(
         db_session=db_session, connector_data=connector_data
     )
+
     mock_credential = CredentialBase(
         credential_json={}, admin_public=True, source=connector_data.source
     )
     credential = create_credential(
         mock_credential, user=user, db_session=db_session
     )
+
     access_type = (
         AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
     )
+
     response = add_credential_to_connector(
         db_session=db_session,
         user=user,
@@ -786,7 +791,7 @@ def connector_run_once(
     """Used to trigger indexing on a set of cc_pairs associated with a
     single connector."""

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)

     connector_id = run_info.connector_id
     specified_credential_ids = run_info.credential_ids
diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py
index ae2ab8c6e8c..cd2ffe08422 100644
--- a/backend/danswer/server/manage/users.py
+++ b/backend/danswer/server/manage/users.py
@@ -38,7 +38,7 @@
 from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
 from danswer.configs.constants import AuthType
 from danswer.db.auth import get_total_users
-from danswer.db.engine import current_tenant_id
+from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
 from danswer.db.engine import get_session
 from danswer.db.models import AccessToken
 from danswer.db.models import DocumentSet__User
@@ -188,7 +188,7 @@ def bulk_invite_users(
             status_code=400, detail="Auth is disabled, cannot invite users"
         )

-    tenant_id = current_tenant_id.get()
+    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()

     normalized_emails = []
     try:
@@ -222,7 +222,9 @@ def bulk_invite_users(
         return number_of_invited_users
     try:
         logger.info("Registering tenant users")
-        register_tenant_users(current_tenant_id.get(), get_total_users(db_session))
+        register_tenant_users(
+            CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users(db_session)
+        )
         if ENABLE_EMAIL_INVITES:
             try:
                 for email in all_emails:
@@ -250,13 +252,15 @@ def remove_invited_user(
     user_emails = get_invited_users()
     remaining_users = [user for user in user_emails if user != user_email.user_email]

-    tenant_id = current_tenant_id.get()
+    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
     remove_users_from_tenant([user_email.user_email], tenant_id)
     number_of_invited_users = write_invited_users(remaining_users)

     try:
         if MULTI_TENANT:
-            register_tenant_users(current_tenant_id.get(), get_total_users(db_session))
+            register_tenant_users(
+                CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users(db_session)
+            )
     except Exception:
         logger.error(
             "Request to update number of seats taken in control plane failed. "
diff --git a/backend/danswer/server/query_and_chat/token_limit.py b/backend/danswer/server/query_and_chat/token_limit.py
index 6221eae3346..ec94e2ece4d 100644
--- a/backend/danswer/server/query_and_chat/token_limit.py
+++ b/backend/danswer/server/query_and_chat/token_limit.py
@@ -21,7 +21,7 @@
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import fetch_versioned_implementation
 from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
-from shared_configs.configs import current_tenant_id
+from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR

 logger = setup_logger()
@@ -41,7 +41,7 @@ def check_token_rate_limits(
     versioned_rate_limit_strategy = fetch_versioned_implementation(
         "danswer.server.query_and_chat.token_limit", "_check_token_rate_limits"
     )
-    return versioned_rate_limit_strategy(user, current_tenant_id.get())
+    return versioned_rate_limit_strategy(user, CURRENT_TENANT_ID_CONTEXTVAR.get())


 def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None:
diff --git a/backend/ee/danswer/background/celery/apps/beat.py b/backend/ee/danswer/background/celery/apps/beat.py
index 20325e77df6..bee219e2471 100644
--- a/backend/ee/danswer/background/celery/apps/beat.py
+++ b/backend/ee/danswer/background/celery/apps/beat.py
@@ -41,7 +41,7 @@
         beat_schedule[task_name] = {
             "task": task["task"],
             "schedule": task["schedule"],
-            "args": (tenant_id,),  # Must pass tenant_id as an argument
+            "kwargs": {"tenant_id": tenant_id},  # Must pass tenant_id as an argument
         }

 # Include any existing beat schedules
diff --git a/backend/ee/danswer/background/celery/apps/primary.py b/backend/ee/danswer/background/celery/apps/primary.py
index 97c5b0221ca..be27d22868e 100644
--- a/backend/ee/danswer/background/celery/apps/primary.py
+++ b/backend/ee/danswer/background/celery/apps/primary.py
@@ -29,7 +29,7 @@
     run_external_group_permission_sync,
 )
 from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
-from shared_configs.configs import current_tenant_id
+from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR

 logger = setup_logger()
@@ -39,7 +39,9 @@

 @build_celery_task_wrapper(name_sync_external_doc_permissions_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
-def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -> None:
+def sync_external_doc_permissions_task(
+    cc_pair_id: int, *, tenant_id: str | None
+) -> None:
     with get_session_with_tenant(tenant_id) as db_session:
         run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)

@@ -47,7 +49,7 @@
 @build_celery_task_wrapper(name_sync_external_group_permissions_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
 def sync_external_group_permissions_task(
-    cc_pair_id: int, tenant_id: str | None
+    cc_pair_id: int, *, tenant_id: str | None
 ) -> None:
     with get_session_with_tenant(tenant_id) as db_session:
         run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@@ -56,7 +58,7 @@ def sync_external_group_permissions_task(
 @build_celery_task_wrapper(name_chat_ttl_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
 def perform_ttl_management_task(
-    retention_limit_days: int, tenant_id: str | None
+    retention_limit_days: int, *, tenant_id: str | None
 ) -> None:
     with get_session_with_tenant(tenant_id) as db_session:
         delete_chat_sessions_older_than(retention_limit_days, db_session)
@@ -69,7 +71,7 @@
     name="check_sync_external_doc_permissions_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None:
+def check_sync_external_doc_permissions_task(*, tenant_id: str | None) -> None:
     """Runs periodically to sync external permissions"""
     with get_session_with_tenant(tenant_id) as db_session:
         cc_pairs = get_all_auto_sync_cc_pairs(db_session)
@@ -86,7 +88,7 @@
     name="check_sync_external_group_permissions_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_sync_external_group_permissions_task(tenant_id: str | None) -> None:
+def check_sync_external_group_permissions_task(*, tenant_id: str | None) -> None:
     """Runs periodically to sync external group permissions"""
     with get_session_with_tenant(tenant_id) as db_session:
         cc_pairs = get_all_auto_sync_cc_pairs(db_session)
@@ -103,12 +105,12 @@
     name="check_ttl_management_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_ttl_management_task(tenant_id: str | None) -> None:
+def check_ttl_management_task(*, tenant_id: str | None) -> None:
     """Runs periodically to check if any ttl tasks should be run and adds them
     to the queue"""
     token = None
     if MULTI_TENANT and tenant_id is not None:
-        token = current_tenant_id.set(tenant_id)
+        token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

     settings = load_settings()
     retention_limit_days = settings.maximum_chat_retention_days
@@ -120,14 +122,14 @@ def check_ttl_management_task(
             ),
         )
     if token is not None:
-        current_tenant_id.reset(token)
+        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


 @celery_app.task(
     name="autogenerate_usage_report_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def autogenerate_usage_report_task(tenant_id: str | None) -> None:
+def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
     """This generates usage report under the /admin/generate-usage/report endpoint"""
     with get_session_with_tenant(tenant_id) as db_session:
         create_new_usage_report(
diff --git a/backend/ee/danswer/server/middleware/tenant_tracking.py b/backend/ee/danswer/server/middleware/tenant_tracking.py
index 63b0f82be8f..eba02d0428b 100644
--- a/backend/ee/danswer/server/middleware/tenant_tracking.py
+++ b/backend/ee/danswer/server/middleware/tenant_tracking.py
@@ -11,7 +11,7 @@
 from danswer.configs.app_configs import MULTI_TENANT
 from danswer.configs.app_configs import SECRET_JWT_KEY
 from danswer.db.engine import is_valid_schema_name
-from shared_configs.configs import current_tenant_id
+from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
 from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
@@ -49,7 +49,7 @@ async def set_tenant_id(
         else:
             tenant_id = POSTGRES_DEFAULT_SCHEMA

-        current_tenant_id.set(tenant_id)
+        CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
         logger.info(f"Middleware set current_tenant_id to: {tenant_id}")

         response = await call_next(request)
diff --git a/backend/ee/danswer/server/tenants/api.py b/backend/ee/danswer/server/tenants/api.py
index 9ee598a2e26..342554c1c43 100644
--- a/backend/ee/danswer/server/tenants/api.py
+++ b/backend/ee/danswer/server/tenants/api.py
@@ -24,7 +24,7 @@
 from ee.danswer.server.tenants.provisioning import ensure_schema_exists
 from ee.danswer.server.tenants.provisioning import run_alembic_migrations
 from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
-from shared_configs.configs import current_tenant_id
+from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR

 stripe.api_key = STRIPE_SECRET_KEY
@@ -55,7 +55,7 @@ def create_tenant(
         else:
             logger.info(f"Schema already exists for tenant {tenant_id}")

-        token = current_tenant_id.set(tenant_id)
+        token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
         run_alembic_migrations(tenant_id)

         with get_session_with_tenant(tenant_id) as db_session:
@@ -74,7 +74,7 @@ def create_tenant(
         )
     finally:
         if token is not None:
-            current_tenant_id.reset(token)
+            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


 @router.post("/product-gating")
@@ -89,7 +89,7 @@ def gate_product(
     2) User's card has declined
     """
     tenant_id = product_gating_request.tenant_id
-    token = current_tenant_id.set(tenant_id)
+    token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

     settings = load_settings()
     settings.product_gating = product_gating_request.product_gating
@@ -100,7 +100,7 @@ def gate_product(
         create_notification(None, product_gating_request.notification, db_session)

     if token is not None:
-        current_tenant_id.reset(token)
+        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


 @router.get("/billing-information", response_model=BillingInformation)
@@ -108,14 +108,16 @@ async def billing_information(
     _: User = Depends(current_admin_user),
 ) -> BillingInformation:
     logger.info("Fetching billing information")
-    return BillingInformation(**fetch_billing_information(current_tenant_id.get()))
+    return BillingInformation(
+        **fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
+    )


 @router.post("/create-customer-portal-session")
 async def create_customer_portal_session(_: User = Depends(current_admin_user)) -> dict:
     try:
         # Fetch tenant_id and current tenant's information
-        tenant_id = current_tenant_id.get()
+        tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
         stripe_info = fetch_tenant_stripe_information(tenant_id)
         stripe_customer_id = stripe_info.get("stripe_customer_id")
         if not stripe_customer_id:
diff --git a/backend/ee/danswer/server/tenants/billing.py b/backend/ee/danswer/server/tenants/billing.py
index 5dcd96713de..681ac835e5f 100644
--- a/backend/ee/danswer/server/tenants/billing.py
+++ b/backend/ee/danswer/server/tenants/billing.py
@@ -8,7 +8,6 @@
 from ee.danswer.configs.app_configs import STRIPE_PRICE_ID
 from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY
 from ee.danswer.server.tenants.access import generate_data_plane_token
-from shared_configs.configs import current_tenant_id

 stripe.api_key = STRIPE_SECRET_KEY
@@ -50,7 +49,6 @@ def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
     if not STRIPE_PRICE_ID:
         raise Exception("STRIPE_PRICE_ID is not set")

-    tenant_id = current_tenant_id.get()
     response = fetch_tenant_stripe_information(tenant_id)
     stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))
diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py
index 77139125f6e..5c24aebe749 100644
--- a/backend/shared_configs/configs.py
+++ b/backend/shared_configs/configs.py
@@ -131,7 +131,7 @@ def validate_cors_origin(origin: str) -> None:
 POSTGRES_DEFAULT_SCHEMA = os.environ.get("POSTGRES_DEFAULT_SCHEMA") or "public"

-current_tenant_id = contextvars.ContextVar(
+CURRENT_TENANT_ID_CONTEXTVAR = contextvars.ContextVar(
     "current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
 )