diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 244aeb2b70f..ea4e7be93d4 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -105,6 +105,7 @@ from danswer.tools.tool_runner import ToolCallFinalResult from danswer.tools.utils import compute_all_tool_tokens from danswer.tools.utils import explicit_tool_calling_supported +from danswer.utils.headers import header_dict_to_header_list from danswer.utils.logger import setup_logger from danswer.utils.timing import log_generator_function_time @@ -276,7 +277,7 @@ def stream_chat_message_objects( # on the `new_msg_req.message`. Currently, requires a state where the last message is a use_existing_user_message: bool = False, litellm_additional_headers: dict[str, str] | None = None, - tool_additional_headers: dict[str, str] | None = None, + custom_tool_additional_headers: dict[str, str] | None = None, is_connected: Callable[[], bool] | None = None, enforce_chat_session_id_for_search_docs: bool = True, ) -> ChatPacketStream: @@ -640,7 +641,12 @@ def stream_chat_message_objects( chat_session_id=chat_session_id, message_id=user_message.id if user_message else None, ), - custom_headers=db_tool_model.custom_headers, + custom_headers=(db_tool_model.custom_headers or []) + + ( + header_dict_to_header_list( + custom_tool_additional_headers or {} + ) + ), ), ) @@ -863,7 +869,7 @@ def stream_chat_message( user: User | None, use_existing_user_message: bool = False, litellm_additional_headers: dict[str, str] | None = None, - tool_additional_headers: dict[str, str] | None = None, + custom_tool_additional_headers: dict[str, str] | None = None, is_connected: Callable[[], bool] | None = None, ) -> Iterator[str]: with get_session_context_manager() as db_session: @@ -873,7 +879,7 @@ def stream_chat_message( db_session=db_session, use_existing_user_message=use_existing_user_message, litellm_additional_headers=litellm_additional_headers, - tool_additional_headers=tool_additional_headers, + custom_tool_additional_headers=custom_tool_additional_headers, is_connected=is_connected, ) for obj in objects: diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py index 4454eee159b..c9668cd8136 100644 --- a/backend/danswer/configs/model_configs.py +++ b/backend/danswer/configs/model_configs.py @@ -119,19 +119,3 @@ logger.error( "Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object" ) - - -# List of headers to pass through to tool calls (e.g., API requests made by tools) -# This allows for dynamic configuration of tool behavior based on incoming request headers -TOOL_PASS_THROUGH_HEADERS: list[str] | None = None -_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get("TOOL_PASS_THROUGH_HEADERS") -if _TOOL_PASS_THROUGH_HEADERS_RAW: - try: - TOOL_PASS_THROUGH_HEADERS = json.loads(_TOOL_PASS_THROUGH_HEADERS_RAW) - except Exception: - from danswer.utils.logger import setup_logger - - logger = setup_logger() - logger.error( - "Failed to parse TOOL_PASS_THROUGH_HEADERS, must be a valid JSON object" - ) diff --git a/backend/danswer/configs/tool_configs.py b/backend/danswer/configs/tool_configs.py new file mode 100644 index 00000000000..3170cb31ff9 --- /dev/null +++ b/backend/danswer/configs/tool_configs.py @@ -0,0 +1,22 @@ +import json +import os + + +# if specified, will pass through request headers to the call to API calls made by custom tools +CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None +_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get( + "CUSTOM_TOOL_PASS_THROUGH_HEADERS" +) +if _CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW: + try: + CUSTOM_TOOL_PASS_THROUGH_HEADERS = json.loads( + _CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW + ) + except Exception: + # need to import here to avoid circular imports + from danswer.utils.logger import setup_logger + + logger = setup_logger() + logger.error( + "Failed to parse CUSTOM_TOOL_PASS_THROUGH_HEADERS, must be a valid JSON object" + ) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 9f8f1e2371a..2101fc74e90 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -60,6 +60,7 @@ from danswer.search.enums import RecencyBiasSetting from danswer.utils.encryption import decrypt_bytes_to_string from danswer.utils.encryption import encrypt_string_to_bytes +from danswer.utils.headers import HeaderItemDict from shared_configs.enums import EmbeddingProvider from shared_configs.enums import RerankerProvider @@ -1288,7 +1289,7 @@ class Tool(Base): openapi_schema: Mapped[dict[str, Any] | None] = mapped_column( postgresql.JSONB(), nullable=True ) - custom_headers: Mapped[list[dict[str, str]] | None] = mapped_column( + custom_headers: Mapped[list[HeaderItemDict] | None] = mapped_column( postgresql.JSONB(), nullable=True ) # user who created / owns the tool. Will be None for built-in tools. diff --git a/backend/danswer/db/tools.py b/backend/danswer/db/tools.py index 248744b5639..0fd126d0065 100644 --- a/backend/danswer/db/tools.py +++ b/backend/danswer/db/tools.py @@ -1,4 +1,5 @@ from typing import Any +from typing import cast from uuid import UUID from sqlalchemy import select @@ -6,6 +7,7 @@ from danswer.db.models import Tool from danswer.server.features.tool.models import Header +from danswer.utils.headers import HeaderItemDict from danswer.utils.logger import setup_logger logger = setup_logger() @@ -67,7 +69,9 @@ def update_tool( if user_id is not None: tool.user_id = user_id if custom_headers is not None: - tool.custom_headers = [header.dict() for header in custom_headers] + tool.custom_headers = [ + cast(HeaderItemDict, header.model_dump()) for header in custom_headers + ] db_session.commit() return tool diff --git a/backend/danswer/llm/factory.py b/backend/danswer/llm/factory.py index 904735d5ffe..f930c3d3358 100644 --- a/backend/danswer/llm/factory.py +++ b/backend/danswer/llm/factory.py @@ -7,9 +7,9 @@ from danswer.db.models import Persona from danswer.llm.chat_llm import DefaultMultiLLM from danswer.llm.exceptions import GenAIDisabledException -from danswer.llm.headers import build_llm_extra_headers from danswer.llm.interfaces import LLM from danswer.llm.override_models import LLMOverride +from danswer.utils.headers import build_llm_extra_headers def get_main_llm_from_tuple( diff --git a/backend/danswer/llm/headers.py b/backend/danswer/llm/headers.py deleted file mode 100644 index 13622167d99..00000000000 --- a/backend/danswer/llm/headers.py +++ /dev/null @@ -1,12 +0,0 @@ -from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS - - -def build_llm_extra_headers( - additional_headers: dict[str, str] | None = None -) -> dict[str, str]: - extra_headers: dict[str, str] = {} - if additional_headers: - extra_headers.update(additional_headers) - if LITELLM_EXTRA_HEADERS: - extra_headers.update(LITELLM_EXTRA_HEADERS) - return extra_headers diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 5b0cc30a1b4..6e2d3c40988 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -25,7 +25,6 @@ from danswer.configs.constants import FileOrigin from danswer.configs.constants import MessageType from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS -from danswer.configs.model_configs import TOOL_PASS_THROUGH_HEADERS from danswer.db.chat import create_chat_session from danswer.db.chat import create_new_chat_message from danswer.db.chat import delete_chat_session @@ -74,6 +73,7 @@ from danswer.server.query_and_chat.models import SearchFeedbackRequest from danswer.server.query_and_chat.models import UpdateChatSessionThreadRequest from danswer.server.query_and_chat.token_limit import check_token_rate_limits +from danswer.utils.headers import get_custom_tool_additional_request_headers from danswer.utils.logger import setup_logger @@ -338,8 +338,8 @@ def stream_generator() -> Generator[str, None, None]: litellm_additional_headers=extract_headers( request.headers, LITELLM_PASS_THROUGH_HEADERS ), - tool_additional_headers=extract_headers( - request.headers, TOOL_PASS_THROUGH_HEADERS + custom_tool_additional_headers=get_custom_tool_additional_request_headers( + request.headers ), is_connected=is_disconnected_func, ): diff --git a/backend/danswer/tools/custom/custom_tool.py b/backend/danswer/tools/custom/custom_tool.py index 8f4a4b23fa8..ee431af70e1 100644 --- a/backend/danswer/tools/custom/custom_tool.py +++ b/backend/danswer/tools/custom/custom_tool.py @@ -29,6 +29,8 @@ from danswer.tools.models import MESSAGE_ID_PLACEHOLDER from danswer.tools.tool import Tool from danswer.tools.tool import ToolResponse +from danswer.utils.headers import header_list_to_header_dict +from danswer.utils.headers import HeaderItemDict from danswer.utils.logger import setup_logger logger = setup_logger() @@ -46,8 +48,7 @@ def __init__( self, method_spec: MethodSpec, base_url: str, - custom_headers: list[dict[str, str]] | None = [], - tool_additional_headers: dict[str, str] | None = None, + custom_headers: list[HeaderItemDict] | None = None, ) -> None: self._base_url = base_url self._method_spec = method_spec @@ -55,9 +56,9 @@ def __init__( self._name = self._method_spec.name self._description = self._method_spec.summary - self.headers = { - header["key"]: header["value"] for header in (custom_headers or []) - } | (tool_additional_headers or {}) + self.headers = ( + header_list_to_header_dict(custom_headers) if custom_headers else {} + ) @property def name(self) -> str: @@ -184,8 +185,7 @@ def final_result(self, *args: ToolResponse) -> JSON_ro: def build_custom_tools_from_openapi_schema_and_headers( openapi_schema: dict[str, Any], - tool_additional_headers: dict[str, str] | None = None, - custom_headers: list[dict[str, str]] | None = [], + custom_headers: list[HeaderItemDict] | None = None, dynamic_schema_info: DynamicSchemaInfo | None = None, ) -> list[CustomTool]: if dynamic_schema_info: @@ -205,8 +205,7 @@ def build_custom_tools_from_openapi_schema_and_headers( url = openapi_to_url(openapi_schema) method_specs = openapi_to_method_specs(openapi_schema) return [ - CustomTool(method_spec, url, custom_headers, tool_additional_headers) - for method_spec in method_specs + CustomTool(method_spec, url, custom_headers) for method_spec in method_specs ] diff --git a/backend/danswer/tools/images/image_generation_tool.py b/backend/danswer/tools/images/image_generation_tool.py index 3c1fa75c742..3584d50f77e 100644 --- a/backend/danswer/tools/images/image_generation_tool.py +++ b/backend/danswer/tools/images/image_generation_tool.py @@ -11,13 +11,13 @@ from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF from danswer.key_value_store.interface import JSON_ro from danswer.llm.answering.models import PreviousMessage -from danswer.llm.headers import build_llm_extra_headers from danswer.llm.interfaces import LLM from danswer.llm.utils import build_content_with_imgs from danswer.llm.utils import message_to_string from danswer.prompts.constants import GENERAL_SEP_PAT from danswer.tools.tool import Tool from danswer.tools.tool import ToolResponse +from danswer.utils.headers import build_llm_extra_headers from danswer.utils.logger import setup_logger from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel diff --git a/backend/danswer/utils/headers.py b/backend/danswer/utils/headers.py new file mode 100644 index 00000000000..5ccf61a51e1 --- /dev/null +++ b/backend/danswer/utils/headers.py @@ -0,0 +1,79 @@ +from typing import TypedDict + +from fastapi.datastructures import Headers + +from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS +from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS +from danswer.configs.tool_configs import CUSTOM_TOOL_PASS_THROUGH_HEADERS +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +class HeaderItemDict(TypedDict): + key: str + value: str + + +def clean_header_list(headers_to_clean: list[HeaderItemDict]) -> dict[str, str]: + cleaned_headers: dict[str, str] = {} + for item in headers_to_clean: + key = item["key"] + value = item["value"] + if key in cleaned_headers: + logger.warning( + f"Duplicate header {key} found in custom headers, ignoring..." + ) + continue + cleaned_headers[key] = value + return cleaned_headers + + +def header_dict_to_header_list(header_dict: dict[str, str]) -> list[HeaderItemDict]: + return [{"key": key, "value": value} for key, value in header_dict.items()] + + +def header_list_to_header_dict(header_list: list[HeaderItemDict]) -> dict[str, str]: + return {header["key"]: header["value"] for header in header_list} + + +def get_relevant_headers( + headers: dict[str, str] | Headers, desired_headers: list[str] | None +) -> dict[str, str]: + if not desired_headers: + return {} + + pass_through_headers: dict[str, str] = {} + for key in desired_headers: + if key in headers: + pass_through_headers[key] = headers[key] + else: + # fastapi makes all header keys lowercase, handling that here + lowercase_key = key.lower() + if lowercase_key in headers: + pass_through_headers[lowercase_key] = headers[lowercase_key] + + return pass_through_headers + + +def get_litellm_additional_request_headers( + headers: dict[str, str] | Headers +) -> dict[str, str]: + return get_relevant_headers(headers, LITELLM_PASS_THROUGH_HEADERS) + + +def build_llm_extra_headers( + additional_headers: dict[str, str] | None = None +) -> dict[str, str]: + extra_headers: dict[str, str] = {} + if additional_headers: + extra_headers.update(additional_headers) + if LITELLM_EXTRA_HEADERS: + extra_headers.update(LITELLM_EXTRA_HEADERS) + return extra_headers + + +def get_custom_tool_additional_request_headers( + headers: dict[str, str] | Headers +) -> dict[str, str]: + return get_relevant_headers(headers, CUSTOM_TOOL_PASS_THROUGH_HEADERS) diff --git a/backend/tests/unit/danswer/tools/custom/test_custom_tools.py b/backend/tests/unit/danswer/tools/custom/test_custom_tools.py index 5b07e21bb83..6139f41e62a 100644 --- a/backend/tests/unit/danswer/tools/custom/test_custom_tools.py +++ b/backend/tests/unit/danswer/tools/custom/test_custom_tools.py @@ -13,6 +13,7 @@ from danswer.tools.custom.custom_tool import validate_openapi_schema from danswer.tools.models import DynamicSchemaInfo from danswer.tools.tool import ToolResponse +from danswer.utils.headers import HeaderItemDict class TestCustomTool(unittest.TestCase): @@ -143,7 +144,7 @@ def test_custom_tool_with_headers( Test the custom tool with custom headers. Verifies that the tool correctly includes the custom headers in the request. """ - custom_headers: list[dict[str, str]] = [ + custom_headers: list[HeaderItemDict] = [ {"key": "Authorization", "value": "Bearer token123"}, {"key": "Custom-Header", "value": "CustomValue"}, ] @@ -171,7 +172,7 @@ def test_custom_tool_with_empty_headers( Test the custom tool with an empty list of custom headers. Verifies that the tool correctly handles an empty list of headers. """ - custom_headers: list[dict[str, str]] = [] + custom_headers: list[HeaderItemDict] = [] tools = build_custom_tools_from_openapi_schema_and_headers( self.openapi_schema, custom_headers=custom_headers,