Fix Prompt for Non Function Calling LLMs #3241

Merged 1 commit on Nov 24, 2024.
backend/danswer/llm/answering/answer.py (2 additions, 0 deletions)

@@ -233,6 +233,8 @@ def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:

         # DEBUG: good breakpoint
         stream = self.llm.stream(
+            # For tool-calling LLMs, we want to insert the task prompt as part of this flow; the LLM
+            # may choose not to call any tools and just generate the answer, in which case the task prompt is needed.
             prompt=current_llm_call.prompt_builder.build(),
             tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
             tool_choice=(
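The added comment captures the core issue: a tool-capable model is free to skip the tools and answer directly, so the task prompt must already be in the messages before streaming begins. A minimal sketch of that control flow, assuming everything except the `... or None` pattern taken from the diff:

from dataclasses import dataclass, field
from typing import Any

@dataclass
class FakeLLMCall:
    # Hypothetical stand-in for danswer's LLMCall; field names are assumptions.
    messages: list[dict[str, Any]] = field(default_factory=list)
    tools: list[Any] = field(default_factory=list)

def build_stream_kwargs(call: FakeLLMCall, task_prompt: str) -> dict[str, Any]:
    # Bake the task prompt into the prompt up front: if the model answers
    # directly without a tool call, the instructions are already in place.
    prompt = call.messages + [{"role": "user", "content": task_prompt}]
    return {
        "prompt": prompt,
        "tools": list(call.tools) or None,  # mirrors `... or None` in the diff
    }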
backend/danswer/llm/answering/prompts/build.py (2 additions, 6 deletions)

@@ -58,8 +58,8 @@ def __init__(
         user_message: HumanMessage,
         message_history: list[PreviousMessage],
         llm_config: LLMConfig,
+        raw_user_text: str,
         single_message_history: str | None = None,
-        raw_user_text: str | None = None,
     ) -> None:
         self.max_tokens = compute_max_llm_input_tokens(llm_config)

@@ -89,11 +89,7 @@ def __init__(

         self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []

-        self.raw_user_message = (
-            HumanMessage(content=raw_user_text)
-            if raw_user_text is not None
-            else user_message
-        )
+        self.raw_user_message = raw_user_text

     def update_system_prompt(self, system_message: SystemMessage | None) -> None:
         if not system_message:
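The effect of this change: `raw_user_text` becomes a required argument, and `raw_user_message` now holds the raw string itself rather than a HumanMessage (or a fallback to the full `user_message`). A minimal sketch of the new shape, with everything outside the two field names assumed:

class BuilderSketch:
    # Hypothetical reduction of the prompt builder to the fields this PR touches.
    def __init__(self, raw_user_text: str, single_message_history: str | None = None):
        # Before this PR: wrapped as HumanMessage(content=raw_user_text), falling
        # back to the full user_message when raw_user_text was None.
        # After: keep the plain string; callers wrap it when they need a message type.
        self.raw_user_message = raw_user_text
        self.single_message_history = single_message_history

builder = BuilderSketch("What changed in the Q3 report?")
assert isinstance(builder.raw_user_message, str)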
backend/danswer/llm/answering/tool/tool_response_handler.py (3 additions, 3 deletions)

@@ -62,7 +62,7 @@ def get_tool_call_for_non_tool_calling_llm(
                 llm_call.force_use_tool.args
                 if llm_call.force_use_tool.args is not None
                 else tool.get_args_for_non_tool_calling_llm(
-                    query=llm_call.prompt_builder.get_user_message_content(),
+                    query=llm_call.prompt_builder.raw_user_message,
                     history=llm_call.prompt_builder.raw_message_history,
                     llm=llm,
                     force_run=True,

@@ -76,7 +76,7 @@ def get_tool_call_for_non_tool_calling_llm(
     else:
         tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
             tools=llm_call.tools,
-            query=llm_call.prompt_builder.get_user_message_content(),
+            query=llm_call.prompt_builder.raw_user_message,
             history=llm_call.prompt_builder.raw_message_history,
             llm=llm,
         )

@@ -95,7 +95,7 @@ def get_tool_call_for_non_tool_calling_llm(
             select_single_tool_for_non_tool_calling_llm(
                 tools_and_args=available_tools_and_args,
                 history=llm_call.prompt_builder.raw_message_history,
-                query=llm_call.prompt_builder.get_user_message_content(),
+                query=llm_call.prompt_builder.raw_user_message,
                 llm=llm,
             )
             if available_tools_and_args
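All three call sites make the same substitution: the query handed to the tool-selection prompts is now the raw user string from the builder rather than the fully built user message. A self-contained sketch of that selection flow for models without native function calling, with every name here hypothetical:

from typing import Callable, Optional

def choose_tool_for_non_tool_calling_llm(
    raw_user_message: str,
    history: list[str],
    tool_names: list[str],
    classify: Callable[[str, list[str]], Optional[str]],
) -> Optional[str]:
    # `classify` stands in for the LLM prompt that picks a tool name (or None)
    # from the raw query plus the message history.
    picked = classify(raw_user_message, history)
    return picked if picked in tool_names else None

# Usage with a trivial keyword-based classifier in place of the LLM:
pick = choose_tool_for_non_tool_calling_llm(
    "search for the onboarding doc",
    history=[],
    tool_names=["search"],
    classify=lambda q, h: "search" if "search" in q else None,
)
assert pick == "search"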
@@ -1,5 +1,7 @@
 from typing import cast

+from langchain_core.messages import HumanMessage
+
 from danswer.chat.models import LlmDoc
 from danswer.llm.answering.models import AnswerStyleConfig
 from danswer.llm.answering.models import PromptConfig

@@ -58,9 +60,11 @@ def build_next_prompt_for_search_like_tool(
         # For Quotes, the system prompt is included in the user prompt
         prompt_builder.update_system_prompt(None)

+        human_message = HumanMessage(content=prompt_builder.raw_user_message)
+
         prompt_builder.update_user_prompt(
             build_quotes_user_message(
-                message=prompt_builder.raw_user_message,
+                message=human_message,
                 context_docs=final_context_documents,
                 history_str=prompt_builder.single_message_history or "",
                 prompt=prompt_config,
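This is the consumer side of the builder change: because `raw_user_message` is now a plain string, the quotes flow constructs the HumanMessage itself before building the user prompt. A minimal usage sketch, where only `HumanMessage` and the attribute name come from the diff and the sample text is invented:

from langchain_core.messages import HumanMessage

raw_user_message = "Summarize the latest incident report."
# Wrap the raw string at the point of use instead of inside the prompt builder,
# so other consumers (e.g. tool selection) can keep working with the plain string.
human_message = HumanMessage(content=raw_user_message)
assert human_message.content == raw_user_message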