diff --git a/README.md b/README.md index 1f9fbef5b2f..0b7f87ceaa4 100644 --- a/README.md +++ b/README.md @@ -1,48 +1,48 @@ - +
-
Open Source Gen-AI Chat + Unified Search.
+Open Source Gen-AI + Enterprise Search.
-[Danswer](https://www.danswer.ai/) is the AI Assistant connected to your company's docs, apps, and people. -Danswer provides a Chat interface and plugs into any LLM of your choice. Danswer can be deployed anywhere and for any +[Onyx](https://www.onyx.app/) (formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people. +Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your -own control. Danswer is MIT licensed and designed to be modular and easily extensible. The system also comes fully ready +own control. Onyx is dual-licensed, with most of it under the MIT license, and designed to be modular and easily extensible. The system also comes fully ready for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for -configuring Personas (AI Assistants) and their Prompts. +configuring AI Assistants. -Danswer also serves as a Unified Search across all common workplace tools such as Slack, Google Drive, Confluence, etc. -By combining LLMs and team specific knowledge, Danswer becomes a subject matter expert for the team. Imagine ChatGPT if +Onyx also serves as an Enterprise Search across all common workplace tools such as Slack, Google Drive, Confluence, etc. +By combining LLMs and team-specific knowledge, Onyx becomes a subject matter expert for the team. Imagine ChatGPT if it had access to your team's unique knowledge! It enables questions such as "A customer wants feature X, is this already supported?" or "Where's the pull request for feature Y?"
diff --git a/backend/alembic/versions/bf7a81109301_delete_input_prompts.py b/backend/alembic/versions/bf7a81109301_delete_input_prompts.py new file mode 100644 index 00000000000..7aa3faf3277 --- /dev/null +++ b/backend/alembic/versions/bf7a81109301_delete_input_prompts.py @@ -0,0 +1,57 @@ +"""delete_input_prompts + +Revision ID: bf7a81109301 +Revises: f7a894b06d02 +Create Date: 2024-12-09 12:00:49.884228 + +""" +from alembic import op +import sqlalchemy as sa +import fastapi_users_db_sqlalchemy + + +# revision identifiers, used by Alembic. 
+revision = "bf7a81109301" +down_revision = "f7a894b06d02" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.drop_table("inputprompt__user") + op.drop_table("inputprompt") + + +def downgrade() -> None: + op.create_table( + "inputprompt", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("prompt", sa.String(), nullable=False), + sa.Column("content", sa.String(), nullable=False), + sa.Column("active", sa.Boolean(), nullable=False), + sa.Column("is_public", sa.Boolean(), nullable=False), + sa.Column( + "user_id", + fastapi_users_db_sqlalchemy.generics.GUID(), + nullable=True, + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "inputprompt__user", + sa.Column("input_prompt_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["input_prompt_id"], + ["inputprompt.id"], + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["inputprompt.id"], + ), + sa.PrimaryKeyConstraint("input_prompt_id", "user_id"), + ) diff --git a/backend/danswer/auth/invited_users.py b/backend/danswer/auth/invited_users.py index fb30332afd9..ff3a8cce95e 100644 --- a/backend/danswer/auth/invited_users.py +++ b/backend/danswer/auth/invited_users.py @@ -9,7 +9,6 @@ def get_invited_users() -> list[str]: try: store = get_kv_store() - return cast(list, store.load(KV_USER_STORE_KEY)) except KvKeyNotFoundError: return list() diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index c338674db4a..18a6b0b38a4 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -342,6 +342,12 @@ os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true" ) +# Egnyte specific configs +EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE") +EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN") +EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID") +EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET") + DASK_JOB_CLIENT_ENABLED = ( os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true" ) @@ -405,21 +411,28 @@ # We don't want the metadata to overwhelm the actual contents of the chunk SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true" # Timeout to wait for job's last update before killing it, in hours -CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3)) +CLEANUP_INDEXING_JOBS_TIMEOUT = int( + os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT") or 3 +) # The indexer will warn in the logs whenver a document exceeds this threshold (in bytes) INDEXING_SIZE_WARNING_THRESHOLD = int( - os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD", 100 * 1024 * 1024) + os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD") or 100 * 1024 * 1024 ) # during indexing, will log verbose memory diff stats every x batches and at the end. # 0 disables this behavior and is the default. -INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0)) +INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0) # During an indexing attempt, specifies the number of batches which are allowed to # exception without aborting the attempt. 
-INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT", 0)) +INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0) +# Maximum file size in a document to be indexed +MAX_DOCUMENT_CHARS = int(os.environ.get("MAX_DOCUMENT_CHARS") or 5_000_000) +MAX_FILE_SIZE_BYTES = int( + os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024 +) # 2GB in bytes ##### # Miscellaneous diff --git a/backend/danswer/configs/chat_configs.py b/backend/danswer/configs/chat_configs.py index 2d72bed0f5a..88ff301a99e 100644 --- a/backend/danswer/configs/chat_configs.py +++ b/backend/danswer/configs/chat_configs.py @@ -3,7 +3,6 @@ PROMPTS_YAML = "./danswer/seeding/prompts.yaml" PERSONAS_YAML = "./danswer/seeding/personas.yaml" -INPUT_PROMPT_YAML = "./danswer/seeding/input_prompts.yaml" NUM_RETURNED_HITS = 50 # Used for LLM filtering and reranking diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index e84c229a696..b9b5f7deb26 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -132,6 +132,7 @@ class DocumentSource(str, Enum): NOT_APPLICABLE = "not_applicable" FRESHDESK = "freshdesk" FIREFLIES = "fireflies" + EGNYTE = "egnyte" DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE] diff --git a/backend/danswer/connectors/confluence/onyx_confluence.py b/backend/danswer/connectors/confluence/onyx_confluence.py index 267c0f9edeb..c6ffd5aa3fb 100644 --- a/backend/danswer/connectors/confluence/onyx_confluence.py +++ b/backend/danswer/connectors/confluence/onyx_confluence.py @@ -368,4 +368,5 @@ def build_confluence_client( backoff_and_retry=True, max_backoff_retries=10, max_backoff_seconds=60, + cloud=is_cloud, ) diff --git a/backend/danswer/connectors/egnyte/connector.py b/backend/danswer/connectors/egnyte/connector.py new file mode 100644 index 00000000000..264e837bf94 --- /dev/null +++ b/backend/danswer/connectors/egnyte/connector.py @@ -0,0 +1,373 @@ +import io +import os +from collections.abc import Generator +from datetime import datetime +from datetime import timezone +from logging import Logger +from typing import Any +from typing import cast +from typing import IO + +import requests +from retry import retry + +from danswer.configs.app_configs import EGNYTE_BASE_DOMAIN +from danswer.configs.app_configs import EGNYTE_CLIENT_ID +from danswer.configs.app_configs import EGNYTE_CLIENT_SECRET +from danswer.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE +from danswer.configs.app_configs import INDEX_BATCH_SIZE +from danswer.configs.constants import DocumentSource +from danswer.connectors.interfaces import GenerateDocumentsOutput +from danswer.connectors.interfaces import LoadConnector +from danswer.connectors.interfaces import OAuthConnector +from danswer.connectors.interfaces import PollConnector +from danswer.connectors.interfaces import SecondsSinceUnixEpoch +from danswer.connectors.models import BasicExpertInfo +from danswer.connectors.models import ConnectorMissingCredentialError +from danswer.connectors.models import Document +from danswer.connectors.models import Section +from danswer.file_processing.extract_file_text import detect_encoding +from danswer.file_processing.extract_file_text import extract_file_text +from danswer.file_processing.extract_file_text import get_file_ext +from danswer.file_processing.extract_file_text import is_text_file_extension +from danswer.file_processing.extract_file_text import is_valid_file_ext +from 
danswer.file_processing.extract_file_text import read_text_file +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + +_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1" +_EGNYTE_APP_BASE = "https://{domain}.egnyte.com" +_TIMEOUT = 60 + + +def _request_with_retries( + method: str, + url: str, + data: dict[str, Any] | None = None, + headers: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + timeout: int = _TIMEOUT, + stream: bool = False, + tries: int = 8, + delay: float = 1, + backoff: float = 2, +) -> requests.Response: + @retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger)) + def _make_request() -> requests.Response: + response = requests.request( + method, + url, + data=data, + headers=headers, + params=params, + timeout=timeout, + stream=stream, + ) + response.raise_for_status() + return response + + return _make_request() + + +def _parse_last_modified(last_modified: str) -> datetime: + return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace( + tzinfo=timezone.utc + ) + + +def _process_egnyte_file( + file_metadata: dict[str, Any], + file_content: IO, + base_url: str, + folder_path: str | None = None, +) -> Document | None: + """Process an Egnyte file into a Document object + + Args: + file_data: The file data from Egnyte API + file_content: The raw content of the file in bytes + base_url: The base URL for the Egnyte instance + folder_path: Optional folder path to filter results + """ + # Skip if file path doesn't match folder path filter + if folder_path and not file_metadata["path"].startswith(folder_path): + raise ValueError( + f"File path {file_metadata['path']} does not match folder path {folder_path}" + ) + + file_name = file_metadata["name"] + extension = get_file_ext(file_name) + if not is_valid_file_ext(extension): + logger.warning(f"Skipping file '{file_name}' with extension '{extension}'") + return None + + # Extract text content based on file type + if is_text_file_extension(file_name): + encoding = detect_encoding(file_content) + file_content_raw, file_metadata = read_text_file( + file_content, encoding=encoding, ignore_danswer_metadata=False + ) + else: + file_content_raw = extract_file_text( + file=file_content, + file_name=file_name, + break_on_unprocessable=True, + ) + + # Build the web URL for the file + web_url = f"{base_url}/navigate/file/{file_metadata['group_id']}" + + # Create document metadata + metadata: dict[str, str | list[str]] = { + "file_path": file_metadata["path"], + "last_modified": file_metadata.get("last_modified", ""), + } + + # Add lock info if present + if lock_info := file_metadata.get("lock_info"): + metadata[ + "lock_owner" + ] = f"{lock_info.get('first_name', '')} {lock_info.get('last_name', '')}" + + # Create the document owners + primary_owner = None + if uploaded_by := file_metadata.get("uploaded_by"): + primary_owner = BasicExpertInfo( + email=uploaded_by, # Using username as email since that's what we have + ) + + # Create the document + return Document( + id=f"egnyte-{file_metadata['entry_id']}", + sections=[Section(text=file_content_raw.strip(), link=web_url)], + source=DocumentSource.EGNYTE, + semantic_identifier=file_name, + metadata=metadata, + doc_updated_at=( + _parse_last_modified(file_metadata["last_modified"]) + if "last_modified" in file_metadata + else None + ), + primary_owners=[primary_owner] if primary_owner else None, + ) + + +class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector): + def __init__( + self, + 
folder_path: str | None = None, + batch_size: int = INDEX_BATCH_SIZE, + ) -> None: + self.domain = "" # will always be set in `load_credentials` + self.folder_path = folder_path or "" # Root folder if not specified + self.batch_size = batch_size + self.access_token: str | None = None + + @classmethod + def oauth_id(cls) -> DocumentSource: + return DocumentSource.EGNYTE + + @classmethod + def oauth_authorization_url(cls, base_domain: str, state: str) -> str: + if not EGNYTE_CLIENT_ID: + raise ValueError("EGNYTE_CLIENT_ID environment variable must be set") + if not EGNYTE_BASE_DOMAIN: + raise ValueError("EGNYTE_DOMAIN environment variable must be set") + + if EGNYTE_LOCALHOST_OVERRIDE: + base_domain = EGNYTE_LOCALHOST_OVERRIDE + + callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte" + return ( + f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token" + f"?client_id={EGNYTE_CLIENT_ID}" + f"&redirect_uri={callback_uri}" + f"&scope=Egnyte.filesystem" + f"&state={state}" + f"&response_type=code" + ) + + @classmethod + def oauth_code_to_token(cls, code: str) -> dict[str, Any]: + if not EGNYTE_CLIENT_ID: + raise ValueError("EGNYTE_CLIENT_ID environment variable must be set") + if not EGNYTE_CLIENT_SECRET: + raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set") + if not EGNYTE_BASE_DOMAIN: + raise ValueError("EGNYTE_DOMAIN environment variable must be set") + + # Exchange code for token + url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token" + data = { + "client_id": EGNYTE_CLIENT_ID, + "client_secret": EGNYTE_CLIENT_SECRET, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": f"{EGNYTE_LOCALHOST_OVERRIDE or ''}/connector/oauth/callback/egnyte", + "scope": "Egnyte.filesystem", + } + headers = {"Content-Type": "application/x-www-form-urlencoded"} + + response = _request_with_retries( + method="POST", + url=url, + data=data, + headers=headers, + # try a lot faster since this is a realtime flow + backoff=0, + delay=0.1, + ) + if not response.ok: + raise RuntimeError(f"Failed to exchange code for token: {response.text}") + + token_data = response.json() + return { + "domain": EGNYTE_BASE_DOMAIN, + "access_token": token_data["access_token"], + } + + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + self.domain = credentials["domain"] + self.access_token = credentials["access_token"] + return None + + def _get_files_list( + self, + path: str, + ) -> list[dict[str, Any]]: + if not self.access_token or not self.domain: + raise ConnectorMissingCredentialError("Egnyte") + + headers = { + "Authorization": f"Bearer {self.access_token}", + } + + params: dict[str, Any] = { + "list_content": True, + } + + url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{path or ''}" + response = _request_with_retries( + method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT + ) + if not response.ok: + raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}") + + data = response.json() + all_files: list[dict[str, Any]] = [] + + # Add files from current directory + all_files.extend(data.get("files", [])) + + # Recursively traverse folders + for item in data.get("folders", []): + all_files.extend(self._get_files_list(item["path"])) + + return all_files + + def _filter_files( + self, + files: list[dict[str, Any]], + start_time: datetime | None = None, + end_time: datetime | None = None, + ) -> list[dict[str, Any]]: + filtered_files = [] + for file in files: + if file["is_folder"]: + 
continue + + file_modified = _parse_last_modified(file["last_modified"]) + if start_time and file_modified < start_time: + continue + if end_time and file_modified > end_time: + continue + + filtered_files.append(file) + + return filtered_files + + def _process_files( + self, + start_time: datetime | None = None, + end_time: datetime | None = None, + ) -> Generator[list[Document], None, None]: + files = self._get_files_list(self.folder_path) + files = self._filter_files(files, start_time, end_time) + + current_batch: list[Document] = [] + for file in files: + try: + # Set up request with streaming enabled + headers = { + "Authorization": f"Bearer {self.access_token}", + } + url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}" + response = _request_with_retries( + method="GET", + url=url, + headers=headers, + timeout=_TIMEOUT, + stream=True, + ) + + if not response.ok: + logger.error( + f"Failed to fetch file content: {file['path']} (status code: {response.status_code})" + ) + continue + + # Stream the response content into a BytesIO buffer + buffer = io.BytesIO() + for chunk in response.iter_content(chunk_size=8192): + if chunk: + buffer.write(chunk) + + # Reset buffer's position to the start + buffer.seek(0) + + # Process the streamed file content + doc = _process_egnyte_file( + file_metadata=file, + file_content=buffer, + base_url=_EGNYTE_APP_BASE.format(domain=self.domain), + folder_path=self.folder_path, + ) + + if doc is not None: + current_batch.append(doc) + + if len(current_batch) >= self.batch_size: + yield current_batch + current_batch = [] + + except Exception: + logger.exception(f"Failed to process file {file['path']}") + continue + + if current_batch: + yield current_batch + + def load_from_state(self) -> GenerateDocumentsOutput: + yield from self._process_files() + + def poll_source( + self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch + ) -> GenerateDocumentsOutput: + start_time = datetime.fromtimestamp(start, tz=timezone.utc) + end_time = datetime.fromtimestamp(end, tz=timezone.utc) + + yield from self._process_files(start_time=start_time, end_time=end_time) + + +if __name__ == "__main__": + connector = EgnyteConnector() + connector.load_credentials( + { + "domain": os.environ["EGNYTE_DOMAIN"], + "access_token": os.environ["EGNYTE_ACCESS_TOKEN"], + } + ) + document_batches = connector.load_from_state() + print(next(document_batches)) diff --git a/backend/danswer/connectors/factory.py b/backend/danswer/connectors/factory.py index 40f926b31d1..87d1539d3d2 100644 --- a/backend/danswer/connectors/factory.py +++ b/backend/danswer/connectors/factory.py @@ -15,6 +15,7 @@ from danswer.connectors.discourse.connector import DiscourseConnector from danswer.connectors.document360.connector import Document360Connector from danswer.connectors.dropbox.connector import DropboxConnector +from danswer.connectors.egnyte.connector import EgnyteConnector from danswer.connectors.file.connector import LocalFileConnector from danswer.connectors.fireflies.connector import FirefliesConnector from danswer.connectors.freshdesk.connector import FreshdeskConnector @@ -103,6 +104,7 @@ def identify_connector_class( DocumentSource.XENFORO: XenforoConnector, DocumentSource.FRESHDESK: FreshdeskConnector, DocumentSource.FIREFLIES: FirefliesConnector, + DocumentSource.EGNYTE: EgnyteConnector, } connector_by_source = connector_map.get(source, {}) diff --git a/backend/danswer/connectors/file/connector.py b/backend/danswer/connectors/file/connector.py index 
b263354822f..70b7219f65a 100644 --- a/backend/danswer/connectors/file/connector.py +++ b/backend/danswer/connectors/file/connector.py @@ -17,11 +17,11 @@ from danswer.connectors.models import Document from danswer.connectors.models import Section from danswer.db.engine import get_session_with_tenant -from danswer.file_processing.extract_file_text import check_file_ext_is_valid from danswer.file_processing.extract_file_text import detect_encoding from danswer.file_processing.extract_file_text import extract_file_text from danswer.file_processing.extract_file_text import get_file_ext from danswer.file_processing.extract_file_text import is_text_file_extension +from danswer.file_processing.extract_file_text import is_valid_file_ext from danswer.file_processing.extract_file_text import load_files_from_zip from danswer.file_processing.extract_file_text import read_pdf_file from danswer.file_processing.extract_file_text import read_text_file @@ -50,7 +50,7 @@ def _read_files_and_metadata( file_content, ignore_dirs=True ): yield os.path.join(directory_path, file_info.filename), file, metadata - elif check_file_ext_is_valid(extension): + elif is_valid_file_ext(extension): yield file_name, file_content, metadata else: logger.warning(f"Skipping file '{file_name}' with extension '{extension}'") @@ -63,7 +63,7 @@ def _process_file( pdf_pass: str | None = None, ) -> list[Document]: extension = get_file_ext(file_name) - if not check_file_ext_is_valid(extension): + if not is_valid_file_ext(extension): logger.warning(f"Skipping file '{file_name}' with extension '{extension}'") return [] diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index d1f7ef7b274..771f9239e98 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -4,11 +4,13 @@ from concurrent.futures import ThreadPoolExecutor from functools import partial from typing import Any +from typing import cast from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore from danswer.configs.app_configs import INDEX_BATCH_SIZE +from danswer.configs.app_configs import MAX_FILE_SIZE_BYTES from danswer.configs.constants import DocumentSource from danswer.connectors.google_drive.doc_conversion import build_slim_document from danswer.connectors.google_drive.doc_conversion import ( @@ -451,12 +453,14 @@ def _fetch_drive_items( if isinstance(self.creds, ServiceAccountCredentials) else self._manage_oauth_retrieval ) - return retrieval_method( + drive_files = retrieval_method( is_slim=is_slim, start=start, end=end, ) + return drive_files + def _extract_docs_from_google_drive( self, start: SecondsSinceUnixEpoch | None = None, @@ -472,6 +476,15 @@ def _extract_docs_from_google_drive( files_to_process = [] # Gather the files into batches to be processed in parallel for file in self._fetch_drive_items(is_slim=False, start=start, end=end): + if ( + file.get("size") + and int(cast(str, file.get("size"))) > MAX_FILE_SIZE_BYTES + ): + logger.warning( + f"Skipping file {file.get('name', 'Unknown')} as it is too large: {file.get('size')} bytes" + ) + continue + files_to_process.append(file) if len(files_to_process) >= LARGE_BATCH_SIZE: yield from _process_files_batch( diff --git a/backend/danswer/connectors/google_drive/file_retrieval.py b/backend/danswer/connectors/google_drive/file_retrieval.py index 
962d531b076..9b9b17a8c27 100644 --- a/backend/danswer/connectors/google_drive/file_retrieval.py +++ b/backend/danswer/connectors/google_drive/file_retrieval.py @@ -16,7 +16,7 @@ FILE_FIELDS = ( "nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, " - "shortcutDetails, owners(emailAddress))" + "shortcutDetails, owners(emailAddress), size)" ) SLIM_FILE_FIELDS = ( "nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), " diff --git a/backend/danswer/connectors/interfaces.py b/backend/danswer/connectors/interfaces.py index c53b3de5f2f..3ab447a7a88 100644 --- a/backend/danswer/connectors/interfaces.py +++ b/backend/danswer/connectors/interfaces.py @@ -2,6 +2,7 @@ from collections.abc import Iterator from typing import Any +from danswer.configs.constants import DocumentSource from danswer.connectors.models import Document from danswer.connectors.models import SlimDocument @@ -64,6 +65,23 @@ def retrieve_all_slim_documents( raise NotImplementedError +class OAuthConnector(BaseConnector): + @classmethod + @abc.abstractmethod + def oauth_id(cls) -> DocumentSource: + raise NotImplementedError + + @classmethod + @abc.abstractmethod + def oauth_authorization_url(cls, base_domain: str, state: str) -> str: + raise NotImplementedError + + @classmethod + @abc.abstractmethod + def oauth_code_to_token(cls, code: str) -> dict[str, Any]: + raise NotImplementedError + + # Event driven class EventConnector(BaseConnector): @abc.abstractmethod diff --git a/backend/danswer/connectors/linear/connector.py b/backend/danswer/connectors/linear/connector.py index 22b769562d1..c6da61555bd 100644 --- a/backend/danswer/connectors/linear/connector.py +++ b/backend/danswer/connectors/linear/connector.py @@ -132,7 +132,6 @@ def _process_issues( branchName customerTicketCount description - descriptionData comments { nodes { url @@ -215,5 +214,6 @@ def poll_source( if __name__ == "__main__": connector = LinearConnector() connector.load_credentials({"linear_api_key": os.environ["LINEAR_API_KEY"]}) + document_batches = connector.load_from_state() print(next(document_batches)) diff --git a/backend/danswer/connectors/slack/connector.py b/backend/danswer/connectors/slack/connector.py index 22ace603bd4..b550e42d21f 100644 --- a/backend/danswer/connectors/slack/connector.py +++ b/backend/danswer/connectors/slack/connector.py @@ -171,7 +171,9 @@ def thread_to_doc( else first_message ) - doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}" + doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}".replace( + "\n", " " + ) return Document( id=f"{channel_id}__{thread[0]['ts']}", diff --git a/backend/danswer/danswerbot/slack/blocks.py b/backend/danswer/danswerbot/slack/blocks.py index bb7ec29e8c0..34ec92e7daa 100644 --- a/backend/danswer/danswerbot/slack/blocks.py +++ b/backend/danswer/danswerbot/slack/blocks.py @@ -204,7 +204,8 @@ def _build_documents_blocks( continue seen_docs_identifiers.add(d.document_id) - doc_sem_id = d.semantic_identifier + # Strip newlines from the semantic identifier for Slackbot formatting + doc_sem_id = d.semantic_identifier.replace("\n", " ") if d.source_type == DocumentSource.SLACK.value: doc_sem_id = "#" + doc_sem_id diff --git a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py index 8e93620e298..e1e3673e8ec 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py +++ 
b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py @@ -373,7 +373,9 @@ def _get_slack_answer( respond_in_thread( client=client, channel=channel, - receiver_ids=receiver_ids, + receiver_ids=[message_info.sender] + if message_info.is_bot_msg and message_info.sender + else receiver_ids, text="Hello! Danswer has some results for you!", blocks=all_blocks, thread_ts=message_ts_to_respond_to, diff --git a/backend/danswer/danswerbot/slack/utils.py b/backend/danswer/danswerbot/slack/utils.py index 4221b4faf6f..147356b76d1 100644 --- a/backend/danswer/danswerbot/slack/utils.py +++ b/backend/danswer/danswerbot/slack/utils.py @@ -11,6 +11,7 @@ from slack_sdk import WebClient from slack_sdk.errors import SlackApiError from slack_sdk.models.blocks import Block +from slack_sdk.models.blocks import SectionBlock from slack_sdk.models.metadata import Metadata from slack_sdk.socket_mode import SocketModeClient @@ -140,6 +141,40 @@ def remove_danswer_bot_tag(message_str: str, client: WebClient) -> str: return re.sub(rf"<@{bot_tag_id}>\s", "", message_str) +def _check_for_url_in_block(block: Block) -> bool: + """ + Check if the block has a key that contains "url" in it + """ + block_dict = block.to_dict() + + def check_dict_for_url(d: dict) -> bool: + for key, value in d.items(): + if "url" in key.lower(): + return True + if isinstance(value, dict): + if check_dict_for_url(value): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, dict) and check_dict_for_url(item): + return True + return False + + return check_dict_for_url(block_dict) + + +def _build_error_block(error_message: str) -> Block: + """ + Build an error block to display in slack so that the user can see + the error without completely breaking + """ + display_text = ( + "There was an error displaying all of the Onyx answers." + f" Please let an admin or an onyx developer know. 
Error: {error_message}" + ) + return SectionBlock(text=display_text) + + @retry( tries=DANSWER_BOT_NUM_RETRIES, delay=0.25, @@ -162,24 +197,9 @@ def respond_in_thread( message_ids: list[str] = [] if not receiver_ids: slack_call = make_slack_api_rate_limited(client.chat_postMessage) - response = slack_call( - channel=channel, - text=text, - blocks=blocks, - thread_ts=thread_ts, - metadata=metadata, - unfurl_links=unfurl, - unfurl_media=unfurl, - ) - if not response.get("ok"): - raise RuntimeError(f"Failed to post message: {response}") - message_ids.append(response["message_ts"]) - else: - slack_call = make_slack_api_rate_limited(client.chat_postEphemeral) - for receiver in receiver_ids: + try: response = slack_call( channel=channel, - user=receiver, text=text, blocks=blocks, thread_ts=thread_ts, @@ -187,8 +207,68 @@ unfurl_links=unfurl, unfurl_media=unfurl, ) - if not response.get("ok"): - raise RuntimeError(f"Failed to post message: {response}") + except Exception as e: + logger.warning(f"Failed to post message: {e} \n blocks: {blocks}") + logger.warning("Trying again without blocks that have urls") + + if not blocks: + raise e + + blocks_without_urls = [ + block for block in blocks if not _check_for_url_in_block(block) + ] + blocks_without_urls.append(_build_error_block(str(e))) + + # Try again without blocks containing url + response = slack_call( + channel=channel, + text=text, + blocks=blocks_without_urls, + thread_ts=thread_ts, + metadata=metadata, + unfurl_links=unfurl, + unfurl_media=unfurl, + ) + + message_ids.append(response["message_ts"]) + else: + slack_call = make_slack_api_rate_limited(client.chat_postEphemeral) + for receiver in receiver_ids: + try: + response = slack_call( + channel=channel, + user=receiver, + text=text, + blocks=blocks, + thread_ts=thread_ts, + metadata=metadata, + unfurl_links=unfurl, + unfurl_media=unfurl, + ) + except Exception as e: + logger.warning(f"Failed to post message: {e} \n blocks: {blocks}") + logger.warning("Trying again without blocks that have urls") + + if not blocks: + raise e + + blocks_without_urls = [ + block for block in blocks if not _check_for_url_in_block(block) + ] + blocks_without_urls.append(_build_error_block(str(e))) + + # Try again without blocks containing url + response = slack_call( + channel=channel, + user=receiver, + text=text, + blocks=blocks_without_urls, + thread_ts=thread_ts, + metadata=metadata, + unfurl_links=unfurl, + unfurl_media=unfurl, + ) + message_ids.append(response["message_ts"]) return message_ids diff --git a/backend/danswer/db/credentials.py b/backend/danswer/db/credentials.py index 5f4b83b4791..3ee165b34d0 100644 --- a/backend/danswer/db/credentials.py +++ b/backend/danswer/db/credentials.py @@ -20,7 +20,6 @@ from danswer.db.models import User from danswer.db.models import User__UserGroup from danswer.server.documents.models import CredentialBase -from danswer.server.documents.models import CredentialDataUpdateRequest from danswer.utils.logger import setup_logger @@ -262,7 +261,8 @@ def _cleanup_credential__user_group_relationships__no_commit( def alter_credential( credential_id: int, - credential_data: CredentialDataUpdateRequest, + name: str, + credential_json: dict[str, Any], user: User, db_session: Session, ) -> Credential | None: @@ -272,11 +272,13 @@ if credential is None: return None - credential.name = credential_data.name + credential.name = name - # Update only the keys present in credential_data.credential_json - for key, value in 
credential_data.credential_json.items(): - credential.credential_json[key] = value + # Assign a new dictionary to credential.credential_json + credential.credential_json = { + **credential.credential_json, + **credential_json, + } credential.user_id = user.id if user is not None else None db_session.commit() @@ -309,8 +311,8 @@ def update_credential_json( credential = fetch_credential_by_id(credential_id, user, db_session) if credential is None: return None - credential.credential_json = credential_json + credential.credential_json = credential_json db_session.commit() return credential diff --git a/backend/danswer/db/input_prompt.py b/backend/danswer/db/input_prompt.py deleted file mode 100644 index efa54d986a1..00000000000 --- a/backend/danswer/db/input_prompt.py +++ /dev/null @@ -1,202 +0,0 @@ -from uuid import UUID - -from fastapi import HTTPException -from sqlalchemy import select -from sqlalchemy.orm import Session - -from danswer.db.models import InputPrompt -from danswer.db.models import User -from danswer.server.features.input_prompt.models import InputPromptSnapshot -from danswer.server.manage.models import UserInfo -from danswer.utils.logger import setup_logger - - -logger = setup_logger() - - -def insert_input_prompt_if_not_exists( - user: User | None, - input_prompt_id: int | None, - prompt: str, - content: str, - active: bool, - is_public: bool, - db_session: Session, - commit: bool = True, -) -> InputPrompt: - if input_prompt_id is not None: - input_prompt = ( - db_session.query(InputPrompt).filter_by(id=input_prompt_id).first() - ) - else: - query = db_session.query(InputPrompt).filter(InputPrompt.prompt == prompt) - if user: - query = query.filter(InputPrompt.user_id == user.id) - else: - query = query.filter(InputPrompt.user_id.is_(None)) - input_prompt = query.first() - - if input_prompt is None: - input_prompt = InputPrompt( - id=input_prompt_id, - prompt=prompt, - content=content, - active=active, - is_public=is_public or user is None, - user_id=user.id if user else None, - ) - db_session.add(input_prompt) - - if commit: - db_session.commit() - - return input_prompt - - -def insert_input_prompt( - prompt: str, - content: str, - is_public: bool, - user: User | None, - db_session: Session, -) -> InputPrompt: - input_prompt = InputPrompt( - prompt=prompt, - content=content, - active=True, - is_public=is_public or user is None, - user_id=user.id if user is not None else None, - ) - db_session.add(input_prompt) - db_session.commit() - - return input_prompt - - -def update_input_prompt( - user: User | None, - input_prompt_id: int, - prompt: str, - content: str, - active: bool, - db_session: Session, -) -> InputPrompt: - input_prompt = db_session.scalar( - select(InputPrompt).where(InputPrompt.id == input_prompt_id) - ) - if input_prompt is None: - raise ValueError(f"No input prompt with id {input_prompt_id}") - - if not validate_user_prompt_authorization(user, input_prompt): - raise HTTPException(status_code=401, detail="You don't own this prompt") - - input_prompt.prompt = prompt - input_prompt.content = content - input_prompt.active = active - - db_session.commit() - return input_prompt - - -def validate_user_prompt_authorization( - user: User | None, input_prompt: InputPrompt -) -> bool: - prompt = InputPromptSnapshot.from_model(input_prompt=input_prompt) - - if prompt.user_id is not None: - if user is None: - return False - - user_details = UserInfo.from_model(user) - if str(user_details.id) != str(prompt.user_id): - return False - return True - - -def 
remove_public_input_prompt(input_prompt_id: int, db_session: Session) -> None: - input_prompt = db_session.scalar( - select(InputPrompt).where(InputPrompt.id == input_prompt_id) - ) - - if input_prompt is None: - raise ValueError(f"No input prompt with id {input_prompt_id}") - - if not input_prompt.is_public: - raise HTTPException(status_code=400, detail="This prompt is not public") - - db_session.delete(input_prompt) - db_session.commit() - - -def remove_input_prompt( - user: User | None, input_prompt_id: int, db_session: Session -) -> None: - input_prompt = db_session.scalar( - select(InputPrompt).where(InputPrompt.id == input_prompt_id) - ) - if input_prompt is None: - raise ValueError(f"No input prompt with id {input_prompt_id}") - - if input_prompt.is_public: - raise HTTPException( - status_code=400, detail="Cannot delete public prompts with this method" - ) - - if not validate_user_prompt_authorization(user, input_prompt): - raise HTTPException(status_code=401, detail="You do not own this prompt") - - db_session.delete(input_prompt) - db_session.commit() - - -def fetch_input_prompt_by_id( - id: int, user_id: UUID | None, db_session: Session -) -> InputPrompt: - query = select(InputPrompt).where(InputPrompt.id == id) - - if user_id: - query = query.where( - (InputPrompt.user_id == user_id) | (InputPrompt.user_id is None) - ) - else: - # If no user_id is provided, only fetch prompts without a user_id (aka public) - query = query.where(InputPrompt.user_id == None) # noqa - - result = db_session.scalar(query) - - if result is None: - raise HTTPException(422, "No input prompt found") - - return result - - -def fetch_public_input_prompts( - db_session: Session, -) -> list[InputPrompt]: - query = select(InputPrompt).where(InputPrompt.is_public) - return list(db_session.scalars(query).all()) - - -def fetch_input_prompts_by_user( - db_session: Session, - user_id: UUID | None, - active: bool | None = None, - include_public: bool = False, -) -> list[InputPrompt]: - query = select(InputPrompt) - - if user_id is not None: - if include_public: - query = query.where( - (InputPrompt.user_id == user_id) | InputPrompt.is_public - ) - else: - query = query.where(InputPrompt.user_id == user_id) - - elif include_public: - query = query.where(InputPrompt.is_public) - - if active is not None: - query = query.where(InputPrompt.active == active) - - return list(db_session.scalars(query).all()) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index b9e0f7bc416..031500360b0 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -159,9 +159,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base): ) prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user") - input_prompts: Mapped[list["InputPrompt"]] = relationship( - "InputPrompt", back_populates="user" - ) # Personas owned by this user personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user") @@ -178,31 +175,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base): ) -class InputPrompt(Base): - __tablename__ = "inputprompt" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - prompt: Mapped[str] = mapped_column(String) - content: Mapped[str] = mapped_column(String) - active: Mapped[bool] = mapped_column(Boolean) - user: Mapped[User | None] = relationship("User", back_populates="input_prompts") - is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) - user_id: Mapped[UUID | None] = mapped_column( - ForeignKey("user.id", 
ondelete="CASCADE"), nullable=True - ) - - -class InputPrompt__User(Base): - __tablename__ = "inputprompt__user" - - input_prompt_id: Mapped[int] = mapped_column( - ForeignKey("inputprompt.id"), primary_key=True - ) - user_id: Mapped[UUID | None] = mapped_column( - ForeignKey("inputprompt.id"), primary_key=True - ) - - class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base): pass @@ -596,6 +568,25 @@ class Connector(Base): list["DocumentByConnectorCredentialPair"] ] = relationship("DocumentByConnectorCredentialPair", back_populates="connector") + # synchronize this validation logic with RefreshFrequencySchema etc on front end + # until we have a centralized validation schema + + # TODO(rkuo): experiment with SQLAlchemy validators rather than manual checks + # https://docs.sqlalchemy.org/en/20/orm/mapped_attributes.html + def validate_refresh_freq(self) -> None: + if self.refresh_freq is not None: + if self.refresh_freq < 60: + raise ValueError( + "refresh_freq must be greater than or equal to 60 seconds." + ) + + def validate_prune_freq(self) -> None: + if self.prune_freq is not None: + if self.prune_freq < 86400: + raise ValueError( + "prune_freq must be greater than or equal to 86400 seconds." + ) + class Credential(Base): __tablename__ = "credential" diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py index f8602ed5af7..ee97885b376 100644 --- a/backend/danswer/db/persona.py +++ b/backend/danswer/db/persona.py @@ -453,9 +453,9 @@ def upsert_persona( """ if persona_id is not None: - persona = db_session.query(Persona).filter_by(id=persona_id).first() + existing_persona = db_session.query(Persona).filter_by(id=persona_id).first() else: - persona = _get_persona_by_name( + existing_persona = _get_persona_by_name( persona_name=name, user=user, db_session=db_session ) @@ -481,62 +481,78 @@ def upsert_persona( prompts = None if prompt_ids is not None: prompts = db_session.query(Prompt).filter(Prompt.id.in_(prompt_ids)).all() - if not prompts and prompt_ids: - raise ValueError("prompts not found") + + if prompts is not None and len(prompts) == 0: + raise ValueError( + f"Invalid Persona config, no valid prompts " + f"specified. Specified IDs were: '{prompt_ids}'" + ) # ensure all specified tools are valid if tools: validate_persona_tools(tools) - if persona: + if existing_persona: # Built-in personas can only be updated through YAML configuration. # This ensures that core system personas are not modified unintentionally. - if persona.builtin_persona and not builtin_persona: + if existing_persona.builtin_persona and not builtin_persona: raise ValueError("Cannot update builtin persona with non-builtin.") # this checks if the user has permission to edit the persona - persona = fetch_persona_by_id( - db_session=db_session, persona_id=persona.id, user=user, get_editable=True + # will raise an Exception if the user does not have permission + existing_persona = fetch_persona_by_id( + db_session=db_session, + persona_id=existing_persona.id, + user=user, + get_editable=True, ) # The following update excludes `default`, `built-in`, and display priority. # Display priority is handled separately in the `display-priority` endpoint. # `default` and `built-in` properties can only be set when creating a persona. 
- persona.name = name - persona.description = description - persona.num_chunks = num_chunks - persona.chunks_above = chunks_above - persona.chunks_below = chunks_below - persona.llm_relevance_filter = llm_relevance_filter - persona.llm_filter_extraction = llm_filter_extraction - persona.recency_bias = recency_bias - persona.llm_model_provider_override = llm_model_provider_override - persona.llm_model_version_override = llm_model_version_override - persona.starter_messages = starter_messages - persona.deleted = False # Un-delete if previously deleted - persona.is_public = is_public - persona.icon_color = icon_color - persona.icon_shape = icon_shape + existing_persona.name = name + existing_persona.description = description + existing_persona.num_chunks = num_chunks + existing_persona.chunks_above = chunks_above + existing_persona.chunks_below = chunks_below + existing_persona.llm_relevance_filter = llm_relevance_filter + existing_persona.llm_filter_extraction = llm_filter_extraction + existing_persona.recency_bias = recency_bias + existing_persona.llm_model_provider_override = llm_model_provider_override + existing_persona.llm_model_version_override = llm_model_version_override + existing_persona.starter_messages = starter_messages + existing_persona.deleted = False # Un-delete if previously deleted + existing_persona.is_public = is_public + existing_persona.icon_color = icon_color + existing_persona.icon_shape = icon_shape if remove_image or uploaded_image_id: - persona.uploaded_image_id = uploaded_image_id - persona.is_visible = is_visible - persona.search_start_date = search_start_date - persona.category_id = category_id + existing_persona.uploaded_image_id = uploaded_image_id + existing_persona.is_visible = is_visible + existing_persona.search_start_date = search_start_date + existing_persona.category_id = category_id # Do not delete any associations manually added unless # a new updated list is provided if document_sets is not None: - persona.document_sets.clear() - persona.document_sets = document_sets or [] + existing_persona.document_sets.clear() + existing_persona.document_sets = document_sets or [] if prompts is not None: - persona.prompts.clear() - persona.prompts = prompts or [] + existing_persona.prompts.clear() + existing_persona.prompts = prompts if tools is not None: - persona.tools = tools or [] + existing_persona.tools = tools or [] + + persona = existing_persona else: - persona = Persona( + if not prompts: + raise ValueError( + "Invalid Persona config. " + "Must specify at least one prompt for a new persona." 
+ ) + + new_persona = Persona( id=persona_id, user_id=user.id if user else None, is_public=is_public, @@ -549,7 +565,7 @@ def upsert_persona( llm_filter_extraction=llm_filter_extraction, recency_bias=recency_bias, builtin_persona=builtin_persona, - prompts=prompts or [], + prompts=prompts, document_sets=document_sets or [], llm_model_provider_override=llm_model_provider_override, llm_model_version_override=llm_model_version_override, @@ -564,8 +580,8 @@ def upsert_persona( is_default_persona=is_default_persona, category_id=category_id, ) - db_session.add(persona) - + db_session.add(new_persona) + persona = new_persona if commit: db_session.commit() else: diff --git a/backend/danswer/file_processing/extract_file_text.py b/backend/danswer/file_processing/extract_file_text.py index 428afd9ae52..58016e80d63 100644 --- a/backend/danswer/file_processing/extract_file_text.py +++ b/backend/danswer/file_processing/extract_file_text.py @@ -70,7 +70,7 @@ def get_file_ext(file_path_or_name: str | Path) -> str: return extension -def check_file_ext_is_valid(ext: str) -> bool: +def is_valid_file_ext(ext: str) -> bool: return ext in VALID_FILE_EXTENSIONS @@ -364,7 +364,7 @@ def extract_file_text( elif file_name is not None: final_extension = get_file_ext(file_name) - if check_file_ext_is_valid(final_extension): + if is_valid_file_ext(final_extension): return extension_to_function.get(final_extension, file_io_to_text)(file) # Either the file somehow has no name or the extension is not one that we recognize diff --git a/backend/danswer/indexing/indexing_pipeline.py b/backend/danswer/indexing/indexing_pipeline.py index b1ee8f4d944..bace61cec80 100644 --- a/backend/danswer/indexing/indexing_pipeline.py +++ b/backend/danswer/indexing/indexing_pipeline.py @@ -1,4 +1,5 @@ import traceback +from collections.abc import Callable from functools import partial from http import HTTPStatus from typing import Protocol @@ -12,6 +13,7 @@ from danswer.access.models import DocumentAccess from danswer.configs.app_configs import ENABLE_MULTIPASS_INDEXING from danswer.configs.app_configs import INDEXING_EXCEPTION_LIMIT +from danswer.configs.app_configs import MAX_DOCUMENT_CHARS from danswer.configs.constants import DEFAULT_BOOST from danswer.connectors.cross_connector_utils.miscellaneous_utils import ( get_experts_stores_representations, @@ -202,40 +204,13 @@ def index_doc_batch_with_handler( def index_doc_batch_prepare( - document_batch: list[Document], + documents: list[Document], index_attempt_metadata: IndexAttemptMetadata, db_session: Session, ignore_time_skip: bool = False, ) -> DocumentBatchPrepareContext | None: """Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc. This preceeds indexing it into the actual document index.""" - documents: list[Document] = [] - for document in document_batch: - empty_contents = not any(section.text.strip() for section in document.sections) - if ( - (not document.title or not document.title.strip()) - and not document.semantic_identifier.strip() - and empty_contents - ): - # Skip documents that have neither title nor content - # If the document doesn't have either, then there is no useful information in it - # This is again verified later in the pipeline after chunking but at that point there should - # already be no documents that are empty. - logger.warning( - f"Skipping document with ID {document.id} as it has neither title nor content." 
- ) - continue - - if document.title is not None and not document.title.strip() and empty_contents: - # The title is explicitly empty ("" and not None) and the document is empty - # so when building the chunk text representation, it will be empty and unuseable - logger.warning( - f"Skipping document with ID {document.id} as the chunks will be empty." - ) - continue - - documents.append(document) - # Create a trimmed list of docs that don't have a newer updated at # Shortcuts the time-consuming flow on connector index retries document_ids: list[str] = [document.id for document in documents] @@ -282,17 +257,64 @@ ) +def filter_documents(document_batch: list[Document]) -> list[Document]: + documents: list[Document] = [] + for document in document_batch: + empty_contents = not any(section.text.strip() for section in document.sections) + if ( + (not document.title or not document.title.strip()) + and not document.semantic_identifier.strip() + and empty_contents + ): + # Skip documents that have neither title nor content + # If the document doesn't have either, then there is no useful information in it + # This is again verified later in the pipeline after chunking but at that point there should + # already be no documents that are empty. + logger.warning( + f"Skipping document with ID {document.id} as it has neither title nor content." + ) + continue + + if document.title is not None and not document.title.strip() and empty_contents: + # The title is explicitly empty ("" and not None) and the document is empty + # so when building the chunk text representation, it will be empty and unusable + logger.warning( + f"Skipping document with ID {document.id} as the chunks will be empty." + ) + continue + + section_chars = sum(len(section.text) for section in document.sections) + if ( + MAX_DOCUMENT_CHARS + and len(document.title or document.semantic_identifier) + section_chars + > MAX_DOCUMENT_CHARS + ): + # Skip documents that are too long, later on there are more memory intensive steps done on the text + # and the container will run out of memory and crash. Several other checks are included upstream but + # those are at the connector level so a catchall is still needed. + # Assumption here is that files that are that long, are generated files and not the type users + # generally care for. + logger.warning( + f"Skipping document with ID {document.id} as it is too long." 
+ ) + continue + + documents.append(document) + return documents + + @log_function_time(debug_only=True) def index_doc_batch( *, + document_batch: list[Document], chunker: Chunker, embedder: IndexingEmbedder, document_index: DocumentIndex, - document_batch: list[Document], index_attempt_metadata: IndexAttemptMetadata, db_session: Session, ignore_time_skip: bool = False, tenant_id: str | None = None, + filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents, ) -> tuple[int, int]: """Takes different pieces of the indexing pipeline and applies it to a batch of documents Note that the documents should already be batched at this point so that it does not inflate the @@ -309,8 +331,11 @@ is_public=False, ) + logger.debug("Filtering Documents") + filtered_documents = filter_fnc(document_batch) + ctx = index_doc_batch_prepare( - document_batch=document_batch, + documents=filtered_documents, index_attempt_metadata=index_attempt_metadata, ignore_time_skip=ignore_time_skip, db_session=db_session, diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py index f4b09d261fd..88b8f0396d5 100644 --- a/backend/danswer/llm/chat_llm.py +++ b/backend/danswer/llm/chat_llm.py @@ -268,12 +268,16 @@ def __init__( # NOTE: have to set these as environment variables for Litellm since # not all are able to passed in but they always support them set as env - # variables + # variables. We'll also try passing them in, since litellm just ignores + # additional kwargs (and some kwargs MUST be passed in rather than set as + # env variables) if custom_config: for k, v in custom_config.items(): os.environ[k] = v model_kwargs = model_kwargs or {} + if custom_config: + model_kwargs.update(custom_config) if extra_headers: model_kwargs.update({"extra_headers": extra_headers}) if extra_body: diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index 41c6592c6e8..47175c8d9fd 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -1,5 +1,4 @@ import copy -import io import json from collections.abc import Callable from collections.abc import Iterator @@ -7,7 +6,6 @@ from typing import cast import litellm # type: ignore -import pandas as pd import tiktoken from langchain.prompts.base import StringPromptValue from langchain.prompts.chat import ChatPromptValue @@ -100,53 +98,32 @@ def litellm_exception_to_error_msg( return error_msg -# Processes CSV files to show the first 5 rows and max_columns (default 40) columns -def _process_csv_file(file: InMemoryChatFile, max_columns: int = 40) -> str: - df = pd.read_csv(io.StringIO(file.content.decode("utf-8"))) - - csv_preview = df.head().to_string(max_cols=max_columns) - - file_name_section = ( - f"CSV FILE NAME: {file.filename}\n" - if file.filename - else "CSV FILE (NO NAME PROVIDED):\n" - ) - return f"{file_name_section}{CODE_BLOCK_PAT.format(csv_preview)}\n\n\n" - - def _build_content( message: str, files: list[InMemoryChatFile] | None = None, ) -> str: """Applies all non-image files.""" - text_files = ( - [file for file in files if file.file_type == ChatFileType.PLAIN_TEXT] - if files - else None - ) + if not files: + return message - csv_files = ( - [file for file in files if file.file_type == ChatFileType.CSV] - if files - else None - ) + text_files = [ + file + for file in files + if file.file_type in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV) + ] - if not text_files and not csv_files: + if not text_files: return message final_message_with_files = "FILES:\n\n" - for file in 
text_files or []: + for file in text_files: file_content = file.content.decode("utf-8") file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else "" final_message_with_files += ( f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n" ) - for file in csv_files or []: - final_message_with_files += _process_csv_file(file) - - final_message_with_files += message - return final_message_with_files + return final_message_with_files + message def build_content_with_imgs( diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 60e06f99e70..49fdfadfbf3 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -52,12 +52,9 @@ from danswer.server.documents.credential import router as credential_router from danswer.server.documents.document import router as document_router from danswer.server.documents.indexing import router as indexing_router +from danswer.server.documents.standard_oauth import router as standard_oauth_router from danswer.server.features.document_set.api import router as document_set_router from danswer.server.features.folder.api import router as folder_router -from danswer.server.features.input_prompt.api import ( - admin_router as admin_input_prompt_router, -) -from danswer.server.features.input_prompt.api import basic_router as input_prompt_router from danswer.server.features.notifications.api import router as notification_router from danswer.server.features.persona.api import admin_router as admin_persona_router from danswer.server.features.persona.api import basic_router as persona_router @@ -259,8 +256,6 @@ def get_application() -> FastAPI: ) include_router_with_global_prefix_prepended(application, persona_router) include_router_with_global_prefix_prepended(application, admin_persona_router) - include_router_with_global_prefix_prepended(application, input_prompt_router) - include_router_with_global_prefix_prepended(application, admin_input_prompt_router) include_router_with_global_prefix_prepended(application, notification_router) include_router_with_global_prefix_prepended(application, prompt_router) include_router_with_global_prefix_prepended(application, tool_router) @@ -282,8 +277,9 @@ def get_application() -> FastAPI: application, get_full_openai_assistants_api_router() ) include_router_with_global_prefix_prepended(application, long_term_logs_router) - include_router_with_global_prefix_prepended(application, oauth_router) + include_router_with_global_prefix_prepended(application, standard_oauth_router) include_router_with_global_prefix_prepended(application, api_key_router) + include_router_with_global_prefix_prepended(application, oauth_router) if AUTH_TYPE == AuthType.DISABLED: # Server logs this during auth setup verification step diff --git a/backend/danswer/seeding/input_prompts.yaml b/backend/danswer/seeding/input_prompts.yaml deleted file mode 100644 index cc7dbe78ea1..00000000000 --- a/backend/danswer/seeding/input_prompts.yaml +++ /dev/null @@ -1,24 +0,0 @@ -input_prompts: - - id: -5 - prompt: "Elaborate" - content: "Elaborate on the above, give me a more in depth explanation." - active: true - is_public: true - - - id: -4 - prompt: "Reword" - content: "Help me rewrite the following politely and concisely for professional communication:\n" - active: true - is_public: true - - - id: -3 - prompt: "Email" - content: "Write a professional email for me including a subject line, signature, etc. Template the parts that need editing with [ ]. 
The email should cover the following points:\n" - active: true - is_public: true - - - id: -2 - prompt: "Debug" - content: "Provide step-by-step troubleshooting instructions for the following issue:\n" - active: true - is_public: true diff --git a/backend/danswer/seeding/load_docs.py b/backend/danswer/seeding/load_docs.py index 1567f7f6bbb..5fe591423f0 100644 --- a/backend/danswer/seeding/load_docs.py +++ b/backend/danswer/seeding/load_docs.py @@ -196,7 +196,7 @@ def seed_initial_documents( docs, chunks = _create_indexable_chunks(processed_docs, tenant_id) index_doc_batch_prepare( - document_batch=docs, + documents=docs, index_attempt_metadata=IndexAttemptMetadata( connector_id=connector_id, credential_id=PUBLIC_CREDENTIAL_ID, diff --git a/backend/danswer/seeding/load_yamls.py b/backend/danswer/seeding/load_yamls.py index c93851dd5c0..6efa1efd368 100644 --- a/backend/danswer/seeding/load_yamls.py +++ b/backend/danswer/seeding/load_yamls.py @@ -1,13 +1,11 @@ import yaml from sqlalchemy.orm import Session -from danswer.configs.chat_configs import INPUT_PROMPT_YAML from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT from danswer.configs.chat_configs import PERSONAS_YAML from danswer.configs.chat_configs import PROMPTS_YAML from danswer.context.search.enums import RecencyBiasSetting from danswer.db.document_set import get_or_create_document_set_by_name -from danswer.db.input_prompt import insert_input_prompt_if_not_exists from danswer.db.models import DocumentSet as DocumentSetDBModel from danswer.db.models import Persona from danswer.db.models import Prompt as PromptDBModel @@ -79,6 +77,9 @@ def load_personas_from_yaml( if prompts: prompt_ids = [prompt.id for prompt in prompts if prompt is not None] + if not prompt_ids: + raise ValueError("Invalid Persona config, no prompts exist") + p_id = persona.get("id") tool_ids = [] @@ -123,45 +124,24 @@ def load_personas_from_yaml( tool_ids=tool_ids, builtin_persona=True, is_public=True, - display_priority=existing_persona.display_priority - if existing_persona is not None - else persona.get("display_priority"), - is_visible=existing_persona.is_visible - if existing_persona is not None - else persona.get("is_visible"), + display_priority=( + existing_persona.display_priority + if existing_persona is not None + else persona.get("display_priority") + ), + is_visible=( + existing_persona.is_visible + if existing_persona is not None + else persona.get("is_visible") + ), db_session=db_session, ) -def load_input_prompts_from_yaml( - db_session: Session, input_prompts_yaml: str = INPUT_PROMPT_YAML -) -> None: - with open(input_prompts_yaml, "r") as file: - data = yaml.safe_load(file) - - all_input_prompts = data.get("input_prompts", []) - for input_prompt in all_input_prompts: - # If these prompts are deleted (which is a hard delete in the DB), on server startup - # they will be recreated, but the user can always just deactivate them, just a light inconvenience - - insert_input_prompt_if_not_exists( - user=None, - input_prompt_id=input_prompt.get("id"), - prompt=input_prompt["prompt"], - content=input_prompt["content"], - is_public=input_prompt["is_public"], - active=input_prompt.get("active", True), - db_session=db_session, - commit=True, - ) - - def load_chat_yamls( db_session: Session, prompt_yaml: str = PROMPTS_YAML, personas_yaml: str = PERSONAS_YAML, - input_prompts_yaml: str = INPUT_PROMPT_YAML, ) -> None: load_prompts_from_yaml(db_session, prompt_yaml) load_personas_from_yaml(db_session, personas_yaml) - 
load_input_prompts_from_yaml(db_session, input_prompts_yaml) diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py index 46bdb2078c4..424ae256462 100644 --- a/backend/danswer/server/documents/cc_pair.py +++ b/backend/danswer/server/documents/cc_pair.py @@ -45,6 +45,7 @@ from danswer.redis.redis_connector import RedisConnector from danswer.redis.redis_pool import get_redis_client from danswer.server.documents.models import CCPairFullInfo +from danswer.server.documents.models import CCPropertyUpdateRequest from danswer.server.documents.models import CCStatusUpdateRequest from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.server.documents.models import ConnectorCredentialPairMetadata @@ -308,6 +309,46 @@ def update_cc_pair_name( raise HTTPException(status_code=400, detail="Name must be unique") +@router.put("/admin/cc-pair/{cc_pair_id}/property") +def update_cc_pair_property( + cc_pair_id: int, + update_request: CCPropertyUpdateRequest, # in seconds + user: User | None = Depends(current_curator_or_admin_user), + db_session: Session = Depends(get_session), +) -> StatusResponse[int]: + cc_pair = get_connector_credential_pair_from_id( + cc_pair_id=cc_pair_id, + db_session=db_session, + user=user, + get_editable=True, + ) + if not cc_pair: + raise HTTPException( + status_code=400, detail="CC Pair not found for current user's permissions" + ) + + # Can we centralize logic for updating connector properties + # so that we don't need to manually validate everywhere? + if update_request.name == "refresh_frequency": + cc_pair.connector.refresh_freq = int(update_request.value) + cc_pair.connector.validate_refresh_freq() + db_session.commit() + + msg = "Refresh frequency updated successfully" + elif update_request.name == "pruning_frequency": + cc_pair.connector.prune_freq = int(update_request.value) + cc_pair.connector.validate_prune_freq() + db_session.commit() + + msg = "Pruning frequency updated successfully" + else: + raise HTTPException( + status_code=400, detail=f"Property name {update_request.name} is not valid." 
+ ) + + return StatusResponse(success=True, message=msg, data=cc_pair_id) + + @router.get("/admin/cc-pair/{cc_pair_id}/last_pruned") def get_cc_pair_last_pruned( cc_pair_id: int, diff --git a/backend/danswer/server/documents/credential.py b/backend/danswer/server/documents/credential.py index 3ae60a5d69a..1cd118cd938 100644 --- a/backend/danswer/server/documents/credential.py +++ b/backend/danswer/server/documents/credential.py @@ -178,7 +178,13 @@ def update_credential_data( user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> CredentialBase: - credential = alter_credential(credential_id, credential_update, user, db_session) + credential = alter_credential( + credential_id, + credential_update.name, + credential_update.credential_json, + user, + db_session, + ) if credential is None: raise HTTPException( diff --git a/backend/danswer/server/documents/models.py b/backend/danswer/server/documents/models.py index 7b523d929ec..2c4f509444f 100644 --- a/backend/danswer/server/documents/models.py +++ b/backend/danswer/server/documents/models.py @@ -364,6 +364,11 @@ class RunConnectorRequest(BaseModel): from_beginning: bool = False +class CCPropertyUpdateRequest(BaseModel): + name: str + value: str + + """Connectors Models""" diff --git a/backend/danswer/server/documents/standard_oauth.py b/backend/danswer/server/documents/standard_oauth.py new file mode 100644 index 00000000000..ddc85761914 --- /dev/null +++ b/backend/danswer/server/documents/standard_oauth.py @@ -0,0 +1,142 @@ +import uuid +from typing import Annotated +from typing import cast + +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from fastapi import Query +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from danswer.auth.users import current_user +from danswer.configs.app_configs import WEB_DOMAIN +from danswer.configs.constants import DocumentSource +from danswer.connectors.interfaces import OAuthConnector +from danswer.db.credentials import create_credential +from danswer.db.engine import get_current_tenant_id +from danswer.db.engine import get_session +from danswer.db.models import User +from danswer.redis.redis_pool import get_redis_client +from danswer.server.documents.models import CredentialBase +from danswer.utils.logger import setup_logger +from danswer.utils.subclasses import find_all_subclasses_in_dir + +logger = setup_logger() + +router = APIRouter(prefix="/connector/oauth") + +_OAUTH_STATE_KEY_FMT = "oauth_state:{state}" +_OAUTH_STATE_EXPIRATION_SECONDS = 10 * 60 # 10 minutes + +# Cache for OAuth connectors, populated at module load time +_OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {} + + +def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]: + """Walk through the connectors package to find all OAuthConnector implementations""" + global _OAUTH_CONNECTORS + if _OAUTH_CONNECTORS: # Return cached connectors if already discovered + return _OAUTH_CONNECTORS + + oauth_connectors = find_all_subclasses_in_dir( + cast(type[OAuthConnector], OAuthConnector), "danswer.connectors" + ) + + _OAUTH_CONNECTORS = {cls.oauth_id(): cls for cls in oauth_connectors} + return _OAUTH_CONNECTORS + + +# Discover OAuth connectors at module load time +_discover_oauth_connectors() + + +class AuthorizeResponse(BaseModel): + redirect_url: str + + +@router.get("/authorize/{source}") +def oauth_authorize( + source: DocumentSource, + desired_return_url: Annotated[str | None, Query()] = None, + _: User = 
Depends(current_user), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> AuthorizeResponse: + """Initiates the OAuth flow by redirecting to the provider's auth page""" + oauth_connectors = _discover_oauth_connectors() + + if source not in oauth_connectors: + raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}") + + connector_cls = oauth_connectors[source] + base_url = WEB_DOMAIN + + # store state in redis + if not desired_return_url: + desired_return_url = f"{base_url}/admin/connectors/{source}?step=0" + redis_client = get_redis_client(tenant_id=tenant_id) + state = str(uuid.uuid4()) + redis_client.set( + _OAUTH_STATE_KEY_FMT.format(state=state), + desired_return_url, + ex=_OAUTH_STATE_EXPIRATION_SECONDS, + ) + + return AuthorizeResponse( + redirect_url=connector_cls.oauth_authorization_url(base_url, state) + ) + + +class CallbackResponse(BaseModel): + redirect_url: str + + +@router.get("/callback/{source}") +def oauth_callback( + source: DocumentSource, + code: Annotated[str, Query()], + state: Annotated[str, Query()], + db_session: Session = Depends(get_session), + user: User = Depends(current_user), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> CallbackResponse: + """Handles the OAuth callback and exchanges the code for tokens""" + oauth_connectors = _discover_oauth_connectors() + + if source not in oauth_connectors: + raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}") + + connector_cls = oauth_connectors[source] + + # get state from redis + redis_client = get_redis_client(tenant_id=tenant_id) + original_url_bytes = cast( + bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state)) + ) + if not original_url_bytes: + raise HTTPException(status_code=400, detail="Invalid OAuth state") + original_url = original_url_bytes.decode("utf-8") + + token_info = connector_cls.oauth_code_to_token(code) + + # Create a new credential with the token info + credential_data = CredentialBase( + credential_json=token_info, + admin_public=True, # Or based on some logic/parameter + source=source, + name=f"{source.title()} OAuth Credential", + ) + + credential = create_credential( + credential_data=credential_data, + user=user, + db_session=db_session, + ) + + return CallbackResponse( + redirect_url=( + f"{original_url}?credentialId={credential.id}" + if "?" 
not in original_url + else f"{original_url}&credentialId={credential.id}" + ) + ) diff --git a/backend/danswer/server/features/input_prompt/__init__.py b/backend/danswer/server/features/input_prompt/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/backend/danswer/server/features/input_prompt/api.py b/backend/danswer/server/features/input_prompt/api.py deleted file mode 100644 index 58eecd0093d..00000000000 --- a/backend/danswer/server/features/input_prompt/api.py +++ /dev/null @@ -1,134 +0,0 @@ -from fastapi import APIRouter -from fastapi import Depends -from fastapi import HTTPException -from sqlalchemy.orm import Session - -from danswer.auth.users import current_admin_user -from danswer.auth.users import current_user -from danswer.db.engine import get_session -from danswer.db.input_prompt import fetch_input_prompt_by_id -from danswer.db.input_prompt import fetch_input_prompts_by_user -from danswer.db.input_prompt import fetch_public_input_prompts -from danswer.db.input_prompt import insert_input_prompt -from danswer.db.input_prompt import remove_input_prompt -from danswer.db.input_prompt import remove_public_input_prompt -from danswer.db.input_prompt import update_input_prompt -from danswer.db.models import User -from danswer.server.features.input_prompt.models import CreateInputPromptRequest -from danswer.server.features.input_prompt.models import InputPromptSnapshot -from danswer.server.features.input_prompt.models import UpdateInputPromptRequest -from danswer.utils.logger import setup_logger - -logger = setup_logger() - -basic_router = APIRouter(prefix="/input_prompt") -admin_router = APIRouter(prefix="/admin/input_prompt") - - -@basic_router.get("") -def list_input_prompts( - user: User | None = Depends(current_user), - include_public: bool = False, - db_session: Session = Depends(get_session), -) -> list[InputPromptSnapshot]: - user_prompts = fetch_input_prompts_by_user( - user_id=user.id if user is not None else None, - db_session=db_session, - include_public=include_public, - ) - return [InputPromptSnapshot.from_model(prompt) for prompt in user_prompts] - - -@basic_router.get("/{input_prompt_id}") -def get_input_prompt( - input_prompt_id: int, - user: User | None = Depends(current_user), - db_session: Session = Depends(get_session), -) -> InputPromptSnapshot: - input_prompt = fetch_input_prompt_by_id( - id=input_prompt_id, - user_id=user.id if user is not None else None, - db_session=db_session, - ) - return InputPromptSnapshot.from_model(input_prompt=input_prompt) - - -@basic_router.post("") -def create_input_prompt( - create_input_prompt_request: CreateInputPromptRequest, - user: User | None = Depends(current_user), - db_session: Session = Depends(get_session), -) -> InputPromptSnapshot: - input_prompt = insert_input_prompt( - prompt=create_input_prompt_request.prompt, - content=create_input_prompt_request.content, - is_public=create_input_prompt_request.is_public, - user=user, - db_session=db_session, - ) - return InputPromptSnapshot.from_model(input_prompt) - - -@basic_router.patch("/{input_prompt_id}") -def patch_input_prompt( - input_prompt_id: int, - update_input_prompt_request: UpdateInputPromptRequest, - user: User | None = Depends(current_user), - db_session: Session = Depends(get_session), -) -> InputPromptSnapshot: - try: - updated_input_prompt = update_input_prompt( - user=user, - input_prompt_id=input_prompt_id, - prompt=update_input_prompt_request.prompt, - content=update_input_prompt_request.content, - 
active=update_input_prompt_request.active, - db_session=db_session, - ) - except ValueError as e: - error_msg = "Error occurred while updated input prompt" - logger.warn(f"{error_msg}. Stack trace: {e}") - raise HTTPException(status_code=404, detail=error_msg) - - return InputPromptSnapshot.from_model(updated_input_prompt) - - -@basic_router.delete("/{input_prompt_id}") -def delete_input_prompt( - input_prompt_id: int, - user: User | None = Depends(current_user), - db_session: Session = Depends(get_session), -) -> None: - try: - remove_input_prompt(user, input_prompt_id, db_session) - - except ValueError as e: - error_msg = "Error occurred while deleting input prompt" - logger.warn(f"{error_msg}. Stack trace: {e}") - raise HTTPException(status_code=404, detail=error_msg) - - -@admin_router.delete("/{input_prompt_id}") -def delete_public_input_prompt( - input_prompt_id: int, - _: User | None = Depends(current_admin_user), - db_session: Session = Depends(get_session), -) -> None: - try: - remove_public_input_prompt(input_prompt_id, db_session) - - except ValueError as e: - error_msg = "Error occurred while deleting input prompt" - logger.warn(f"{error_msg}. Stack trace: {e}") - raise HTTPException(status_code=404, detail=error_msg) - - -@admin_router.get("") -def list_public_input_prompts( - _: User | None = Depends(current_admin_user), - db_session: Session = Depends(get_session), -) -> list[InputPromptSnapshot]: - user_prompts = fetch_public_input_prompts( - db_session=db_session, - ) - return [InputPromptSnapshot.from_model(prompt) for prompt in user_prompts] diff --git a/backend/danswer/server/features/input_prompt/models.py b/backend/danswer/server/features/input_prompt/models.py deleted file mode 100644 index 21ce2ba4e5b..00000000000 --- a/backend/danswer/server/features/input_prompt/models.py +++ /dev/null @@ -1,47 +0,0 @@ -from uuid import UUID - -from pydantic import BaseModel - -from danswer.db.models import InputPrompt -from danswer.utils.logger import setup_logger - -logger = setup_logger() - - -class CreateInputPromptRequest(BaseModel): - prompt: str - content: str - is_public: bool - - -class UpdateInputPromptRequest(BaseModel): - prompt: str - content: str - active: bool - - -class InputPromptResponse(BaseModel): - id: int - prompt: str - content: str - active: bool - - -class InputPromptSnapshot(BaseModel): - id: int - prompt: str - content: str - active: bool - user_id: UUID | None - is_public: bool - - @classmethod - def from_model(cls, input_prompt: InputPrompt) -> "InputPromptSnapshot": - return InputPromptSnapshot( - id=input_prompt.id, - prompt=input_prompt.prompt, - content=input_prompt.content, - active=input_prompt.active, - user_id=input_prompt.user_id, - is_public=input_prompt.is_public, - ) diff --git a/backend/danswer/server/manage/models.py b/backend/danswer/server/manage/models.py index 8d3955e0ff6..0e37fc89191 100644 --- a/backend/danswer/server/manage/models.py +++ b/backend/danswer/server/manage/models.py @@ -266,5 +266,7 @@ class FullModelVersionResponse(BaseModel): class AllUsersResponse(BaseModel): accepted: list[FullUserSnapshot] invited: list[InvitedUserSnapshot] + slack_users: list[FullUserSnapshot] accepted_pages: int invited_pages: int + slack_users_pages: int diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index 33162b93430..75fd9dfe3a8 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -119,6 +119,7 @@ def set_user_role( def list_all_users( q: str | None = 
None, accepted_page: int | None = None, + slack_users_page: int | None = None, invited_page: int | None = None, user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), @@ -131,7 +132,12 @@ def list_all_users( for user in list_users(db_session, email_filter_string=q) if not is_api_key_email_address(user.email) ] - accepted_emails = {user.email for user in users} + + slack_users = [user for user in users if user.role == UserRole.SLACK_USER] + accepted_users = [user for user in users if user.role != UserRole.SLACK_USER] + + accepted_emails = {user.email for user in accepted_users} + slack_users_emails = {user.email for user in slack_users} invited_emails = get_invited_users() if q: invited_emails = [ @@ -139,10 +145,11 @@ def list_all_users( ] accepted_count = len(accepted_emails) + slack_users_count = len(slack_users_emails) invited_count = len(invited_emails) # If any of q, accepted_page, or invited_page is None, return all users - if accepted_page is None or invited_page is None: + if accepted_page is None or invited_page is None or slack_users_page is None: return AllUsersResponse( accepted=[ FullUserSnapshot( @@ -153,11 +160,23 @@ def list_all_users( UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED ), ) - for user in users + for user in accepted_users + ], + slack_users=[ + FullUserSnapshot( + id=user.id, + email=user.email, + role=user.role, + status=( + UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED + ), + ) + for user in slack_users ], invited=[InvitedUserSnapshot(email=email) for email in invited_emails], accepted_pages=1, invited_pages=1, + slack_users_pages=1, ) # Otherwise, return paginated results @@ -169,13 +188,27 @@ def list_all_users( role=user.role, status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED, ) - for user in users + for user in accepted_users ][accepted_page * USERS_PAGE_SIZE : (accepted_page + 1) * USERS_PAGE_SIZE], + slack_users=[ + FullUserSnapshot( + id=user.id, + email=user.email, + role=user.role, + status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED, + ) + for user in slack_users + ][ + slack_users_page + * USERS_PAGE_SIZE : (slack_users_page + 1) + * USERS_PAGE_SIZE + ], invited=[InvitedUserSnapshot(email=email) for email in invited_emails][ invited_page * USERS_PAGE_SIZE : (invited_page + 1) * USERS_PAGE_SIZE ], accepted_pages=accepted_count // USERS_PAGE_SIZE + 1, invited_pages=invited_count // USERS_PAGE_SIZE + 1, + slack_users_pages=slack_users_count // USERS_PAGE_SIZE + 1, ) diff --git a/backend/danswer/utils/subclasses.py b/backend/danswer/utils/subclasses.py new file mode 100644 index 00000000000..72408f98b08 --- /dev/null +++ b/backend/danswer/utils/subclasses.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import importlib +import os +import pkgutil +import sys +from types import ModuleType +from typing import List +from typing import Type +from typing import TypeVar + +T = TypeVar("T") + + +def import_all_modules_from_dir(dir_path: str) -> List[ModuleType]: + """ + Imports all modules found in the given directory and its subdirectories, + returning a list of imported module objects. 
+ """ + dir_path = os.path.abspath(dir_path) + + if dir_path not in sys.path: + sys.path.insert(0, dir_path) + + imported_modules: List[ModuleType] = [] + + for _, package_name, _ in pkgutil.walk_packages([dir_path]): + try: + module = importlib.import_module(package_name) + imported_modules.append(module) + except Exception as e: + # Handle or log exceptions as needed + print(f"Could not import {package_name}: {e}") + + return imported_modules + + +def all_subclasses(cls: Type[T]) -> List[Type[T]]: + """ + Recursively find all subclasses of the given class. + """ + direct_subs = cls.__subclasses__() + result: List[Type[T]] = [] + for subclass in direct_subs: + result.append(subclass) + # Extend the result by recursively calling all_subclasses + result.extend(all_subclasses(subclass)) + return result + + +def find_all_subclasses_in_dir(parent_class: Type[T], directory: str) -> List[Type[T]]: + """ + Imports all modules from the given directory (and subdirectories), + then returns all classes that are subclasses of parent_class. + + :param parent_class: The class to find subclasses of. + :param directory: The directory to search for subclasses. + :return: A list of all subclasses of parent_class found in the directory. + """ + # First import all modules to ensure classes are loaded into memory + import_all_modules_from_dir(directory) + + # Gather all subclasses of the given parent class + subclasses = all_subclasses(parent_class) + return subclasses + + +# Example usage: +if __name__ == "__main__": + + class Animal: + pass + + # Suppose "mymodules" contains files that define classes inheriting from Animal + found_subclasses = find_all_subclasses_in_dir(Animal, "mymodules") + for sc in found_subclasses: + print("Found subclass:", sc.__name__) diff --git a/backend/ee/danswer/db/external_perm.py b/backend/ee/danswer/db/external_perm.py index 4df635788fc..7121130e3eb 100644 --- a/backend/ee/danswer/db/external_perm.py +++ b/backend/ee/danswer/db/external_perm.py @@ -76,7 +76,7 @@ def replace_user__ext_group_for_cc_pair( new_external_permissions = [] for external_group in group_defs: for user_email in external_group.user_emails: - user_id = email_id_map.get(user_email) + user_id = email_id_map.get(user_email.lower()) if user_id is None: logger.warning( f"User in group {external_group.id}" diff --git a/backend/ee/danswer/server/seeding.py b/backend/ee/danswer/server/seeding.py index b4920815a83..f1081fe5f37 100644 --- a/backend/ee/danswer/server/seeding.py +++ b/backend/ee/danswer/server/seeding.py @@ -132,13 +132,18 @@ def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> if personas: logger.notice("Seeding Personas") for persona in personas: + if not persona.prompt_ids: + raise ValueError( + f"Invalid Persona with name {persona.name}; no prompts exist" + ) + upsert_persona( user=None, # Seeding is done as admin name=persona.name, description=persona.description, - num_chunks=persona.num_chunks - if persona.num_chunks is not None - else 0.0, + num_chunks=( + persona.num_chunks if persona.num_chunks is not None else 0.0 + ), llm_relevance_filter=persona.llm_relevance_filter, llm_filter_extraction=persona.llm_filter_extraction, recency_bias=RecencyBiasSetting.AUTO, diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index c72be9e4ac3..ef04c0a7f05 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -1,4 +1,6 @@ +import asyncio import json +from types import TracebackType from typing import cast from typing 
import Optional @@ -6,11 +8,11 @@ import openai import vertexai # type: ignore import voyageai # type: ignore -from cohere import Client as CohereClient +from cohere import AsyncClient as CohereAsyncClient from fastapi import APIRouter from fastapi import HTTPException from google.oauth2 import service_account # type: ignore -from litellm import embedding +from litellm import aembedding from litellm.exceptions import RateLimitError from retry import retry from sentence_transformers import CrossEncoder # type: ignore @@ -63,22 +65,31 @@ def __init__( provider: EmbeddingProvider, api_url: str | None = None, api_version: str | None = None, + timeout: int = API_BASED_EMBEDDING_TIMEOUT, ) -> None: self.provider = provider self.api_key = api_key self.api_url = api_url self.api_version = api_version + self.timeout = timeout + self.http_client = httpx.AsyncClient(timeout=timeout) + self._closed = False - def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]: + async def _embed_openai( + self, texts: list[str], model: str | None + ) -> list[Embedding]: if not model: model = DEFAULT_OPENAI_MODEL - client = openai.OpenAI(api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT) + # Use the OpenAI specific timeout for this one + client = openai.AsyncOpenAI( + api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT + ) final_embeddings: list[Embedding] = [] try: for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN): - response = client.embeddings.create(input=text_batch, model=model) + response = await client.embeddings.create(input=text_batch, model=model) final_embeddings.extend( [embedding.embedding for embedding in response.data] ) @@ -93,19 +104,19 @@ def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]: logger.error(error_string) raise RuntimeError(error_string) - def _embed_cohere( + async def _embed_cohere( self, texts: list[str], model: str | None, embedding_type: str ) -> list[Embedding]: if not model: model = DEFAULT_COHERE_MODEL - client = CohereClient(api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT) + client = CohereAsyncClient(api_key=self.api_key) final_embeddings: list[Embedding] = [] for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN): # Does not use the same tokenizer as the Danswer API server but it's approximately the same # empirically it's only off by a very few tokens so it's not a big deal - response = client.embed( + response = await client.embed( texts=text_batch, model=model, input_type=embedding_type, @@ -114,26 +125,29 @@ def _embed_cohere( final_embeddings.extend(cast(list[Embedding], response.embeddings)) return final_embeddings - def _embed_voyage( + async def _embed_voyage( self, texts: list[str], model: str | None, embedding_type: str ) -> list[Embedding]: if not model: model = DEFAULT_VOYAGE_MODEL - client = voyageai.Client( + client = voyageai.AsyncClient( api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT ) - response = client.embed( - texts, + response = await client.embed( + texts=texts, model=model, input_type=embedding_type, truncation=True, ) + return response.embeddings - def _embed_azure(self, texts: list[str], model: str | None) -> list[Embedding]: - response = embedding( + async def _embed_azure( + self, texts: list[str], model: str | None + ) -> list[Embedding]: + response = await aembedding( model=model, input=texts, timeout=API_BASED_EMBEDDING_TIMEOUT, @@ -142,10 +156,9 @@ def _embed_azure(self, texts: list[str], model: str | None) -> list[Embedding]: 
api_version=self.api_version, ) embeddings = [embedding["embedding"] for embedding in response.data] - return embeddings - def _embed_vertex( + async def _embed_vertex( self, texts: list[str], model: str | None, embedding_type: str ) -> list[Embedding]: if not model: @@ -158,7 +171,7 @@ def _embed_vertex( vertexai.init(project=project_id, credentials=credentials) client = TextEmbeddingModel.from_pretrained(model) - embeddings = client.get_embeddings( + embeddings = await client.get_embeddings_async( [ TextEmbeddingInput( text, @@ -166,11 +179,11 @@ def _embed_vertex( ) for text in texts ], - auto_truncate=True, # Also this is default + auto_truncate=True, # This is the default ) return [embedding.values for embedding in embeddings] - def _embed_litellm_proxy( + async def _embed_litellm_proxy( self, texts: list[str], model_name: str | None ) -> list[Embedding]: if not model_name: @@ -183,22 +196,20 @@ def _embed_litellm_proxy( {} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"} ) - with httpx.Client() as client: - response = client.post( - self.api_url, - json={ - "model": model_name, - "input": texts, - }, - headers=headers, - timeout=API_BASED_EMBEDDING_TIMEOUT, - ) - response.raise_for_status() - result = response.json() - return [embedding["embedding"] for embedding in result["data"]] + response = await self.http_client.post( + self.api_url, + json={ + "model": model_name, + "input": texts, + }, + headers=headers, + ) + response.raise_for_status() + result = response.json() + return [embedding["embedding"] for embedding in result["data"]] @retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY) - def embed( + async def embed( self, *, texts: list[str], @@ -207,19 +218,19 @@ def embed( deployment_name: str | None = None, ) -> list[Embedding]: if self.provider == EmbeddingProvider.OPENAI: - return self._embed_openai(texts, model_name) + return await self._embed_openai(texts, model_name) elif self.provider == EmbeddingProvider.AZURE: - return self._embed_azure(texts, f"azure/{deployment_name}") + return await self._embed_azure(texts, f"azure/{deployment_name}") elif self.provider == EmbeddingProvider.LITELLM: - return self._embed_litellm_proxy(texts, model_name) + return await self._embed_litellm_proxy(texts, model_name) embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type) if self.provider == EmbeddingProvider.COHERE: - return self._embed_cohere(texts, model_name, embedding_type) + return await self._embed_cohere(texts, model_name, embedding_type) elif self.provider == EmbeddingProvider.VOYAGE: - return self._embed_voyage(texts, model_name, embedding_type) + return await self._embed_voyage(texts, model_name, embedding_type) elif self.provider == EmbeddingProvider.GOOGLE: - return self._embed_vertex(texts, model_name, embedding_type) + return await self._embed_vertex(texts, model_name, embedding_type) else: raise ValueError(f"Unsupported provider: {self.provider}") @@ -233,6 +244,30 @@ def create( logger.debug(f"Creating Embedding instance for provider: {provider}") return CloudEmbedding(api_key, provider, api_url, api_version) + async def aclose(self) -> None: + """Explicitly close the client.""" + if not self._closed: + await self.http_client.aclose() + self._closed = True + + async def __aenter__(self) -> "CloudEmbedding": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.aclose() + + def __del__(self) -> None: + """Finalizer to 
warn about unclosed clients.""" + if not self._closed: + logger.warning( + "CloudEmbedding was not properly closed. Use 'async with' or call aclose()" + ) + def get_embedding_model( model_name: str, @@ -242,9 +277,6 @@ def get_embedding_model( global _GLOBAL_MODELS_DICT # A dictionary to store models - if _GLOBAL_MODELS_DICT is None: - _GLOBAL_MODELS_DICT = {} - if model_name not in _GLOBAL_MODELS_DICT: logger.notice(f"Loading {model_name}") # Some model architectures that aren't built into the Transformers or Sentence @@ -275,7 +307,7 @@ def get_local_reranking_model( @simple_log_function_time() -def embed_text( +async def embed_text( texts: list[str], text_type: EmbedTextType, model_name: str | None, @@ -311,18 +343,18 @@ def embed_text( "Cloud models take an explicit text type instead." ) - cloud_model = CloudEmbedding( + async with CloudEmbedding( api_key=api_key, provider=provider_type, api_url=api_url, api_version=api_version, - ) - embeddings = cloud_model.embed( - texts=texts, - model_name=model_name, - deployment_name=deployment_name, - text_type=text_type, - ) + ) as cloud_model: + embeddings = await cloud_model.embed( + texts=texts, + model_name=model_name, + deployment_name=deployment_name, + text_type=text_type, + ) if any(embedding is None for embedding in embeddings): error_message = "Embeddings contain None values\n" @@ -338,8 +370,12 @@ def embed_text( local_model = get_embedding_model( model_name=model_name, max_context_length=max_context_length ) - embeddings_vectors = local_model.encode( - prefixed_texts, normalize_embeddings=normalize_embeddings + # Run CPU-bound embedding in a thread pool + embeddings_vectors = await asyncio.get_event_loop().run_in_executor( + None, + lambda: local_model.encode( + prefixed_texts, normalize_embeddings=normalize_embeddings + ), ) embeddings = [ embedding if isinstance(embedding, list) else embedding.tolist() @@ -357,27 +393,31 @@ def embed_text( @simple_log_function_time() -def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]: +async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]: cross_encoder = get_local_reranking_model(model_name) - return cross_encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore + # Run CPU-bound reranking in a thread pool + return await asyncio.get_event_loop().run_in_executor( + None, + lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(), # type: ignore + ) -def cohere_rerank( +async def cohere_rerank( query: str, docs: list[str], model_name: str, api_key: str ) -> list[float]: - cohere_client = CohereClient(api_key=api_key) - response = cohere_client.rerank(query=query, documents=docs, model=model_name) + cohere_client = CohereAsyncClient(api_key=api_key) + response = await cohere_client.rerank(query=query, documents=docs, model=model_name) results = response.results sorted_results = sorted(results, key=lambda item: item.index) return [result.relevance_score for result in sorted_results] -def litellm_rerank( +async def litellm_rerank( query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None ) -> list[float]: headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"} - with httpx.Client() as client: - response = client.post( + async with httpx.AsyncClient() as client: + response = await client.post( api_url, json={ "model": model_name, @@ -411,7 +451,7 @@ async def process_embed_request( else: prefix = None - embeddings = embed_text( + embeddings = await embed_text( 
texts=embed_request.texts, model_name=embed_request.model_name, deployment_name=embed_request.deployment_name, @@ -451,7 +491,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons try: if rerank_request.provider_type is None: - sim_scores = local_rerank( + sim_scores = await local_rerank( query=rerank_request.query, docs=rerank_request.documents, model_name=rerank_request.model_name, @@ -461,7 +501,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons if rerank_request.api_url is None: raise ValueError("API URL is required for LiteLLM reranking.") - sim_scores = litellm_rerank( + sim_scores = await litellm_rerank( query=rerank_request.query, docs=rerank_request.documents, api_url=rerank_request.api_url, @@ -474,7 +514,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons elif rerank_request.provider_type == RerankerProvider.COHERE: if rerank_request.api_key is None: raise RuntimeError("Cohere Rerank Requires an API Key") - sim_scores = cohere_rerank( + sim_scores = await cohere_rerank( query=rerank_request.query, docs=rerank_request.documents, model_name=rerank_request.model_name, diff --git a/backend/model_server/management_endpoints.py b/backend/model_server/management_endpoints.py index 56640a2fa73..4c6387e0708 100644 --- a/backend/model_server/management_endpoints.py +++ b/backend/model_server/management_endpoints.py @@ -6,12 +6,12 @@ @router.get("/health") -def healthcheck() -> Response: +async def healthcheck() -> Response: return Response(status_code=200) @router.get("/gpu-status") -def gpu_status() -> dict[str, bool | str]: +async def gpu_status() -> dict[str, bool | str]: if torch.cuda.is_available(): return {"gpu_available": True, "type": "cuda"} elif torch.backends.mps.is_available(): diff --git a/backend/model_server/utils.py b/backend/model_server/utils.py index 0c2d6bac5dc..86192b031f6 100644 --- a/backend/model_server/utils.py +++ b/backend/model_server/utils.py @@ -1,3 +1,4 @@ +import asyncio import time from collections.abc import Callable from collections.abc import Generator @@ -21,21 +22,39 @@ def simple_log_function_time( include_args: bool = False, ) -> Callable[[F], F]: def decorator(func: F) -> F: - @wraps(func) - def wrapped_func(*args: Any, **kwargs: Any) -> Any: - start_time = time.time() - result = func(*args, **kwargs) - elapsed_time_str = str(time.time() - start_time) - log_name = func_name or func.__name__ - args_str = f" args={args} kwargs={kwargs}" if include_args else "" - final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds" - if debug_only: - logger.debug(final_log) - else: - logger.notice(final_log) - - return result - - return cast(F, wrapped_func) + if asyncio.iscoroutinefunction(func): + + @wraps(func) + async def wrapped_async_func(*args: Any, **kwargs: Any) -> Any: + start_time = time.time() + result = await func(*args, **kwargs) + elapsed_time_str = str(time.time() - start_time) + log_name = func_name or func.__name__ + args_str = f" args={args} kwargs={kwargs}" if include_args else "" + final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds" + if debug_only: + logger.debug(final_log) + else: + logger.notice(final_log) + return result + + return cast(F, wrapped_async_func) + else: + + @wraps(func) + def wrapped_sync_func(*args: Any, **kwargs: Any) -> Any: + start_time = time.time() + result = func(*args, **kwargs) + elapsed_time_str = str(time.time() - start_time) + log_name = func_name or func.__name__ + args_str = f" 
args={args} kwargs={kwargs}" if include_args else "" + final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds" + if debug_only: + logger.debug(final_log) + else: + logger.notice(final_log) + return result + + return cast(F, wrapped_sync_func) return decorator diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 8a13bb8a74f..3a4996d9014 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -29,7 +29,7 @@ trafilatura==1.12.2 langchain==0.1.17 langchain-core==0.1.50 langchain-text-splitters==0.0.1 -litellm==1.53.1 +litellm==1.54.1 lxml==5.3.0 lxml_html_clean==0.2.2 llama-index==0.9.45 diff --git a/backend/requirements/dev.txt b/backend/requirements/dev.txt index 27304dbef37..a89b8db674d 100644 --- a/backend/requirements/dev.txt +++ b/backend/requirements/dev.txt @@ -1,30 +1,34 @@ black==23.3.0 +boto3-stubs[s3]==1.34.133 celery-types==0.19.0 +cohere==5.6.1 +google-cloud-aiplatform==1.58.0 +lxml==5.3.0 +lxml_html_clean==0.2.2 mypy-extensions==1.0.0 mypy==1.8.0 +pandas-stubs==2.2.3.241009 +pandas==2.2.3 pre-commit==3.2.2 +pytest-asyncio==0.22.0 pytest==7.4.4 reorder-python-imports==3.9.0 ruff==0.0.286 -types-PyYAML==6.0.12.11 +sentence-transformers==2.6.1 +trafilatura==1.12.2 types-beautifulsoup4==4.12.0.3 types-html5lib==1.1.11.13 types-oauthlib==3.2.0.9 -types-setuptools==68.0.0.3 -types-Pillow==10.2.0.20240822 types-passlib==1.7.7.20240106 +types-Pillow==10.2.0.20240822 types-psutil==5.9.5.17 types-psycopg2==2.9.21.10 types-python-dateutil==2.8.19.13 types-pytz==2023.3.1.1 +types-PyYAML==6.0.12.11 types-regex==2023.3.23.1 types-requests==2.28.11.17 types-retry==0.9.9.3 +types-setuptools==68.0.0.3 types-urllib3==1.26.25.11 -trafilatura==1.12.2 -lxml==5.3.0 -lxml_html_clean==0.2.2 -boto3-stubs[s3]==1.34.133 -pandas==2.2.3 -pandas-stubs==2.2.3.241009 -cohere==5.6.1 \ No newline at end of file +voyageai==0.2.3 diff --git a/backend/requirements/model_server.txt b/backend/requirements/model_server.txt index 4803dc64eb6..531382cb4b1 100644 --- a/backend/requirements/model_server.txt +++ b/backend/requirements/model_server.txt @@ -12,5 +12,5 @@ torch==2.2.0 transformers==4.39.2 uvicorn==0.21.1 voyageai==0.2.3 -litellm==1.50.2 +litellm==1.54.1 sentry-sdk[fastapi,celery,starlette]==2.14.0 \ No newline at end of file diff --git a/backend/tests/integration/common_utils/managers/persona.py b/backend/tests/integration/common_utils/managers/persona.py index de2d9db25c1..e5392dfb68b 100644 --- a/backend/tests/integration/common_utils/managers/persona.py +++ b/backend/tests/integration/common_utils/managers/persona.py @@ -42,7 +42,7 @@ def create( "is_public": is_public, "llm_filter_extraction": llm_filter_extraction, "recency_bias": recency_bias, - "prompt_ids": prompt_ids or [], + "prompt_ids": prompt_ids or [0], "document_set_ids": document_set_ids or [], "tool_ids": tool_ids or [], "llm_model_provider_override": llm_model_provider_override, diff --git a/backend/tests/integration/common_utils/managers/tenant.py b/backend/tests/integration/common_utils/managers/tenant.py index fc411018df7..c25a1b2ec6e 100644 --- a/backend/tests/integration/common_utils/managers/tenant.py +++ b/backend/tests/integration/common_utils/managers/tenant.py @@ -69,8 +69,10 @@ def get_all_users( return AllUsersResponse( accepted=[FullUserSnapshot(**user) for user in data["accepted"]], invited=[InvitedUserSnapshot(**user) for user in data["invited"]], + slack_users=[FullUserSnapshot(**user) for user in data["slack_users"]], 
accepted_pages=data["accepted_pages"], invited_pages=data["invited_pages"], + slack_users_pages=data["slack_users_pages"], ) @staticmethod diff --git a/backend/tests/integration/common_utils/managers/user.py b/backend/tests/integration/common_utils/managers/user.py index 43286c6a716..26cb29cdffb 100644 --- a/backend/tests/integration/common_utils/managers/user.py +++ b/backend/tests/integration/common_utils/managers/user.py @@ -130,8 +130,10 @@ def verify( all_users = AllUsersResponse( accepted=[FullUserSnapshot(**user) for user in data["accepted"]], invited=[InvitedUserSnapshot(**user) for user in data["invited"]], + slack_users=[FullUserSnapshot(**user) for user in data["slack_users"]], accepted_pages=data["accepted_pages"], invited_pages=data["invited_pages"], + slack_users_pages=data["slack_users_pages"], ) for accepted_user in all_users.accepted: if accepted_user.email == user.email and accepted_user.id == user.id: diff --git a/backend/tests/integration/tests/api_key/test_api_key.py b/backend/tests/integration/tests/api_key/test_api_key.py index bd0618b962d..34023d897a5 100644 --- a/backend/tests/integration/tests/api_key/test_api_key.py +++ b/backend/tests/integration/tests/api_key/test_api_key.py @@ -27,13 +27,6 @@ def test_limited(reset: None) -> None: ) assert response.status_code == 200 - # test basic endpoints - response = requests.get( - f"{API_SERVER_URL}/input_prompt", - headers=api_key.headers, - ) - assert response.status_code == 403 - # test admin endpoints response = requests.get( f"{API_SERVER_URL}/admin/api-key", diff --git a/backend/tests/unit/danswer/indexing/test_indexing_pipeline.py b/backend/tests/unit/danswer/indexing/test_indexing_pipeline.py new file mode 100644 index 00000000000..612535f67ed --- /dev/null +++ b/backend/tests/unit/danswer/indexing/test_indexing_pipeline.py @@ -0,0 +1,120 @@ +from typing import List + +from danswer.configs.app_configs import MAX_DOCUMENT_CHARS +from danswer.connectors.models import Document +from danswer.connectors.models import DocumentSource +from danswer.connectors.models import Section +from danswer.indexing.indexing_pipeline import filter_documents + + +def create_test_document( + doc_id: str = "test_id", + title: str | None = "Test Title", + semantic_id: str = "test_semantic_id", + sections: List[Section] | None = None, +) -> Document: + if sections is None: + sections = [Section(text="Test content", link="test_link")] + return Document( + id=doc_id, + title=title, + semantic_identifier=semantic_id, + sections=sections, + source=DocumentSource.FILE, + metadata={}, + ) + + +def test_filter_documents_empty_title_and_content() -> None: + doc = create_test_document( + title="", semantic_id="", sections=[Section(text="", link="test_link")] + ) + result = filter_documents([doc]) + assert len(result) == 0 + + +def test_filter_documents_empty_title_with_content() -> None: + doc = create_test_document( + title="", sections=[Section(text="Valid content", link="test_link")] + ) + result = filter_documents([doc]) + assert len(result) == 1 + assert result[0].id == "test_id" + + +def test_filter_documents_empty_content_with_title() -> None: + doc = create_test_document( + title="Valid Title", sections=[Section(text="", link="test_link")] + ) + result = filter_documents([doc]) + assert len(result) == 1 + assert result[0].id == "test_id" + + +def test_filter_documents_exceeding_max_chars() -> None: + if not MAX_DOCUMENT_CHARS: # Skip if no max chars configured + return + long_text = "a" * (MAX_DOCUMENT_CHARS + 1) + doc = 
create_test_document(sections=[Section(text=long_text, link="test_link")]) + result = filter_documents([doc]) + assert len(result) == 0 + + +def test_filter_documents_valid_document() -> None: + doc = create_test_document( + title="Valid Title", sections=[Section(text="Valid content", link="test_link")] + ) + result = filter_documents([doc]) + assert len(result) == 1 + assert result[0].id == "test_id" + assert result[0].title == "Valid Title" + + +def test_filter_documents_whitespace_only() -> None: + doc = create_test_document( + title=" ", semantic_id=" ", sections=[Section(text=" ", link="test_link")] + ) + result = filter_documents([doc]) + assert len(result) == 0 + + +def test_filter_documents_semantic_id_no_title() -> None: + doc = create_test_document( + title=None, + semantic_id="Valid Semantic ID", + sections=[Section(text="Valid content", link="test_link")], + ) + result = filter_documents([doc]) + assert len(result) == 1 + assert result[0].semantic_identifier == "Valid Semantic ID" + + +def test_filter_documents_multiple_sections() -> None: + doc = create_test_document( + sections=[ + Section(text="Content 1", link="test_link"), + Section(text="Content 2", link="test_link"), + Section(text="Content 3", link="test_link"), + ] + ) + result = filter_documents([doc]) + assert len(result) == 1 + assert len(result[0].sections) == 3 + + +def test_filter_documents_multiple_documents() -> None: + docs = [ + create_test_document(doc_id="1", title="Title 1"), + create_test_document( + doc_id="2", title="", sections=[Section(text="", link="test_link")] + ), # Should be filtered + create_test_document(doc_id="3", title="Title 3"), + ] + result = filter_documents(docs) + assert len(result) == 2 + assert {doc.id for doc in result} == {"1", "3"} + + +def test_filter_documents_empty_batch() -> None: + result = filter_documents([]) + assert len(result) == 0 diff --git a/backend/tests/unit/model_server/test_embedding.py b/backend/tests/unit/model_server/test_embedding.py new file mode 100644 index 00000000000..6781ab27aa0 --- /dev/null +++ b/backend/tests/unit/model_server/test_embedding.py @@ -0,0 +1,198 @@ +import asyncio +import time +from collections.abc import AsyncGenerator +from typing import Any +from typing import List +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest +from httpx import AsyncClient +from litellm.exceptions import RateLimitError + +from model_server.encoders import CloudEmbedding +from model_server.encoders import embed_text +from model_server.encoders import local_rerank +from model_server.encoders import process_embed_request +from shared_configs.enums import EmbeddingProvider +from shared_configs.enums import EmbedTextType +from shared_configs.model_server_models import EmbedRequest + + +@pytest.fixture +async def mock_http_client() -> AsyncGenerator[AsyncMock, None]: + with patch("httpx.AsyncClient") as mock: + client = AsyncMock(spec=AsyncClient) + mock.return_value = client + client.post = AsyncMock() + async with client as c: + yield c + + +@pytest.fixture +def sample_embeddings() -> List[List[float]]: + return [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + + +@pytest.mark.asyncio +async def test_cloud_embedding_context_manager() -> None: + async with CloudEmbedding("fake-key", EmbeddingProvider.OPENAI) as embedding: + assert not embedding._closed + assert embedding._closed + + +@pytest.mark.asyncio +async def test_cloud_embedding_explicit_close() -> None: + embedding = CloudEmbedding("fake-key", 
EmbeddingProvider.OPENAI) + assert not embedding._closed + await embedding.aclose() + assert embedding._closed + + +@pytest.mark.asyncio +async def test_openai_embedding( + mock_http_client: AsyncMock, sample_embeddings: List[List[float]] +) -> None: + with patch("openai.AsyncOpenAI") as mock_openai: + mock_client = AsyncMock() + mock_openai.return_value = mock_client + + mock_response = MagicMock() + mock_response.data = [MagicMock(embedding=emb) for emb in sample_embeddings] + mock_client.embeddings.create = AsyncMock(return_value=mock_response) + + embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI) + result = await embedding._embed_openai( + ["test1", "test2"], "text-embedding-ada-002" + ) + + assert result == sample_embeddings + mock_client.embeddings.create.assert_called_once() + + +@pytest.mark.asyncio +async def test_embed_text_cloud_provider() -> None: + with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed: + mock_embed.return_value = [[0.1, 0.2], [0.3, 0.4]] + mock_embed.side_effect = AsyncMock(return_value=[[0.1, 0.2], [0.3, 0.4]]) + + result = await embed_text( + texts=["test1", "test2"], + text_type=EmbedTextType.QUERY, + model_name="fake-model", + deployment_name=None, + max_context_length=512, + normalize_embeddings=True, + api_key="fake-key", + provider_type=EmbeddingProvider.OPENAI, + prefix=None, + api_url=None, + api_version=None, + ) + + assert result == [[0.1, 0.2], [0.3, 0.4]] + mock_embed.assert_called_once() + + +@pytest.mark.asyncio +async def test_embed_text_local_model() -> None: + with patch("model_server.encoders.get_embedding_model") as mock_get_model: + mock_model = MagicMock() + mock_model.encode.return_value = [[0.1, 0.2], [0.3, 0.4]] + mock_get_model.return_value = mock_model + + result = await embed_text( + texts=["test1", "test2"], + text_type=EmbedTextType.QUERY, + model_name="fake-local-model", + deployment_name=None, + max_context_length=512, + normalize_embeddings=True, + api_key=None, + provider_type=None, + prefix=None, + api_url=None, + api_version=None, + ) + + assert result == [[0.1, 0.2], [0.3, 0.4]] + mock_model.encode.assert_called_once() + + +@pytest.mark.asyncio +async def test_local_rerank() -> None: + with patch("model_server.encoders.get_local_reranking_model") as mock_get_model: + mock_model = MagicMock() + mock_array = MagicMock() + mock_array.tolist.return_value = [0.8, 0.6] + mock_model.predict.return_value = mock_array + mock_get_model.return_value = mock_model + + result = await local_rerank( + query="test query", docs=["doc1", "doc2"], model_name="fake-rerank-model" + ) + + assert result == [0.8, 0.6] + mock_model.predict.assert_called_once() + + +@pytest.mark.asyncio +async def test_rate_limit_handling() -> None: + with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed: + mock_embed.side_effect = RateLimitError( + "Rate limit exceeded", llm_provider="openai", model="fake-model" + ) + + with pytest.raises(RateLimitError): + await embed_text( + texts=["test"], + text_type=EmbedTextType.QUERY, + model_name="fake-model", + deployment_name=None, + max_context_length=512, + normalize_embeddings=True, + api_key="fake-key", + provider_type=EmbeddingProvider.OPENAI, + prefix=None, + api_url=None, + api_version=None, + ) + + +@pytest.mark.asyncio +async def test_concurrent_embeddings() -> None: + def mock_encode(*args: Any, **kwargs: Any) -> List[List[float]]: + time.sleep(5) + return [[0.1, 0.2, 0.3]] + + test_req = EmbedRequest( + texts=["test"], + 
model_name="'nomic-ai/nomic-embed-text-v1'", + deployment_name=None, + max_context_length=512, + normalize_embeddings=True, + api_key=None, + provider_type=None, + text_type=EmbedTextType.QUERY, + manual_query_prefix=None, + manual_passage_prefix=None, + api_url=None, + api_version=None, + ) + + with patch("model_server.encoders.get_embedding_model") as mock_get_model: + mock_model = MagicMock() + mock_model.encode = mock_encode + mock_get_model.return_value = mock_model + start_time = time.time() + + tasks = [process_embed_request(test_req) for _ in range(5)] + await asyncio.gather(*tasks) + + end_time = time.time() + + # 5 * 5 seconds = 25 seconds, this test ensures that the embeddings are at least yielding the thread + # However, the developer may still introduce unnecessary blocking above the mock and this test will + # still pass as long as it's less than (7 - 5) / 5 seconds + assert end_time - start_time < 7 diff --git a/ct.yaml b/ct.yaml index f568ef5d52b..cec4478c850 100644 --- a/ct.yaml +++ b/ct.yaml @@ -6,7 +6,7 @@ chart-dirs: # must be kept in sync with Chart.yaml chart-repos: - - vespa=https://danswer-ai.github.io/vespa-helm-charts + - vespa=https://onyx-dot-app.github.io/vespa-helm-charts - postgresql=https://charts.bitnami.com/bitnami helm-extra-args: --debug --timeout 600s diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index bcd73b729c1..19991de2d37 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -183,6 +183,13 @@ services: - GONG_CONNECTOR_START_TIME=${GONG_CONNECTOR_START_TIME:-} - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-} - GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-} + - MAX_DOCUMENT_CHARS=${MAX_DOCUMENT_CHARS:-} + - MAX_FILE_SIZE_BYTES=${MAX_FILE_SIZE_BYTES:-} + # Egnyte OAuth Configs + - EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-} + - EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-} + - EGNYTE_BASE_DOMAIN=${EGNYTE_BASE_DOMAIN:-} + - EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-} # Celery Configs (defaults are set in the supervisord.conf file. 
# prefer doing that to have one source of defaults) - CELERY_WORKER_INDEXING_CONCURRENCY=${CELERY_WORKER_INDEXING_CONCURRENCY:-} diff --git a/deployment/helm/charts/danswer/Chart.lock b/deployment/helm/charts/danswer/Chart.lock index 26cc24e4494..af26f510eb1 100644 --- a/deployment/helm/charts/danswer/Chart.lock +++ b/deployment/helm/charts/danswer/Chart.lock @@ -3,13 +3,13 @@ dependencies: repository: https://charts.bitnami.com/bitnami version: 14.3.1 - name: vespa - repository: https://danswer-ai.github.io/vespa-helm-charts - version: 0.2.16 + repository: https://onyx-dot-app.github.io/vespa-helm-charts + version: 0.2.18 - name: nginx repository: oci://registry-1.docker.io/bitnamicharts version: 15.14.0 - name: redis repository: https://charts.bitnami.com/bitnami version: 20.1.0 -digest: sha256:711bbb76ba6ab604a36c9bf1839ab6faa5610afb21e535afd933c78f2d102232 -generated: "2024-11-07T09:39:30.17171-08:00" +digest: sha256:5c9eb3d55d5f8e3beb64f26d26f686c8d62755daa10e2e6d87530bdf2fbbf957 +generated: "2024-12-10T10:47:35.812483-08:00" diff --git a/deployment/helm/charts/danswer/Chart.yaml b/deployment/helm/charts/danswer/Chart.yaml index 8cda8e8ba2e..b033122c0fc 100644 --- a/deployment/helm/charts/danswer/Chart.yaml +++ b/deployment/helm/charts/danswer/Chart.yaml @@ -23,8 +23,8 @@ dependencies: repository: https://charts.bitnami.com/bitnami condition: postgresql.enabled - name: vespa - version: 0.2.16 - repository: https://danswer-ai.github.io/vespa-helm-charts + version: 0.2.18 + repository: https://onyx-dot-app.github.io/vespa-helm-charts condition: vespa.enabled - name: nginx version: 15.14.0 diff --git a/deployment/kubernetes/env-configmap.yaml b/deployment/kubernetes/env-configmap.yaml index 176e468c110..84bd6747973 100644 --- a/deployment/kubernetes/env-configmap.yaml +++ b/deployment/kubernetes/env-configmap.yaml @@ -61,6 +61,8 @@ data: WEB_CONNECTOR_VALIDATE_URLS: "" GONG_CONNECTOR_START_TIME: "" NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP: "" + MAX_DOCUMENT_CHARS: "" + MAX_FILE_SIZE_BYTES: "" # DanswerBot SlackBot Configs DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER: "" DANSWER_BOT_DISPLAY_ERROR_MSGS: "" diff --git a/web/public/Egnyte.png b/web/public/Egnyte.png new file mode 100644 index 00000000000..54eef07dc28 Binary files /dev/null and b/web/public/Egnyte.png differ diff --git a/web/public/Wikipedia.png b/web/public/Wikipedia.png new file mode 100644 index 00000000000..30d9a3bbae0 Binary files /dev/null and b/web/public/Wikipedia.png differ diff --git a/web/public/Wikipedia.svg b/web/public/Wikipedia.svg deleted file mode 100644 index ee4a3caa55f..00000000000 --- a/web/public/Wikipedia.svg +++ /dev/null @@ -1,535 +0,0 @@ - - diff --git a/web/src/app/admin/api-key/DanswerApiKeyForm.tsx b/web/src/app/admin/api-key/DanswerApiKeyForm.tsx index 80bb84d626f..27d6457d141 100644 --- a/web/src/app/admin/api-key/DanswerApiKeyForm.tsx +++ b/web/src/app/admin/api-key/DanswerApiKeyForm.tsx @@ -82,7 +82,7 @@ export const DanswerApiKeyForm = ({ }} > {({ isSubmitting, values, setFieldValue }) => ( -