Skip to content

Commit

Permalink
Tenant provisioning in the dataplane (#2694)
Browse files Browse the repository at this point in the history
* add tenant provisioning to data plane

* minor typing update

* ensure tenant router included

* proper auth check

* update disabling logic

* validated basic provisioning

* use new kv store
  • Loading branch information
pablonyx authored Oct 6, 2024
1 parent e00f467 commit 0da736b
Show file tree
Hide file tree
Showing 16 changed files with 620 additions and 428 deletions.
28 changes: 28 additions & 0 deletions backend/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Optional
from typing import Tuple

import jwt
from email_validator import EmailNotValidError
from email_validator import validate_email
from fastapi import APIRouter
Expand Down Expand Up @@ -37,8 +38,10 @@
from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DATA_PLANE_SECRET
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import EXPECTED_API_KEY
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SMTP_PASS
Expand Down Expand Up @@ -504,3 +507,28 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
def get_default_admin_user_emails_() -> list[str]:
# No default seeding available for Danswer MIT
return []


async def control_plane_dep(request: Request) -> None:
api_key = request.headers.get("X-API-KEY")
if api_key != EXPECTED_API_KEY:
logger.warning("Invalid API key")
raise HTTPException(status_code=401, detail="Invalid API key")

auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
logger.warning("Invalid authorization header")
raise HTTPException(status_code=401, detail="Invalid authorization header")

token = auth_header.split(" ")[1]
try:
payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=["HS256"])
if payload.get("scope") != "tenant:create":
logger.warning("Insufficient permissions")
raise HTTPException(status_code=403, detail="Insufficient permissions")
except jwt.ExpiredSignatureError:
logger.warning("Token has expired")
raise HTTPException(status_code=401, detail="Token has expired")
except jwt.InvalidTokenError:
logger.warning("Invalid token")
raise HTTPException(status_code=401, detail="Invalid token")
237 changes: 119 additions & 118 deletions backend/danswer/chat/load_yamls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Persona
Expand All @@ -18,148 +17,150 @@
from danswer.search.enums import RecencyBiasSetting


def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
def load_prompts_from_yaml(
db_session: Session, prompts_yaml: str = PROMPTS_YAML
) -> None:
with open(prompts_yaml, "r") as file:
data = yaml.safe_load(file)

all_prompts = data.get("prompts", [])
with Session(get_sqlalchemy_engine()) as db_session:
for prompt in all_prompts:
upsert_prompt(
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
description=prompt["description"].strip(),
system_prompt=prompt["system"].strip(),
task_prompt=prompt["task"].strip(),
include_citations=prompt["include_citations"],
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
db_session=db_session,
commit=True,
)
for prompt in all_prompts:
upsert_prompt(
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
description=prompt["description"].strip(),
system_prompt=prompt["system"].strip(),
task_prompt=prompt["task"].strip(),
include_citations=prompt["include_citations"],
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
db_session=db_session,
commit=True,
)


def load_personas_from_yaml(
db_session: Session,
personas_yaml: str = PERSONAS_YAML,
default_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
) -> None:
with open(personas_yaml, "r") as file:
data = yaml.safe_load(file)

all_personas = data.get("personas", [])
with Session(get_sqlalchemy_engine()) as db_session:
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
]

# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
doc_set_ids = None

prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]

# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
doc_set_ids = None

prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")

if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]

p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)
.filter(ToolDBModel.name == "ImageGenerationTool")
.first()
)
if image_gen_tool:
tool_ids.append(image_gen_tool.id)

llm_model_provider_override = persona.get("llm_model_provider_override")
llm_model_version_override = persona.get("llm_model_version_override")

# Set specific overrides for image generation persona
if persona.get("image_generation"):
llm_model_version_override = "gpt-4o"

existing_persona = (
db_session.query(Persona)
.filter(Persona.name == persona["name"])
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")

if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]

p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)
.filter(ToolDBModel.name == "ImageGenerationTool")
.first()
)

upsert_persona(
user=None,
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
num_chunks=persona.get("num_chunks")
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
icon_shape=persona.get("icon_shape"),
icon_color=persona.get("icon_color"),
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
tool_ids=tool_ids,
builtin_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
db_session=db_session,
)


def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None:
if image_gen_tool:
tool_ids.append(image_gen_tool.id)

llm_model_provider_override = persona.get("llm_model_provider_override")
llm_model_version_override = persona.get("llm_model_version_override")

# Set specific overrides for image generation persona
if persona.get("image_generation"):
llm_model_version_override = "gpt-4o"

existing_persona = (
db_session.query(Persona).filter(Persona.name == persona["name"]).first()
)

upsert_persona(
user=None,
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
num_chunks=persona.get("num_chunks")
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
icon_shape=persona.get("icon_shape"),
icon_color=persona.get("icon_color"),
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
tool_ids=tool_ids,
builtin_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
db_session=db_session,
)


def load_input_prompts_from_yaml(
db_session: Session, input_prompts_yaml: str = INPUT_PROMPT_YAML
) -> None:
with open(input_prompts_yaml, "r") as file:
data = yaml.safe_load(file)

all_input_prompts = data.get("input_prompts", [])
with Session(get_sqlalchemy_engine()) as db_session:
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience

insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)


def load_chat_yamls(
db_session: Session,
prompt_yaml: str = PROMPTS_YAML,
personas_yaml: str = PERSONAS_YAML,
input_prompts_yaml: str = INPUT_PROMPT_YAML,
) -> None:
load_prompts_from_yaml(prompt_yaml)
load_personas_from_yaml(personas_yaml)
load_input_prompts_from_yaml(input_prompts_yaml)
load_prompts_from_yaml(db_session, prompt_yaml)
load_personas_from_yaml(db_session, personas_yaml)
load_input_prompts_from_yaml(db_session, input_prompts_yaml)
4 changes: 4 additions & 0 deletions backend/danswer/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,7 @@

MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")


DATA_PLANE_SECRET = os.environ.get("DATA_PLANE_SECRET", "")
EXPECTED_API_KEY = os.environ.get("EXPECTED_API_KEY", "")
3 changes: 2 additions & 1 deletion backend/danswer/db/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def get_db_current_time(db_session: Session) -> datetime:


# Regular expression to validate schema names to prevent SQL injection
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")


def is_valid_schema_name(name: str) -> bool:
Expand Down Expand Up @@ -281,6 +281,7 @@ def get_session_with_tenant(tenant_id: str | None = None) -> Session:
tenant_id = current_tenant_id.get()

if not is_valid_schema_name(tenant_id):
logger.error(f"Invalid tenant ID: {tenant_id}")
raise Exception("Invalid tenant ID")

engine = SqlEngine.get_engine()
Expand Down
9 changes: 3 additions & 6 deletions backend/danswer/key_value_store/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from sqlalchemy.orm import Session

from danswer.db.engine import get_session_factory
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import KVStore
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KeyValueStore
Expand All @@ -26,12 +26,9 @@ def __init__(self) -> None:

@contextmanager
def get_session(self) -> Iterator[Session]:
factory = get_session_factory()
session: Session = factory()
try:
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
yield session
finally:
session.close()

def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
# Not encrypted in Redis, but encrypted in Postgres
Expand Down
Loading

0 comments on commit 0da736b

Please sign in to comment.