From 7906d9edc837767834fc51d15d2a7313a8d888d2 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Thu, 17 Oct 2024 19:55:05 -0700 Subject: [PATCH] Add all-tenants migration for K8 job (#2846) * add migration * update migration logic for tenants * k * k * k * k --- backend/alembic/env.py | 171 +++++++++++++++++++++++++++++------------ 1 file changed, 120 insertions(+), 51 deletions(-) diff --git a/backend/alembic/env.py b/backend/alembic/env.py index afa5a9669c1..c89d3455227 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -1,10 +1,11 @@ +from sqlalchemy.engine.base import Connection from typing import Any import asyncio from logging.config import fileConfig +import logging from alembic import context from sqlalchemy import pool -from sqlalchemy.engine import Connection from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.sql import text @@ -12,6 +13,7 @@ from danswer.db.engine import build_connection_string from danswer.db.models import Base from celery.backends.database.session import ResultModelBase # type: ignore +from danswer.background.celery.celery_app import get_all_tenant_ids # Alembic Config object config = context.config @@ -22,66 +24,42 @@ ): fileConfig(config.config_file_name) -# Add your model's MetaData object here -# for 'autogenerate' support -# from myapp import mymodel -# target_metadata = mymodel.Base.metadata +# Add your model's MetaData object here for 'autogenerate' support target_metadata = [Base.metadata, ResultModelBase.metadata] - -def get_schema_options() -> tuple[str, bool]: - x_args_raw = context.get_x_argument() - x_args = {} - for arg in x_args_raw: - for pair in arg.split(","): - if "=" in pair: - key, value = pair.split("=", 1) - x_args[key.strip()] = value.strip() - schema_name = x_args.get("schema", "public") - create_schema = x_args.get("create_schema", "true").lower() == "true" - return schema_name, create_schema - - EXCLUDE_TABLES = {"kombu_queue", "kombu_message"} +# Set up logging +logger = logging.getLogger(__name__) + def include_object( object: Any, name: str, type_: str, reflected: bool, compare_to: Any ) -> bool: + """ + Determines whether a database object should be included in migrations. + Excludes specified tables from migrations. + """ if type_ == "table" and name in EXCLUDE_TABLES: return False return True -def run_migrations_offline() -> None: - """Run migrations in 'offline' mode. - This configures the context with just a URL - and not an Engine, though an Engine is acceptable - here as well. By skipping the Engine creation - we don't even need a DBAPI to be available. - Calls to context.execute() here emit the given string to the - script output. +def get_schema_options() -> tuple[str, bool, bool]: """ - schema_name, _ = get_schema_options() - url = build_connection_string() - - context.configure( - url=url, - target_metadata=target_metadata, # type: ignore - literal_binds=True, - include_object=include_object, - version_table_schema=schema_name, - include_schemas=True, - script_location=config.get_main_option("script_location"), - dialect_opts={"paramstyle": "named"}, - ) - - with context.begin_transaction(): - context.run_migrations() - - -def do_run_migrations(connection: Connection) -> None: - schema_name, create_schema = get_schema_options() + Parses command-line options passed via '-x' in Alembic commands. + Recognizes 'schema', 'create_schema', and 'upgrade_all_tenants' options. + """ + x_args_raw = context.get_x_argument() + x_args = {} + for arg in x_args_raw: + for pair in arg.split(","): + if "=" in pair: + key, value = pair.split("=", 1) + x_args[key.strip()] = value.strip() + schema_name = x_args.get("schema", "public") + create_schema = x_args.get("create_schema", "true").lower() == "true" + upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true" if MULTI_TENANT and schema_name == "public": raise ValueError( @@ -89,6 +67,17 @@ def do_run_migrations(connection: Connection) -> None: "Please specify a tenant-specific schema." ) + return schema_name, create_schema, upgrade_all_tenants + + +def do_run_migrations( + connection: Connection, schema_name: str, create_schema: bool +) -> None: + """ + Executes migrations in the specified schema. + """ + logger.info(f"About to migrate schema: {schema_name}") + if create_schema: connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"')) connection.execute(text("COMMIT")) @@ -112,18 +101,98 @@ def do_run_migrations(connection: Connection) -> None: async def run_async_migrations() -> None: - connectable = create_async_engine( + """ + Determines whether to run migrations for a single schema or all schemas, + and executes migrations accordingly. + """ + schema_name, create_schema, upgrade_all_tenants = get_schema_options() + + engine = create_async_engine( build_connection_string(), poolclass=pool.NullPool, ) - async with connectable.connect() as connection: - await connection.run_sync(do_run_migrations) + if upgrade_all_tenants: + # Run migrations for all tenant schemas sequentially + tenant_schemas = get_all_tenant_ids() + + for schema in tenant_schemas: + try: + logger.info(f"Migrating schema: {schema}") + async with engine.connect() as connection: + await connection.run_sync( + do_run_migrations, + schema_name=schema, + create_schema=create_schema, + ) + except Exception as e: + logger.error(f"Error migrating schema {schema}: {e}") + raise + else: + try: + logger.info(f"Migrating schema: {schema_name}") + async with engine.connect() as connection: + await connection.run_sync( + do_run_migrations, + schema_name=schema_name, + create_schema=create_schema, + ) + except Exception as e: + logger.error(f"Error migrating schema {schema_name}: {e}") + raise + + await engine.dispose() + + +def run_migrations_offline() -> None: + """ + Run migrations in 'offline' mode. + """ + schema_name, _, upgrade_all_tenants = get_schema_options() + url = build_connection_string() + + if upgrade_all_tenants: + # Run offline migrations for all tenant schemas + engine = create_async_engine(url) + tenant_schemas = get_all_tenant_ids() + engine.sync_engine.dispose() + + for schema in tenant_schemas: + logger.info(f"Migrating schema: {schema}") + context.configure( + url=url, + target_metadata=target_metadata, # type: ignore + literal_binds=True, + include_object=include_object, + version_table_schema=schema, + include_schemas=True, + script_location=config.get_main_option("script_location"), + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + else: + logger.info(f"Migrating schema: {schema_name}") + context.configure( + url=url, + target_metadata=target_metadata, # type: ignore + literal_binds=True, + include_object=include_object, + version_table_schema=schema_name, + include_schemas=True, + script_location=config.get_main_option("script_location"), + dialect_opts={"paramstyle": "named"}, + ) - await connectable.dispose() + with context.begin_transaction(): + context.run_migrations() def run_migrations_online() -> None: + """ + Runs migrations in 'online' mode using an asynchronous engine. + """ asyncio.run(run_async_migrations())