-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for passthrough auth for custom tool calls (#2824)
* Add support for passthrough auth for custom tool calls * Fix formatting
- Loading branch information
Showing
12 changed files
with
134 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import json | ||
import os | ||
|
||
|
||
# if specified, will pass through request headers to the call to API calls made by custom tools | ||
CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None | ||
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get( | ||
"CUSTOM_TOOL_PASS_THROUGH_HEADERS" | ||
) | ||
if _CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW: | ||
try: | ||
CUSTOM_TOOL_PASS_THROUGH_HEADERS = json.loads( | ||
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW | ||
) | ||
except Exception: | ||
# need to import here to avoid circular imports | ||
from danswer.utils.logger import setup_logger | ||
|
||
logger = setup_logger() | ||
logger.error( | ||
"Failed to parse CUSTOM_TOOL_PASS_THROUGH_HEADERS, must be a valid JSON object" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from typing import TypedDict | ||
|
||
from fastapi.datastructures import Headers | ||
|
||
from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS | ||
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS | ||
from danswer.configs.tool_configs import CUSTOM_TOOL_PASS_THROUGH_HEADERS | ||
from danswer.utils.logger import setup_logger | ||
|
||
logger = setup_logger() | ||
|
||
|
||
class HeaderItemDict(TypedDict): | ||
key: str | ||
value: str | ||
|
||
|
||
def clean_header_list(headers_to_clean: list[HeaderItemDict]) -> dict[str, str]: | ||
cleaned_headers: dict[str, str] = {} | ||
for item in headers_to_clean: | ||
key = item["key"] | ||
value = item["value"] | ||
if key in cleaned_headers: | ||
logger.warning( | ||
f"Duplicate header {key} found in custom headers, ignoring..." | ||
) | ||
continue | ||
cleaned_headers[key] = value | ||
return cleaned_headers | ||
|
||
|
||
def header_dict_to_header_list(header_dict: dict[str, str]) -> list[HeaderItemDict]: | ||
return [{"key": key, "value": value} for key, value in header_dict.items()] | ||
|
||
|
||
def header_list_to_header_dict(header_list: list[HeaderItemDict]) -> dict[str, str]: | ||
return {header["key"]: header["value"] for header in header_list} | ||
|
||
|
||
def get_relevant_headers( | ||
headers: dict[str, str] | Headers, desired_headers: list[str] | None | ||
) -> dict[str, str]: | ||
if not desired_headers: | ||
return {} | ||
|
||
pass_through_headers: dict[str, str] = {} | ||
for key in desired_headers: | ||
if key in headers: | ||
pass_through_headers[key] = headers[key] | ||
else: | ||
# fastapi makes all header keys lowercase, handling that here | ||
lowercase_key = key.lower() | ||
if lowercase_key in headers: | ||
pass_through_headers[lowercase_key] = headers[lowercase_key] | ||
|
||
return pass_through_headers | ||
|
||
|
||
def get_litellm_additional_request_headers( | ||
headers: dict[str, str] | Headers | ||
) -> dict[str, str]: | ||
return get_relevant_headers(headers, LITELLM_PASS_THROUGH_HEADERS) | ||
|
||
|
||
def build_llm_extra_headers( | ||
additional_headers: dict[str, str] | None = None | ||
) -> dict[str, str]: | ||
extra_headers: dict[str, str] = {} | ||
if additional_headers: | ||
extra_headers.update(additional_headers) | ||
if LITELLM_EXTRA_HEADERS: | ||
extra_headers.update(LITELLM_EXTRA_HEADERS) | ||
return extra_headers | ||
|
||
|
||
def get_custom_tool_additional_request_headers( | ||
headers: dict[str, str] | Headers | ||
) -> dict[str, str]: | ||
return get_relevant_headers(headers, CUSTOM_TOOL_PASS_THROUGH_HEADERS) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters