Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Evaluate None to default #3069

Merged
merged 5 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions backend/danswer/db/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,26 +323,36 @@ async def get_async_session_with_tenant(
yield session


def get_session_with_default_tenant() -> Generator[Session, None, None]:
"""
Get a database session using the current tenant ID from the context variable.
"""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
return get_session_with_tenant(tenant_id)


@contextmanager
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""
Generate a database session bound to a connection with the appropriate tenant schema set.
This preserves the tenant ID across the session and reverts to the previous tenant ID
after the session is closed.
If tenant ID is set, we save the previous tenant ID from the context var to set
after the session is closed. The value `None` evaluates to the default schema.
Generate a database session for a specific tenant.

This function:
1. Sets the database schema to the specified tenant's schema.
2. Preserves the tenant ID across the session.
3. Reverts to the previous tenant ID after the session is closed.
4. Uses the default schema if no tenant ID is provided.
"""
engine = get_sqlalchemy_engine()

# Store the previous tenant ID
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA

if tenant_id is None:
tenant_id = previous_tenant_id
else:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
tenant_id = POSTGRES_DEFAULT_SCHEMA

CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

event.listen(engine, "checkout", set_search_path_on_checkout)

Expand Down
4 changes: 2 additions & 2 deletions backend/danswer/db/search_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import IndexAttempt
Expand Down Expand Up @@ -152,7 +152,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]:

def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
if db_session is None:
with get_session_with_tenant() as db_session:
with get_session_with_default_tenant() as db_session:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_current_search_settings(db_session)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pydantic import BaseModel

from danswer.configs.constants import FileOrigin
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_session_with_default_tenant
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
Expand Down Expand Up @@ -187,7 +187,7 @@ def get_args_for_non_tool_calling_llm(
def _save_and_get_file_references(
self, file_content: bytes | str, content_type: str
) -> List[str]:
with get_session_with_tenant() as db_session:
with get_session_with_default_tenant() as db_session:
file_store = get_default_file_store(db_session)

file_id = str(uuid.uuid4())
Expand Down Expand Up @@ -299,7 +299,7 @@ def build_next_prompt(

# Load files from storage
files = []
with get_session_with_tenant() as db_session:
with get_session_with_default_tenant() as db_session:
file_store = get_default_file_store(db_session)

for file_id in response.tool_result.file_ids:
Expand Down
Loading