From 9105f95d138a21bca3c5c64722cf93b1c3a09fb9 Mon Sep 17 00:00:00 2001
From: rkuo-danswer
Date: Tue, 22 Oct 2024 15:57:36 -0700
Subject: [PATCH] Feature/celery refactor (#2813)

* fresh indexing feature branch

* cherry pick test

* Revert "cherry pick test"

This reverts commit 2a624220687affdda3de347e30f2011136f64bda.

* set multitenant so that vespa fields match when indexing

* cleanup pass

* mypy

* pass through env var to control celery indexing concurrency

* comments on task kickoff and some logging improvements

* disentangle configuration for different workers and beats.

* use get_session_with_tenant

* comment out all of update.py

* rename to RedisConnectorIndexingFenceData

* first check num_indexing_workers

* refactor RedisConnectorIndexingFenceData

* comment out on_worker_process_init

* missed a file

* scope db sessions to short lengths

* update launch.json template

* fix types

* code review
---
 .vscode/launch.template.jsonc                 | 264 +++++++-
 .../background/celery/apps/app_base.py        | 256 ++++++++
 .../danswer/background/celery/apps/beat.py    |  99 +++
 .../danswer/background/celery/apps/heavy.py   |  88 +++
 .../background/celery/apps/indexing.py        | 116 ++++
 .../danswer/background/celery/apps/light.py   |  89 +++
 .../danswer/background/celery/apps/primary.py | 278 ++++++++
 .../background/celery/apps/task_formatters.py |  26 +
 .../danswer/background/celery/celery_app.py   | 601 ------------------
 .../danswer/background/celery/celery_redis.py |   2 +-
 .../danswer/background/celery/celery_utils.py |  29 +-
 .../{celeryconfig.py => configs/base.py}      |  24 +-
 .../danswer/background/celery/configs/beat.py |  14 +
 .../background/celery/configs/heavy.py        |  20 +
 .../background/celery/configs/indexing.py     |  21 +
 .../background/celery/configs/light.py        |  22 +
 .../background/celery/configs/primary.py      |  20 +
 .../celery/tasks/connector_deletion/tasks.py  |  34 +-
 .../background/celery/tasks/indexing/tasks.py |  47 +-
 .../background/celery/tasks/periodic/tasks.py |   2 +-
 .../background/celery/tasks/pruning/tasks.py  |  34 +-
 .../shared/RedisConnectorIndexingFenceData.py |  10 +
 .../background/celery/tasks/shared/tasks.py   |   2 +-
 .../background/celery/tasks/vespa/tasks.py    | 155 +++--
 .../{celery_run.py => versioned_apps/beat.py} |   7 +-
 .../background/celery/versioned_apps/heavy.py |  17 +
 .../celery/versioned_apps/indexing.py         |  17 +
 .../background/celery/versioned_apps/light.py |  17 +
 .../celery/versioned_apps/primary.py          |   8 +
 backend/danswer/configs/app_configs.py        |  35 +
 backend/danswer/server/documents/cc_pair.py   |   6 +-
 backend/danswer/server/documents/connector.py |   2 +
 .../danswer/server/manage/administrative.py   |   4 +-
 .../ee/danswer/background/celery/apps/beat.py |  52 ++
 .../celery/{celery_app.py => apps/primary.py} |  55 +-
 .../background/celery/tasks/vespa/tasks.py    |   2 +-
 backend/scripts/dev_run_background_jobs.py    |  22 +-
 backend/supervisord.conf                      |  34 +-
 38 files changed, 1665 insertions(+), 866 deletions(-)
 create mode 100644 backend/danswer/background/celery/apps/app_base.py
 create mode 100644 backend/danswer/background/celery/apps/beat.py
 create mode 100644 backend/danswer/background/celery/apps/heavy.py
 create mode 100644 backend/danswer/background/celery/apps/indexing.py
 create mode 100644 backend/danswer/background/celery/apps/light.py
 create mode 100644 backend/danswer/background/celery/apps/primary.py
 create mode 100644 backend/danswer/background/celery/apps/task_formatters.py
 delete mode 100644 backend/danswer/background/celery/celery_app.py
 rename backend/danswer/background/celery/{celeryconfig.py => configs/base.py} (95%)
create mode 100644 backend/danswer/background/celery/configs/beat.py create mode 100644 backend/danswer/background/celery/configs/heavy.py create mode 100644 backend/danswer/background/celery/configs/indexing.py create mode 100644 backend/danswer/background/celery/configs/light.py create mode 100644 backend/danswer/background/celery/configs/primary.py create mode 100644 backend/danswer/background/celery/tasks/shared/RedisConnectorIndexingFenceData.py rename backend/danswer/background/celery/{celery_run.py => versioned_apps/beat.py} (55%) create mode 100644 backend/danswer/background/celery/versioned_apps/heavy.py create mode 100644 backend/danswer/background/celery/versioned_apps/indexing.py create mode 100644 backend/danswer/background/celery/versioned_apps/light.py create mode 100644 backend/danswer/background/celery/versioned_apps/primary.py create mode 100644 backend/ee/danswer/background/celery/apps/beat.py rename backend/ee/danswer/background/celery/{celery_app.py => apps/primary.py} (76%) diff --git a/.vscode/launch.template.jsonc b/.vscode/launch.template.jsonc index c733800981c..87875907cd5 100644 --- a/.vscode/launch.template.jsonc +++ b/.vscode/launch.template.jsonc @@ -6,19 +6,69 @@ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "compounds": [ + { + // Dummy entry used to label the group + "name": "--- Compound ---", + "configurations": [ + "--- Individual ---" + ], + "presentation": { + "group": "1", + } + }, { "name": "Run All Danswer Services", "configurations": [ "Web Server", "Model Server", "API Server", - "Indexing", - "Background Jobs", - "Slack Bot" - ] - } + "Slack Bot", + "Celery primary", + "Celery light", + "Celery heavy", + "Celery indexing", + "Celery beat", + ], + "presentation": { + "group": "1", + } + }, + { + "name": "Web / Model / API", + "configurations": [ + "Web Server", + "Model Server", + "API Server", + ], + "presentation": { + "group": "1", + } + }, + { + "name": "Celery (all)", + "configurations": [ + "Celery primary", + "Celery light", + "Celery heavy", + "Celery indexing", + "Celery beat" + ], + "presentation": { + "group": "1", + } + } ], "configurations": [ + { + // Dummy entry used to label the group + "name": "--- Individual ---", + "type": "node", + "request": "launch", + "presentation": { + "group": "2", + "order": 0 + } + }, { "name": "Web Server", "type": "node", @@ -29,7 +79,11 @@ "runtimeArgs": [ "run", "dev" ], - "console": "integratedTerminal" + "presentation": { + "group": "2", + }, + "console": "integratedTerminal", + "consoleTitle": "Web Server Console" }, { "name": "Model Server", @@ -48,7 +102,11 @@ "--reload", "--port", "9000" - ] + ], + "presentation": { + "group": "2", + }, + "consoleTitle": "Model Server Console" }, { "name": "API Server", @@ -68,57 +126,171 @@ "--reload", "--port", "8080" - ] + ], + "presentation": { + "group": "2", + }, + "consoleTitle": "API Server Console" }, + // For the listener to access the Slack API, + // DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project { - "name": "Indexing", - "consoleName": "Indexing", + "name": "Slack Bot", + "consoleName": "Slack Bot", "type": "debugpy", "request": "launch", - "program": "danswer/background/update.py", + "program": "danswer/danswerbot/slack/listener.py", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { - "ENABLE_MULTIPASS_INDEXING": "false", "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1", 
"PYTHONPATH": "." - } + }, + "presentation": { + "group": "2", + }, + "consoleTitle": "Slack Bot Console" }, - // Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev { - "name": "Background Jobs", - "consoleName": "Background Jobs", + "name": "Celery primary", "type": "debugpy", "request": "launch", - "program": "scripts/dev_run_background_jobs.py", + "module": "celery", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { - "LOG_DANSWER_MODEL_INTERACTIONS": "True", + "LOG_LEVEL": "INFO", + "PYTHONUNBUFFERED": "1", + "PYTHONPATH": "." + }, + "args": [ + "-A", + "danswer.background.celery.versioned_apps.primary", + "worker", + "--pool=threads", + "--concurrency=4", + "--prefetch-multiplier=1", + "--loglevel=INFO", + "--hostname=primary@%n", + "-Q", + "celery", + ], + "presentation": { + "group": "2", + }, + "consoleTitle": "Celery primary Console" + }, + { + "name": "Celery light", + "type": "debugpy", + "request": "launch", + "module": "celery", + "cwd": "${workspaceFolder}/backend", + "envFile": "${workspaceFolder}/.vscode/.env", + "env": { + "LOG_LEVEL": "INFO", + "PYTHONUNBUFFERED": "1", + "PYTHONPATH": "." + }, + "args": [ + "-A", + "danswer.background.celery.versioned_apps.light", + "worker", + "--pool=threads", + "--concurrency=64", + "--prefetch-multiplier=8", + "--loglevel=INFO", + "--hostname=light@%n", + "-Q", + "vespa_metadata_sync,connector_deletion", + ], + "presentation": { + "group": "2", + }, + "consoleTitle": "Celery light Console" + }, + { + "name": "Celery heavy", + "type": "debugpy", + "request": "launch", + "module": "celery", + "cwd": "${workspaceFolder}/backend", + "envFile": "${workspaceFolder}/.vscode/.env", + "env": { + "LOG_LEVEL": "INFO", + "PYTHONUNBUFFERED": "1", + "PYTHONPATH": "." + }, + "args": [ + "-A", + "danswer.background.celery.versioned_apps.heavy", + "worker", + "--pool=threads", + "--concurrency=4", + "--prefetch-multiplier=1", + "--loglevel=INFO", + "--hostname=heavy@%n", + "-Q", + "connector_pruning", + ], + "presentation": { + "group": "2", + }, + "consoleTitle": "Celery heavy Console" + }, + { + "name": "Celery indexing", + "type": "debugpy", + "request": "launch", + "module": "celery", + "cwd": "${workspaceFolder}/backend", + "envFile": "${workspaceFolder}/.vscode/.env", + "env": { + "ENABLE_MULTIPASS_INDEXING": "false", "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." }, "args": [ - "--no-indexing" - ] + "-A", + "danswer.background.celery.versioned_apps.indexing", + "worker", + "--pool=threads", + "--concurrency=1", + "--prefetch-multiplier=1", + "--loglevel=INFO", + "--hostname=indexing@%n", + "-Q", + "connector_indexing", + ], + "presentation": { + "group": "2", + }, + "consoleTitle": "Celery indexing Console" }, - // For the listner to access the Slack API, - // DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project { - "name": "Slack Bot", - "consoleName": "Slack Bot", + "name": "Celery beat", "type": "debugpy", "request": "launch", - "program": "danswer/danswerbot/slack/listener.py", + "module": "celery", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." 
- } + }, + "args": [ + "-A", + "danswer.background.celery.versioned_apps.beat", + "beat", + "--loglevel=INFO", + ], + "presentation": { + "group": "2", + }, + "consoleTitle": "Celery beat Console" }, { "name": "Pytest", @@ -137,8 +309,22 @@ "-v" // Specify a sepcific module/test to run or provide nothing to run all tests //"tests/unit/danswer/llm/answering/test_prune_and_merge.py" - ] + ], + "presentation": { + "group": "2", + }, + "consoleTitle": "Pytest Console" }, + { + // Dummy entry used to label the group + "name": "--- Tasks ---", + "type": "node", + "request": "launch", + "presentation": { + "group": "3", + "order": 0 + } + }, { "name": "Clear and Restart External Volumes and Containers", "type": "node", @@ -147,7 +333,27 @@ "runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"], "cwd": "${workspaceFolder}", "console": "integratedTerminal", - "stopOnEntry": true - } + "stopOnEntry": true, + "presentation": { + "group": "3", + }, + }, + { + // Celery jobs launched through a single background script (legacy) + // Recommend using the "Celery (all)" compound launch instead. + "name": "Background Jobs", + "consoleName": "Background Jobs", + "type": "debugpy", + "request": "launch", + "program": "scripts/dev_run_background_jobs.py", + "cwd": "${workspaceFolder}/backend", + "envFile": "${workspaceFolder}/.vscode/.env", + "env": { + "LOG_DANSWER_MODEL_INTERACTIONS": "True", + "LOG_LEVEL": "DEBUG", + "PYTHONUNBUFFERED": "1", + "PYTHONPATH": "." + }, + }, ] } diff --git a/backend/danswer/background/celery/apps/app_base.py b/backend/danswer/background/celery/apps/app_base.py new file mode 100644 index 00000000000..2a52abde5d1 --- /dev/null +++ b/backend/danswer/background/celery/apps/app_base.py @@ -0,0 +1,256 @@ +import logging +import multiprocessing +import time +from typing import Any + +import sentry_sdk +from celery import Task +from celery.exceptions import WorkerShutdown +from celery.states import READY_STATES +from celery.utils.log import get_task_logger +from sentry_sdk.integrations.celery import CeleryIntegration + +from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter +from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter +from danswer.background.celery.celery_redis import RedisConnectorCredentialPair +from danswer.background.celery.celery_redis import RedisConnectorDeletion +from danswer.background.celery.celery_redis import RedisConnectorPruning +from danswer.background.celery.celery_redis import RedisDocumentSet +from danswer.background.celery.celery_redis import RedisUserGroup +from danswer.background.celery.celery_utils import celery_is_worker_primary +from danswer.configs.constants import DanswerRedisLocks +from danswer.redis.redis_pool import get_redis_client +from danswer.utils.logger import ColoredFormatter +from danswer.utils.logger import PlainFormatter +from danswer.utils.logger import setup_logger +from shared_configs.configs import SENTRY_DSN + +logger = setup_logger() + +task_logger = get_task_logger(__name__) + +if SENTRY_DSN: + sentry_sdk.init( + dsn=SENTRY_DSN, + integrations=[CeleryIntegration()], + traces_sample_rate=0.5, + ) + logger.info("Sentry initialized") +else: + logger.debug("Sentry DSN not provided, skipping Sentry initialization") + + +def on_task_prerun( + sender: Any | None = None, + task_id: str | None = None, + task: Task | None = None, + args: tuple | None = None, + kwargs: dict | None = None, + **kwds: Any, +) -> None: + pass + + +def on_task_postrun( + sender: 
Any | None = None, + task_id: str | None = None, + task: Task | None = None, + args: tuple | None = None, + kwargs: dict | None = None, + retval: Any | None = None, + state: str | None = None, + **kwds: Any, +) -> None: + """We handle this signal in order to remove completed tasks + from their respective tasksets. This allows us to track the progress of document set + and user group syncs. + + This function runs after any task completes (both success and failure) + Note that this signal does not fire on a task that failed to complete and is going + to be retried. + + This also does not fire if a worker with acks_late=False crashes (which all of our + long running workers are) + """ + if not task: + return + + task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}") + + if state not in READY_STATES: + return + + if not task_id: + return + + r = get_redis_client() + + if task_id.startswith(RedisConnectorCredentialPair.PREFIX): + r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id) + return + + if task_id.startswith(RedisDocumentSet.PREFIX): + document_set_id = RedisDocumentSet.get_id_from_task_id(task_id) + if document_set_id is not None: + rds = RedisDocumentSet(int(document_set_id)) + r.srem(rds.taskset_key, task_id) + return + + if task_id.startswith(RedisUserGroup.PREFIX): + usergroup_id = RedisUserGroup.get_id_from_task_id(task_id) + if usergroup_id is not None: + rug = RedisUserGroup(int(usergroup_id)) + r.srem(rug.taskset_key, task_id) + return + + if task_id.startswith(RedisConnectorDeletion.PREFIX): + cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id) + if cc_pair_id is not None: + rcd = RedisConnectorDeletion(int(cc_pair_id)) + r.srem(rcd.taskset_key, task_id) + return + + if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX): + cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id) + if cc_pair_id is not None: + rcp = RedisConnectorPruning(int(cc_pair_id)) + r.srem(rcp.taskset_key, task_id) + return + + +def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None: + """The first signal sent on celery worker startup""" + multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn + + +def wait_for_redis(sender: Any, **kwargs: Any) -> None: + r = get_redis_client() + + WAIT_INTERVAL = 5 + WAIT_LIMIT = 60 + + time_start = time.monotonic() + logger.info("Redis: Readiness check starting.") + while True: + try: + if r.ping(): + break + except Exception: + pass + + time_elapsed = time.monotonic() - time_start + logger.info( + f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}" + ) + if time_elapsed > WAIT_LIMIT: + msg = ( + f"Redis: Readiness check did not succeed within the timeout " + f"({WAIT_LIMIT} seconds). Exiting..." + ) + logger.error(msg) + raise WorkerShutdown(msg) + + time.sleep(WAIT_INTERVAL) + + logger.info("Redis: Readiness check succeeded. Continuing...") + return + + +def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None: + r = get_redis_client() + + WAIT_INTERVAL = 5 + WAIT_LIMIT = 60 + + logger.info("Running as a secondary celery worker.") + logger.info("Waiting for primary worker to be ready...") + time_start = time.monotonic() + while True: + if r.exists(DanswerRedisLocks.PRIMARY_WORKER): + break + + time.monotonic() + time_elapsed = time.monotonic() - time_start + logger.info( + f"Primary worker is not ready yet. 
elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}" + ) + if time_elapsed > WAIT_LIMIT: + msg = ( + f"Primary worker was not ready within the timeout. " + f"({WAIT_LIMIT} seconds). Exiting..." + ) + logger.error(msg) + raise WorkerShutdown(msg) + + time.sleep(WAIT_INTERVAL) + + logger.info("Wait for primary worker completed successfully. Continuing...") + return + + +def on_worker_ready(sender: Any, **kwargs: Any) -> None: + task_logger.info("worker_ready signal received.") + + +def on_worker_shutdown(sender: Any, **kwargs: Any) -> None: + if not celery_is_worker_primary(sender): + return + + if not sender.primary_worker_lock: + return + + logger.info("Releasing primary worker lock.") + lock = sender.primary_worker_lock + if lock.owned(): + lock.release() + sender.primary_worker_lock = None + + +def on_setup_logging( + loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any +) -> None: + # TODO: could unhardcode format and colorize and accept these as options from + # celery's config + + # reformats the root logger + root_logger = logging.getLogger() + + root_handler = logging.StreamHandler() # Set up a handler for the root logger + root_formatter = ColoredFormatter( + "%(asctime)s %(filename)30s %(lineno)4s: %(message)s", + datefmt="%m/%d/%Y %I:%M:%S %p", + ) + root_handler.setFormatter(root_formatter) + root_logger.addHandler(root_handler) # Apply the handler to the root logger + + if logfile: + root_file_handler = logging.FileHandler(logfile) + root_file_formatter = PlainFormatter( + "%(asctime)s %(filename)30s %(lineno)4s: %(message)s", + datefmt="%m/%d/%Y %I:%M:%S %p", + ) + root_file_handler.setFormatter(root_file_formatter) + root_logger.addHandler(root_file_handler) + + root_logger.setLevel(loglevel) + + # reformats celery's task logger + task_formatter = CeleryTaskColoredFormatter( + "%(asctime)s %(filename)30s %(lineno)4s: %(message)s", + datefmt="%m/%d/%Y %I:%M:%S %p", + ) + task_handler = logging.StreamHandler() # Set up a handler for the task logger + task_handler.setFormatter(task_formatter) + task_logger.addHandler(task_handler) # Apply the handler to the task logger + + if logfile: + task_file_handler = logging.FileHandler(logfile) + task_file_formatter = CeleryTaskPlainFormatter( + "%(asctime)s %(filename)30s %(lineno)4s: %(message)s", + datefmt="%m/%d/%Y %I:%M:%S %p", + ) + task_file_handler.setFormatter(task_file_formatter) + task_logger.addHandler(task_file_handler) + + task_logger.setLevel(loglevel) + task_logger.propagate = False diff --git a/backend/danswer/background/celery/apps/beat.py b/backend/danswer/background/celery/apps/beat.py new file mode 100644 index 00000000000..47be61e36be --- /dev/null +++ b/backend/danswer/background/celery/apps/beat.py @@ -0,0 +1,99 @@ +from datetime import timedelta +from typing import Any + +from celery import Celery +from celery import signals +from celery.signals import beat_init + +import danswer.background.celery.apps.app_base as app_base +from danswer.configs.constants import DanswerCeleryPriority +from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME +from danswer.db.engine import get_all_tenant_ids +from danswer.db.engine import SqlEngine +from danswer.utils.logger import setup_logger + +logger = setup_logger() + +celery_app = Celery(__name__) +celery_app.config_from_object("danswer.background.celery.configs.beat") + + +@beat_init.connect +def on_beat_init(sender: Any, **kwargs: Any) -> None: + logger.info("beat_init signal received.") + SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME) + 
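+    # beat only schedules tasks and hands them off to the workers, so a small
+    # connection pool should be all it needs here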
SqlEngine.init_engine(pool_size=2, max_overflow=0) + app_base.wait_for_redis(sender, **kwargs) + + +@signals.setup_logging.connect +def on_setup_logging( + loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any +) -> None: + app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs) + + +##### +# Celery Beat (Periodic Tasks) Settings +##### + +tenant_ids = get_all_tenant_ids() + +tasks_to_schedule = [ + { + "name": "check-for-vespa-sync", + "task": "check_for_vespa_sync_task", + "schedule": timedelta(seconds=5), + "options": {"priority": DanswerCeleryPriority.HIGH}, + }, + { + "name": "check-for-connector-deletion", + "task": "check_for_connector_deletion_task", + "schedule": timedelta(seconds=60), + "options": {"priority": DanswerCeleryPriority.HIGH}, + }, + { + "name": "check-for-indexing", + "task": "check_for_indexing", + "schedule": timedelta(seconds=10), + "options": {"priority": DanswerCeleryPriority.HIGH}, + }, + { + "name": "check-for-prune", + "task": "check_for_pruning", + "schedule": timedelta(seconds=10), + "options": {"priority": DanswerCeleryPriority.HIGH}, + }, + { + "name": "kombu-message-cleanup", + "task": "kombu_message_cleanup_task", + "schedule": timedelta(seconds=3600), + "options": {"priority": DanswerCeleryPriority.LOWEST}, + }, + { + "name": "monitor-vespa-sync", + "task": "monitor_vespa_sync", + "schedule": timedelta(seconds=5), + "options": {"priority": DanswerCeleryPriority.HIGH}, + }, +] + +# Build the celery beat schedule dynamically +beat_schedule = {} + +for tenant_id in tenant_ids: + for task in tasks_to_schedule: + task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task + beat_schedule[task_name] = { + "task": task["task"], + "schedule": task["schedule"], + "options": task["options"], + "args": (tenant_id,), # Must pass tenant_id as an argument + } + +# Include any existing beat schedules +existing_beat_schedule = celery_app.conf.beat_schedule or {} +beat_schedule.update(existing_beat_schedule) + +# Update the Celery app configuration once +celery_app.conf.beat_schedule = beat_schedule diff --git a/backend/danswer/background/celery/apps/heavy.py b/backend/danswer/background/celery/apps/heavy.py new file mode 100644 index 00000000000..ba53776bedb --- /dev/null +++ b/backend/danswer/background/celery/apps/heavy.py @@ -0,0 +1,88 @@ +import multiprocessing +from typing import Any + +from celery import Celery +from celery import signals +from celery import Task +from celery.signals import celeryd_init +from celery.signals import worker_init +from celery.signals import worker_ready +from celery.signals import worker_shutdown + +import danswer.background.celery.apps.app_base as app_base +from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME +from danswer.db.engine import SqlEngine +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + +celery_app = Celery(__name__) +celery_app.config_from_object("danswer.background.celery.configs.heavy") + + +@signals.task_prerun.connect +def on_task_prerun( + sender: Any | None = None, + task_id: str | None = None, + task: Task | None = None, + args: tuple | None = None, + kwargs: dict | None = None, + **kwds: Any, +) -> None: + app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds) + + +@signals.task_postrun.connect +def on_task_postrun( + sender: Any | None = None, + task_id: str | None = None, + task: Task | None = None, + args: tuple | None = None, + kwargs: dict | None = None, + retval: Any | None = None, + 
state: str | None = None, + **kwds: Any, +) -> None: + app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds) + + +@celeryd_init.connect +def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None: + app_base.on_celeryd_init(sender, conf, **kwargs) + + +@worker_init.connect +def on_worker_init(sender: Any, **kwargs: Any) -> None: + logger.info("worker_init signal received.") + logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}") + + SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME) + SqlEngine.init_engine(pool_size=8, max_overflow=0) + + app_base.wait_for_redis(sender, **kwargs) + app_base.on_secondary_worker_init(sender, **kwargs) + + +@worker_ready.connect +def on_worker_ready(sender: Any, **kwargs: Any) -> None: + app_base.on_worker_ready(sender, **kwargs) + + +@worker_shutdown.connect +def on_worker_shutdown(sender: Any, **kwargs: Any) -> None: + app_base.on_worker_shutdown(sender, **kwargs) + + +@signals.setup_logging.connect +def on_setup_logging( + loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any +) -> None: + app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs) + + +celery_app.autodiscover_tasks( + [ + "danswer.background.celery.tasks.pruning", + ] +) diff --git a/backend/danswer/background/celery/apps/indexing.py b/backend/danswer/background/celery/apps/indexing.py new file mode 100644 index 00000000000..5e51ebc8c54 --- /dev/null +++ b/backend/danswer/background/celery/apps/indexing.py @@ -0,0 +1,116 @@ +import multiprocessing +from typing import Any + +from celery import Celery +from celery import signals +from celery import Task +from celery.signals import celeryd_init +from celery.signals import worker_init +from celery.signals import worker_ready +from celery.signals import worker_shutdown +from sqlalchemy.orm import Session + +import danswer.background.celery.apps.app_base as app_base +from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME +from danswer.db.engine import SqlEngine +from danswer.db.search_settings import get_current_search_settings +from danswer.db.swap_index import check_index_swap +from danswer.natural_language_processing.search_nlp_models import EmbeddingModel +from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder +from danswer.utils.logger import setup_logger +from shared_configs.configs import INDEXING_MODEL_SERVER_HOST +from shared_configs.configs import MODEL_SERVER_PORT + + +logger = setup_logger() + +celery_app = Celery(__name__) +celery_app.config_from_object("danswer.background.celery.configs.indexing") + + +@signals.task_prerun.connect +def on_task_prerun( + sender: Any | None = None, + task_id: str | None = None, + task: Task | None = None, + args: tuple | None = None, + kwargs: dict | None = None, + **kwds: Any, +) -> None: + app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds) + + +@signals.task_postrun.connect +def on_task_postrun( + sender: Any | None = None, + task_id: str | None = None, + task: Task | None = None, + args: tuple | None = None, + kwargs: dict | None = None, + retval: Any | None = None, + state: str | None = None, + **kwds: Any, +) -> None: + app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds) + + +@celeryd_init.connect +def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None: + app_base.on_celeryd_init(sender, conf, **kwargs) + + +@worker_init.connect +def 
on_worker_init(sender: Any, **kwargs: Any) -> None: + logger.info("worker_init signal received.") + logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}") + + SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME) + SqlEngine.init_engine(pool_size=8, max_overflow=0) + + # TODO: why is this necessary for the indexer to do? + engine = SqlEngine.get_engine() + with Session(engine) as db_session: + check_index_swap(db_session=db_session) + search_settings = get_current_search_settings(db_session) + + # So that the first time users aren't surprised by really slow speed of first + # batch of documents indexed + if search_settings.provider_type is None: + logger.notice("Running a first inference to warm up embedding model") + embedding_model = EmbeddingModel.from_db_model( + search_settings=search_settings, + server_host=INDEXING_MODEL_SERVER_HOST, + server_port=MODEL_SERVER_PORT, + ) + + warm_up_bi_encoder( + embedding_model=embedding_model, + ) + logger.notice("First inference complete.") + + app_base.wait_for_redis(sender, **kwargs) + app_base.on_secondary_worker_init(sender, **kwargs) + + +@worker_ready.connect +def on_worker_ready(sender: Any, **kwargs: Any) -> None: + app_base.on_worker_ready(sender, **kwargs) + + +@worker_shutdown.connect +def on_worker_shutdown(sender: Any, **kwargs: Any) -> None: + app_base.on_worker_shutdown(sender, **kwargs) + + +@signals.setup_logging.connect +def on_setup_logging( + loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any +) -> None: + app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs) + + +celery_app.autodiscover_tasks( + [ + "danswer.background.celery.tasks.indexing", + ] +) diff --git a/backend/danswer/background/celery/apps/light.py b/backend/danswer/background/celery/apps/light.py new file mode 100644 index 00000000000..6f39074b601 --- /dev/null +++ b/backend/danswer/background/celery/apps/light.py @@ -0,0 +1,89 @@ +import multiprocessing +from typing import Any + +from celery import Celery +from celery import signals +from celery import Task +from celery.signals import celeryd_init +from celery.signals import worker_init +from celery.signals import worker_ready +from celery.signals import worker_shutdown + +import danswer.background.celery.apps.app_base as app_base +from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME +from danswer.db.engine import SqlEngine +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + +celery_app = Celery(__name__) +celery_app.config_from_object("danswer.background.celery.configs.light") + + +@signals.task_prerun.connect +def on_task_prerun( + sender: Any | None = None, + task_id: str | None = None, + task: Task | None = None, + args: tuple | None = None, + kwargs: dict | None = None, + **kwds: Any, +) -> None: + app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds) + + +@signals.task_postrun.connect +def on_task_postrun( + sender: Any | None = None, + task_id: str | None = None, + task: Task | None = None, + args: tuple | None = None, + kwargs: dict | None = None, + retval: Any | None = None, + state: str | None = None, + **kwds: Any, +) -> None: + app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds) + + +@celeryd_init.connect +def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None: + app_base.on_celeryd_init(sender, conf, **kwargs) + + +@worker_init.connect +def on_worker_init(sender: Any, **kwargs: Any) -> None: + 
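+    # the light worker runs a high volume of short tasks (vespa metadata sync,
+    # connector deletion), so its engine pool below is sized from the worker's
+    # configured concurrency rather than the fixed pool the other workers use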
logger.info("worker_init signal received.") + logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}") + + SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME) + SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) + + app_base.wait_for_redis(sender, **kwargs) + app_base.on_secondary_worker_init(sender, **kwargs) + + +@worker_ready.connect +def on_worker_ready(sender: Any, **kwargs: Any) -> None: + app_base.on_worker_ready(sender, **kwargs) + + +@worker_shutdown.connect +def on_worker_shutdown(sender: Any, **kwargs: Any) -> None: + app_base.on_worker_shutdown(sender, **kwargs) + + +@signals.setup_logging.connect +def on_setup_logging( + loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any +) -> None: + app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs) + + +celery_app.autodiscover_tasks( + [ + "danswer.background.celery.tasks.shared", + "danswer.background.celery.tasks.vespa", + ] +) diff --git a/backend/danswer/background/celery/apps/primary.py b/backend/danswer/background/celery/apps/primary.py new file mode 100644 index 00000000000..c99607f4bc3 --- /dev/null +++ b/backend/danswer/background/celery/apps/primary.py @@ -0,0 +1,278 @@ +import multiprocessing +from typing import Any + +import redis +from celery import bootsteps # type: ignore +from celery import Celery +from celery import signals +from celery import Task +from celery.exceptions import WorkerShutdown +from celery.signals import celeryd_init +from celery.signals import worker_init +from celery.signals import worker_ready +from celery.signals import worker_shutdown +from celery.utils.log import get_task_logger + +import danswer.background.celery.apps.app_base as app_base +from danswer.background.celery.celery_redis import RedisConnectorCredentialPair +from danswer.background.celery.celery_redis import RedisConnectorDeletion +from danswer.background.celery.celery_redis import RedisConnectorIndexing +from danswer.background.celery.celery_redis import RedisConnectorPruning +from danswer.background.celery.celery_redis import RedisDocumentSet +from danswer.background.celery.celery_redis import RedisUserGroup +from danswer.background.celery.celery_utils import celery_is_worker_primary +from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT +from danswer.configs.constants import DanswerRedisLocks +from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME +from danswer.db.engine import SqlEngine +from danswer.redis.redis_pool import get_redis_client +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + +# use this within celery tasks to get celery task specific logging +task_logger = get_task_logger(__name__) + +celery_app = Celery(__name__) +celery_app.config_from_object("danswer.background.celery.configs.primary") + + +@signals.task_prerun.connect +def on_task_prerun( + sender: Any | None = None, + task_id: str | None = None, + task: Task | None = None, + args: tuple | None = None, + kwargs: dict | None = None, + **kwds: Any, +) -> None: + app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds) + + +@signals.task_postrun.connect +def on_task_postrun( + sender: Any | None = None, + task_id: str | None = None, + task: Task | None = None, + args: tuple | None = None, + kwargs: dict | None = None, + retval: Any | None = None, + state: str | None = None, + **kwds: Any, +) -> None: + app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds) + + 
+@celeryd_init.connect +def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None: + app_base.on_celeryd_init(sender, conf, **kwargs) + + +@worker_init.connect +def on_worker_init(sender: Any, **kwargs: Any) -> None: + logger.info("worker_init signal received.") + logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}") + + SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME) + SqlEngine.init_engine(pool_size=8, max_overflow=0) + + app_base.wait_for_redis(sender, **kwargs) + + logger.info("Running as the primary celery worker.") + + # This is singleton work that should be done on startup exactly once + # by the primary worker + r = get_redis_client() + + # For the moment, we're assuming that we are the only primary worker + # that should be running. + # TODO: maybe check for or clean up another zombie primary worker if we detect it + r.delete(DanswerRedisLocks.PRIMARY_WORKER) + + # this process wide lock is taken to help other workers start up in order. + # it is planned to use this lock to enforce singleton behavior on the primary + # worker, since the primary worker does redis cleanup on startup, but this isn't + # implemented yet. + lock = r.lock( + DanswerRedisLocks.PRIMARY_WORKER, + timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT, + ) + + logger.info("Primary worker lock: Acquire starting.") + acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2) + if acquired: + logger.info("Primary worker lock: Acquire succeeded.") + else: + logger.error("Primary worker lock: Acquire failed!") + raise WorkerShutdown("Primary worker lock could not be acquired!") + + sender.primary_worker_lock = lock + + # As currently designed, when this worker starts as "primary", we reinitialize redis + # to a clean state (for our purposes, anyway) + r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK) + r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK) + + r.delete(RedisConnectorCredentialPair.get_taskset_key()) + r.delete(RedisConnectorCredentialPair.get_fence_key()) + + for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"): + r.delete(key) + + for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"): + r.delete(key) + + +# @worker_process_init.connect +# def on_worker_process_init(sender: Any, **kwargs: Any) -> None: +# """This only runs inside child processes when the worker is in pool=prefork mode. 
+# This may be technically unnecessary since we're finding prefork pools to be +# unstable and currently aren't planning on using them.""" +# logger.info("worker_process_init signal received.") +# SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME) +# SqlEngine.init_engine(pool_size=5, max_overflow=0) + +# # https://stackoverflow.com/questions/43944787/sqlalchemy-celery-with-scoped-session-error +# SqlEngine.get_engine().dispose(close=False) + + +@worker_ready.connect +def on_worker_ready(sender: Any, **kwargs: Any) -> None: + app_base.on_worker_ready(sender, **kwargs) + + +@worker_shutdown.connect +def on_worker_shutdown(sender: Any, **kwargs: Any) -> None: + app_base.on_worker_shutdown(sender, **kwargs) + + +@signals.setup_logging.connect +def on_setup_logging( + loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any +) -> None: + app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs) + + +class HubPeriodicTask(bootsteps.StartStopStep): + """Regularly reacquires the primary worker lock outside of the task queue. + Use the task_logger in this class to avoid double logging. + + This cannot be done inside a regular beat task because it must run on schedule and + a queue of existing work would starve the task from running. + """ + + # it's unclear to me whether using the hub's timer or the bootstep timer is better + requires = {"celery.worker.components:Hub"} + + def __init__(self, worker: Any, **kwargs: Any) -> None: + self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds + self.task_tref = None + + def start(self, worker: Any) -> None: + if not celery_is_worker_primary(worker): + return + + # Access the worker's event loop (hub) + hub = worker.consumer.controller.hub + + # Schedule the periodic task + self.task_tref = hub.call_repeatedly( + self.interval, self.run_periodic_task, worker + ) + task_logger.info("Scheduled periodic task with hub.") + + def run_periodic_task(self, worker: Any) -> None: + try: + if not worker.primary_worker_lock: + return + + if not hasattr(worker, "primary_worker_lock"): + return + + r = get_redis_client() + + lock: redis.lock.Lock = worker.primary_worker_lock + + if lock.owned(): + task_logger.debug("Reacquiring primary worker lock.") + lock.reacquire() + else: + task_logger.warning( + "Full acquisition of primary worker lock. " + "Reasons could be computer sleep or a clock change." 
+ ) + lock = r.lock( + DanswerRedisLocks.PRIMARY_WORKER, + timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT, + ) + + task_logger.info("Primary worker lock: Acquire starting.") + acquired = lock.acquire( + blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2 + ) + if acquired: + task_logger.info("Primary worker lock: Acquire succeeded.") + else: + task_logger.error("Primary worker lock: Acquire failed!") + raise TimeoutError("Primary worker lock could not be acquired!") + + worker.primary_worker_lock = lock + except Exception: + task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.") + + def stop(self, worker: Any) -> None: + # Cancel the scheduled task when the worker stops + if self.task_tref: + self.task_tref.cancel() + task_logger.info("Canceled periodic task with hub.") + + +celery_app.steps["worker"].add(HubPeriodicTask) + +celery_app.autodiscover_tasks( + [ + "danswer.background.celery.tasks.connector_deletion", + "danswer.background.celery.tasks.indexing", + "danswer.background.celery.tasks.periodic", + "danswer.background.celery.tasks.pruning", + "danswer.background.celery.tasks.shared", + "danswer.background.celery.tasks.vespa", + ] +) diff --git a/backend/danswer/background/celery/apps/task_formatters.py b/backend/danswer/background/celery/apps/task_formatters.py new file mode 100644 index 00000000000..e82b23a5431 --- /dev/null +++ b/backend/danswer/background/celery/apps/task_formatters.py @@ -0,0 +1,26 @@ +import logging + +from celery import current_task + +from danswer.utils.logger import ColoredFormatter +from danswer.utils.logger import PlainFormatter + + +class CeleryTaskPlainFormatter(PlainFormatter): + def format(self, record: logging.LogRecord) -> str: + task = current_task + if task and task.request: + record.__dict__.update(task_id=task.request.id, task_name=task.name) + record.msg = f"[{task.name}({task.request.id})] {record.msg}" + + return super().format(record) + + +class CeleryTaskColoredFormatter(ColoredFormatter): + def format(self, record: logging.LogRecord) -> str: + task = current_task + if task and task.request: + record.__dict__.update(task_id=task.request.id, task_name=task.name) + record.msg = f"[{task.name}({task.request.id})] {record.msg}" + + return super().format(record) diff --git a/backend/danswer/background/celery/celery_app.py b/backend/danswer/background/celery/celery_app.py deleted file mode 100644 index ee59b8d50fd..00000000000 --- a/backend/danswer/background/celery/celery_app.py +++ /dev/null @@ -1,601 +0,0 @@ -import logging -import multiprocessing -import time -from datetime import timedelta -from typing import Any - -import redis -import sentry_sdk -from celery import bootsteps # type: ignore -from celery import Celery -from celery import current_task -from celery import signals -from celery import Task -from celery.exceptions import WorkerShutdown -from celery.signals import beat_init -from celery.signals import celeryd_init -from celery.signals import worker_init -from celery.signals import worker_ready -from celery.signals import worker_shutdown -from celery.states import READY_STATES -from celery.utils.log import get_task_logger -from sentry_sdk.integrations.celery import CeleryIntegration - -from danswer.background.celery.celery_redis import RedisConnectorCredentialPair -from danswer.background.celery.celery_redis import RedisConnectorDeletion -from danswer.background.celery.celery_redis import RedisConnectorIndexing -from danswer.background.celery.celery_redis import RedisConnectorPruning -from 
danswer.background.celery.celery_redis import RedisDocumentSet -from danswer.background.celery.celery_redis import RedisUserGroup -from danswer.background.celery.celery_utils import celery_is_worker_primary -from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT -from danswer.configs.constants import DanswerCeleryPriority -from danswer.configs.constants import DanswerRedisLocks -from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME -from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME -from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME -from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME -from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME -from danswer.db.engine import get_all_tenant_ids -from danswer.db.engine import get_session_with_tenant -from danswer.db.engine import SqlEngine -from danswer.db.search_settings import get_current_search_settings -from danswer.db.swap_index import check_index_swap -from danswer.natural_language_processing.search_nlp_models import EmbeddingModel -from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder -from danswer.redis.redis_pool import get_redis_client -from danswer.utils.logger import ColoredFormatter -from danswer.utils.logger import PlainFormatter -from danswer.utils.logger import setup_logger -from shared_configs.configs import INDEXING_MODEL_SERVER_HOST -from shared_configs.configs import MODEL_SERVER_PORT -from shared_configs.configs import SENTRY_DSN - -logger = setup_logger() - -# use this within celery tasks to get celery task specific logging -task_logger = get_task_logger(__name__) - -if SENTRY_DSN: - sentry_sdk.init( - dsn=SENTRY_DSN, - integrations=[CeleryIntegration()], - traces_sample_rate=0.5, - ) - logger.info("Sentry initialized") -else: - logger.debug("Sentry DSN not provided, skipping Sentry initialization") - - -celery_app = Celery(__name__) -celery_app.config_from_object( - "danswer.background.celery.celeryconfig" -) # Load configuration from 'celeryconfig.py' - - -@signals.task_prerun.connect -def on_task_prerun( - sender: Any | None = None, - task_id: str | None = None, - task: Task | None = None, - args: tuple | None = None, - kwargs: dict | None = None, - **kwds: Any, -) -> None: - pass - - -@signals.task_postrun.connect -def on_task_postrun( - sender: Any | None = None, - task_id: str | None = None, - task: Task | None = None, - args: tuple | None = None, - kwargs: dict | None = None, - retval: Any | None = None, - state: str | None = None, - **kwds: Any, -) -> None: - """We handle this signal in order to remove completed tasks - from their respective tasksets. This allows us to track the progress of document set - and user group syncs. - - This function runs after any task completes (both success and failure) - Note that this signal does not fire on a task that failed to complete and is going - to be retried. 
- - This also does not fire if a worker with acks_late=False crashes (which all of our - long running workers are) - """ - if not task: - return - - task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}") - - if state not in READY_STATES: - return - - if not task_id: - return - - r = get_redis_client() - - if task_id.startswith(RedisConnectorCredentialPair.PREFIX): - r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id) - return - - if task_id.startswith(RedisDocumentSet.PREFIX): - document_set_id = RedisDocumentSet.get_id_from_task_id(task_id) - if document_set_id is not None: - rds = RedisDocumentSet(int(document_set_id)) - r.srem(rds.taskset_key, task_id) - return - - if task_id.startswith(RedisUserGroup.PREFIX): - usergroup_id = RedisUserGroup.get_id_from_task_id(task_id) - if usergroup_id is not None: - rug = RedisUserGroup(int(usergroup_id)) - r.srem(rug.taskset_key, task_id) - return - - if task_id.startswith(RedisConnectorDeletion.PREFIX): - cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id) - if cc_pair_id is not None: - rcd = RedisConnectorDeletion(int(cc_pair_id)) - r.srem(rcd.taskset_key, task_id) - return - - if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX): - cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id) - if cc_pair_id is not None: - rcp = RedisConnectorPruning(int(cc_pair_id)) - r.srem(rcp.taskset_key, task_id) - return - - -@celeryd_init.connect -def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None: - """The first signal sent on celery worker startup""" - multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn - - -@beat_init.connect -def on_beat_init(sender: Any, **kwargs: Any) -> None: - SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME) - SqlEngine.init_engine(pool_size=2, max_overflow=0) - - -@worker_init.connect -def on_worker_init(sender: Any, **kwargs: Any) -> None: - logger.info("worker_init signal received.") - logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}") - - # decide some initial startup settings based on the celery worker's hostname - # (set at the command line) - hostname = sender.hostname - if hostname.startswith("light"): - SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME) - SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) - elif hostname.startswith("heavy"): - SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME) - SqlEngine.init_engine(pool_size=8, max_overflow=0) - elif hostname.startswith("indexing"): - SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME) - SqlEngine.init_engine(pool_size=8, max_overflow=0) - - # TODO: why is this necessary for the indexer to do? 
- with get_session_with_tenant(tenant_id) as db_session: - check_index_swap(db_session=db_session) - search_settings = get_current_search_settings(db_session) - - # So that the first time users aren't surprised by really slow speed of first - # batch of documents indexed - - if search_settings.provider_type is None: - logger.notice("Running a first inference to warm up embedding model") - embedding_model = EmbeddingModel.from_db_model( - search_settings=search_settings, - server_host=INDEXING_MODEL_SERVER_HOST, - server_port=MODEL_SERVER_PORT, - ) - - warm_up_bi_encoder( - embedding_model=embedding_model, - ) - logger.notice("First inference complete.") - else: - SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME) - SqlEngine.init_engine(pool_size=8, max_overflow=0) - - r = get_redis_client() - - WAIT_INTERVAL = 5 - WAIT_LIMIT = 60 - - time_start = time.monotonic() - logger.info("Redis: Readiness check starting.") - while True: - try: - if r.ping(): - break - except Exception: - pass - - time_elapsed = time.monotonic() - time_start - logger.info( - f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}" - ) - if time_elapsed > WAIT_LIMIT: - msg = ( - f"Redis: Readiness check did not succeed within the timeout " - f"({WAIT_LIMIT} seconds). Exiting..." - ) - logger.error(msg) - raise WorkerShutdown(msg) - - time.sleep(WAIT_INTERVAL) - - logger.info("Redis: Readiness check succeeded. Continuing...") - - if not celery_is_worker_primary(sender): - logger.info("Running as a secondary celery worker.") - logger.info("Waiting for primary worker to be ready...") - time_start = time.monotonic() - while True: - if r.exists(DanswerRedisLocks.PRIMARY_WORKER): - break - - time.monotonic() - time_elapsed = time.monotonic() - time_start - logger.info( - f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}" - ) - if time_elapsed > WAIT_LIMIT: - msg = ( - f"Primary worker was not ready within the timeout. " - f"({WAIT_LIMIT} seconds). Exiting..." - ) - logger.error(msg) - raise WorkerShutdown(msg) - - time.sleep(WAIT_INTERVAL) - - logger.info("Wait for primary worker completed successfully. Continuing...") - return - - logger.info("Running as the primary celery worker.") - - # This is singleton work that should be done on startup exactly once - # by the primary worker - r = get_redis_client() - - # For the moment, we're assuming that we are the only primary worker - # that should be running. - # TODO: maybe check for or clean up another zombie primary worker if we detect it - r.delete(DanswerRedisLocks.PRIMARY_WORKER) - - # this process wide lock is taken to help other workers start up in order. - # it is planned to use this lock to enforce singleton behavior on the primary - # worker, since the primary worker does redis cleanup on startup, but this isn't - # implemented yet. 
- lock = r.lock( - DanswerRedisLocks.PRIMARY_WORKER, - timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT, - ) - - logger.info("Primary worker lock: Acquire starting.") - acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2) - if acquired: - logger.info("Primary worker lock: Acquire succeeded.") - else: - logger.error("Primary worker lock: Acquire failed!") - raise WorkerShutdown("Primary worker lock could not be acquired!") - - sender.primary_worker_lock = lock - - # As currently designed, when this worker starts as "primary", we reinitialize redis - # to a clean state (for our purposes, anyway) - r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK) - r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK) - - r.delete(RedisConnectorCredentialPair.get_taskset_key()) - r.delete(RedisConnectorCredentialPair.get_fence_key()) - - for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"): - r.delete(key) - - for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"): - r.delete(key) - - for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"): - r.delete(key) - - for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"): - r.delete(key) - - for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"): - r.delete(key) - - for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"): - r.delete(key) - - for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"): - r.delete(key) - - for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"): - r.delete(key) - - for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"): - r.delete(key) - - for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"): - r.delete(key) - - for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"): - r.delete(key) - - for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"): - r.delete(key) - - for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"): - r.delete(key) - - for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"): - r.delete(key) - - -# @worker_process_init.connect -# def on_worker_process_init(sender: Any, **kwargs: Any) -> None: -# """This only runs inside child processes when the worker is in pool=prefork mode. 
-# This may be technically unnecessary since we're finding prefork pools to be -# unstable and currently aren't planning on using them.""" -# logger.info("worker_process_init signal received.") -# SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME) -# SqlEngine.init_engine(pool_size=5, max_overflow=0) - -# # https://stackoverflow.com/questions/43944787/sqlalchemy-celery-with-scoped-session-error -# SqlEngine.get_engine().dispose(close=False) - - -@worker_ready.connect -def on_worker_ready(sender: Any, **kwargs: Any) -> None: - task_logger.info("worker_ready signal received.") - - -@worker_shutdown.connect -def on_worker_shutdown(sender: Any, **kwargs: Any) -> None: - if not celery_is_worker_primary(sender): - return - - if not sender.primary_worker_lock: - return - - logger.info("Releasing primary worker lock.") - lock = sender.primary_worker_lock - if lock.owned(): - lock.release() - sender.primary_worker_lock = None - - -class CeleryTaskPlainFormatter(PlainFormatter): - def format(self, record: logging.LogRecord) -> str: - task = current_task - if task and task.request: - record.__dict__.update(task_id=task.request.id, task_name=task.name) - record.msg = f"[{task.name}({task.request.id})] {record.msg}" - - return super().format(record) - - -class CeleryTaskColoredFormatter(ColoredFormatter): - def format(self, record: logging.LogRecord) -> str: - task = current_task - if task and task.request: - record.__dict__.update(task_id=task.request.id, task_name=task.name) - record.msg = f"[{task.name}({task.request.id})] {record.msg}" - - return super().format(record) - - -@signals.setup_logging.connect -def on_setup_logging( - loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any -) -> None: - # TODO: could unhardcode format and colorize and accept these as options from - # celery's config - - # reformats the root logger - root_logger = logging.getLogger() - - root_handler = logging.StreamHandler() # Set up a handler for the root logger - root_formatter = ColoredFormatter( - "%(asctime)s %(filename)30s %(lineno)4s: %(message)s", - datefmt="%m/%d/%Y %I:%M:%S %p", - ) - root_handler.setFormatter(root_formatter) - root_logger.addHandler(root_handler) # Apply the handler to the root logger - - if logfile: - root_file_handler = logging.FileHandler(logfile) - root_file_formatter = PlainFormatter( - "%(asctime)s %(filename)30s %(lineno)4s: %(message)s", - datefmt="%m/%d/%Y %I:%M:%S %p", - ) - root_file_handler.setFormatter(root_file_formatter) - root_logger.addHandler(root_file_handler) - - root_logger.setLevel(loglevel) - - # reformats celery's task logger - task_formatter = CeleryTaskColoredFormatter( - "%(asctime)s %(filename)30s %(lineno)4s: %(message)s", - datefmt="%m/%d/%Y %I:%M:%S %p", - ) - task_handler = logging.StreamHandler() # Set up a handler for the task logger - task_handler.setFormatter(task_formatter) - task_logger.addHandler(task_handler) # Apply the handler to the task logger - - if logfile: - task_file_handler = logging.FileHandler(logfile) - task_file_formatter = CeleryTaskPlainFormatter( - "%(asctime)s %(filename)30s %(lineno)4s: %(message)s", - datefmt="%m/%d/%Y %I:%M:%S %p", - ) - task_file_handler.setFormatter(task_file_formatter) - task_logger.addHandler(task_file_handler) - - task_logger.setLevel(loglevel) - task_logger.propagate = False - - -class HubPeriodicTask(bootsteps.StartStopStep): - """Regularly reacquires the primary worker lock outside of the task queue. - Use the task_logger in this class to avoid double logging. 
-
-    This cannot be done inside a regular beat task because it must run on schedule and
-    a queue of existing work would starve the task from running.
-    """
-
-    # it's unclear to me whether using the hub's timer or the bootstep timer is better
-    requires = {"celery.worker.components:Hub"}
-
-    def __init__(self, worker: Any, **kwargs: Any) -> None:
-        self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8  # Interval in seconds
-        self.task_tref = None
-
-    def start(self, worker: Any) -> None:
-        if not celery_is_worker_primary(worker):
-            return
-
-        # Access the worker's event loop (hub)
-        hub = worker.consumer.controller.hub
-
-        # Schedule the periodic task
-        self.task_tref = hub.call_repeatedly(
-            self.interval, self.run_periodic_task, worker
-        )
-        task_logger.info("Scheduled periodic task with hub.")
-
-    def run_periodic_task(self, worker: Any) -> None:
-        try:
-            # check hasattr before reading the attribute so workers that never
-            # set the lock don't raise AttributeError
-            if not hasattr(worker, "primary_worker_lock"):
-                return
-
-            if not worker.primary_worker_lock:
-                return
-
-            r = get_redis_client()
-
-            lock: redis.lock.Lock = worker.primary_worker_lock
-
-            if lock.owned():
-                task_logger.debug("Reacquiring primary worker lock.")
-                lock.reacquire()
-            else:
-                task_logger.warning(
-                    "Full acquisition of primary worker lock. "
-                    "Reasons could be computer sleep or a clock change."
-                )
-                lock = r.lock(
-                    DanswerRedisLocks.PRIMARY_WORKER,
-                    timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
-                )
-
-                task_logger.info("Primary worker lock: Acquire starting.")
-                acquired = lock.acquire(
-                    blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
-                )
-                if acquired:
-                    task_logger.info("Primary worker lock: Acquire succeeded.")
-                else:
-                    task_logger.error("Primary worker lock: Acquire failed!")
-                    raise TimeoutError("Primary worker lock could not be acquired!")
-
-                worker.primary_worker_lock = lock
-        except Exception:
-            task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.")
-
-    def stop(self, worker: Any) -> None:
-        # Cancel the scheduled task when the worker stops
-        if self.task_tref:
-            self.task_tref.cancel()
-            task_logger.info("Canceled periodic task with hub.")
-
-
-celery_app.steps["worker"].add(HubPeriodicTask)
-
-celery_app.autodiscover_tasks(
-    [
-        "danswer.background.celery.tasks.connector_deletion",
-        "danswer.background.celery.tasks.indexing",
-        "danswer.background.celery.tasks.periodic",
-        "danswer.background.celery.tasks.pruning",
-        "danswer.background.celery.tasks.shared",
-        "danswer.background.celery.tasks.vespa",
-    ]
-)
-
-#####
-# Celery Beat (Periodic Tasks) Settings
-#####
-
-tenant_ids = get_all_tenant_ids()
-
-tasks_to_schedule = [
-    {
-        "name": "check-for-vespa-sync",
-        "task": "check_for_vespa_sync_task",
-        "schedule": timedelta(seconds=5),
-        "options": {"priority": DanswerCeleryPriority.HIGH},
-    },
-    {
-        "name": "check-for-connector-deletion",
-        "task": "check_for_connector_deletion_task",
-        "schedule": timedelta(seconds=60),
-        "options": {"priority": DanswerCeleryPriority.HIGH},
-    },
-    {
-        "name": "check-for-indexing",
-        "task": "check_for_indexing",
-        "schedule": timedelta(seconds=10),
-        "options": {"priority": DanswerCeleryPriority.HIGH},
-    },
-    {
-        "name": "check-for-prune",
-        "task": "check_for_pruning",
-        "schedule": timedelta(seconds=10),
-        "options": {"priority": DanswerCeleryPriority.HIGH},
-    },
-    {
-        "name": "kombu-message-cleanup",
-        "task": "kombu_message_cleanup_task",
-        "schedule": timedelta(seconds=3600),
-        "options": {"priority": DanswerCeleryPriority.LOWEST},
-    },
-    {
-        "name": "monitor-vespa-sync",
-        "task": "monitor_vespa_sync",
-        "schedule": timedelta(seconds=5),
-        "options": {"priority": DanswerCeleryPriority.HIGH},
-    },
-]
-
-# Build the celery beat schedule dynamically
-beat_schedule = {}
-
-for tenant_id in tenant_ids:
-    for task in tasks_to_schedule:
-        task_name = f"{task['name']}-{tenant_id}"  # Unique name for each scheduled task
-        beat_schedule[task_name] = {
-            "task": task["task"],
-            "schedule": task["schedule"],
-            "options": task["options"],
-            "args": (tenant_id,),  # Must pass tenant_id as an argument
-        }
-
-# Include any existing beat schedules
-existing_beat_schedule = celery_app.conf.beat_schedule or {}
-beat_schedule.update(existing_beat_schedule)
-
-# Update the Celery app configuration once
-celery_app.conf.beat_schedule = beat_schedule
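
For illustration, this is roughly what the per-tenant expansion above produces.
The tenant id "tenant_abc" and the resulting dict are a sketch, not code from
the patch:

    from datetime import timedelta

    from danswer.configs.constants import DanswerCeleryPriority

    # one schedule entry per (task, tenant) pair, keyed by task name + tenant id
    beat_schedule = {
        "check-for-indexing-tenant_abc": {
            "task": "check_for_indexing",
            "schedule": timedelta(seconds=10),
            "options": {"priority": DanswerCeleryPriority.HIGH},
            "args": ("tenant_abc",),  # tenant_id is passed as the task's argument
        },
        # ...repeated for the other scheduled tasks and for every tenant
    }
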
- "options": {"priority": DanswerCeleryPriority.HIGH}, - }, -] - -# Build the celery beat schedule dynamically -beat_schedule = {} - -for tenant_id in tenant_ids: - for task in tasks_to_schedule: - task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task - beat_schedule[task_name] = { - "task": task["task"], - "schedule": task["schedule"], - "options": task["options"], - "args": (tenant_id,), # Must pass tenant_id as an argument - } - -# Include any existing beat schedules -existing_beat_schedule = celery_app.conf.beat_schedule or {} -beat_schedule.update(existing_beat_schedule) - -# Update the Celery app configuration once -celery_app.conf.beat_schedule = beat_schedule diff --git a/backend/danswer/background/celery/celery_redis.py b/backend/danswer/background/celery/celery_redis.py index 53f20946077..f1a5697e246 100644 --- a/backend/danswer/background/celery/celery_redis.py +++ b/backend/danswer/background/celery/celery_redis.py @@ -10,7 +10,7 @@ from redis import Redis from sqlalchemy.orm import Session -from danswer.background.celery.celeryconfig import CELERY_SEPARATOR +from danswer.background.celery.configs.base import CELERY_SEPARATOR from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from danswer.configs.constants import DanswerCeleryPriority from danswer.configs.constants import DanswerCeleryQueues diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py index 4b499268cb4..794f89232c5 100644 --- a/backend/danswer/background/celery/celery_utils.py +++ b/backend/danswer/background/celery/celery_utils.py @@ -3,13 +3,10 @@ from datetime import timezone from typing import Any -from sqlalchemy import text from sqlalchemy.orm import Session from danswer.background.celery.celery_redis import RedisConnectorDeletion from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE -from danswer.configs.app_configs import MULTI_TENANT -from danswer.configs.constants import TENANT_ID_PREFIX from danswer.connectors.cross_connector_utils.rate_limit_wrapper import ( rate_limit_builder, ) @@ -19,7 +16,6 @@ from danswer.connectors.interfaces import SlimConnector from danswer.connectors.models import Document from danswer.db.connector_credential_pair import get_connector_credential_pair -from danswer.db.engine import get_session_with_tenant from danswer.db.enums import TaskStatus from danswer.db.models import TaskQueueState from danswer.redis.redis_pool import get_redis_client @@ -129,33 +125,10 @@ def celery_is_listening_to_queue(worker: Any, name: str) -> bool: def celery_is_worker_primary(worker: Any) -> bool: """There are multiple approaches that could be taken to determine if a celery worker is 'primary', as defined by us. 
But the way we do it is to check the hostname set - for the celery worker, which can be done either in celeryconfig.py or on the + for the celery worker, which can be done on the command line with '--hostname'.""" hostname = worker.hostname if hostname.startswith("primary"): return True return False - - -def get_all_tenant_ids() -> list[str] | list[None]: - if not MULTI_TENANT: - return [None] - with get_session_with_tenant(tenant_id="public") as session: - result = session.execute( - text( - """ - SELECT schema_name - FROM information_schema.schemata - WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')""" - ) - ) - tenant_ids = [row[0] for row in result] - - valid_tenants = [ - tenant - for tenant in tenant_ids - if tenant is None or tenant.startswith(TENANT_ID_PREFIX) - ] - - return valid_tenants diff --git a/backend/danswer/background/celery/celeryconfig.py b/backend/danswer/background/celery/configs/base.py similarity index 95% rename from backend/danswer/background/celery/celeryconfig.py rename to backend/danswer/background/celery/configs/base.py index 3f96364de1e..886fcf545c9 100644 --- a/backend/danswer/background/celery/celeryconfig.py +++ b/backend/danswer/background/celery/configs/base.py @@ -31,21 +31,10 @@ if REDIS_SSL_CA_CERTS: SSL_QUERY_PARAMS += f"&ssl_ca_certs={REDIS_SSL_CA_CERTS}" +# region Broker settings # example celery_broker_url: "redis://:password@localhost:6379/15" broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}" -result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}" - -# NOTE: prefetch 4 is significantly faster than prefetch 1 for small tasks -# however, prefetching is bad when tasks are lengthy as those tasks -# can stall other tasks. -worker_prefetch_multiplier = 4 - -# Leaving this to the default of True may cause double logging since both our own app -# and celery think they are controlling the logger. -# TODO: Configure celery's logger entirely manually and set this to False -# worker_hijack_root_logger = False - broker_connection_retry_on_startup = True broker_pool_limit = CELERY_BROKER_POOL_LIMIT @@ -60,6 +49,7 @@ "socket_keepalive": True, "socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS, } +# endregion # redis backend settings # https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings @@ -73,10 +63,19 @@ task_default_priority = DanswerCeleryPriority.MEDIUM task_acks_late = True +# region Task result backend settings # It's possible we don't even need celery's result backend, in which case all of the optimization below # might be irrelevant +result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}" result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default +# endregion + +# Leaving this to the default of True may cause double logging since both our own app +# and celery think they are controlling the logger. +# TODO: Configure celery's logger entirely manually and set this to False +# worker_hijack_root_logger = False +# region Notes on serialization performance # Option 0: Defaults (json serializer, no compression) # about 1.5 KB per queued task. 
1KB in queue, 400B for result, 100 as a child entry in generator result @@ -102,3 +101,4 @@ # task_serializer = "pickle-bzip2" # result_serializer = "pickle-bzip2" # accept_content=["pickle", "pickle-bzip2"] +# endregion diff --git a/backend/danswer/background/celery/configs/beat.py b/backend/danswer/background/celery/configs/beat.py new file mode 100644 index 00000000000..ef8b21c386f --- /dev/null +++ b/backend/danswer/background/celery/configs/beat.py @@ -0,0 +1,14 @@ +# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html +import danswer.background.celery.configs.base as shared_config + +broker_url = shared_config.broker_url +broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup +broker_pool_limit = shared_config.broker_pool_limit +broker_transport_options = shared_config.broker_transport_options + +redis_socket_keepalive = shared_config.redis_socket_keepalive +redis_retry_on_timeout = shared_config.redis_retry_on_timeout +redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval + +result_backend = shared_config.result_backend +result_expires = shared_config.result_expires # 86400 seconds is the default diff --git a/backend/danswer/background/celery/configs/heavy.py b/backend/danswer/background/celery/configs/heavy.py new file mode 100644 index 00000000000..2d1c65aa86e --- /dev/null +++ b/backend/danswer/background/celery/configs/heavy.py @@ -0,0 +1,20 @@ +import danswer.background.celery.configs.base as shared_config + +broker_url = shared_config.broker_url +broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup +broker_pool_limit = shared_config.broker_pool_limit +broker_transport_options = shared_config.broker_transport_options + +redis_socket_keepalive = shared_config.redis_socket_keepalive +redis_retry_on_timeout = shared_config.redis_retry_on_timeout +redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval + +result_backend = shared_config.result_backend +result_expires = shared_config.result_expires # 86400 seconds is the default + +task_default_priority = shared_config.task_default_priority +task_acks_late = shared_config.task_acks_late + +worker_concurrency = 4 +worker_pool = "threads" +worker_prefetch_multiplier = 1 diff --git a/backend/danswer/background/celery/configs/indexing.py b/backend/danswer/background/celery/configs/indexing.py new file mode 100644 index 00000000000..d2b1b99baa9 --- /dev/null +++ b/backend/danswer/background/celery/configs/indexing.py @@ -0,0 +1,21 @@ +import danswer.background.celery.configs.base as shared_config +from danswer.configs.app_configs import CELERY_WORKER_INDEXING_CONCURRENCY + +broker_url = shared_config.broker_url +broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup +broker_pool_limit = shared_config.broker_pool_limit +broker_transport_options = shared_config.broker_transport_options + +redis_socket_keepalive = shared_config.redis_socket_keepalive +redis_retry_on_timeout = shared_config.redis_retry_on_timeout +redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval + +result_backend = shared_config.result_backend +result_expires = shared_config.result_expires # 86400 seconds is the default + +task_default_priority = shared_config.task_default_priority +task_acks_late = shared_config.task_acks_late + +worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY +worker_pool = "threads" +worker_prefetch_multiplier = 1 diff --git 
a/backend/danswer/background/celery/configs/light.py b/backend/danswer/background/celery/configs/light.py new file mode 100644 index 00000000000..f75ddfd0fb5 --- /dev/null +++ b/backend/danswer/background/celery/configs/light.py @@ -0,0 +1,22 @@ +import danswer.background.celery.configs.base as shared_config +from danswer.configs.app_configs import CELERY_WORKER_LIGHT_CONCURRENCY +from danswer.configs.app_configs import CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER + +broker_url = shared_config.broker_url +broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup +broker_pool_limit = shared_config.broker_pool_limit +broker_transport_options = shared_config.broker_transport_options + +redis_socket_keepalive = shared_config.redis_socket_keepalive +redis_retry_on_timeout = shared_config.redis_retry_on_timeout +redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval + +result_backend = shared_config.result_backend +result_expires = shared_config.result_expires # 86400 seconds is the default + +task_default_priority = shared_config.task_default_priority +task_acks_late = shared_config.task_acks_late + +worker_concurrency = CELERY_WORKER_LIGHT_CONCURRENCY +worker_pool = "threads" +worker_prefetch_multiplier = CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER diff --git a/backend/danswer/background/celery/configs/primary.py b/backend/danswer/background/celery/configs/primary.py new file mode 100644 index 00000000000..2d1c65aa86e --- /dev/null +++ b/backend/danswer/background/celery/configs/primary.py @@ -0,0 +1,20 @@ +import danswer.background.celery.configs.base as shared_config + +broker_url = shared_config.broker_url +broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup +broker_pool_limit = shared_config.broker_pool_limit +broker_transport_options = shared_config.broker_transport_options + +redis_socket_keepalive = shared_config.redis_socket_keepalive +redis_retry_on_timeout = shared_config.redis_retry_on_timeout +redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval + +result_backend = shared_config.result_backend +result_expires = shared_config.result_expires # 86400 seconds is the default + +task_default_priority = shared_config.task_default_priority +task_acks_late = shared_config.task_acks_late + +worker_concurrency = 4 +worker_pool = "threads" +worker_prefetch_multiplier = 1 diff --git a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py index b13daff61fc..b3c2eea30b0 100644 --- a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py +++ b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py @@ -1,20 +1,20 @@ import redis +from celery import Celery from celery import shared_task +from celery import Task from celery.exceptions import SoftTimeLimitExceeded from redis import Redis from sqlalchemy.orm import Session -from sqlalchemy.orm.exc import ObjectDeletedError -from danswer.background.celery.celery_app import celery_app -from danswer.background.celery.celery_app import task_logger +from danswer.background.celery.apps.app_base import task_logger from danswer.background.celery.celery_redis import RedisConnectorDeletion from danswer.configs.app_configs import JOB_TIMEOUT from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from danswer.configs.constants import DanswerRedisLocks +from danswer.db.connector_credential_pair import 
get_connector_credential_pair_from_id from danswer.db.connector_credential_pair import get_connector_credential_pairs from danswer.db.engine import get_session_with_tenant from danswer.db.enums import ConnectorCredentialPairStatus -from danswer.db.models import ConnectorCredentialPair from danswer.redis.redis_pool import get_redis_client @@ -22,8 +22,9 @@ name="check_for_connector_deletion_task", soft_time_limit=JOB_TIMEOUT, trail=False, + bind=True, ) -def check_for_connector_deletion_task(tenant_id: str | None) -> None: +def check_for_connector_deletion_task(self: Task, tenant_id: str | None) -> None: r = get_redis_client() lock_beat = r.lock( @@ -36,11 +37,16 @@ def check_for_connector_deletion_task(tenant_id: str | None) -> None: if not lock_beat.acquire(blocking=False): return + cc_pair_ids: list[int] = [] with get_session_with_tenant(tenant_id) as db_session: cc_pairs = get_connector_credential_pairs(db_session) for cc_pair in cc_pairs: + cc_pair_ids.append(cc_pair.id) + + for cc_pair_id in cc_pair_ids: + with get_session_with_tenant(tenant_id) as db_session: try_generate_document_cc_pair_cleanup_tasks( - cc_pair, db_session, r, lock_beat, tenant_id + self.app, cc_pair_id, db_session, r, lock_beat, tenant_id ) except SoftTimeLimitExceeded: task_logger.info( @@ -54,7 +60,8 @@ def check_for_connector_deletion_task(tenant_id: str | None) -> None: def try_generate_document_cc_pair_cleanup_tasks( - cc_pair: ConnectorCredentialPair, + app: Celery, + cc_pair_id: int, db_session: Session, r: Redis, lock_beat: redis.lock.Lock, @@ -67,18 +74,17 @@ def try_generate_document_cc_pair_cleanup_tasks( lock_beat.reacquire() - rcd = RedisConnectorDeletion(cc_pair.id) + rcd = RedisConnectorDeletion(cc_pair_id) # don't generate sync tasks if tasks are still pending if r.exists(rcd.fence_key): return None - # we need to refresh the state of the object inside the fence + # we need to load the state of the object inside the fence # to avoid a race condition with db.commit/fence deletion # at the end of this taskset - try: - db_session.refresh(cc_pair) - except ObjectDeletedError: + cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + if not cc_pair: return None if cc_pair.status != ConnectorCredentialPairStatus.DELETING: @@ -91,9 +97,7 @@ def try_generate_document_cc_pair_cleanup_tasks( task_logger.info( f"RedisConnectorDeletion.generate_tasks starting. 
cc_pair_id={cc_pair.id}" ) - tasks_generated = rcd.generate_tasks( - celery_app, db_session, r, lock_beat, tenant_id - ) + tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id) if tasks_generated is None: return None diff --git a/backend/danswer/background/celery/tasks/indexing/tasks.py b/backend/danswer/background/celery/tasks/indexing/tasks.py index 0e8e59bf5a6..ed08787d53e 100644 --- a/backend/danswer/background/celery/tasks/indexing/tasks.py +++ b/backend/danswer/background/celery/tasks/indexing/tasks.py @@ -5,15 +5,18 @@ from typing import cast from uuid import uuid4 +from celery import Celery from celery import shared_task +from celery import Task from celery.exceptions import SoftTimeLimitExceeded from redis import Redis from sqlalchemy.orm import Session -from danswer.background.celery.celery_app import celery_app -from danswer.background.celery.celery_app import task_logger +from danswer.background.celery.apps.app_base import task_logger from danswer.background.celery.celery_redis import RedisConnectorIndexing -from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData +from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import ( + RedisConnectorIndexingFenceData, +) from danswer.background.indexing.job_client import SimpleJobClient from danswer.background.indexing.run_indexing import run_indexing_entrypoint from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP @@ -50,8 +53,9 @@ @shared_task( name="check_for_indexing", soft_time_limit=300, + bind=True, ) -def check_for_indexing(tenant_id: str | None) -> int | None: +def check_for_indexing(self: Task, tenant_id: str | None) -> int | None: tasks_created = 0 r = get_redis_client() @@ -66,26 +70,37 @@ def check_for_indexing(tenant_id: str | None) -> int | None: if not lock_beat.acquire(blocking=False): return None + cc_pair_ids: list[int] = [] with get_session_with_tenant(tenant_id) as db_session: - # Get the primary search settings - primary_search_settings = get_current_search_settings(db_session) - search_settings = [primary_search_settings] + cc_pairs = fetch_connector_credential_pairs(db_session) + for cc_pair_entry in cc_pairs: + cc_pair_ids.append(cc_pair_entry.id) - # Check for secondary search settings - secondary_search_settings = get_secondary_search_settings(db_session) - if secondary_search_settings is not None: - # If secondary settings exist, add them to the list - search_settings.append(secondary_search_settings) + for cc_pair_id in cc_pair_ids: + with get_session_with_tenant(tenant_id) as db_session: + # Get the primary search settings + primary_search_settings = get_current_search_settings(db_session) + search_settings = [primary_search_settings] + + # Check for secondary search settings + secondary_search_settings = get_secondary_search_settings(db_session) + if secondary_search_settings is not None: + # If secondary settings exist, add them to the list + search_settings.append(secondary_search_settings) - cc_pairs = fetch_connector_credential_pairs(db_session) - for cc_pair in cc_pairs: for search_settings_instance in search_settings: rci = RedisConnectorIndexing( - cc_pair.id, search_settings_instance.id + cc_pair_id, search_settings_instance.id ) if r.exists(rci.fence_key): continue + cc_pair = get_connector_credential_pair_from_id( + cc_pair_id, db_session + ) + if not cc_pair: + continue + last_attempt = get_last_attempt_for_cc_pair( cc_pair.id, search_settings_instance.id, db_session ) @@ -101,6 +116,7 @@ def 
check_for_indexing(tenant_id: str | None) -> int | None: # using a task queue and only allowing one task per cc_pair/search_setting # prevents us from starving out certain attempts attempt_id = try_creating_indexing_task( + self.app, cc_pair, search_settings_instance, False, @@ -210,6 +226,7 @@ def _should_index( def try_creating_indexing_task( + celery_app: Celery, cc_pair: ConnectorCredentialPair, search_settings: SearchSettings, reindex: bool, diff --git a/backend/danswer/background/celery/tasks/periodic/tasks.py b/backend/danswer/background/celery/tasks/periodic/tasks.py index d8da5ba9ca9..20baa7c52fa 100644 --- a/backend/danswer/background/celery/tasks/periodic/tasks.py +++ b/backend/danswer/background/celery/tasks/periodic/tasks.py @@ -11,7 +11,7 @@ from sqlalchemy import text from sqlalchemy.orm import Session -from danswer.background.celery.celery_app import task_logger +from danswer.background.celery.apps.app_base import task_logger from danswer.configs.app_configs import JOB_TIMEOUT from danswer.configs.constants import PostgresAdvisoryLocks from danswer.db.engine import get_session_with_tenant diff --git a/backend/danswer/background/celery/tasks/pruning/tasks.py b/backend/danswer/background/celery/tasks/pruning/tasks.py index 4bfde82292a..698c2937299 100644 --- a/backend/danswer/background/celery/tasks/pruning/tasks.py +++ b/backend/danswer/background/celery/tasks/pruning/tasks.py @@ -3,13 +3,14 @@ from datetime import timezone from uuid import uuid4 +from celery import Celery from celery import shared_task +from celery import Task from celery.exceptions import SoftTimeLimitExceeded from redis import Redis from sqlalchemy.orm import Session -from danswer.background.celery.celery_app import celery_app -from danswer.background.celery.celery_app import task_logger +from danswer.background.celery.apps.app_base import task_logger from danswer.background.celery.celery_redis import RedisConnectorPruning from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING @@ -23,6 +24,7 @@ from danswer.connectors.factory import instantiate_connector from danswer.connectors.models import InputType from danswer.db.connector_credential_pair import get_connector_credential_pair +from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id from danswer.db.connector_credential_pair import get_connector_credential_pairs from danswer.db.document import get_documents_for_connector_credential_pair from danswer.db.engine import get_session_with_tenant @@ -37,8 +39,9 @@ @shared_task( name="check_for_pruning", soft_time_limit=JOB_TIMEOUT, + bind=True, ) -def check_for_pruning(tenant_id: str | None) -> None: +def check_for_pruning(self: Task, tenant_id: str | None) -> None: r = get_redis_client() lock_beat = r.lock( @@ -51,15 +54,24 @@ def check_for_pruning(tenant_id: str | None) -> None: if not lock_beat.acquire(blocking=False): return + cc_pair_ids: list[int] = [] with get_session_with_tenant(tenant_id) as db_session: cc_pairs = get_connector_credential_pairs(db_session) - for cc_pair in cc_pairs: - lock_beat.reacquire() + for cc_pair_entry in cc_pairs: + cc_pair_ids.append(cc_pair_entry.id) + + for cc_pair_id in cc_pair_ids: + lock_beat.reacquire() + with get_session_with_tenant(tenant_id) as db_session: + cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + if not cc_pair: + continue + if not is_pruning_due(cc_pair, db_session, r): continue tasks_created = 
try_creating_prune_generator_task( - cc_pair, db_session, r, tenant_id + self.app, cc_pair, db_session, r, tenant_id ) if not tasks_created: continue @@ -118,6 +130,7 @@ def is_pruning_due( def try_creating_prune_generator_task( + celery_app: Celery, cc_pair: ConnectorCredentialPair, db_session: Session, r: Redis, @@ -196,9 +209,14 @@ def try_creating_prune_generator_task( soft_time_limit=JOB_TIMEOUT, track_started=True, trail=False, + bind=True, ) def connector_pruning_generator_task( - cc_pair_id: int, connector_id: int, credential_id: int, tenant_id: str | None + self: Task, + cc_pair_id: int, + connector_id: int, + credential_id: int, + tenant_id: str | None, ) -> None: """connector pruning task. For a cc pair, this task pulls all document IDs from the source and compares those IDs to locally stored documents and deletes all locally stored IDs missing @@ -278,7 +296,7 @@ def redis_increment_callback(amount: int) -> None: f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}" ) tasks_generated = rcp.generate_tasks( - celery_app, db_session, r, None, tenant_id + self.app, db_session, r, None, tenant_id ) if tasks_generated is None: return None diff --git a/backend/danswer/background/celery/tasks/shared/RedisConnectorIndexingFenceData.py b/backend/danswer/background/celery/tasks/shared/RedisConnectorIndexingFenceData.py new file mode 100644 index 00000000000..224571a4231 --- /dev/null +++ b/backend/danswer/background/celery/tasks/shared/RedisConnectorIndexingFenceData.py @@ -0,0 +1,10 @@ +from datetime import datetime + +from pydantic import BaseModel + + +class RedisConnectorIndexingFenceData(BaseModel): + index_attempt_id: int | None + started: datetime | None + submitted: datetime + celery_task_id: str | None diff --git a/backend/danswer/background/celery/tasks/shared/tasks.py b/backend/danswer/background/celery/tasks/shared/tasks.py index 26f9d1aac10..474a749e786 100644 --- a/backend/danswer/background/celery/tasks/shared/tasks.py +++ b/backend/danswer/background/celery/tasks/shared/tasks.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from danswer.access.access import get_access_for_document -from danswer.background.celery.celery_app import task_logger +from danswer.background.celery.apps.app_base import task_logger from danswer.db.document import delete_document_by_connector_credential_pair__no_commit from danswer.db.document import delete_documents_complete__no_commit from danswer.db.document import get_document diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py index 2d79045c44f..53e26be6954 100644 --- a/backend/danswer/background/celery/tasks/vespa/tasks.py +++ b/backend/danswer/background/celery/tasks/vespa/tasks.py @@ -5,6 +5,7 @@ from typing import cast import redis +from celery import Celery from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded @@ -14,8 +15,7 @@ from sqlalchemy.orm import Session from danswer.access.access import get_access_for_document -from danswer.background.celery.celery_app import celery_app -from danswer.background.celery.celery_app import task_logger +from danswer.background.celery.apps.app_base import task_logger from danswer.background.celery.celery_redis import celery_get_queue_length from danswer.background.celery.celery_redis import RedisConnectorCredentialPair from danswer.background.celery.celery_redis import RedisConnectorDeletion @@ -23,7 +23,9 @@ from danswer.background.celery.celery_redis import 
RedisConnectorPruning from danswer.background.celery.celery_redis import RedisDocumentSet from danswer.background.celery.celery_redis import RedisUserGroup -from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData +from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import ( + RedisConnectorIndexingFenceData, +) from danswer.configs.app_configs import JOB_TIMEOUT from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from danswer.configs.constants import DanswerCeleryQueues @@ -54,7 +56,6 @@ from danswer.db.index_attempt import mark_attempt_failed from danswer.db.models import DocumentSet from danswer.db.models import IndexAttempt -from danswer.db.models import UserGroup from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index from danswer.document_index.interfaces import VespaDocumentFields @@ -73,8 +74,9 @@ name="check_for_vespa_sync_task", soft_time_limit=JOB_TIMEOUT, trail=False, + bind=True, ) -def check_for_vespa_sync_task(tenant_id: str | None) -> None: +def check_for_vespa_sync_task(self: Task, tenant_id: str | None) -> None: """Runs periodically to check if any document needs syncing. Generates sets of tasks for Celery if syncing is needed.""" @@ -91,35 +93,53 @@ def check_for_vespa_sync_task(tenant_id: str | None) -> None: return with get_session_with_tenant(tenant_id) as db_session: - try_generate_stale_document_sync_tasks(db_session, r, lock_beat, tenant_id) + try_generate_stale_document_sync_tasks( + self.app, db_session, r, lock_beat, tenant_id + ) + # region document set scan + document_set_ids: list[int] = [] + with get_session_with_tenant(tenant_id) as db_session: # check if any document sets are not synced document_set_info = fetch_document_sets( user_id=None, db_session=db_session, include_outdated=True ) + for document_set, _ in document_set_info: + document_set_ids.append(document_set.id) + + for document_set_id in document_set_ids: + with get_session_with_tenant(tenant_id) as db_session: try_generate_document_set_sync_tasks( - document_set, db_session, r, lock_beat, tenant_id + self.app, document_set_id, db_session, r, lock_beat, tenant_id ) + # endregion - # check if any user groups are not synced - if global_version.is_ee_version(): - try: - fetch_user_groups = fetch_versioned_implementation( - "danswer.db.user_group", "fetch_user_groups" - ) - + # check if any user groups are not synced + if global_version.is_ee_version(): + try: + fetch_user_groups = fetch_versioned_implementation( + "danswer.db.user_group", "fetch_user_groups" + ) + except ModuleNotFoundError: + # Always exceptions on the MIT version, which is expected + # We shouldn't actually get here if the ee version check works + pass + else: + usergroup_ids: list[int] = [] + with get_session_with_tenant(tenant_id) as db_session: user_groups = fetch_user_groups( db_session=db_session, only_up_to_date=False ) + for usergroup in user_groups: + usergroup_ids.append(usergroup.id) + + for usergroup_id in usergroup_ids: + with get_session_with_tenant(tenant_id) as db_session: try_generate_user_group_sync_tasks( - usergroup, db_session, r, lock_beat, tenant_id + self.app, usergroup_id, db_session, r, lock_beat, tenant_id ) - except ModuleNotFoundError: - # Always exceptions on the MIT version, which is expected - # We shouldn't actually get here if the ee version check works - pass except SoftTimeLimitExceeded: task_logger.info( @@ -133,7 +153,11 @@ def 
check_for_vespa_sync_task(tenant_id: str | None) -> None: def try_generate_stale_document_sync_tasks( - db_session: Session, r: Redis, lock_beat: redis.lock.Lock, tenant_id: str | None + celery_app: Celery, + db_session: Session, + r: Redis, + lock_beat: redis.lock.Lock, + tenant_id: str | None, ) -> int | None: # the fence is up, do nothing if r.exists(RedisConnectorCredentialPair.get_fence_key()): @@ -184,7 +208,8 @@ def try_generate_stale_document_sync_tasks( def try_generate_document_set_sync_tasks( - document_set: DocumentSet, + celery_app: Celery, + document_set_id: int, db_session: Session, r: Redis, lock_beat: redis.lock.Lock, @@ -192,7 +217,7 @@ def try_generate_document_set_sync_tasks( ) -> int | None: lock_beat.reacquire() - rds = RedisDocumentSet(document_set.id) + rds = RedisDocumentSet(document_set_id) # don't generate document set sync tasks if tasks are still pending if r.exists(rds.fence_key): @@ -200,7 +225,10 @@ def try_generate_document_set_sync_tasks( # don't generate sync tasks if we're up to date # race condition with the monitor/cleanup function if we use a cached result! - db_session.refresh(document_set) + document_set = get_document_set_by_id(db_session, document_set_id) + if not document_set: + return None + if document_set.is_up_to_date: return None @@ -235,7 +263,8 @@ def try_generate_document_set_sync_tasks( def try_generate_user_group_sync_tasks( - usergroup: UserGroup, + celery_app: Celery, + usergroup_id: int, db_session: Session, r: Redis, lock_beat: redis.lock.Lock, @@ -243,14 +272,21 @@ def try_generate_user_group_sync_tasks( ) -> int | None: lock_beat.reacquire() - rug = RedisUserGroup(usergroup.id) + rug = RedisUserGroup(usergroup_id) # don't generate sync tasks if tasks are still pending if r.exists(rug.fence_key): return None # race condition with the monitor/cleanup function if we use a cached result! 
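# Aside (illustrative, not from this patch): fetch_versioned_implementation,
# used just below, resolves the EE variant of a module at runtime when running
# the enterprise build, and the MIT variant otherwise. A simplified sketch of
# the helper, with details assumed:
#
#     import importlib
#     from typing import Any
#
#     def fetch_versioned_implementation(module: str, attribute: str) -> Any:
#         if global_version.is_ee_version():
#             module = "ee." + module
#         return getattr(importlib.import_module(module), attribute)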
- db_session.refresh(usergroup) + fetch_user_group = fetch_versioned_implementation( + "danswer.db.user_group", "fetch_user_group" + ) + + usergroup = fetch_user_group(db_session, usergroup_id) + if not usergroup: + return None + if usergroup.is_up_to_date: return None @@ -680,36 +716,9 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: f"pruning={n_pruning}" ) - lock_beat.reacquire() - if r.exists(RedisConnectorCredentialPair.get_fence_key()): - monitor_connector_taskset(r) - - lock_beat.reacquire() - for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"): - monitor_connector_deletion_taskset(key_bytes, r, tenant_id) - + # do some cleanup before clearing fences + # check the db for any outstanding index attempts with get_session_with_tenant(tenant_id) as db_session: - lock_beat.reacquire() - for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"): - monitor_document_set_taskset(key_bytes, r, db_session) - - lock_beat.reacquire() - for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"): - monitor_usergroup_taskset = ( - fetch_versioned_implementation_with_fallback( - "danswer.background.celery.tasks.vespa.tasks", - "monitor_usergroup_taskset", - noop_fallback, - ) - ) - monitor_usergroup_taskset(key_bytes, r, db_session) - - lock_beat.reacquire() - for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"): - monitor_ccpair_pruning_taskset(key_bytes, r, db_session) - - # do some cleanup before clearing fences - # check the db for any outstanding index attempts attempts: list[IndexAttempt] = [] attempts.extend( get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session) @@ -727,8 +736,42 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: if not r.exists(rci.fence_key): mark_attempt_failed(a, db_session, failure_reason=failure_reason) + lock_beat.reacquire() + if r.exists(RedisConnectorCredentialPair.get_fence_key()): + monitor_connector_taskset(r) + + lock_beat.reacquire() + for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"): + lock_beat.reacquire() + monitor_connector_deletion_taskset(key_bytes, r, tenant_id) + + lock_beat.reacquire() + for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"): + lock_beat.reacquire() + with get_session_with_tenant(tenant_id) as db_session: + monitor_document_set_taskset(key_bytes, r, db_session) + + lock_beat.reacquire() + for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"): + lock_beat.reacquire() + monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback( + "danswer.background.celery.tasks.vespa.tasks", + "monitor_usergroup_taskset", + noop_fallback, + ) + with get_session_with_tenant(tenant_id) as db_session: + monitor_usergroup_taskset(key_bytes, r, db_session) + + lock_beat.reacquire() + for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"): + lock_beat.reacquire() + with get_session_with_tenant(tenant_id) as db_session: + monitor_ccpair_pruning_taskset(key_bytes, r, db_session) + + lock_beat.reacquire() + for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"): lock_beat.reacquire() - for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"): + with get_session_with_tenant(tenant_id) as db_session: monitor_ccpair_indexing_taskset(key_bytes, r, db_session) # uncomment for debugging if needed diff --git a/backend/danswer/background/celery/celery_run.py b/backend/danswer/background/celery/versioned_apps/beat.py similarity index 55% rename from 
backend/danswer/background/celery/celery_run.py rename to backend/danswer/background/celery/versioned_apps/beat.py index 0fdb2f044a8..d1b7dc591d9 100644 --- a/backend/danswer/background/celery/celery_run.py +++ b/backend/danswer/background/celery/versioned_apps/beat.py @@ -1,9 +1,8 @@ -"""Entry point for running celery worker / celery beat.""" +"""Factory stub for running celery worker / celery beat.""" from danswer.utils.variable_functionality import fetch_versioned_implementation from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable - set_is_ee_based_on_env_variable() -celery_app = fetch_versioned_implementation( - "danswer.background.celery.celery_app", "celery_app" +app = fetch_versioned_implementation( + "danswer.background.celery.apps.beat", "celery_app" ) diff --git a/backend/danswer/background/celery/versioned_apps/heavy.py b/backend/danswer/background/celery/versioned_apps/heavy.py new file mode 100644 index 00000000000..c2b58a53bfc --- /dev/null +++ b/backend/danswer/background/celery/versioned_apps/heavy.py @@ -0,0 +1,17 @@ +"""Factory stub for running celery worker / celery beat. +This code is different from the primary/beat stubs because there is no EE version to +fetch. Port over the code in those files if we add an EE version of this worker.""" +from celery import Celery + +from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable + +set_is_ee_based_on_env_variable() + + +def get_app() -> Celery: + from danswer.background.celery.apps.heavy import celery_app + + return celery_app + + +app = get_app() diff --git a/backend/danswer/background/celery/versioned_apps/indexing.py b/backend/danswer/background/celery/versioned_apps/indexing.py new file mode 100644 index 00000000000..ed26fc548bc --- /dev/null +++ b/backend/danswer/background/celery/versioned_apps/indexing.py @@ -0,0 +1,17 @@ +"""Factory stub for running celery worker / celery beat. +This code is different from the primary/beat stubs because there is no EE version to +fetch. Port over the code in those files if we add an EE version of this worker.""" +from celery import Celery + +from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable + +set_is_ee_based_on_env_variable() + + +def get_app() -> Celery: + from danswer.background.celery.apps.indexing import celery_app + + return celery_app + + +app = get_app() diff --git a/backend/danswer/background/celery/versioned_apps/light.py b/backend/danswer/background/celery/versioned_apps/light.py new file mode 100644 index 00000000000..3d229431ce5 --- /dev/null +++ b/backend/danswer/background/celery/versioned_apps/light.py @@ -0,0 +1,17 @@ +"""Factory stub for running celery worker / celery beat. +This code is different from the primary/beat stubs because there is no EE version to +fetch. 
Port over the code in those files if we add an EE version of this worker.""" +from celery import Celery + +from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable + +set_is_ee_based_on_env_variable() + + +def get_app() -> Celery: + from danswer.background.celery.apps.light import celery_app + + return celery_app + + +app = get_app() diff --git a/backend/danswer/background/celery/versioned_apps/primary.py b/backend/danswer/background/celery/versioned_apps/primary.py new file mode 100644 index 00000000000..2d97caa3da5 --- /dev/null +++ b/backend/danswer/background/celery/versioned_apps/primary.py @@ -0,0 +1,8 @@ +"""Factory stub for running celery worker / celery beat.""" +from danswer.utils.variable_functionality import fetch_versioned_implementation +from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable + +set_is_ee_based_on_env_variable() +app = fetch_versioned_implementation( + "danswer.background.celery.apps.primary", "celery_app" +) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index d53fb0b12ea..caf7a103b94 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -198,6 +198,41 @@ except ValueError: CELERY_BROKER_POOL_LIMIT = CELERY_BROKER_POOL_LIMIT_DEFAULT +CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT = 24 +try: + CELERY_WORKER_LIGHT_CONCURRENCY = int( + os.environ.get( + "CELERY_WORKER_LIGHT_CONCURRENCY", CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT + ) + ) +except ValueError: + CELERY_WORKER_LIGHT_CONCURRENCY = CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT + +CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT = 8 +try: + CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER = int( + os.environ.get( + "CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER", + CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT, + ) + ) +except ValueError: + CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER = ( + CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT + ) + +CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 1 +try: + env_value = os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY") + if not env_value: + env_value = os.environ.get("NUM_INDEXING_WORKERS") + + if not env_value: + env_value = str(CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT) + CELERY_WORKER_INDEXING_CONCURRENCY = int(env_value) +except ValueError: + CELERY_WORKER_INDEXING_CONCURRENCY = CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT + ##### # Connector Configs ##### diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py index 9cfe72275af..db35807ad54 100644 --- a/backend/danswer/server/documents/cc_pair.py +++ b/backend/danswer/server/documents/cc_pair.py @@ -16,6 +16,7 @@ from danswer.background.celery.tasks.pruning.tasks import ( try_creating_prune_generator_task, ) +from danswer.background.celery.versioned_apps.primary import app as primary_app from danswer.db.connector_credential_pair import add_credential_to_connector from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id from danswer.db.connector_credential_pair import remove_credential_from_connector @@ -49,6 +50,7 @@ ) from ee.danswer.db.user_group import validate_user_creation_permissions + logger = setup_logger() router = APIRouter(prefix="/manage") @@ -261,7 +263,7 @@ def prune_cc_pair( f"{cc_pair.connector.name} connector." 
) tasks_created = try_creating_prune_generator_task( - cc_pair, db_session, r, current_tenant_id.get() + primary_app, cc_pair, db_session, r, current_tenant_id.get() ) if not tasks_created: raise HTTPException( @@ -318,7 +320,7 @@ def sync_cc_pair( db_session: Session = Depends(get_session), ) -> StatusResponse[list[int]]: # avoiding circular refs - from ee.danswer.background.celery.celery_app import ( + from ee.danswer.background.celery.apps.primary import ( sync_external_doc_permissions_task, ) diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py index 8de42db3863..54d11e867bd 100644 --- a/backend/danswer/server/documents/connector.py +++ b/backend/danswer/server/documents/connector.py @@ -18,6 +18,7 @@ from danswer.background.celery.celery_redis import RedisConnectorIndexing from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot from danswer.background.celery.tasks.indexing.tasks import try_creating_indexing_task +from danswer.background.celery.versioned_apps.primary import app as primary_app from danswer.configs.app_configs import ENABLED_CONNECTOR_TYPES from danswer.configs.constants import DocumentSource from danswer.configs.constants import FileOrigin @@ -834,6 +835,7 @@ def connector_run_once( for cc_pair in connector_credential_pairs: if cc_pair is not None: attempt_id = try_creating_indexing_task( + primary_app, cc_pair, search_settings, run_info.from_beginning, diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py index 7771c1ed824..d16aa59c4cb 100644 --- a/backend/danswer/server/manage/administrative.py +++ b/backend/danswer/server/manage/administrative.py @@ -10,7 +10,7 @@ from danswer.auth.users import current_admin_user from danswer.auth.users import current_curator_or_admin_user -from danswer.background.celery.celery_app import celery_app +from danswer.background.celery.versioned_apps.primary import app as primary_app from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ from danswer.configs.constants import DanswerCeleryPriority from danswer.configs.constants import DocumentSource @@ -195,7 +195,7 @@ def create_deletion_attempt_for_connector_id( db_session.commit() # run the beat task to pick up this deletion from the db immediately - celery_app.send_task( + primary_app.send_task( "check_for_connector_deletion_task", priority=DanswerCeleryPriority.HIGH, kwargs={"tenant_id": tenant_id}, diff --git a/backend/ee/danswer/background/celery/apps/beat.py b/backend/ee/danswer/background/celery/apps/beat.py new file mode 100644 index 00000000000..20325e77df6 --- /dev/null +++ b/backend/ee/danswer/background/celery/apps/beat.py @@ -0,0 +1,52 @@ +##### +# Celery Beat (Periodic Tasks) Settings +##### +from datetime import timedelta + +from danswer.background.celery.apps.beat import celery_app +from danswer.db.engine import get_all_tenant_ids + + +tenant_ids = get_all_tenant_ids() + +tasks_to_schedule = [ + { + "name": "sync-external-doc-permissions", + "task": "check_sync_external_doc_permissions_task", + "schedule": timedelta(seconds=5), # TODO: optimize this + }, + { + "name": "sync-external-group-permissions", + "task": "check_sync_external_group_permissions_task", + "schedule": timedelta(seconds=5), # TODO: optimize this + }, + { + "name": "autogenerate_usage_report", + "task": "autogenerate_usage_report_task", + "schedule": timedelta(days=30), # TODO: change this to config flag + }, + { + "name": "check-ttl-management", + 
"task": "check_ttl_management_task", + "schedule": timedelta(hours=1), + }, +] + +# Build the celery beat schedule dynamically +beat_schedule = {} + +for tenant_id in tenant_ids: + for task in tasks_to_schedule: + task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task + beat_schedule[task_name] = { + "task": task["task"], + "schedule": task["schedule"], + "args": (tenant_id,), # Must pass tenant_id as an argument + } + +# Include any existing beat schedules +existing_beat_schedule = celery_app.conf.beat_schedule or {} +beat_schedule.update(existing_beat_schedule) + +# Update the Celery app configuration +celery_app.conf.beat_schedule = beat_schedule diff --git a/backend/ee/danswer/background/celery/celery_app.py b/backend/ee/danswer/background/celery/apps/primary.py similarity index 76% rename from backend/ee/danswer/background/celery/celery_app.py rename to backend/ee/danswer/background/celery/apps/primary.py index 4010b8b3998..97c5b0221ca 100644 --- a/backend/ee/danswer/background/celery/celery_app.py +++ b/backend/ee/danswer/background/celery/apps/primary.py @@ -1,11 +1,8 @@ -from datetime import timedelta - -from danswer.background.celery.celery_app import celery_app +from danswer.background.celery.apps.primary import celery_app from danswer.background.task_utils import build_celery_task_wrapper from danswer.configs.app_configs import JOB_TIMEOUT from danswer.configs.app_configs import MULTI_TENANT from danswer.db.chat import delete_chat_sessions_older_than -from danswer.db.engine import get_all_tenant_ids from danswer.db.engine import get_session_with_tenant from danswer.server.settings.store import load_settings from danswer.utils.logger import setup_logger @@ -138,53 +135,3 @@ def autogenerate_usage_report_task(tenant_id: str | None) -> None: user_id=None, period=None, ) - - -##### -# Celery Beat (Periodic Tasks) Settings -##### - - -tenant_ids = get_all_tenant_ids() - -tasks_to_schedule = [ - { - "name": "sync-external-doc-permissions", - "task": "check_sync_external_doc_permissions_task", - "schedule": timedelta(seconds=5), # TODO: optimize this - }, - { - "name": "sync-external-group-permissions", - "task": "check_sync_external_group_permissions_task", - "schedule": timedelta(seconds=5), # TODO: optimize this - }, - { - "name": "autogenerate_usage_report", - "task": "autogenerate_usage_report_task", - "schedule": timedelta(days=30), # TODO: change this to config flag - }, - { - "name": "check-ttl-management", - "task": "check_ttl_management_task", - "schedule": timedelta(hours=1), - }, -] - -# Build the celery beat schedule dynamically -beat_schedule = {} - -for tenant_id in tenant_ids: - for task in tasks_to_schedule: - task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task - beat_schedule[task_name] = { - "task": task["task"], - "schedule": task["schedule"], - "args": (tenant_id,), # Must pass tenant_id as an argument - } - -# Include any existing beat schedules -existing_beat_schedule = celery_app.conf.beat_schedule or {} -beat_schedule.update(existing_beat_schedule) - -# Update the Celery app configuration -celery_app.conf.beat_schedule = beat_schedule diff --git a/backend/ee/danswer/background/celery/tasks/vespa/tasks.py b/backend/ee/danswer/background/celery/tasks/vespa/tasks.py index 259f2474928..a2b45324d46 100644 --- a/backend/ee/danswer/background/celery/tasks/vespa/tasks.py +++ b/backend/ee/danswer/background/celery/tasks/vespa/tasks.py @@ -3,7 +3,7 @@ from redis import Redis from sqlalchemy.orm import Session -from 
danswer.background.celery.celery_app import task_logger +from danswer.background.celery.apps.app_base import task_logger from danswer.background.celery.celery_redis import RedisUserGroup from danswer.utils.logger import setup_logger from ee.danswer.db.user_group import delete_user_group diff --git a/backend/scripts/dev_run_background_jobs.py b/backend/scripts/dev_run_background_jobs.py index f3a00392465..1ca823e0935 100644 --- a/backend/scripts/dev_run_background_jobs.py +++ b/backend/scripts/dev_run_background_jobs.py @@ -20,14 +20,13 @@ def run_jobs() -> None: cmd_worker_primary = [ "celery", "-A", - "ee.danswer.background.celery.celery_app", + "danswer.background.celery.versioned_apps.primary", "worker", "--pool=threads", "--concurrency=6", "--prefetch-multiplier=1", "--loglevel=INFO", - "-n", - "primary@%n", + "--hostname=primary@%n", "-Q", "celery", ] @@ -35,14 +34,13 @@ def run_jobs() -> None: cmd_worker_light = [ "celery", "-A", - "ee.danswer.background.celery.celery_app", + "danswer.background.celery.versioned_apps.light", "worker", "--pool=threads", "--concurrency=16", "--prefetch-multiplier=8", "--loglevel=INFO", - "-n", - "light@%n", + "--hostname=light@%n", "-Q", "vespa_metadata_sync,connector_deletion", ] @@ -50,14 +48,13 @@ def run_jobs() -> None: cmd_worker_heavy = [ "celery", "-A", - "ee.danswer.background.celery.celery_app", + "danswer.background.celery.versioned_apps.heavy", "worker", "--pool=threads", "--concurrency=6", "--prefetch-multiplier=1", "--loglevel=INFO", - "-n", - "heavy@%n", + "--hostname=heavy@%n", "-Q", "connector_pruning", ] @@ -65,21 +62,20 @@ def run_jobs() -> None: cmd_worker_indexing = [ "celery", "-A", - "ee.danswer.background.celery.celery_app", + "danswer.background.celery.versioned_apps.indexing", "worker", "--pool=threads", "--concurrency=1", "--prefetch-multiplier=1", "--loglevel=INFO", - "-n", - "indexing@%n", + "--hostname=indexing@%n", "--queues=connector_indexing", ] cmd_beat = [ "celery", "-A", - "ee.danswer.background.celery.celery_app", + "danswer.background.celery.versioned_apps.beat", "beat", "--loglevel=INFO", ] diff --git a/backend/supervisord.conf b/backend/supervisord.conf index 76026bc5667..93472161854 100644 --- a/backend/supervisord.conf +++ b/backend/supervisord.conf @@ -15,10 +15,7 @@ logfile=/var/log/supervisord.log # relatively compute-light (e.g. they tend to just make a bunch of requests to # Vespa / Postgres) [program:celery_worker_primary] -command=celery -A danswer.background.celery.celery_run:celery_app worker - --pool=threads - --concurrency=4 - --prefetch-multiplier=1 +command=celery -A danswer.background.celery.versioned_apps.primary worker --loglevel=INFO --hostname=primary@%%n -Q celery @@ -33,13 +30,10 @@ stopasgroup=true # since this is often the bottleneck for "sync" jobs (e.g. document set syncing, # user group syncing, deletion, etc.) 
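# Tuning note (illustrative): the --concurrency and --prefetch-multiplier flags
# that used to be set on this command now come from configs/light.py, which
# reads CELERY_WORKER_LIGHT_CONCURRENCY (default 24) and
# CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER (default 8) from the process
# environment, so tuning happens by exporting those variables (e.g. in the
# container environment supervisord inherits) rather than by editing this file.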
[program:celery_worker_light] -command=bash -c "celery -A danswer.background.celery.celery_run:celery_app worker \ - --pool=threads \ - --concurrency=${CELERY_WORKER_LIGHT_CONCURRENCY:-24} \ - --prefetch-multiplier=${CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER:-8} \ - --loglevel=INFO \ - --hostname=light@%%n \ - -Q vespa_metadata_sync,connector_deletion" +command=celery -A danswer.background.celery.versioned_apps.light worker + --loglevel=INFO + --hostname=light@%%n + -Q vespa_metadata_sync,connector_deletion stdout_logfile=/var/log/celery_worker_light.log stdout_logfile_maxbytes=16MB redirect_stderr=true @@ -48,10 +42,7 @@ startsecs=10 stopasgroup=true [program:celery_worker_heavy] -command=celery -A danswer.background.celery.celery_run:celery_app worker - --pool=threads - --concurrency=4 - --prefetch-multiplier=1 +command=celery -A danswer.background.celery.versioned_apps.heavy worker --loglevel=INFO --hostname=heavy@%%n -Q connector_pruning @@ -63,13 +54,10 @@ startsecs=10 stopasgroup=true [program:celery_worker_indexing] -command=bash -c "celery -A danswer.background.celery.celery_run:celery_app worker \ - --pool=threads \ - --concurrency=${CELERY_WORKER_INDEXING_CONCURRENCY:-${NUM_INDEXING_WORKERS:-1}} \ - --prefetch-multiplier=1 \ - --loglevel=INFO \ - --hostname=indexing@%%n \ - -Q connector_indexing" +command=celery -A danswer.background.celery.versioned_apps.indexing worker + --loglevel=INFO + --hostname=indexing@%%n + -Q connector_indexing stdout_logfile=/var/log/celery_worker_indexing.log stdout_logfile_maxbytes=16MB redirect_stderr=true @@ -79,7 +67,7 @@ stopasgroup=true # Job scheduler for periodic tasks [program:celery_beat] -command=celery -A danswer.background.celery.celery_run:celery_app beat +command=celery -A danswer.background.celery.versioned_apps.beat beat stdout_logfile=/var/log/celery_beat.log stdout_logfile_maxbytes=16MB redirect_stderr=true
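
A note on the -A targets used above: the old commands named the app attribute
explicitly (danswer.background.celery.celery_run:celery_app), while the new
entries reference a versioned_apps module alone. Celery accepts the bare module
form because it falls back to a module-level attribute named "app", which is
exactly what each factory stub exports. Illustrative comparison (trailing
options elided):

    # old form: module:attribute
    celery -A danswer.background.celery.celery_run:celery_app worker ...

    # new form: bare module; celery resolves the module-level "app"
    celery -A danswer.background.celery.versioned_apps.primary worker ...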