Skip to content

Commit

Permalink
Add assistant notifications + update assistant context (#2816)
Browse files Browse the repository at this point in the history
* add assistant notifications

* nit

* update context

* validated

* ensure context passed properly

* validated + cleaned

* nit: naming

* k

* k

* final validation + new ui

* nit + video

* nit

* nit

* nit

* k

* fix typos
  • Loading branch information
pablonyx authored Oct 19, 2024
1 parent 6913efe commit 8b220d2
Show file tree
Hide file tree
Showing 37 changed files with 822 additions and 346 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""add additional data to notifications
Revision ID: 1b10e1fda030
Revises: 6756efa39ada
Create Date: 2024-10-15 19:26:44.071259
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "1b10e1fda030"
down_revision = "6756efa39ada"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column(
"notification", sa.Column("additional_data", postgresql.JSONB(), nullable=True)
)


def downgrade() -> None:
op.drop_column("notification", "additional_data")
1 change: 1 addition & 0 deletions backend/danswer/configs/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ class DocumentSource(str, Enum):

class NotificationType(str, Enum):
REINDEX = "reindex"
PERSONA_SHARED = "persona_shared"


class BlobType(str, Enum):
Expand Down
3 changes: 3 additions & 0 deletions backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ class Notification(Base):
first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))

user: Mapped[User] = relationship("User", back_populates="notifications")
additional_data: Mapped[dict | None] = mapped_column(
postgresql.JSONB(), nullable=True
)


"""
Expand Down
27 changes: 25 additions & 2 deletions backend/danswer/db/notification.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from uuid import UUID

from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
Expand All @@ -8,16 +10,37 @@


def create_notification(
user: User | None,
user_id: UUID | None,
notif_type: NotificationType,
db_session: Session,
additional_data: dict | None = None,
) -> Notification:
# Check if an undismissed notification of the same type and data exists
existing_notification = (
db_session.query(Notification)
.filter_by(
user_id=user_id,
notif_type=notif_type,
dismissed=False,
)
.filter(Notification.additional_data == additional_data)
.first()
)

if existing_notification:
# Update the last_shown timestamp
existing_notification.last_shown = func.now()
db_session.commit()
return existing_notification

# Create a new notification if none exists
notification = Notification(
user_id=user.id if user else None,
user_id=user_id,
notif_type=notif_type,
dismissed=False,
last_shown=func.now(),
first_shown=func.now(),
additional_data=additional_data,
)
db_session.add(notification)
db_session.commit()
Expand Down
2 changes: 2 additions & 0 deletions backend/danswer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
admin_router as admin_input_prompt_router,
)
from danswer.server.features.input_prompt.api import basic_router as input_prompt_router
from danswer.server.features.notifications.api import router as notification_router
from danswer.server.features.persona.api import admin_router as admin_persona_router
from danswer.server.features.persona.api import basic_router as persona_router
from danswer.server.features.prompt.api import basic_router as prompt_router
Expand Down Expand Up @@ -246,6 +247,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, admin_persona_router)
include_router_with_global_prefix_prepended(application, input_prompt_router)
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
include_router_with_global_prefix_prepended(application, notification_router)
include_router_with_global_prefix_prepended(application, prompt_router)
include_router_with_global_prefix_prepended(application, tool_router)
include_router_with_global_prefix_prepended(application, admin_tool_router)
Expand Down
47 changes: 47 additions & 0 deletions backend/danswer/server/features/notifications/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session

from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.notification import dismiss_notification
from danswer.db.notification import get_notification_by_id
from danswer.db.notification import get_notifications
from danswer.server.settings.models import Notification as NotificationModel
from danswer.utils.logger import setup_logger

logger = setup_logger()

router = APIRouter(prefix="/notifications")


@router.get("")
def get_notifications_api(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[NotificationModel]:
notifications = [
NotificationModel.from_model(notif)
for notif in get_notifications(user, db_session, include_dismissed=False)
]
return notifications


@router.post("/{notification_id}/dismiss")
def dismiss_notification_endpoint(
notification_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
notification = get_notification_by_id(notification_id, user, db_session)
except PermissionError:
raise HTTPException(
status_code=403, detail="Not authorized to dismiss this notification"
)
except ValueError:
raise HTTPException(status_code=404, detail="Notification not found")

dismiss_notification(notification, db_session)
48 changes: 36 additions & 12 deletions backend/danswer/server/features/persona/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import NotificationType
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.notification import create_notification
from danswer.db.persona import create_update_persona
from danswer.db.persona import get_persona_by_id
from danswer.db.persona import get_personas
Expand All @@ -28,6 +30,7 @@
from danswer.file_store.models import ChatFileType
from danswer.llm.answering.prompts.utils import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSharedNotificationData
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.features.persona.models import PromptTemplateResponse
from danswer.server.models import DisplayPriorityRequest
Expand Down Expand Up @@ -183,11 +186,12 @@ class PersonaShareRequest(BaseModel):
user_ids: list[UUID]


# We notify each user when a user is shared with them
@basic_router.patch("/{persona_id}/share")
def share_persona(
persona_id: int,
persona_share_request: PersonaShareRequest,
user: User | None = Depends(current_user),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
update_persona_shared_users(
Expand All @@ -197,6 +201,18 @@ def share_persona(
db_session=db_session,
)

for user_id in persona_share_request.user_ids:
# Don't notify the user that they have access to their own persona
if user_id != user.id:
create_notification(
user_id=user_id,
notif_type=NotificationType.PERSONA_SHARED,
db_session=db_session,
additional_data=PersonaSharedNotificationData(
persona_id=persona_id,
).model_dump(),
)


@basic_router.delete("/{persona_id}")
def delete_persona(
Expand All @@ -216,23 +232,31 @@ def list_personas(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
include_deleted: bool = False,
persona_ids: list[int] = Query(None),
) -> list[PersonaSnapshot]:
return [
PersonaSnapshot.from_model(persona)
for persona in get_personas(
user=user,
include_deleted=include_deleted,
db_session=db_session,
get_editable=False,
joinedload_all=True,
)
# If the persona has an image generation tool and it's not available, don't include it
personas = get_personas(
user=user,
include_deleted=include_deleted,
db_session=db_session,
get_editable=False,
joinedload_all=True,
)

if persona_ids:
personas = [p for p in personas if p.id in persona_ids]

# Filter out personas with unavailable tools
personas = [
p
for p in personas
if not (
any(tool.in_code_tool_id == "ImageGenerationTool" for tool in persona.tools)
any(tool.in_code_tool_id == "ImageGenerationTool" for tool in p.tools)
and not is_image_generation_available(db_session=db_session)
)
]

return [PersonaSnapshot.from_model(p) for p in personas]


@basic_router.get("/{persona_id}")
def get_persona(
Expand Down
4 changes: 4 additions & 0 deletions backend/danswer/server/features/persona/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,7 @@ def from_model(

class PromptTemplateResponse(BaseModel):
final_prompt_template: str


class PersonaSharedNotificationData(BaseModel):
persona_id: int
26 changes: 3 additions & 23 deletions backend/danswer/server/settings/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from danswer.db.models import User
from danswer.db.notification import create_notification
from danswer.db.notification import dismiss_all_notifications
from danswer.db.notification import dismiss_notification
from danswer.db.notification import get_notification_by_id
from danswer.db.notification import get_notifications
from danswer.db.notification import update_notification_last_shown
from danswer.key_value_store.factory import get_kv_store
Expand Down Expand Up @@ -55,7 +53,7 @@ def fetch_settings(
"""Settings and notifications are stuffed into this single endpoint to reduce number of
Postgres calls"""
general_settings = load_settings()
user_notifications = get_user_notifications(user, db_session)
user_notifications = get_reindex_notification(user, db_session)

try:
kv_store = get_kv_store()
Expand All @@ -70,25 +68,7 @@ def fetch_settings(
)


@basic_router.post("/notifications/{notification_id}/dismiss")
def dismiss_notification_endpoint(
notification_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
notification = get_notification_by_id(notification_id, user, db_session)
except PermissionError:
raise HTTPException(
status_code=403, detail="Not authorized to dismiss this notification"
)
except ValueError:
raise HTTPException(status_code=404, detail="Notification not found")

dismiss_notification(notification, db_session)


def get_user_notifications(
def get_reindex_notification(
user: User | None, db_session: Session
) -> list[Notification]:
"""Get notifications for the user, currently the logic is very specific to the reindexing flag"""
Expand Down Expand Up @@ -121,7 +101,7 @@ def get_user_notifications(

if not reindex_notifs:
notif = create_notification(
user=user,
user_id=user.id if user else None,
notif_type=NotificationType.REINDEX,
db_session=db_session,
)
Expand Down
2 changes: 2 additions & 0 deletions backend/danswer/server/settings/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class Notification(BaseModel):
dismissed: bool
last_shown: datetime
first_shown: datetime
additional_data: dict | None = None

@classmethod
def from_model(cls, notif: NotificationDBModel) -> "Notification":
Expand All @@ -33,6 +34,7 @@ def from_model(cls, notif: NotificationDBModel) -> "Notification":
dismissed=notif.dismissed,
last_shown=notif.last_shown,
first_shown=notif.first_shown,
additional_data=notif.additional_data,
)


Expand Down
12 changes: 10 additions & 2 deletions web/src/app/admin/settings/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,20 @@ export interface Settings {
product_gating: GatingType;
}

export enum NotificationType {
PERSONA_SHARED = "persona_shared",
REINDEX_NEEDED = "reindex_needed",
}

export interface Notification {
id: number;
notif_type: string;
time_created: string;
dismissed: boolean;
last_shown: string;
first_shown: string;
additional_data?: {
persona_id?: number;
[key: string]: any;
};
}

export interface NavigationItem {
Expand Down
14 changes: 3 additions & 11 deletions web/src/app/assistants/SidebarWrapper.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,9 @@ interface SidebarWrapperProps<T extends object> {
folders?: Folder[];
initiallyToggled: boolean;
openedFolders?: { [key: number]: boolean };
content: (props: T) => ReactNode;
headerProps: {
page: pageType;
user: User | null;
};
contentProps: T;
page: pageType;
size?: "sm" | "lg";
children: ReactNode;
}

export default function SidebarWrapper<T extends object>({
Expand All @@ -42,10 +37,8 @@ export default function SidebarWrapper<T extends object>({
folders,
openedFolders,
page,
headerProps,
contentProps,
content,
size = "sm",
children,
}: SidebarWrapperProps<T>) {
const [toggledSidebar, setToggledSidebar] = useState(initiallyToggled);
const [showDocSidebar, setShowDocSidebar] = useState(false); // State to track if sidebar is open
Expand Down Expand Up @@ -144,7 +137,6 @@ export default function SidebarWrapper<T extends object>({
sidebarToggled={toggledSidebar}
toggleSidebar={toggleSidebar}
page="assistants"
user={headerProps.user}
/>
<div className="w-full flex">
<div
Expand All @@ -163,7 +155,7 @@ export default function SidebarWrapper<T extends object>({
<div
className={`mt-4 w-full ${size == "lg" ? "max-w-4xl" : "max-w-3xl"} mx-auto`}
>
{content(contentProps)}
{children}
</div>
</div>
</div>
Expand Down
Loading

0 comments on commit 8b220d2

Please sign in to comment.