Skip to content
This repository has been archived by the owner on Dec 11, 2024. It is now read-only.

Commit

Permalink
Refactor imports and enhance employee context handling in prompts
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Co <[email protected]>
  • Loading branch information
onimsha committed Nov 15, 2024
1 parent 1b705f4 commit f09afb3
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 55 deletions.
7 changes: 3 additions & 4 deletions backend/danswer/danswerbot/slack/handlers/handle_message.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
import datetime

from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from sqlalchemy.orm import Session

from danswer.configs.danswerbot_configs import DANSWER_BOT_FEEDBACK_REMINDER
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.danswerbot.slack.blocks import get_feedback_reminder_blocks
Expand All @@ -24,6 +20,9 @@
from danswer.db.users import add_non_web_user_if_not_exists
from danswer.utils.logger import setup_logger
from shared_configs.configs import SLACK_CHANNEL_ID
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from sqlalchemy.orm import Session

logger_base = setup_logger()

Expand Down
43 changes: 24 additions & 19 deletions backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,6 @@
from typing import Optional
from typing import TypeVar

from fastapi import HTTPException
from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import SectionBlock
from sqlalchemy.orm import Session

from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
Expand Down Expand Up @@ -54,6 +47,15 @@
from danswer.search.models import RerankingDetails
from danswer.search.models import RetrievalDetails
from danswer.utils.logger import DanswerLoggingAdapter
from danswer.utils.logger import setup_logger
from fastapi import HTTPException
from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import SectionBlock
from sqlalchemy.orm import Session

logger = setup_logger()


srl = SlackRateLimiter()
Expand Down Expand Up @@ -101,12 +103,11 @@ def handle_regular_answer(
messages = message_info.thread_messages
message_ts_to_respond_to = message_info.msg_to_respond
is_bot_msg = message_info.is_bot_msg
user = None
if message_info.is_bot_dm:
if message_info.email:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
user = get_user_by_email(message_info.email, db_session)

if message_info.email:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
user = get_user_by_email(message_info.email, db_session)

document_set_names: list[str] | None = None
persona = slack_bot_config.persona if slack_bot_config else None
Expand Down Expand Up @@ -253,16 +254,20 @@ def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | Non
answer = _get_answer(
DirectQARequest(
messages=messages,
multilingual_query_expansion=saved_search_settings.multilingual_expansion
if saved_search_settings
else None,
multilingual_query_expansion=(
saved_search_settings.multilingual_expansion
if saved_search_settings
else None
),
prompt_id=prompt.id if prompt else None,
persona_id=persona.id if persona is not None else 0,
retrieval_options=retrieval_details,
chain_of_thought=not disable_cot,
rerank_settings=RerankingDetails.from_db_model(saved_search_settings)
if saved_search_settings
else None,
rerank_settings=(
RerankingDetails.from_db_model(saved_search_settings)
if saved_search_settings
else None
),
)
)
except Exception as e:
Expand Down
9 changes: 7 additions & 2 deletions backend/danswer/llm/answering/answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.build import default_build_system_message
from danswer.llm.answering.prompts.build import default_build_user_message
from danswer.llm.answering.prompts.citations_prompt import build_citations_system_message
from danswer.llm.answering.prompts.citations_prompt import (
build_citations_system_message,
)
from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message
from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message
from danswer.llm.answering.stream_processing.citation_processing import (
Expand Down Expand Up @@ -53,7 +55,9 @@
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import check_which_tools_should_run_for_non_tool_calling_llm
from danswer.tools.tool_runner import (
check_which_tools_should_run_for_non_tool_calling_llm,
)
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.tools.tool_runner import ToolRunner
Expand Down Expand Up @@ -183,6 +187,7 @@ def _update_prompt_builder_for_search_tool(
context_docs=final_context_documents,
history_str=self.single_message_history or "",
prompt=self.prompt_config,
user_email=self.user_email,
)
)

Expand Down
40 changes: 19 additions & 21 deletions backend/danswer/llm/answering/prompts/citations_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,31 @@
from danswer.db.search_settings import get_multilingual_expansion
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.models import PromptConfig
from danswer.llm.factory import get_llms_for_persona, get_main_llm_from_tuple
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import (
build_content_with_imgs,
check_number_of_tokens,
get_max_input_tokens,
)
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT
from danswer.prompts.direct_qa_prompts import (
CITATIONS_PROMPT,
CITATIONS_PROMPT_FOR_TOOL_CALLING,
)
from danswer.prompts.prompt_utils import (
add_date_time_to_prompt,
add_employee_context_to_prompt,
build_complete_context_str,
build_task_prompt_reminders,
)
from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT
from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT_FOR_TOOL_CALLING
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import add_employee_context_to_prompt
from danswer.prompts.prompt_utils import build_complete_context_str
from danswer.prompts.prompt_utils import build_task_prompt_reminders
from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
from danswer.prompts.token_counts import (
ADDITIONAL_INFO_TOKEN_CNT,
CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
CITATION_REMINDER_TOKEN_CNT,
CITATION_STATEMENT_TOKEN_CNT,
LANGUAGE_HINT_TOKEN_CNT,
)
from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
from langchain.schema.messages import HumanMessage, SystemMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage

logger = setup_logger()

Expand Down Expand Up @@ -136,6 +132,8 @@ def build_citations_system_message(
prompt_str=system_prompt, user_email=user_email
)

logger.debug(f"Built system message: {system_prompt}")

return SystemMessage(content=system_prompt)


Expand Down
33 changes: 31 additions & 2 deletions backend/danswer/llm/answering/prompts/quotes_prompt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from langchain.schema.messages import HumanMessage

from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
Expand All @@ -10,15 +8,21 @@
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import add_employee_context_to_prompt
from danswer.prompts.prompt_utils import build_complete_context_str
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
from langchain.schema.messages import HumanMessage

logger = setup_logger()


def _build_weak_llm_quotes_prompt(
question: str,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: PromptConfig,
user_email: str | None = None,
) -> HumanMessage:
"""Since Danswer supports a variety of LLMs, this less demanding prompt is provided
as an option to use with weaker LLMs such as small version, low float precision, quantized,
Expand All @@ -39,6 +43,10 @@ def _build_weak_llm_quotes_prompt(
if prompt.datetime_aware:
prompt_str = add_date_time_to_prompt(prompt_str=prompt_str)

if user_email:
prompt_str = add_employee_context_to_prompt(
prompt_str=prompt_str, user_email=user_email
)
return HumanMessage(content=prompt_str)


Expand All @@ -47,7 +55,21 @@ def _build_strong_llm_quotes_prompt(
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: PromptConfig,
user_email: str | None = None,
) -> HumanMessage:
"""
Constructs a prompt for the language model based on the provided inputs.
Args:
question (str): The user's query.
context_docs (list[LlmDoc] | list[InferenceChunk]): List of context documents or inference chunks.
history_str (str): The conversation history.
prompt (PromptConfig): The prompt configuration.
user_email (str, optional): The user's email. Defaults to None.
Returns:
HumanMessage: The constructed prompt.
"""
use_language_hint = bool(get_multilingual_expansion())

context_block = ""
Expand All @@ -71,6 +93,11 @@ def _build_strong_llm_quotes_prompt(
if prompt.datetime_aware:
full_prompt = add_date_time_to_prompt(prompt_str=full_prompt)

if user_email:
full_prompt = add_employee_context_to_prompt(
prompt_str=full_prompt, user_email=user_email
)

return HumanMessage(content=full_prompt)


Expand All @@ -79,6 +106,7 @@ def build_quotes_user_message(
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: PromptConfig,
user_email: str,
) -> HumanMessage:
prompt_builder = (
_build_weak_llm_quotes_prompt
Expand All @@ -91,6 +119,7 @@ def build_quotes_user_message(
context_docs=context_docs,
history_str=history_str,
prompt=prompt,
user_email=user_email,
)


Expand Down
22 changes: 15 additions & 7 deletions backend/danswer/prompts/prompt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,32 +61,40 @@ def add_date_time_to_prompt(prompt_str: str) -> str:
+ " "
+ BASIC_TIME_STR.format(datetime_info=get_current_llm_day_time())
)


# Module-level Redis client; add_employee_context_to_prompt uses it to cache
# each user's employee context, keyed by the user's email address.
redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER)


def add_employee_context_to_prompt(prompt_str: str, user_email: str) -> str:
    """Substitute the employee placeholder in a prompt with the user's details.

    The employee context is looked up in Redis first (keyed by ``user_email``).
    On a cache miss the Airtable employee table is scanned for a record whose
    "MV Email" field equals ``user_email``; the resulting context is cached in
    Redis for 7 days.

    Args:
        prompt_str: Prompt text that may contain DANSWER_EMPLOYEE_REPLACEMENT.
        user_email: Email address used to identify the employee.

    Returns:
        The prompt with the placeholder replaced by the employee context. When
        no matching employee is found, the prompt is returned unchanged.
    """
    # Check Redis for cached employee context
    cached_context = redis_client.get(user_email)
    if cached_context:
        logger.info("Employee context retrieved from Redis.")
        return prompt_str.replace(
            DANSWER_EMPLOYEE_REPLACEMENT, cached_context.decode("utf-8")
        )

    airtable_client = AirtableApi(AIRTABLE_API_TOKEN)
    all_employees = airtable_client.table(
        AIRTABLE_EMPLOYEE_BASE_ID, AIRTABLE_EMPLOYEE_TABLE_NAME_OR_ID
    ).all()

    # (label shown to the LLM, Airtable field name), in output order.
    context_fields = (
        ("My Name", "Preferred Name"),
        ("My Title", "Job Role"),
        ("My City Office", "City Office"),
        ("My Division", "Import: Division"),
        ("My Manager", "Reports To"),
        ("My Department", "Import: Department"),
        ("My Employment Status", "Employment Status"),
    )
    cache_ttl_seconds = 7 * 24 * 60 * 60  # 7 days

    employee_context = None
    for employee in all_employees:
        fields = employee.get("fields") or {}
        if fields.get("MV Email") != user_email:
            continue

        # A record may lack some fields; fall back to a neutral value
        # instead of raising KeyError.
        logger.info(f"Employee found: {fields.get('Preferred Name', 'Unknown')}")
        employee_context = "\n".join(
            f"{label}: {fields.get(field_name, 'Unknown')}"
            for label, field_name in context_fields
        )

        # Store the employee context in Redis with a TTL of 7 days
        redis_client.setex(user_email, cache_ttl_seconds, employee_context)
        break

    if employee_context is None:
        # No matching employee: previously this path hit an unbound local.
        logger.warning("No employee record found for the given user email.")
        return prompt_str

    # str.replace is a no-op when the placeholder is absent, so the function
    # always returns a string.
    return prompt_str.replace(DANSWER_EMPLOYEE_REPLACEMENT, employee_context)



def build_task_prompt_reminders(
prompt: Prompt | PromptConfig,
use_language_hint: bool,
Expand Down

0 comments on commit f09afb3

Please sign in to comment.