Skip to content

Commit

Permalink
Session id: int -> UUID (#2814)
Browse files Browse the repository at this point in the history
* session id: int -> UUID

* nit

* validated

* validated downgrade + upgrade + all functionality

* nit

* minor nit

* fix test case
  • Loading branch information
pablonyx authored Oct 16, 2024
1 parent f3fb7c5 commit db0779d
Show file tree
Hide file tree
Showing 29 changed files with 276 additions and 106 deletions.
153 changes: 153 additions & 0 deletions backend/alembic/versions/6756efa39ada_id_uuid_for_chat_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""
Revision ID: 6756efa39ada
Revises: 5d12a446f5c0
Create Date: 2024-10-15 17:47:44.108537
"""
from alembic import op
import sqlalchemy as sa

revision = "6756efa39ada"
down_revision = "5d12a446f5c0"
branch_labels = None
depends_on = None

"""
Migrate chat_session to a UUID primary key and chat_message.chat_session_id to a matching UUID foreign key.
This script:
1. Adds UUID columns to chat_session and chat_message
2. Populates new columns with UUIDs
3. Updates foreign key relationships
4. Removes old integer ID columns
Note: Downgrade will assign new integer IDs, not restore original ones.
"""


def upgrade() -> None:
    """Replace the integer ``chat_session.id`` with a UUID primary key.

    Every existing session row receives a newly generated UUID and
    ``chat_message.chat_session_id`` is rewritten to reference it. The
    previous integer ids are dropped, so ``downgrade`` cannot restore them.
    """
    # gen_random_uuid() is used below; make sure pgcrypto is available for it.
    op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;")

    # Stage the UUID key next to the integer one so both exist while
    # chat_message rows are re-pointed.
    op.add_column(
        "chat_session",
        sa.Column(
            "new_id",
            sa.UUID(as_uuid=True),
            server_default=sa.text("gen_random_uuid()"),
            nullable=False,
        ),
    )

    op.execute("UPDATE chat_session SET new_id = gen_random_uuid();")

    # Nullable only while it is being backfilled; tightened to NOT NULL below.
    op.add_column(
        "chat_message",
        sa.Column("new_chat_session_id", sa.UUID(as_uuid=True), nullable=True),
    )

    # Translate each message's integer session reference into the session's UUID.
    op.execute(
        """
        UPDATE chat_message
        SET new_chat_session_id = cs.new_id
        FROM chat_session cs
        WHERE chat_message.chat_session_id = cs.id;
        """
    )

    # The old FK must go before the integer columns it involves can be dropped.
    op.drop_constraint(
        "chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
    )

    op.drop_column("chat_message", "chat_session_id")
    op.alter_column(
        "chat_message", "new_chat_session_id", new_column_name="chat_session_id"
    )
    # The ORM model declares chat_session_id as non-optional; without this the
    # migrated column would stay nullable and drift from the model.
    op.alter_column(
        "chat_message",
        "chat_session_id",
        existing_type=sa.UUID(as_uuid=True),
        nullable=False,
    )

    # Swap the primary key over to the UUID column.
    op.drop_constraint("chat_session_pkey", "chat_session", type_="primary")
    op.drop_column("chat_session", "id")
    op.alter_column("chat_session", "new_id", new_column_name="id")

    op.create_primary_key("chat_session_pkey", "chat_session", ["id"])

    op.create_foreign_key(
        "chat_message_chat_session_id_fkey",
        "chat_message",
        "chat_session",
        ["chat_session_id"],
        ["id"],
        ondelete="CASCADE",
    )


def downgrade() -> None:
    """Revert ``chat_session.id`` from UUID back to an auto-incrementing integer.

    Integer ids are assigned afresh from a new sequence; the ids that existed
    before ``upgrade`` are not restored. ``chat_message.chat_session_id`` is
    rewritten to reference the new integer ids.
    """
    op.drop_constraint(
        "chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
    )

    # Stage the integer key as nullable so it can be filled from the sequence.
    op.add_column(
        "chat_session",
        sa.Column("old_id", sa.Integer, autoincrement=True, nullable=True),
    )

    # The sequence is created by hand and tied to the column via OWNED BY.
    op.execute("CREATE SEQUENCE chat_session_old_id_seq OWNED BY chat_session.old_id;")
    op.execute(
        "ALTER TABLE chat_session ALTER COLUMN old_id SET DEFAULT nextval('chat_session_old_id_seq');"
    )

    op.execute(
        "UPDATE chat_session SET old_id = nextval('chat_session_old_id_seq') WHERE old_id IS NULL;"
    )

    op.alter_column("chat_session", "old_id", nullable=False)

    # Move the primary key to the integer column while the UUID column still
    # exists; the UUID is needed below to map messages to their session.
    op.drop_constraint("chat_session_pkey", "chat_session", type_="primary")
    op.create_primary_key("chat_session_pkey", "chat_session", ["old_id"])

    # Nullable only while it is being backfilled; tightened to NOT NULL below.
    op.add_column(
        "chat_message",
        sa.Column("old_chat_session_id", sa.Integer, nullable=True),
    )

    # Translate each message's UUID session reference into the new integer id.
    op.execute(
        """
        UPDATE chat_message
        SET old_chat_session_id = cs.old_id
        FROM chat_session cs
        WHERE chat_message.chat_session_id = cs.id;
        """
    )

    op.drop_column("chat_message", "chat_session_id")
    op.alter_column(
        "chat_message", "old_chat_session_id", new_column_name="chat_session_id"
    )
    # Keep the restored integer column NOT NULL rather than leaving it with
    # the nullable setting used for the backfill.
    op.alter_column(
        "chat_message",
        "chat_session_id",
        existing_type=sa.Integer(),
        nullable=False,
    )

    op.create_foreign_key(
        "chat_message_chat_session_id_fkey",
        "chat_message",
        "chat_session",
        ["chat_session_id"],
        ["old_id"],
        ondelete="CASCADE",
    )

    # Drop the UUID column and give the integer column the canonical name.
    op.drop_column("chat_session", "id")
    op.alter_column("chat_session", "old_id", new_column_name="id")

    # Rename the sequence
    op.execute("ALTER SEQUENCE chat_session_old_id_seq RENAME TO chat_session_id_seq;")

    # Update the default value to use the renamed sequence
    op.alter_column(
        "chat_session",
        "id",
        server_default=sa.text("nextval('chat_session_id_seq'::regclass)"),
    )
3 changes: 2 additions & 1 deletion backend/danswer/chat/chat_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
from typing import cast
from uuid import UUID

from fastapi.datastructures import Headers
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -34,7 +35,7 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo


def create_chat_chain(
chat_session_id: int,
chat_session_id: UUID,
db_session: Session,
prefetch_tool_calls: bool = True,
# Optional id at which we finish processing
Expand Down
22 changes: 11 additions & 11 deletions backend/danswer/db/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@


def get_chat_session_by_id(
chat_session_id: int,
chat_session_id: UUID,
user_id: UUID | None,
db_session: Session,
include_deleted: bool = False,
Expand Down Expand Up @@ -87,9 +87,9 @@ def get_chat_sessions_by_slack_thread_id(


def get_valid_messages_from_query_sessions(
chat_session_ids: list[int],
chat_session_ids: list[UUID],
db_session: Session,
) -> dict[int, str]:
) -> dict[UUID, str]:
user_message_subquery = (
select(
ChatMessage.chat_session_id, func.min(ChatMessage.id).label("user_msg_id")
Expand Down Expand Up @@ -196,7 +196,7 @@ def delete_orphaned_search_docs(db_session: Session) -> None:


def delete_messages_and_files_from_chat_session(
chat_session_id: int, db_session: Session
chat_session_id: UUID, db_session: Session
) -> None:
# Select messages older than cutoff_time with files
messages_with_files = db_session.execute(
Expand Down Expand Up @@ -253,7 +253,7 @@ def create_chat_session(
def update_chat_session(
db_session: Session,
user_id: UUID | None,
chat_session_id: int,
chat_session_id: UUID,
description: str | None = None,
sharing_status: ChatSessionSharedStatus | None = None,
) -> ChatSession:
Expand All @@ -276,7 +276,7 @@ def update_chat_session(

def delete_chat_session(
user_id: UUID | None,
chat_session_id: int,
chat_session_id: UUID,
db_session: Session,
hard_delete: bool = HARD_DELETE_CHATS,
) -> None:
Expand Down Expand Up @@ -337,7 +337,7 @@ def get_chat_message(


def get_chat_messages_by_sessions(
chat_session_ids: list[int],
chat_session_ids: list[UUID],
user_id: UUID | None,
db_session: Session,
skip_permission_check: bool = False,
Expand Down Expand Up @@ -370,7 +370,7 @@ def get_search_docs_for_chat_message(


def get_chat_messages_by_session(
chat_session_id: int,
chat_session_id: UUID,
user_id: UUID | None,
db_session: Session,
skip_permission_check: bool = False,
Expand All @@ -397,7 +397,7 @@ def get_chat_messages_by_session(


def get_or_create_root_message(
chat_session_id: int,
chat_session_id: UUID,
db_session: Session,
) -> ChatMessage:
try:
Expand Down Expand Up @@ -433,7 +433,7 @@ def get_or_create_root_message(

def reserve_message_id(
db_session: Session,
chat_session_id: int,
chat_session_id: UUID,
parent_message: int,
message_type: MessageType,
) -> int:
Expand All @@ -460,7 +460,7 @@ def reserve_message_id(


def create_new_chat_message(
chat_session_id: int,
chat_session_id: UUID,
parent_message: ChatMessage,
message: str,
prompt_id: int | None,
Expand Down
11 changes: 9 additions & 2 deletions backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
from typing import Literal
from typing import NotRequired
from typing import Optional
from uuid import uuid4
from typing_extensions import TypedDict # noreorder
from uuid import UUID

from sqlalchemy.dialects.postgresql import UUID as PGUUID

from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
Expand Down Expand Up @@ -920,7 +923,9 @@ class ToolCall(Base):
class ChatSession(Base):
__tablename__ = "chat_session"

id: Mapped[int] = mapped_column(primary_key=True)
id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), primary_key=True, default=uuid4
)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
Expand Down Expand Up @@ -990,7 +995,9 @@ class ChatMessage(Base):
__tablename__ = "chat_message"

id: Mapped[int] = mapped_column(primary_key=True)
chat_session_id: Mapped[int] = mapped_column(ForeignKey("chat_session.id"))
chat_session_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), ForeignKey("chat_session.id")
)

alternate_assistant_id = mapped_column(
Integer, ForeignKey("persona.id"), nullable=True
Expand Down
4 changes: 3 additions & 1 deletion backend/danswer/server/features/folder/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from uuid import UUID

from pydantic import BaseModel

from danswer.server.query_and_chat.models import ChatSessionDetails
Expand All @@ -23,7 +25,7 @@ class FolderUpdateRequest(BaseModel):


class FolderChatSessionRequest(BaseModel):
chat_session_id: int
chat_session_id: UUID


class DeleteFolderOptions(BaseModel):
Expand Down
7 changes: 4 additions & 3 deletions backend/danswer/server/query_and_chat/chat_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import Callable
from collections.abc import Generator
from typing import Tuple
from uuid import UUID

from fastapi import APIRouter
from fastapi import Depends
Expand Down Expand Up @@ -131,7 +132,7 @@ def update_chat_session_model(

@router.get("/get-chat-session/{session_id}")
def get_chat_session(
session_id: int,
session_id: UUID,
is_shared: bool = False,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
Expand Down Expand Up @@ -254,7 +255,7 @@ def rename_chat_session(

@router.patch("/chat-session/{session_id}")
def patch_chat_session(
session_id: int,
session_id: UUID,
chat_session_update_req: ChatSessionUpdateRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
Expand All @@ -271,7 +272,7 @@ def patch_chat_session(

@router.delete("/delete-chat-session/{session_id}")
def delete_chat_session_by_id(
session_id: int,
session_id: UUID,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
Expand Down
Loading

0 comments on commit db0779d

Please sign in to comment.