Skip to content

Commit

Permalink
Fix Prompt for Non Function Calling LLMs (#3241)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuhongsun96 authored Nov 24, 2024
1 parent 413891f commit 3466451
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
2 changes: 2 additions & 0 deletions backend/danswer/llm/answering/answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:

# DEBUG: good breakpoint
stream = self.llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=current_llm_call.prompt_builder.build(),
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
tool_choice=(
Expand Down
8 changes: 2 additions & 6 deletions backend/danswer/llm/answering/prompts/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def __init__(
user_message: HumanMessage,
message_history: list[PreviousMessage],
llm_config: LLMConfig,
raw_user_text: str,
single_message_history: str | None = None,
raw_user_text: str | None = None,
) -> None:
self.max_tokens = compute_max_llm_input_tokens(llm_config)

Expand Down Expand Up @@ -89,11 +89,7 @@ def __init__(

self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []

self.raw_user_message = (
HumanMessage(content=raw_user_text)
if raw_user_text is not None
else user_message
)
self.raw_user_message = raw_user_text

def update_system_prompt(self, system_message: SystemMessage | None) -> None:
if not system_message:
Expand Down
6 changes: 3 additions & 3 deletions backend/danswer/llm/answering/tool/tool_response_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_tool_call_for_non_tool_calling_llm(
llm_call.force_use_tool.args
if llm_call.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=llm_call.prompt_builder.get_user_message_content(),
query=llm_call.prompt_builder.raw_user_message,
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
force_run=True,
Expand All @@ -76,7 +76,7 @@ def get_tool_call_for_non_tool_calling_llm(
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=llm_call.tools,
query=llm_call.prompt_builder.get_user_message_content(),
query=llm_call.prompt_builder.raw_user_message,
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
)
Expand All @@ -95,7 +95,7 @@ def get_tool_call_for_non_tool_calling_llm(
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=llm_call.prompt_builder.raw_message_history,
query=llm_call.prompt_builder.get_user_message_content(),
query=llm_call.prompt_builder.raw_user_message,
llm=llm,
)
if available_tools_and_args
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import cast

from langchain_core.messages import HumanMessage

from danswer.chat.models import LlmDoc
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PromptConfig
Expand Down Expand Up @@ -58,9 +60,11 @@ def build_next_prompt_for_search_like_tool(
# For Quotes, the system prompt is included in the user prompt
prompt_builder.update_system_prompt(None)

human_message = HumanMessage(content=prompt_builder.raw_user_message)

prompt_builder.update_user_prompt(
build_quotes_user_message(
message=prompt_builder.raw_user_message,
message=human_message,
context_docs=final_context_documents,
history_str=prompt_builder.single_message_history or "",
prompt=prompt_config,
Expand Down

0 comments on commit 3466451

Please sign in to comment.