Skip to content

Commit

Permalink
Add support for passthrough auth for custom tool calls (#2824)
Browse files Browse the repository at this point in the history
* Add support for passthrough auth for custom tool calls

* Fix formatting
  • Loading branch information
Weves authored Oct 16, 2024
1 parent db0779d commit 33974fc
Show file tree
Hide file tree
Showing 12 changed files with 134 additions and 50 deletions.
14 changes: 10 additions & 4 deletions backend/danswer/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.headers import header_dict_to_header_list
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time

Expand Down Expand Up @@ -276,7 +277,7 @@ def stream_chat_message_objects(
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
tool_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
enforce_chat_session_id_for_search_docs: bool = True,
) -> ChatPacketStream:
Expand Down Expand Up @@ -640,7 +641,12 @@ def stream_chat_message_objects(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
custom_headers=db_tool_model.custom_headers,
custom_headers=(db_tool_model.custom_headers or [])
+ (
header_dict_to_header_list(
custom_tool_additional_headers or {}
)
),
),
)

Expand Down Expand Up @@ -863,7 +869,7 @@ def stream_chat_message(
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
tool_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
) -> Iterator[str]:
with get_session_context_manager() as db_session:
Expand All @@ -873,7 +879,7 @@ def stream_chat_message(
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
tool_additional_headers=tool_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
is_connected=is_connected,
)
for obj in objects:
Expand Down
16 changes: 0 additions & 16 deletions backend/danswer/configs/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,3 @@
logger.error(
"Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object"
)


# List of headers to pass through to tool calls (e.g., API requests made by tools)
# This allows for dynamic configuration of tool behavior based on incoming request headers
TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get("TOOL_PASS_THROUGH_HEADERS")
if _TOOL_PASS_THROUGH_HEADERS_RAW:
try:
TOOL_PASS_THROUGH_HEADERS = json.loads(_TOOL_PASS_THROUGH_HEADERS_RAW)
except Exception:
from danswer.utils.logger import setup_logger

logger = setup_logger()
logger.error(
"Failed to parse TOOL_PASS_THROUGH_HEADERS, must be a valid JSON object"
)
22 changes: 22 additions & 0 deletions backend/danswer/configs/tool_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import json
import os


# if specified, will pass through request headers to the call to API calls made by custom tools
CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get(
"CUSTOM_TOOL_PASS_THROUGH_HEADERS"
)
if _CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW:
try:
CUSTOM_TOOL_PASS_THROUGH_HEADERS = json.loads(
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW
)
except Exception:
# need to import here to avoid circular imports
from danswer.utils.logger import setup_logger

logger = setup_logger()
logger.error(
"Failed to parse CUSTOM_TOOL_PASS_THROUGH_HEADERS, must be a valid JSON object"
)
3 changes: 2 additions & 1 deletion backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from danswer.search.enums import RecencyBiasSetting
from danswer.utils.encryption import decrypt_bytes_to_string
from danswer.utils.encryption import encrypt_string_to_bytes
from danswer.utils.headers import HeaderItemDict
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import RerankerProvider

Expand Down Expand Up @@ -1288,7 +1289,7 @@ class Tool(Base):
openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
custom_headers: Mapped[list[dict[str, str]] | None] = mapped_column(
custom_headers: Mapped[list[HeaderItemDict] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# user who created / owns the tool. Will be None for built-in tools.
Expand Down
6 changes: 5 additions & 1 deletion backend/danswer/db/tools.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Any
from typing import cast
from uuid import UUID

from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.models import Tool
from danswer.server.features.tool.models import Header
from danswer.utils.headers import HeaderItemDict
from danswer.utils.logger import setup_logger

logger = setup_logger()
Expand Down Expand Up @@ -67,7 +69,9 @@ def update_tool(
if user_id is not None:
tool.user_id = user_id
if custom_headers is not None:
tool.custom_headers = [header.dict() for header in custom_headers]
tool.custom_headers = [
cast(HeaderItemDict, header.model_dump()) for header in custom_headers
]
db_session.commit()

return tool
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/llm/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from danswer.db.models import Persona
from danswer.llm.chat_llm import DefaultMultiLLM
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.headers import build_llm_extra_headers
from danswer.llm.interfaces import LLM
from danswer.llm.override_models import LLMOverride
from danswer.utils.headers import build_llm_extra_headers


def get_main_llm_from_tuple(
Expand Down
12 changes: 0 additions & 12 deletions backend/danswer/llm/headers.py

This file was deleted.

6 changes: 3 additions & 3 deletions backend/danswer/server/query_and_chat/chat_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
from danswer.configs.model_configs import TOOL_PASS_THROUGH_HEADERS
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import delete_chat_session
Expand Down Expand Up @@ -74,6 +73,7 @@
from danswer.server.query_and_chat.models import SearchFeedbackRequest
from danswer.server.query_and_chat.models import UpdateChatSessionThreadRequest
from danswer.server.query_and_chat.token_limit import check_token_rate_limits
from danswer.utils.headers import get_custom_tool_additional_request_headers
from danswer.utils.logger import setup_logger


Expand Down Expand Up @@ -338,8 +338,8 @@ def stream_generator() -> Generator[str, None, None]:
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
tool_additional_headers=extract_headers(
request.headers, TOOL_PASS_THROUGH_HEADERS
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
request.headers
),
is_connected=is_disconnected_func,
):
Expand Down
17 changes: 8 additions & 9 deletions backend/danswer/tools/custom/custom_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from danswer.tools.models import MESSAGE_ID_PLACEHOLDER
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.utils.headers import header_list_to_header_dict
from danswer.utils.headers import HeaderItemDict
from danswer.utils.logger import setup_logger

logger = setup_logger()
Expand All @@ -46,18 +48,17 @@ def __init__(
self,
method_spec: MethodSpec,
base_url: str,
custom_headers: list[dict[str, str]] | None = [],
tool_additional_headers: dict[str, str] | None = None,
custom_headers: list[HeaderItemDict] | None = None,
) -> None:
self._base_url = base_url
self._method_spec = method_spec
self._tool_definition = self._method_spec.to_tool_definition()

self._name = self._method_spec.name
self._description = self._method_spec.summary
self.headers = {
header["key"]: header["value"] for header in (custom_headers or [])
} | (tool_additional_headers or {})
self.headers = (
header_list_to_header_dict(custom_headers) if custom_headers else {}
)

@property
def name(self) -> str:
Expand Down Expand Up @@ -184,8 +185,7 @@ def final_result(self, *args: ToolResponse) -> JSON_ro:

def build_custom_tools_from_openapi_schema_and_headers(
openapi_schema: dict[str, Any],
tool_additional_headers: dict[str, str] | None = None,
custom_headers: list[dict[str, str]] | None = [],
custom_headers: list[HeaderItemDict] | None = None,
dynamic_schema_info: DynamicSchemaInfo | None = None,
) -> list[CustomTool]:
if dynamic_schema_info:
Expand All @@ -205,8 +205,7 @@ def build_custom_tools_from_openapi_schema_and_headers(
url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
return [
CustomTool(method_spec, url, custom_headers, tool_additional_headers)
for method_spec in method_specs
CustomTool(method_spec, url, custom_headers) for method_spec in method_specs
]


Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/tools/images/image_generation_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.headers import build_llm_extra_headers
from danswer.llm.interfaces import LLM
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import message_to_string
from danswer.prompts.constants import GENERAL_SEP_PAT
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.utils.headers import build_llm_extra_headers
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

Expand Down
79 changes: 79 additions & 0 deletions backend/danswer/utils/headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import TypedDict

from fastapi.datastructures import Headers

from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
from danswer.configs.tool_configs import CUSTOM_TOOL_PASS_THROUGH_HEADERS
from danswer.utils.logger import setup_logger

logger = setup_logger()


class HeaderItemDict(TypedDict):
key: str
value: str


def clean_header_list(headers_to_clean: list[HeaderItemDict]) -> dict[str, str]:
cleaned_headers: dict[str, str] = {}
for item in headers_to_clean:
key = item["key"]
value = item["value"]
if key in cleaned_headers:
logger.warning(
f"Duplicate header {key} found in custom headers, ignoring..."
)
continue
cleaned_headers[key] = value
return cleaned_headers


def header_dict_to_header_list(header_dict: dict[str, str]) -> list[HeaderItemDict]:
return [{"key": key, "value": value} for key, value in header_dict.items()]


def header_list_to_header_dict(header_list: list[HeaderItemDict]) -> dict[str, str]:
return {header["key"]: header["value"] for header in header_list}


def get_relevant_headers(
headers: dict[str, str] | Headers, desired_headers: list[str] | None
) -> dict[str, str]:
if not desired_headers:
return {}

pass_through_headers: dict[str, str] = {}
for key in desired_headers:
if key in headers:
pass_through_headers[key] = headers[key]
else:
# fastapi makes all header keys lowercase, handling that here
lowercase_key = key.lower()
if lowercase_key in headers:
pass_through_headers[lowercase_key] = headers[lowercase_key]

return pass_through_headers


def get_litellm_additional_request_headers(
headers: dict[str, str] | Headers
) -> dict[str, str]:
return get_relevant_headers(headers, LITELLM_PASS_THROUGH_HEADERS)


def build_llm_extra_headers(
additional_headers: dict[str, str] | None = None
) -> dict[str, str]:
extra_headers: dict[str, str] = {}
if additional_headers:
extra_headers.update(additional_headers)
if LITELLM_EXTRA_HEADERS:
extra_headers.update(LITELLM_EXTRA_HEADERS)
return extra_headers


def get_custom_tool_additional_request_headers(
headers: dict[str, str] | Headers
) -> dict[str, str]:
return get_relevant_headers(headers, CUSTOM_TOOL_PASS_THROUGH_HEADERS)
5 changes: 3 additions & 2 deletions backend/tests/unit/danswer/tools/custom/test_custom_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from danswer.tools.custom.custom_tool import validate_openapi_schema
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.tool import ToolResponse
from danswer.utils.headers import HeaderItemDict


class TestCustomTool(unittest.TestCase):
Expand Down Expand Up @@ -143,7 +144,7 @@ def test_custom_tool_with_headers(
Test the custom tool with custom headers.
Verifies that the tool correctly includes the custom headers in the request.
"""
custom_headers: list[dict[str, str]] = [
custom_headers: list[HeaderItemDict] = [
{"key": "Authorization", "value": "Bearer token123"},
{"key": "Custom-Header", "value": "CustomValue"},
]
Expand Down Expand Up @@ -171,7 +172,7 @@ def test_custom_tool_with_empty_headers(
Test the custom tool with an empty list of custom headers.
Verifies that the tool correctly handles an empty list of headers.
"""
custom_headers: list[dict[str, str]] = []
custom_headers: list[HeaderItemDict] = []
tools = build_custom_tools_from_openapi_schema_and_headers(
self.openapi_schema,
custom_headers=custom_headers,
Expand Down

0 comments on commit 33974fc

Please sign in to comment.