Skip to content

Commit

Permalink
Multitenant redis update (#2889)
Browse files Browse the repository at this point in the history
* add multi tenancy to redis

* rename context var

* k

* args -> kwargs

* minor update to kv interface

* robustify
  • Loading branch information
pablonyx authored Oct 24, 2024
1 parent b9fb657 commit 0545fb4
Show file tree
Hide file tree
Showing 26 changed files with 408 additions and 204 deletions.
10 changes: 5 additions & 5 deletions backend/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import current_tenant_id
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

logger = setup_logger()
Expand Down Expand Up @@ -249,7 +249,7 @@ async def create(
)

async with get_async_session_with_tenant(tenant_id) as db_session:
token = current_tenant_id.set(tenant_id)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
Expand Down Expand Up @@ -288,7 +288,7 @@ async def create(
else:
raise exceptions.UserAlreadyExists()

current_tenant_id.reset(token)
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user

async def on_after_login(
Expand Down Expand Up @@ -342,7 +342,7 @@ async def oauth_callback(

token = None
async with get_async_session_with_tenant(tenant_id) as db_session:
token = current_tenant_id.set(tenant_id)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

verify_email_in_whitelist(account_email, tenant_id)
verify_email_domain(account_email)
Expand Down Expand Up @@ -432,7 +432,7 @@ async def oauth_callback(
user.oidc_expiry = None # type: ignore

if token:
current_tenant_id.reset(token)
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

return user

Expand Down
77 changes: 60 additions & 17 deletions backend/danswer/background/celery/apps/app_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.engine import get_all_tenant_ids
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
Expand Down Expand Up @@ -56,7 +57,7 @@ def on_task_postrun(
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
kwargs: dict[str, Any] | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
Expand All @@ -83,7 +84,19 @@ def on_task_postrun(
if not task_id:
return

r = get_redis_client()
# Get tenant_id directly from kwargs; each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
tenant_id = None
else:
tenant_id = kwargs.get("tenant_id")

task_logger.debug(
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
)

r = get_redis_client(tenant_id=tenant_id)

if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
Expand Down Expand Up @@ -124,7 +137,7 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None


def wait_for_redis(sender: Any, **kwargs: Any) -> None:
r = get_redis_client()
r = get_redis_client(tenant_id=None)

WAIT_INTERVAL = 5
WAIT_LIMIT = 60
Expand Down Expand Up @@ -157,34 +170,52 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:


def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
r = get_redis_client()

WAIT_INTERVAL = 5
WAIT_LIMIT = 60

logger.info("Running as a secondary celery worker.")
logger.info("Waiting for primary worker to be ready...")
logger.info("Waiting for all tenant primary workers to be ready...")
time_start = time.monotonic()

while True:
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
tenant_ids = get_all_tenant_ids()
# Check if we have a primary worker lock for each tenant
all_tenants_ready = all(
get_redis_client(tenant_id=tenant_id).exists(
DanswerRedisLocks.PRIMARY_WORKER
)
for tenant_id in tenant_ids
)

if all_tenants_ready:
break

time.monotonic()
time_elapsed = time.monotonic() - time_start
ready_tenants = sum(
1
for tenant_id in tenant_ids
if get_redis_client(tenant_id=tenant_id).exists(
DanswerRedisLocks.PRIMARY_WORKER
)
)

logger.info(
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
f"Not all tenant primary workers are ready yet. "
f"Ready tenants: {ready_tenants}/{len(tenant_ids)} "
f"elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)

if time_elapsed > WAIT_LIMIT:
msg = (
f"Primary worker was not ready within the timeout. "
f"Not all tenant primary workers were ready within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)

time.sleep(WAIT_INTERVAL)

logger.info("Wait for primary worker completed successfully. Continuing...")
logger.info("All tenant primary workers are ready. Continuing...")
return


Expand All @@ -196,14 +227,26 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
if not celery_is_worker_primary(sender):
return

if not sender.primary_worker_lock:
if not hasattr(sender, "primary_worker_locks"):
return

logger.info("Releasing primary worker lock.")
lock = sender.primary_worker_lock
if lock.owned():
lock.release()
sender.primary_worker_lock = None
for tenant_id, lock in sender.primary_worker_locks.items():
try:
if lock and lock.owned():
logger.debug(f"Attempting to release lock for tenant {tenant_id}")
try:
lock.release()
logger.debug(f"Successfully released lock for tenant {tenant_id}")
except Exception as e:
logger.error(
f"Failed to release lock for tenant {tenant_id}. Error: {str(e)}"
)
finally:
sender.primary_worker_locks[tenant_id] = None
except Exception as e:
logger.error(
f"Error checking lock status for tenant {tenant_id}. Error: {str(e)}"
)


def on_setup_logging(
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/background/celery/apps/beat.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def on_setup_logging(
"task": task["task"],
"schedule": task["schedule"],
"options": task["options"],
"args": (tenant_id,), # Must pass tenant_id as an argument
"kwargs": {"tenant_id": tenant_id}, # Must pass tenant_id as an argument
}

# Include any existing beat schedules
Expand Down
Loading

0 comments on commit 0545fb4

Please sign in to comment.