Skip to content

Commit

Permalink
add multiple formats to tools (#3041)
Browse files Browse the repository at this point in the history
  • Loading branch information
pablonyx authored Nov 3, 2024
1 parent c2d04f5 commit c6e8bf2
Show file tree
Hide file tree
Showing 10 changed files with 220 additions and 26 deletions.
4 changes: 2 additions & 2 deletions backend/danswer/chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class QAResponse(SearchResponse, DanswerAnswer):
error_msg: str | None = None


class ImageGenerationDisplay(BaseModel):
class FileChatDisplay(BaseModel):
file_ids: list[str]


Expand All @@ -170,7 +170,7 @@ class CustomToolResponse(BaseModel):
| DanswerQuotes
| CitationInfo
| DanswerContexts
| ImageGenerationDisplay
| FileChatDisplay
| CustomToolResponse
| StreamingError
| StreamStopInfo
Expand Down
35 changes: 27 additions & 8 deletions backend/danswer/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import FileChatDisplay
from danswer.chat.models import FinalUsedContextDocsResponse
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
Expand Down Expand Up @@ -275,7 +275,7 @@ def _get_force_search_settings(
| DanswerAnswerPiece
| AllCitations
| CitationInfo
| ImageGenerationDisplay
| FileChatDisplay
| CustomToolResponse
| MessageSpecificCitations
| MessageResponseIDInfo
Expand Down Expand Up @@ -769,7 +769,6 @@ def stream_chat_message_objects(
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)

elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
Expand All @@ -787,7 +786,7 @@ def stream_chat_message_objects(
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
yield ImageGenerationDisplay(
yield FileChatDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
Expand All @@ -801,10 +800,30 @@ def stream_chat_message_objects(
yield qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)

if (
custom_tool_response.response_type == "image"
or custom_tool_response.response_type == "csv"
):
file_ids = custom_tool_response.tool_result.file_ids
ai_message_files = [
FileDescriptor(
id=str(file_id),
type=ChatFileType.IMAGE
if custom_tool_response.response_type == "image"
else ChatFileType.CSV,
)
for file_id in file_ids
]
yield FileChatDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
else:
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)

elif isinstance(packet, StreamStopInfo):
pass
else:
Expand Down
18 changes: 13 additions & 5 deletions backend/danswer/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,10 @@ def translate_danswer_msg_to_langchain(
files: list[InMemoryChatFile] = []

# If the message is a `ChatMessage`, it doesn't have the downloaded files
# attached. Just ignore them for now. Also, OpenAI doesn't allow files to
# be attached to AI messages, so we must remove them
if not isinstance(msg, ChatMessage) and msg.message_type != MessageType.ASSISTANT:
# attached. Just ignore them for now.
if not isinstance(msg, ChatMessage):
files = msg.files
content = build_content_with_imgs(msg.message, files)
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)

if msg.message_type == MessageType.SYSTEM:
raise ValueError("System messages are not currently part of history")
Expand Down Expand Up @@ -188,10 +187,19 @@ def build_content_with_imgs(
message: str,
files: list[InMemoryChatFile] | None = None,
img_urls: list[str] | None = None,
message_type: MessageType = MessageType.USER,
) -> str | list[str | dict[str, Any]]: # matching Langchain's BaseMessage content type
files = files or []
img_files = [file for file in files if file.file_type == ChatFileType.IMAGE]

# Only include image files for user messages
img_files = (
[file for file in files if file.file_type == ChatFileType.IMAGE]
if message_type == MessageType.USER
else []
)

img_urls = img_urls or []

message_main_content = _build_content(message, files)

if not img_files and not img_urls:
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/one_shot_answer/answer_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def stream_answer_objects(
return_contexts=query_req.return_contexts,
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
)
# won't be any ImageGenerationDisplay responses since that tool is never passed in
# won't be any FileChatDisplay responses since that tool is never passed in
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
# for one-shot flow, don't currently do anything with these
if isinstance(packet, ToolResponse):
Expand Down
151 changes: 146 additions & 5 deletions backend/danswer/tools/tool_implementations/custom/custom_tool.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,34 @@
import csv
import json
import uuid
from collections.abc import Generator
from io import BytesIO
from io import StringIO
from typing import Any
from typing import cast
from typing import Dict
from typing import List

import requests
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic import BaseModel

from danswer.configs.constants import FileOrigin
from danswer.db.engine import get_session_with_tenant
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
from danswer.key_value_store.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.interfaces import LLM
from danswer.tools.base_tool import BaseTool
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.models import MESSAGE_ID_PLACEHOLDER
from danswer.tools.models import ToolResponse
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
from danswer.tools.tool_implementations.custom.custom_tool_prompts import (
SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT,
)
Expand All @@ -39,6 +51,9 @@
from danswer.tools.tool_implementations.custom.openapi_parsing import (
validate_openapi_schema,
)
from danswer.tools.tool_implementations.custom.prompt import (
build_custom_image_generation_user_prompt,
)
from danswer.utils.headers import header_list_to_header_dict
from danswer.utils.headers import HeaderItemDict
from danswer.utils.logger import setup_logger
Expand All @@ -48,9 +63,14 @@
CUSTOM_TOOL_RESPONSE_ID = "custom_tool_response"


# Result payload used when a custom tool's HTTP response was stored as a file
# rather than passed through as JSON.
class CustomToolFileResponse(BaseModel):
    # References to saved images or CSVs
    file_ids: list[str]


# Summary of one custom tool invocation; carried as the `response` of the
# ToolResponse yielded by `CustomTool.run`.
class CustomToolCallSummary(BaseModel):
    # Name of the tool that produced this result
    tool_name: str
    # Kind of payload held in `tool_result`, e.g. 'json', 'image', 'csv', 'graph'
    response_type: str
    # The response data. `run` stores the parsed JSON body for 'json' responses
    # and a CustomToolFileResponse for 'image' / 'csv' responses, hence `Any`.
    # (Declared exactly once: a second annotation for the same field would
    # silently override the first.)
    tool_result: Any


class CustomTool(BaseTool):
Expand Down Expand Up @@ -91,6 +111,12 @@ def build_tool_message_content(
self, *args: ToolResponse
) -> str | list[str | dict[str, Any]]:
response = cast(CustomToolCallSummary, args[0].response)

if response.response_type == "image" or response.response_type == "csv":
image_response = cast(CustomToolFileResponse, response.tool_result)
return json.dumps({"file_ids": image_response.file_ids})

# For JSON or other responses, return as-is
return json.dumps(response.tool_result)

"""For LLMs which do NOT support explicit tool calling"""
Expand Down Expand Up @@ -158,6 +184,38 @@ def get_args_for_non_tool_calling_llm(
)
return None

def _save_and_get_file_references(
    self, file_content: bytes | str, content_type: str
) -> List[str]:
    """Store a tool-produced payload in the file store and return its reference.

    Args:
        file_content: Raw payload; text is encoded to bytes with the default
            ``str.encode`` encoding before being stored.
        content_type: Content type of the payload; recorded both as the
            stored file's type and in its metadata.

    Returns:
        A single-element list holding the newly generated file ID (a UUID4
        string), which is also used as the file name and display name.
    """
    new_file_id = str(uuid.uuid4())

    # The file store expects a binary stream, so normalise text to bytes first
    raw_bytes = (
        file_content.encode() if isinstance(file_content, str) else file_content
    )

    with get_session_with_tenant() as db_session:
        get_default_file_store(db_session).save_file(
            file_name=new_file_id,
            content=BytesIO(raw_bytes),
            display_name=new_file_id,
            file_origin=FileOrigin.CHAT_UPLOAD,
            file_type=content_type,
            file_metadata={
                "content_type": content_type,
            },
        )

    return [new_file_id]

def _parse_csv(self, csv_text: str) -> List[Dict[str, Any]]:
csv_file = StringIO(csv_text)
reader = csv.DictReader(csv_file)
return [row for row in reader]

"""Actual execution of the tool"""

def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
Expand All @@ -177,20 +235,103 @@ def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:

url = self._method_spec.build_url(self._base_url, path_params, query_params)
method = self._method_spec.method
# Log request details

response = requests.request(
method, url, json=request_body, headers=self.headers
)
content_type = response.headers.get("Content-Type", "")

if "text/csv" in content_type:
file_ids = self._save_and_get_file_references(
response.content, content_type
)
tool_result = CustomToolFileResponse(file_ids=file_ids)
response_type = "csv"

elif "image/" in content_type:
file_ids = self._save_and_get_file_references(
response.content, content_type
)
tool_result = CustomToolFileResponse(file_ids=file_ids)
response_type = "image"

else:
tool_result = response.json()
response_type = "json"

logger.info(
f"Returning tool response for {self._name} with type {response_type}"
)

yield ToolResponse(
id=CUSTOM_TOOL_RESPONSE_ID,
response=CustomToolCallSummary(
tool_name=self._name, tool_result=response.json()
tool_name=self._name,
response_type=response_type,
tool_result=tool_result,
),
)

def build_next_prompt(
    self,
    prompt_builder: AnswerPromptBuilder,
    tool_call_summary: ToolCallSummary,
    tool_responses: list[ToolResponse],
    using_tool_calling_llm: bool,
) -> AnswerPromptBuilder:
    """Build the follow-up prompt after this tool has run.

    JSON (and any other non-file) responses are delegated to the parent
    implementation. For 'image' and 'csv' responses, the stored files are
    loaded and the user prompt is replaced with a request to summarise them.

    Args:
        prompt_builder: Builder whose user prompt is updated in place.
        tool_call_summary: Forwarded to the parent for non-file responses.
        tool_responses: Responses yielded by ``run``; only the first is used.
        using_tool_calling_llm: Forwarded to the parent for non-file responses.

    Returns:
        The prompt builder to use for the next LLM call.
    """
    summary = cast(CustomToolCallSummary, tool_responses[0].response)
    is_image = summary.response_type == "image"

    # Non-file responses keep the default behaviour of the base tool
    if not (is_image or summary.response_type == "csv"):
        return super().build_next_prompt(
            prompt_builder,
            tool_call_summary,
            tool_responses,
            using_tool_calling_llm,
        )

    chat_file_type = ChatFileType.IMAGE if is_image else ChatFileType.CSV

    # Pull every referenced file back out of storage. This is best-effort:
    # a file that cannot be read is logged and left out of the prompt.
    loaded_files = []
    with get_session_with_tenant() as db_session:
        file_store = get_default_file_store(db_session)

        for file_id in summary.tool_result.file_ids:
            try:
                stored_file = file_store.read_file(file_id, mode="b")
                loaded_files.append(
                    InMemoryChatFile(
                        file_id=file_id,
                        filename=file_id,
                        content=stored_file.read(),
                        file_type=chat_file_type,
                    )
                )
            except Exception:
                logger.exception(f"Failed to read file {file_id}")

    # Swap the user prompt for one that asks about the loaded file content
    original_query = prompt_builder.get_user_message_content()
    summary_prompt = build_custom_image_generation_user_prompt(
        query=original_query,
        files=loaded_files,
        file_type=chat_file_type,
    )
    prompt_builder.update_user_prompt(summary_prompt)

    return prompt_builder

def final_result(self, *args: ToolResponse) -> JSON_ro:
    """Return the tool's result in a JSON-compatible form.

    Args:
        *args: Responses yielded by ``run``; only the first is used.

    Returns:
        For file responses, the ``CustomToolFileResponse`` dumped to a plain
        dict; otherwise the stored ``tool_result`` unchanged.
    """
    response = cast(CustomToolCallSummary, args[0].response)
    # A file response is a pydantic model, so it has to be converted to a
    # plain dict before it can be returned as JSON data. This check must come
    # before any unconditional return of `tool_result`.
    if isinstance(response.tool_result, CustomToolFileResponse):
        return response.tool_result.model_dump()
    return response.tool_result


def build_custom_tools_from_openapi_schema_and_headers(
Expand Down
25 changes: 25 additions & 0 deletions backend/danswer/tools/tool_implementations/custom/prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from langchain_core.messages import HumanMessage

from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.utils import build_content_with_imgs


CUSTOM_IMG_GENERATION_SUMMARY_PROMPT = """
You have just created the attached {file_type} file in response to the following query: "{query}".
Can you please summarize it in a sentence or two? Do NOT include image urls or bulleted lists.
"""


def build_custom_image_generation_user_prompt(
    query: str, file_type: ChatFileType, files: list[InMemoryChatFile] | None = None
) -> HumanMessage:
    """Build the user message asking the LLM to summarise a tool-generated file.

    Args:
        query: The user's original query, quoted inside the prompt text.
        file_type: Kind of file that was produced; its ``value`` is inserted
            into the prompt text.
        files: Files to pass along with the text when building the content.

    Returns:
        A ``HumanMessage`` whose content is produced by
        ``build_content_with_imgs`` from the formatted prompt and ``files``.
    """
    summary_request = CUSTOM_IMG_GENERATION_SUMMARY_PROMPT.format(
        query=query, file_type=file_type.value
    ).strip()
    message_content = build_content_with_imgs(
        message=summary_request,
        files=files,
    )
    return HumanMessage(content=message_content)
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def test_custom_tool_final_result(self) -> None:
mock_response = ToolResponse(
id=CUSTOM_TOOL_RESPONSE_ID,
response=CustomToolCallSummary(
response_type="json",
tool_name="getAssistant",
tool_result={"id": "789", "name": "Final Assistant"},
),
Expand Down
Loading

0 comments on commit c6e8bf2

Please sign in to comment.