Skip to content

Commit

Permalink
added backend endpoint for duplicating a chat session from Slack
Browse files Browse the repository at this point in the history
  • Loading branch information
hagen-danswer committed Nov 24, 2024
1 parent bc798db commit 144d6b2
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 0 deletions.
103 changes: 103 additions & 0 deletions backend/danswer/db/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import timedelta
from uuid import UUID

from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
Expand Down Expand Up @@ -30,6 +31,7 @@
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.persona import get_best_persona_id_for_user
from danswer.db.pg_file_store import delete_lobj_by_name
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
Expand Down Expand Up @@ -250,6 +252,43 @@ def create_chat_session(
return chat_session


def duplicate_chat_session_for_user_from_slack(
db_session: Session,
user: User | None,
chat_session_id: UUID,
) -> ChatSession:
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id,
user_id=None, # Ignore user permissions for this
db_session=db_session,
)
if not chat_session:
raise HTTPException(status_code=400, detail="Invalid Chat Session ID provided")

# This enforces permissions and sets a default
new_persona_id = get_best_persona_id_for_user(
db_session=db_session,
user=user,
persona_id=chat_session.persona_id,
)

return create_chat_session(
db_session=db_session,
user_id=user.id if user else None,
persona_id=new_persona_id,
# This will likely be empty but the frontend will force a rename
description=chat_session.description,
llm_override=chat_session.llm_override,
prompt_override=chat_session.prompt_override,
# Chat sessions from Slack should put people in the chat UI, not the search
one_shot=False,
# Chat is in UI now so this is false
danswerbot_flow=False,
# Maybe we want this in the future to track if it was created from Slack
slack_thread_id=None,
)


def update_chat_session(
db_session: Session,
user_id: UUID | None,
Expand Down Expand Up @@ -377,6 +416,70 @@ def get_chat_messages_by_sessions(
return db_session.execute(stmt).scalars().all()


def add_chats_to_session_from_slack_thread(
db_session: Session,
slack_chat_session_id: UUID,
new_chat_session_id: UUID,
) -> None:
new_root_message = ChatMessage(
chat_session_id=new_chat_session_id,
prompt_id=None,
parent_message=None,
latest_child_message=None,
message="",
token_count=0,
message_type=MessageType.SYSTEM,
)
db_session.add(new_root_message)
db_session.commit()

user_message = None
assistant_message = None
for chat_message in get_chat_messages_by_sessions(
chat_session_ids=[slack_chat_session_id],
user_id=None, # Ignore user permissions for this
db_session=db_session,
skip_permission_check=True,
):
# Should only be 3 messages in a Slack chat session
if chat_message.message_type == MessageType.SYSTEM:
continue
elif chat_message.message_type == MessageType.USER:
user_message = chat_message
elif chat_message.message_type == MessageType.ASSISTANT:
assistant_message = chat_message

if user_message is None or assistant_message is None:
raise HTTPException(
status_code=500,
detail="Couldnt find all messages in Slack chat session",
)

new_user_message = create_new_chat_message(
db_session=db_session,
chat_session_id=new_chat_session_id,
parent_message=new_root_message,
message=user_message.message,
prompt_id=user_message.prompt_id,
token_count=user_message.token_count,
message_type=MessageType.USER,
)
db_session.add(new_user_message)
db_session.commit()

new_assistant_message = create_new_chat_message(
db_session=db_session,
chat_session_id=new_chat_session_id,
parent_message=new_user_message,
message=assistant_message.message,
prompt_id=assistant_message.prompt_id,
token_count=assistant_message.token_count,
message_type=MessageType.ASSISTANT,
)
db_session.add(new_assistant_message)
db_session.commit()


def get_search_docs_for_chat_message(
chat_message_id: int, db_session: Session
) -> list[SearchDoc]:
Expand Down
24 changes: 24 additions & 0 deletions backend/danswer/db/persona.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,30 @@ def fetch_persona_by_id(
return persona


def get_best_persona_id_for_user(
db_session: Session, user: User | None, persona_id: int | None = None
) -> int | None:
if persona_id:
stmt = select(Persona).where(Persona.id == persona_id).distinct()
stmt = _add_user_filters(
stmt=stmt,
user=user,
# We don't want to filter by editable here, we just want to see if the
# persona is usable by the user
get_editable=False,
)
persona = db_session.scalars(stmt).one_or_none()
if persona:
return persona.id

# If the persona is not found, we need to find the best persona for the user
# This is the persona with the highest display priority that the user has access to
stmt = select(Persona).order_by(Persona.display_priority.desc()).distinct()
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=True)
persona = db_session.scalars(stmt).one_or_none()
return persona.id if persona else None


def _get_persona_by_name(
persona_name: str, user: User | None, db_session: Session
) -> Persona | None:
Expand Down
34 changes: 34 additions & 0 deletions backend/danswer/server/query_and_chat/chat_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
from danswer.db.chat import add_chats_to_session_from_slack_thread
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import delete_chat_session
from danswer.db.chat import duplicate_chat_session_for_user_from_slack
from danswer.db.chat import get_chat_message
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_chat_session_by_id
Expand Down Expand Up @@ -532,6 +534,38 @@ def seed_chat(
)


class SeedChatFromSlackRequest(BaseModel):
chat_session_id: UUID


class SeedChatFromSlackResponse(BaseModel):
redirect_url: str


@router.post("/seed-chat-session-from-slack")
def seed_chat_from_slack(
chat_seed_request: SeedChatFromSlackRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SeedChatFromSlackResponse:
slack_chat_session_id = chat_seed_request.chat_session_id
new_chat_session = duplicate_chat_session_for_user_from_slack(
db_session=db_session,
user=user,
chat_session_id=slack_chat_session_id,
)

add_chats_to_session_from_slack_thread(
db_session=db_session,
slack_chat_session_id=slack_chat_session_id,
new_chat_session_id=new_chat_session.id,
)

return SeedChatFromSlackResponse(
redirect_url=f"{WEB_DOMAIN}/chat?chatId={new_chat_session.id}"
)


"""File upload"""


Expand Down

0 comments on commit 144d6b2

Please sign in to comment.