From bc9d5fece702daf3264afc30cc75e09d7171a639 Mon Sep 17 00:00:00 2001 From: rkuo-danswer Date: Wed, 18 Sep 2024 21:01:15 -0700 Subject: [PATCH 001/505] prevent trying to submit to jobclient when it can't take any more work (reduces log spam) (#2482) --- backend/danswer/background/update.py | 42 +++++++++++++++++----------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 5fde7cb3da0..55eff6d0a46 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -211,7 +211,6 @@ def cleanup_indexing_jobs( timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT, ) -> dict[int, Future | SimpleJob]: existing_jobs_copy = existing_jobs.copy() - # clean up completed jobs with Session(get_sqlalchemy_engine()) as db_session: for attempt_id, job in existing_jobs.items(): @@ -312,7 +311,12 @@ def kickoff_indexing_jobs( indexing_attempt_count = 0 + primary_client_full = False + secondary_client_full = False for attempt, search_settings in new_indexing_attempts: + if primary_client_full and secondary_client_full: + break + use_secondary_index = ( search_settings.status == IndexModelStatus.FUTURE if search_settings is not None @@ -337,22 +341,28 @@ def kickoff_indexing_jobs( ) continue - if use_secondary_index: - run = secondary_client.submit( - run_indexing_entrypoint, - attempt.id, - attempt.connector_credential_pair_id, - global_version.get_is_ee_version(), - pure=False, - ) + if not use_secondary_index: + if not primary_client_full: + run = client.submit( + run_indexing_entrypoint, + attempt.id, + attempt.connector_credential_pair_id, + global_version.get_is_ee_version(), + pure=False, + ) + if not run: + primary_client_full = True else: - run = client.submit( - run_indexing_entrypoint, - attempt.id, - attempt.connector_credential_pair_id, - global_version.get_is_ee_version(), - pure=False, - ) + if not secondary_client_full: + run = secondary_client.submit( + run_indexing_entrypoint, + attempt.id, + attempt.connector_credential_pair_id, + global_version.get_is_ee_version(), + pure=False, + ) + if not run: + secondary_client_full = True if run: if indexing_attempt_count == 0: From 3884f1d70a48b4a3414ed56576c9f0052277e044 Mon Sep 17 00:00:00 2001 From: rkuo-danswer Date: Wed, 18 Sep 2024 22:36:07 -0700 Subject: [PATCH 002/505] Bugfix/larger test runner (#2508) * add pip retries to the github workflows too * let's try running on amd64 ... 
docker builds are unusually flaky * bump * try large * no yaml anchors * switch back down to Amd64 --------- Co-authored-by: Richard Kuo --- .github/workflows/pr-helm-chart-testing.yml.disabled.txt | 6 +++--- .github/workflows/pr-python-checks.yml | 6 +++--- .github/workflows/pr-python-connector-tests.yml | 4 ++-- .github/workflows/pr-python-tests.yml | 4 ++-- .github/workflows/run-it.yml | 9 ++++----- 5 files changed, 14 insertions(+), 15 deletions(-) diff --git a/.github/workflows/pr-helm-chart-testing.yml.disabled.txt b/.github/workflows/pr-helm-chart-testing.yml.disabled.txt index 7c4903a07f7..d2d7c7d5b13 100644 --- a/.github/workflows/pr-helm-chart-testing.yml.disabled.txt +++ b/.github/workflows/pr-helm-chart-testing.yml.disabled.txt @@ -37,9 +37,9 @@ jobs: backend/requirements/model_server.txt - run: | python -m pip install --upgrade pip - pip install -r backend/requirements/default.txt - pip install -r backend/requirements/dev.txt - pip install -r backend/requirements/model_server.txt + pip install --retries 5 --timeout 30 -r backend/requirements/default.txt + pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt + pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt - name: Set up chart-testing uses: helm/chart-testing-action@v2.6.1 diff --git a/.github/workflows/pr-python-checks.yml b/.github/workflows/pr-python-checks.yml index 9cc624fa073..e33bad59ef2 100644 --- a/.github/workflows/pr-python-checks.yml +++ b/.github/workflows/pr-python-checks.yml @@ -24,9 +24,9 @@ jobs: backend/requirements/model_server.txt - run: | python -m pip install --upgrade pip - pip install -r backend/requirements/default.txt - pip install -r backend/requirements/dev.txt - pip install -r backend/requirements/model_server.txt + pip install --retries 5 --timeout 30 -r backend/requirements/default.txt + pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt + pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt - name: Run MyPy run: | diff --git a/.github/workflows/pr-python-connector-tests.yml b/.github/workflows/pr-python-connector-tests.yml index 00b92c9b003..d0c693dd19d 100644 --- a/.github/workflows/pr-python-connector-tests.yml +++ b/.github/workflows/pr-python-connector-tests.yml @@ -39,8 +39,8 @@ jobs: - name: Install Dependencies run: | python -m pip install --upgrade pip - pip install -r backend/requirements/default.txt - pip install -r backend/requirements/dev.txt + pip install --retries 5 --timeout 30 -r backend/requirements/default.txt + pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt - name: Run Tests shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}" diff --git a/.github/workflows/pr-python-tests.yml b/.github/workflows/pr-python-tests.yml index 1acdbc5ddb1..30d1d3ad2f9 100644 --- a/.github/workflows/pr-python-tests.yml +++ b/.github/workflows/pr-python-tests.yml @@ -29,8 +29,8 @@ jobs: - name: Install Dependencies run: | python -m pip install --upgrade pip - pip install -r backend/requirements/default.txt - pip install -r backend/requirements/dev.txt + pip install --retries 5 --timeout 30 -r backend/requirements/default.txt + pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt - name: Run Tests shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}" diff --git a/.github/workflows/run-it.yml b/.github/workflows/run-it.yml index 75df647c462..b1a55bfffcc 100644 --- a/.github/workflows/run-it.yml +++ b/.github/workflows/run-it.yml @@ -13,8 +13,7 @@ env: 
jobs: integration-tests: - runs-on: - group: 'arm64-image-builders' + runs-on: Amd64 steps: - name: Checkout code uses: actions/checkout@v4 @@ -41,7 +40,7 @@ jobs: with: context: ./backend file: ./backend/Dockerfile - platforms: linux/arm64 + platforms: linux/amd64 tags: danswer/danswer-backend:it cache-from: type=registry,ref=danswer/danswer-backend:it cache-to: | @@ -53,7 +52,7 @@ jobs: with: context: ./backend file: ./backend/Dockerfile.model_server - platforms: linux/arm64 + platforms: linux/amd64 tags: danswer/danswer-model-server:it cache-from: type=registry,ref=danswer/danswer-model-server:it cache-to: | @@ -65,7 +64,7 @@ jobs: with: context: ./backend file: ./backend/tests/integration/Dockerfile - platforms: linux/arm64 + platforms: linux/amd64 tags: danswer/integration-test-runner:it cache-from: type=registry,ref=danswer/integration-test-runner:it cache-to: | From f404c4b4482ae8a46d484dbd045e3ad49e66ba5f Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Wed, 18 Sep 2024 23:00:58 -0700 Subject: [PATCH 003/505] Move code block default language creation to citation processing (#2501) * move code block default language creation to citaiton processing * add test cases * update copy --- backend/danswer/chat/process_message.py | 1 + .../stream_processing/citation_processing.py | 9 ++ .../test_citation_processing.py | 86 +++++++++++++++++++ 3 files changed, 96 insertions(+) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 26c59ebf60c..93ad3bdd34e 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -709,6 +709,7 @@ def stream_chat_message_objects( yield FinalUsedContextDocsResponse( final_context_docs=packet.response ) + elif packet.id == IMAGE_GENERATION_RESPONSE_ID: img_generation_response = cast( list[ImageGenerationResponse], packet.response diff --git a/backend/danswer/llm/answering/stream_processing/citation_processing.py b/backend/danswer/llm/answering/stream_processing/citation_processing.py index a72fc70a8ff..f1e5489550d 100644 --- a/backend/danswer/llm/answering/stream_processing/citation_processing.py +++ b/backend/danswer/llm/answering/stream_processing/citation_processing.py @@ -85,6 +85,15 @@ def extract_citations_from_stream( curr_segment += token llm_out += token + # Handle code blocks without language tags + if "`" in curr_segment: + if curr_segment.endswith("`"): + continue + elif "```" in curr_segment: + piece_that_comes_after = curr_segment.split("```")[1][0] + if piece_that_comes_after == "\n" and in_code_block(llm_out): + curr_segment = curr_segment.replace("```", "```plaintext") + citation_pattern = r"\[(\d+)\]" citations_found = list(re.finditer(citation_pattern, curr_segment)) diff --git a/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py b/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py index 473ccf2451a..12e3254d6d6 100644 --- a/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py +++ b/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py @@ -286,6 +286,92 @@ def process_text( "[[1]](https://0.com) Citation at the beginning. 
", ["doc_0"], ), + ( + "Code block without language specification", + [ + "Here's", + " a code block", + ":\n```\nd", + "ef example():\n pass\n", + "```\n", + "End of code.", + ], + "Here's a code block:\n```plaintext\ndef example():\n pass\n```\nEnd of code.", + [], + ), + ( + "Code block with language specification", + [ + "Here's a Python code block:\n", + "```", + "python", + "\n", + "def greet", + "(name):", + "\n ", + "print", + "(f'Hello, ", + "{name}!')", + "\n", + "greet('World')", + "\n```\n", + "This function ", + "greets the user.", + ], + "Here's a Python code block:\n```python\ndef greet(name):\n " + "print(f'Hello, {name}!')\ngreet('World')\n```\nThis function greets the user.", + [], + ), + ( + "Multiple code blocks with different languages", + [ + "JavaScript example:\n", + "```", + "javascript", + "\n", + "console", + ".", + "log", + "('Hello, World!');", + "\n```\n", + "Python example", + ":\n", + "```", + "python", + "\n", + "print", + "('Hello, World!')", + "\n```\n", + "Both print greetings", + ".", + ], + "JavaScript example:\n```javascript\nconsole.log('Hello, World!');\n" + "```\nPython example:\n```python\nprint('Hello, World!')\n" + "```\nBoth print greetings.", + [], + ), + ( + "Code block with text block", + [ + "Here's a code block with a text block:\n", + "```\n", + "# This is a comment", + "\n", + "x = 10 # This assigns 10 to x\n", + "print", + "(x) # This prints x", + "\n```\n", + "The code demonstrates variable assignment.", + ], + "Here's a code block with a text block:\n" + "```plaintext\n" + "# This is a comment\n" + "x = 10 # This assigns 10 to x\n" + "print(x) # This prints x\n" + "```\n" + "The code demonstrates variable assignment.", + [], + ), ], ) def test_citation_extraction( From a575d7f1eb0cab900bce9182795f369974e4dcd1 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 19 Sep 2024 12:31:26 -0700 Subject: [PATCH 004/505] Citations prompt for slack now includes thread history (#2510) --- backend/danswer/llm/answering/answer.py | 1 + backend/danswer/llm/answering/prompts/citations_prompt.py | 4 ++++ backend/danswer/prompts/direct_qa_prompts.py | 3 +++ 3 files changed, 8 insertions(+) diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index 3c0dc4961f1..2bbf7bb72b5 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -179,6 +179,7 @@ def _update_prompt_builder_for_search_tool( if self.answer_style_config.citation_config else False ), + history_message=self.single_message_history or "", ) ) elif self.answer_style_config.quotes_config: diff --git a/backend/danswer/llm/answering/prompts/citations_prompt.py b/backend/danswer/llm/answering/prompts/citations_prompt.py index eddae9badb4..52345f3e587 100644 --- a/backend/danswer/llm/answering/prompts/citations_prompt.py +++ b/backend/danswer/llm/answering/prompts/citations_prompt.py @@ -29,6 +29,9 @@ from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT from danswer.search.models import InferenceChunk +from danswer.utils.logger import setup_logger + +logger = setup_logger() def get_prompt_tokens(prompt_config: PromptConfig) -> int: @@ -156,6 +159,7 @@ def build_citations_user_message( user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format( task_prompt=task_prompt_with_reminder, user_query=question, + history_block=history_message, ) user_prompt = user_prompt.strip() diff --git a/backend/danswer/prompts/direct_qa_prompts.py 
b/backend/danswer/prompts/direct_qa_prompts.py index 16768963931..0139da13e88 100644 --- a/backend/danswer/prompts/direct_qa_prompts.py +++ b/backend/danswer/prompts/direct_qa_prompts.py @@ -109,6 +109,9 @@ Refer to the provided context documents when responding to me.{DEFAULT_IGNORE_STATEMENT} \ You should always get right to the point, and never use extraneous language. +CHAT HISTORY: +{{history_block}} + {{task_prompt}} {QUESTION_PAT.upper()} From ef104e9a82783d4b7db74ea873027c26dbf3ce31 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Thu, 19 Sep 2024 13:02:36 -0700 Subject: [PATCH 005/505] Non-spotfix deletion of users (#2499) * add description / robustify * additional minor robustification (ideally we organized cascades slightly better) * update deletion for simplicity * minor typing update --- .../1b8206b29c5d_add_user_delete_cascades.py | 102 ++++++++++++++++++ backend/danswer/db/models.py | 58 +++++++--- backend/danswer/server/manage/users.py | 17 ++- .../admin/users/SignedUpUserTable.tsx | 1 + .../components/modals/DeleteEntityModal.tsx | 3 + 5 files changed, 163 insertions(+), 18 deletions(-) create mode 100644 backend/alembic/versions/1b8206b29c5d_add_user_delete_cascades.py diff --git a/backend/alembic/versions/1b8206b29c5d_add_user_delete_cascades.py b/backend/alembic/versions/1b8206b29c5d_add_user_delete_cascades.py new file mode 100644 index 00000000000..250621f74e2 --- /dev/null +++ b/backend/alembic/versions/1b8206b29c5d_add_user_delete_cascades.py @@ -0,0 +1,102 @@ +"""add_user_delete_cascades + +Revision ID: 1b8206b29c5d +Revises: 35e6853a51d5 +Create Date: 2024-09-18 11:48:59.418726 + +""" +from alembic import op + + +# revision identifiers, used by Alembic. +revision = "1b8206b29c5d" +down_revision = "35e6853a51d5" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.drop_constraint("credential_user_id_fkey", "credential", type_="foreignkey") + op.create_foreign_key( + "credential_user_id_fkey", + "credential", + "user", + ["user_id"], + ["id"], + ondelete="CASCADE", + ) + + op.drop_constraint("chat_session_user_id_fkey", "chat_session", type_="foreignkey") + op.create_foreign_key( + "chat_session_user_id_fkey", + "chat_session", + "user", + ["user_id"], + ["id"], + ondelete="CASCADE", + ) + + op.drop_constraint("chat_folder_user_id_fkey", "chat_folder", type_="foreignkey") + op.create_foreign_key( + "chat_folder_user_id_fkey", + "chat_folder", + "user", + ["user_id"], + ["id"], + ondelete="CASCADE", + ) + + op.drop_constraint("prompt_user_id_fkey", "prompt", type_="foreignkey") + op.create_foreign_key( + "prompt_user_id_fkey", "prompt", "user", ["user_id"], ["id"], ondelete="CASCADE" + ) + + op.drop_constraint("notification_user_id_fkey", "notification", type_="foreignkey") + op.create_foreign_key( + "notification_user_id_fkey", + "notification", + "user", + ["user_id"], + ["id"], + ondelete="CASCADE", + ) + + op.drop_constraint("inputprompt_user_id_fkey", "inputprompt", type_="foreignkey") + op.create_foreign_key( + "inputprompt_user_id_fkey", + "inputprompt", + "user", + ["user_id"], + ["id"], + ondelete="CASCADE", + ) + + +def downgrade() -> None: + op.drop_constraint("credential_user_id_fkey", "credential", type_="foreignkey") + op.create_foreign_key( + "credential_user_id_fkey", "credential", "user", ["user_id"], ["id"] + ) + + op.drop_constraint("chat_session_user_id_fkey", "chat_session", type_="foreignkey") + op.create_foreign_key( + "chat_session_user_id_fkey", "chat_session", "user", ["user_id"], ["id"] + ) + + 
op.drop_constraint("chat_folder_user_id_fkey", "chat_folder", type_="foreignkey") + op.create_foreign_key( + "chat_folder_user_id_fkey", "chat_folder", "user", ["user_id"], ["id"] + ) + + op.drop_constraint("prompt_user_id_fkey", "prompt", type_="foreignkey") + op.create_foreign_key("prompt_user_id_fkey", "prompt", "user", ["user_id"], ["id"]) + + op.drop_constraint("notification_user_id_fkey", "notification", type_="foreignkey") + op.create_foreign_key( + "notification_user_id_fkey", "notification", "user", ["user_id"], ["id"] + ) + + op.drop_constraint("inputprompt_user_id_fkey", "inputprompt", type_="foreignkey") + op.create_foreign_key( + "inputprompt_user_id_fkey", "inputprompt", "user", ["user_id"], ["id"] + ) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 79d8206586b..33d8d8bc0d8 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -108,7 +108,7 @@ class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base): class User(SQLAlchemyBaseUserTableUUID, Base): oauth_accounts: Mapped[list[OAuthAccount]] = relationship( - "OAuthAccount", lazy="joined" + "OAuthAccount", lazy="joined", cascade="all, delete-orphan" ) role: Mapped[UserRole] = mapped_column( Enum(UserRole, native_enum=False, default=UserRole.BASIC) @@ -170,7 +170,9 @@ class InputPrompt(Base): active: Mapped[bool] = mapped_column(Boolean) user: Mapped[User | None] = relationship("User", back_populates="input_prompts") is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) - user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) + user_id: Mapped[UUID | None] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), nullable=True + ) class InputPrompt__User(Base): @@ -214,7 +216,9 @@ class Notification(Base): notif_type: Mapped[NotificationType] = mapped_column( Enum(NotificationType, native_enum=False) ) - user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) + user_id: Mapped[UUID | None] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), nullable=True + ) dismissed: Mapped[bool] = mapped_column(Boolean, default=False) last_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True)) first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True)) @@ -249,7 +253,7 @@ class Persona__User(Base): persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True) user_id: Mapped[UUID | None] = mapped_column( - ForeignKey("user.id"), primary_key=True, nullable=True + ForeignKey("user.id", ondelete="CASCADE"), primary_key=True, nullable=True ) @@ -260,7 +264,7 @@ class DocumentSet__User(Base): ForeignKey("document_set.id"), primary_key=True ) user_id: Mapped[UUID | None] = mapped_column( - ForeignKey("user.id"), primary_key=True, nullable=True + ForeignKey("user.id", ondelete="CASCADE"), primary_key=True, nullable=True ) @@ -541,7 +545,9 @@ class Credential(Base): id: Mapped[int] = mapped_column(primary_key=True) credential_json: Mapped[dict[str, Any]] = mapped_column(EncryptedJson()) - user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) + user_id: Mapped[UUID | None] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), nullable=True + ) # if `true`, then all Admins will have access to the credential admin_public: Mapped[bool] = mapped_column(Boolean, default=True) time_created: Mapped[datetime.datetime] = mapped_column( @@ -865,7 +871,9 @@ class ChatSession(Base): __tablename__ = "chat_session" 
id: Mapped[int] = mapped_column(primary_key=True) - user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) + user_id: Mapped[UUID | None] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), nullable=True + ) persona_id: Mapped[int | None] = mapped_column( ForeignKey("persona.id"), nullable=True ) @@ -1002,7 +1010,9 @@ class ChatFolder(Base): id: Mapped[int] = mapped_column(primary_key=True) # Only null if auth is off - user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) + user_id: Mapped[UUID | None] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), nullable=True + ) name: Mapped[str | None] = mapped_column(String, nullable=True) display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=0) @@ -1133,7 +1143,9 @@ class DocumentSet(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) name: Mapped[str] = mapped_column(String, unique=True) description: Mapped[str] = mapped_column(String) - user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) + user_id: Mapped[UUID | None] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), nullable=True + ) # Whether changes to the document set have been propagated is_up_to_date: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) # If `False`, then the document set is not visible to users who are not explicitly @@ -1177,7 +1189,9 @@ class Prompt(Base): __tablename__ = "prompt" id: Mapped[int] = mapped_column(primary_key=True) - user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) + user_id: Mapped[UUID | None] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), nullable=True + ) name: Mapped[str] = mapped_column(String) description: Mapped[str] = mapped_column(String) system_prompt: Mapped[str] = mapped_column(Text) @@ -1214,7 +1228,9 @@ class Tool(Base): ) # user who created / owns the tool. Will be None for built-in tools. - user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) + user_id: Mapped[UUID | None] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), nullable=True + ) user: Mapped[User | None] = relationship("User", back_populates="custom_tools") # Relationship to Persona through the association table @@ -1238,7 +1254,9 @@ class Persona(Base): __tablename__ = "persona" id: Mapped[int] = mapped_column(primary_key=True) - user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) + user_id: Mapped[UUID | None] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), nullable=True + ) name: Mapped[str] = mapped_column(String) description: Mapped[str] = mapped_column(String) # Number of chunks to pass to the LLM for generation. 
@@ -1434,7 +1452,9 @@ class SamlAccount(Base): __tablename__ = "saml" id: Mapped[int] = mapped_column(primary_key=True) - user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), unique=True) + user_id: Mapped[int] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), unique=True + ) encrypted_cookie: Mapped[str] = mapped_column(Text, unique=True) expires_at: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True)) updated_at: Mapped[datetime.datetime] = mapped_column( @@ -1453,7 +1473,7 @@ class User__UserGroup(Base): ForeignKey("user_group.id"), primary_key=True ) user_id: Mapped[UUID | None] = mapped_column( - ForeignKey("user.id"), primary_key=True, nullable=True + ForeignKey("user.id", ondelete="CASCADE"), primary_key=True, nullable=True ) @@ -1701,7 +1721,9 @@ class ExternalPermission(Base): __tablename__ = "external_permission" id: Mapped[int] = mapped_column(Integer, primary_key=True) - user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) + user_id: Mapped[UUID | None] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), nullable=True + ) # Email is needed because we want to keep track of users not in Danswer to simplify process # when the user joins user_email: Mapped[str] = mapped_column(String) @@ -1730,7 +1752,9 @@ class EmailToExternalUserCache(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) external_user_id: Mapped[str] = mapped_column(String) - user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) + user_id: Mapped[UUID | None] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), nullable=True + ) # Email is needed because we want to keep track of users not in Danswer to simplify process # when the user joins user_email: Mapped[str] = mapped_column(String) @@ -1754,7 +1778,7 @@ class UsageReport(Base): # if None, report was auto-generated requestor_user_id: Mapped[UUID | None] = mapped_column( - ForeignKey("user.id"), nullable=True + ForeignKey("user.id", ondelete="CASCADE"), nullable=True ) time_created: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index 96c79b4cbe7..46beee856f8 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -31,7 +31,10 @@ from danswer.configs.constants import AuthType from danswer.db.engine import get_session from danswer.db.models import AccessToken +from danswer.db.models import DocumentSet__User +from danswer.db.models import Persona__User from danswer.db.models import User +from danswer.db.models import User__UserGroup from danswer.db.users import get_user_by_email from danswer.db.users import list_users from danswer.dynamic_configs.factory import get_dynamic_config_store @@ -237,10 +240,18 @@ async def delete_user( db_session.expunge(user_to_delete) try: - # Delete related OAuthAccounts first for oauth_account in user_to_delete.oauth_accounts: db_session.delete(oauth_account) + db_session.query(DocumentSet__User).filter( + DocumentSet__User.user_id == user_to_delete.id + ).delete() + db_session.query(Persona__User).filter( + Persona__User.user_id == user_to_delete.id + ).delete() + db_session.query(User__UserGroup).filter( + User__UserGroup.user_id == user_to_delete.id + ).delete() db_session.delete(user_to_delete) db_session.commit() @@ -254,6 +265,10 @@ async def delete_user( logger.info(f"Deleted user {user_to_delete.email}") except Exception 
as e: + import traceback + + full_traceback = traceback.format_exc() + logger.error(f"Full stack trace:\n{full_traceback}") db_session.rollback() logger.error(f"Error deleting user {user_to_delete.email}: {str(e)}") raise HTTPException(status_code=500, detail="Error deleting user") diff --git a/web/src/components/admin/users/SignedUpUserTable.tsx b/web/src/components/admin/users/SignedUpUserTable.tsx index b747229ea39..cb30e4d89bc 100644 --- a/web/src/components/admin/users/SignedUpUserTable.tsx +++ b/web/src/components/admin/users/SignedUpUserTable.tsx @@ -195,6 +195,7 @@ const DeleteUserButton = ({ entityName={user.email} onClose={() => setShowDeleteModal(false)} onSubmit={() => trigger({ user_email: user.email, method: "DELETE" })} + additionalDetails="All data associated with this user will be deleted (including personas, tools and chat sessions)." /> )} diff --git a/web/src/components/modals/DeleteEntityModal.tsx b/web/src/components/modals/DeleteEntityModal.tsx index a9648c2f81c..5ef76f9c851 100644 --- a/web/src/components/modals/DeleteEntityModal.tsx +++ b/web/src/components/modals/DeleteEntityModal.tsx @@ -7,11 +7,13 @@ export const DeleteEntityModal = ({ onSubmit, entityType, entityName, + additionalDetails, }: { entityType: string; entityName: string; onClose: () => void; onSubmit: () => void; + additionalDetails?: string; }) => { return ( @@ -23,6 +25,7 @@ export const DeleteEntityModal = ({ Click below to confirm that you want to delete{" "} "{entityName}"
+          {additionalDetails &&
+            <p>{additionalDetails}</p>
+          }
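
The schema change in PATCH 005 above replaces app-level cleanup with database-level cascades. A minimal, self-contained sketch of the pattern (abridged stand-in models, not the repo's full models.py):

from uuid import UUID, uuid4

from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "user"
    id: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)


class ChatSession(Base):
    __tablename__ = "chat_session"
    id: Mapped[int] = mapped_column(primary_key=True)
    # ondelete="CASCADE" is emitted into the table DDL, so Postgres removes
    # these rows itself when the referenced user row is deleted; the delete_user
    # endpoint no longer has to clean up each dependent table by hand.
    user_id: Mapped[UUID | None] = mapped_column(
        ForeignKey("user.id", ondelete="CASCADE"), nullable=True
    )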
From 2274cab5549195aad8e13c1a00f99ffc28422094 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 19 Sep 2024 15:07:36 -0700 Subject: [PATCH 006/505] Added permission syncing (#2340) * Added permission syncing on the backend * Rewored to work with celery alembic fix fixed test * frontend changes * got groups working * added comments and fixed public docs * fixed merge issues * frontend complete! * frontend cleanup and mypy fixes * refactored connector access_type selection * mypy fixes * minor refactor and frontend improvements * get to fetch * renames and comments * minor change to var names * got curator stuff working * addressed pablo's comments * refactored user_external_group to reference users table * implemented polling * small refactor * fixed a whoopsies on the frontend * added scripts to seed dummy docs and test query times * fixed frontend build issue * alembic fix * handled is_public overlap * yuhong feedback * added more checks for sync * black * mypy * fixed circular import * todos * alembic fix * alembic --- ...ced_and_last_modified_to_document_table.py | 2 +- .../61ff3651add4_add_permission_syncing.py | 162 ++++++++ ...76026c_standard_answer_match_regex_flag.py | 2 +- ...f7e58d357687_add_has_web_column_to_user.py | 2 +- backend/danswer/access/access.py | 47 ++- backend/danswer/access/models.py | 62 ++- backend/danswer/access/utils.py | 24 +- .../danswer/background/celery/celery_app.py | 10 +- .../background/indexing/run_indexing.py | 10 +- backend/danswer/connectors/factory.py | 2 +- .../connectors/google_drive/connector.py | 47 +-- .../connectors/google_drive/connector_auth.py | 78 +++- .../connectors/google_drive/constants.py | 8 +- backend/danswer/connectors/models.py | 3 + .../slack/handlers/handle_message.py | 2 +- .../danswer/db/connector_credential_pair.py | 31 +- backend/danswer/db/document.py | 57 ++- backend/danswer/db/document_set.py | 5 +- backend/danswer/db/enums.py | 6 + backend/danswer/db/feedback.py | 3 +- backend/danswer/db/models.py | 122 ++---- backend/danswer/db/users.py | 53 ++- backend/danswer/indexing/indexing_pipeline.py | 8 +- backend/danswer/server/documents/cc_pair.py | 6 +- backend/danswer/server/documents/connector.py | 8 +- backend/danswer/server/documents/models.py | 10 +- .../server/features/document_set/models.py | 7 - .../danswer/server/manage/embedding/models.py | 3 + backend/danswer/server/manage/users.py | 6 + backend/ee/danswer/access/access.py | 74 +++- .../danswer/background/celery/celery_app.py | 36 +- backend/ee/danswer/background/celery_utils.py | 24 +- .../ee/danswer/background/permission_sync.py | 224 ---------- .../danswer/background/task_name_builders.py | 8 +- .../connectors/confluence/perm_sync.py | 12 - backend/ee/danswer/connectors/factory.py | 8 - .../danswer/db/connector_credential_pair.py | 23 +- backend/ee/danswer/db/document.py | 51 ++- backend/ee/danswer/db/external_perm.py | 77 ++++ backend/ee/danswer/db/permission_sync.py | 72 ---- backend/ee/danswer/db/user_group.py | 2 +- .../__init__.py | 0 .../confluence/__init__.py | 0 .../confluence/doc_sync.py | 19 + .../confluence/group_sync.py | 19 + .../google_drive/__init__.py | 0 .../google_drive/doc_sync.py | 148 +++++++ .../google_drive/group_sync.py | 147 +++++++ .../external_permissions/permission_sync.py | 126 ++++++ .../permission_sync_function_map.py | 54 +++ .../permission_sync_utils.py | 56 +++ .../query_time_check/seed_dummy_docs.py | 166 ++++++++ .../query_time_check/test_query_times.py | 122 ++++++ .../common_utils/managers/cc_pair.py | 21 +- 
.../common_utils/managers/document.py | 7 +- .../integration/common_utils/test_models.py | 3 +- .../permissions/test_cc_pair_permissions.py | 13 +- .../permissions/test_doc_set_permissions.py | 5 +- .../permissions/test_whole_curator_flow.py | 3 +- .../[connector]/AddConnectorPage.tsx | 116 +++++- .../admin/connectors/[connector]/Sidebar.tsx | 6 +- .../[connector]/pages/gdrive/Credential.tsx | 70 ++-- .../[connector]/pages/utils/files.ts | 2 +- .../[connector]/pages/utils/google_site.ts | 2 + .../sets/DocumentSetCreationForm.tsx | 4 +- .../status/CCPairIndexingStatusTable.tsx | 7 +- .../app/ee/admin/groups/ConnectorEditor.tsx | 2 +- .../ee/admin/groups/UserGroupCreationForm.tsx | 9 +- .../groups/[groupId]/AddConnectorForm.tsx | 2 +- .../admin/connectors/AccessTypeForm.tsx | 88 ++++ .../connectors/AccessTypeGroupSelector.tsx | 147 +++++++ .../admin/connectors/AutoSyncOptions.tsx | 30 ++ .../admin/connectors/ConnectorForm.tsx | 382 ------------------ .../credentials/CredentialSection.tsx | 1 - .../credentials/actions/CreateCredential.tsx | 2 +- .../lib/connectors/AutoSyncOptionFields.tsx | 46 +++ web/src/lib/connectors/connectors.ts | 3 +- web/src/lib/credential.ts | 9 +- web/src/lib/types.ts | 7 +- 79 files changed, 2192 insertions(+), 1049 deletions(-) create mode 100644 backend/alembic/versions/61ff3651add4_add_permission_syncing.py delete mode 100644 backend/ee/danswer/background/permission_sync.py delete mode 100644 backend/ee/danswer/connectors/confluence/perm_sync.py delete mode 100644 backend/ee/danswer/connectors/factory.py create mode 100644 backend/ee/danswer/db/external_perm.py delete mode 100644 backend/ee/danswer/db/permission_sync.py rename backend/ee/danswer/{connectors => external_permissions}/__init__.py (100%) rename backend/ee/danswer/{connectors => external_permissions}/confluence/__init__.py (100%) create mode 100644 backend/ee/danswer/external_permissions/confluence/doc_sync.py create mode 100644 backend/ee/danswer/external_permissions/confluence/group_sync.py create mode 100644 backend/ee/danswer/external_permissions/google_drive/__init__.py create mode 100644 backend/ee/danswer/external_permissions/google_drive/doc_sync.py create mode 100644 backend/ee/danswer/external_permissions/google_drive/group_sync.py create mode 100644 backend/ee/danswer/external_permissions/permission_sync.py create mode 100644 backend/ee/danswer/external_permissions/permission_sync_function_map.py create mode 100644 backend/ee/danswer/external_permissions/permission_sync_utils.py create mode 100644 backend/scripts/query_time_check/seed_dummy_docs.py create mode 100644 backend/scripts/query_time_check/test_query_times.py create mode 100644 web/src/components/admin/connectors/AccessTypeForm.tsx create mode 100644 web/src/components/admin/connectors/AccessTypeGroupSelector.tsx create mode 100644 web/src/components/admin/connectors/AutoSyncOptions.tsx delete mode 100644 web/src/components/admin/connectors/ConnectorForm.tsx create mode 100644 web/src/lib/connectors/AutoSyncOptionFields.tsx diff --git a/backend/alembic/versions/52a219fb5233_add_last_synced_and_last_modified_to_document_table.py b/backend/alembic/versions/52a219fb5233_add_last_synced_and_last_modified_to_document_table.py index f284c7b4bf1..068342095b6 100644 --- a/backend/alembic/versions/52a219fb5233_add_last_synced_and_last_modified_to_document_table.py +++ b/backend/alembic/versions/52a219fb5233_add_last_synced_and_last_modified_to_document_table.py @@ -1,7 +1,7 @@ """Add last synced and last modified to document table Revision 
ID: 52a219fb5233 -Revises: f17bf3b0d9f1 +Revises: f7e58d357687 Create Date: 2024-08-28 17:40:46.077470 """ diff --git a/backend/alembic/versions/61ff3651add4_add_permission_syncing.py b/backend/alembic/versions/61ff3651add4_add_permission_syncing.py new file mode 100644 index 00000000000..697e1060e0b --- /dev/null +++ b/backend/alembic/versions/61ff3651add4_add_permission_syncing.py @@ -0,0 +1,162 @@ +"""Add Permission Syncing + +Revision ID: 61ff3651add4 +Revises: 1b8206b29c5d +Create Date: 2024-09-05 13:57:11.770413 + +""" +import fastapi_users_db_sqlalchemy + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "61ff3651add4" +down_revision = "1b8206b29c5d" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Admin user who set up connectors will lose access to the docs temporarily + # only way currently to give back access is to rerun from beginning + op.add_column( + "connector_credential_pair", + sa.Column( + "access_type", + sa.String(), + nullable=True, + ), + ) + op.execute( + "UPDATE connector_credential_pair SET access_type = 'PUBLIC' WHERE is_public = true" + ) + op.execute( + "UPDATE connector_credential_pair SET access_type = 'PRIVATE' WHERE is_public = false" + ) + op.alter_column("connector_credential_pair", "access_type", nullable=False) + + op.add_column( + "connector_credential_pair", + sa.Column( + "auto_sync_options", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + ) + op.add_column( + "connector_credential_pair", + sa.Column("last_time_perm_sync", sa.DateTime(timezone=True), nullable=True), + ) + op.drop_column("connector_credential_pair", "is_public") + + op.add_column( + "document", + sa.Column("external_user_emails", postgresql.ARRAY(sa.String()), nullable=True), + ) + op.add_column( + "document", + sa.Column( + "external_user_group_ids", postgresql.ARRAY(sa.String()), nullable=True + ), + ) + op.add_column( + "document", + sa.Column("is_public", sa.Boolean(), nullable=True), + ) + + op.create_table( + "user__external_user_group_id", + sa.Column( + "user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False + ), + sa.Column("external_user_group_id", sa.String(), nullable=False), + sa.Column("cc_pair_id", sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint("user_id"), + ) + + op.drop_column("external_permission", "user_id") + op.drop_column("email_to_external_user_cache", "user_id") + op.drop_table("permission_sync_run") + op.drop_table("external_permission") + op.drop_table("email_to_external_user_cache") + + +def downgrade() -> None: + op.add_column( + "connector_credential_pair", + sa.Column("is_public", sa.BOOLEAN(), nullable=True), + ) + op.execute( + "UPDATE connector_credential_pair SET is_public = (access_type = 'PUBLIC')" + ) + op.alter_column("connector_credential_pair", "is_public", nullable=False) + + op.drop_column("connector_credential_pair", "auto_sync_options") + op.drop_column("connector_credential_pair", "access_type") + op.drop_column("connector_credential_pair", "last_time_perm_sync") + op.drop_column("document", "external_user_emails") + op.drop_column("document", "external_user_group_ids") + op.drop_column("document", "is_public") + + op.drop_table("user__external_user_group_id") + + # Drop the enum type at the end of the downgrade + op.create_table( + "permission_sync_run", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "source_type", + sa.String(), + nullable=False, + ), + 
sa.Column("update_type", sa.String(), nullable=False), + sa.Column("cc_pair_id", sa.Integer(), nullable=True), + sa.Column( + "status", + sa.String(), + nullable=False, + ), + sa.Column("error_msg", sa.Text(), nullable=True), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.ForeignKeyConstraint( + ["cc_pair_id"], + ["connector_credential_pair.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "external_permission", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("user_email", sa.String(), nullable=False), + sa.Column( + "source_type", + sa.String(), + nullable=False, + ), + sa.Column("external_permission_group", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "email_to_external_user_cache", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("external_user_id", sa.String(), nullable=False), + sa.Column("user_id", sa.UUID(), nullable=True), + sa.Column("user_email", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) diff --git a/backend/alembic/versions/efb35676026c_standard_answer_match_regex_flag.py b/backend/alembic/versions/efb35676026c_standard_answer_match_regex_flag.py index c85bb68a3b9..e67d31b81ff 100644 --- a/backend/alembic/versions/efb35676026c_standard_answer_match_regex_flag.py +++ b/backend/alembic/versions/efb35676026c_standard_answer_match_regex_flag.py @@ -1,7 +1,7 @@ """standard answer match_regex flag Revision ID: efb35676026c -Revises: 52a219fb5233 +Revises: 0ebb1d516877 Create Date: 2024-09-11 13:55:46.101149 """ diff --git a/backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py b/backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py index 2d8e7402e48..c2a131d6002 100644 --- a/backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py +++ b/backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py @@ -1,7 +1,7 @@ """add has_web_login column to user Revision ID: f7e58d357687 -Revises: bceb1e139447 +Revises: ba98eba0f66a Create Date: 2024-09-07 20:20:54.522620 """ diff --git a/backend/danswer/access/access.py b/backend/danswer/access/access.py index 9088ddf8425..7c879099594 100644 --- a/backend/danswer/access/access.py +++ b/backend/danswer/access/access.py @@ -1,7 +1,7 @@ from sqlalchemy.orm import Session from danswer.access.models import DocumentAccess -from danswer.access.utils import prefix_user +from danswer.access.utils import prefix_user_email from danswer.configs.constants import PUBLIC_DOC_PAT from danswer.db.document import get_access_info_for_document from danswer.db.document import get_access_info_for_documents @@ -18,10 +18,13 @@ def _get_access_for_document( document_id=document_id, ) - if not info: - return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False) - - return DocumentAccess.build(user_ids=info[1], user_groups=[], is_public=info[2]) + return DocumentAccess.build( + user_emails=info[1] if info and info[1] else [], + user_groups=[], + external_user_emails=[], + external_user_group_ids=[], + is_public=info[2] if info else False, + ) def get_access_for_document( @@ -34,6 +37,16 @@ def get_access_for_document( return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore +def get_null_document_access() -> DocumentAccess: + return 
DocumentAccess( + user_emails=set(), + user_groups=set(), + is_public=False, + external_user_emails=set(), + external_user_group_ids=set(), + ) + + def _get_access_for_documents( document_ids: list[str], db_session: Session, @@ -42,13 +55,27 @@ def _get_access_for_documents( db_session=db_session, document_ids=document_ids, ) - return { - document_id: DocumentAccess.build( - user_ids=user_ids, user_groups=[], is_public=is_public + doc_access = { + document_id: DocumentAccess( + user_emails=set([email for email in user_emails if email]), + # MIT version will wipe all groups and external groups on update + user_groups=set(), + is_public=is_public, + external_user_emails=set(), + external_user_group_ids=set(), ) - for document_id, user_ids, is_public in document_access_info + for document_id, user_emails, is_public in document_access_info } + # Sometimes the document has not be indexed by the indexing job yet, in those cases + # the document does not exist and so we use least permissive. Specifically the EE version + # checks the MIT version permissions and creates a superset. This ensures that this flow + # does not fail even if the Document has not yet been indexed. + for doc_id in document_ids: + if doc_id not in doc_access: + doc_access[doc_id] = get_null_document_access() + return doc_access + def get_access_for_documents( document_ids: list[str], @@ -70,7 +97,7 @@ def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]: matches one entry in the returned set. """ if user: - return {prefix_user(str(user.id)), PUBLIC_DOC_PAT} + return {prefix_user_email(user.email), PUBLIC_DOC_PAT} return {PUBLIC_DOC_PAT} diff --git a/backend/danswer/access/models.py b/backend/danswer/access/models.py index a87e2d94f25..af5a021ca97 100644 --- a/backend/danswer/access/models.py +++ b/backend/danswer/access/models.py @@ -1,30 +1,72 @@ from dataclasses import dataclass -from uuid import UUID -from danswer.access.utils import prefix_user +from danswer.access.utils import prefix_external_group +from danswer.access.utils import prefix_user_email from danswer.access.utils import prefix_user_group from danswer.configs.constants import PUBLIC_DOC_PAT @dataclass(frozen=True) -class DocumentAccess: - user_ids: set[str] # stringified UUIDs - user_groups: set[str] # names of user groups associated with this document +class ExternalAccess: + # Emails of external users with access to the doc externally + external_user_emails: set[str] + # Names or external IDs of groups with access to the doc + external_user_group_ids: set[str] + # Whether the document is public in the external system or Danswer is_public: bool - def to_acl(self) -> list[str]: - return ( - [prefix_user(user_id) for user_id in self.user_ids] + +@dataclass(frozen=True) +class DocumentAccess(ExternalAccess): + # User emails for Danswer users, None indicates admin + user_emails: set[str | None] + # Names of user groups associated with this document + user_groups: set[str] + + def to_acl(self) -> set[str]: + return set( + [ + prefix_user_email(user_email) + for user_email in self.user_emails + if user_email + ] + [prefix_user_group(group_name) for group_name in self.user_groups] + + [ + prefix_user_email(user_email) + for user_email in self.external_user_emails + ] + + [ + # The group names are already prefixed by the source type + # This adds an additional prefix of "external_group:" + prefix_external_group(group_name) + for group_name in self.external_user_group_ids + ] + ([PUBLIC_DOC_PAT] if self.is_public else []) ) @classmethod def 
build( - cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool + cls, + user_emails: list[str | None], + user_groups: list[str], + external_user_emails: list[str], + external_user_group_ids: list[str], + is_public: bool, ) -> "DocumentAccess": return cls( - user_ids={str(user_id) for user_id in user_ids if user_id}, + external_user_emails={ + prefix_user_email(external_email) + for external_email in external_user_emails + }, + external_user_group_ids={ + prefix_external_group(external_group_id) + for external_group_id in external_user_group_ids + }, + user_emails={ + prefix_user_email(user_email) + for user_email in user_emails + if user_email + }, user_groups=set(user_groups), is_public=is_public, ) diff --git a/backend/danswer/access/utils.py b/backend/danswer/access/utils.py index 060560eaedc..82abf9785f8 100644 --- a/backend/danswer/access/utils.py +++ b/backend/danswer/access/utils.py @@ -1,10 +1,24 @@ -def prefix_user(user_id: str) -> str: - """Prefixes a user ID to eliminate collision with group names. - This assumes that groups are prefixed with a different prefix.""" - return f"user_id:{user_id}" +from danswer.configs.constants import DocumentSource + + +def prefix_user_email(user_email: str) -> str: + """Prefixes a user email to eliminate collision with group names. + This applies to both a Danswer user and an External user, this is to make the query time + more efficient""" + return f"user_email:{user_email}" def prefix_user_group(user_group_name: str) -> str: - """Prefixes a user group name to eliminate collision with user IDs. + """Prefixes a user group name to eliminate collision with user emails. This assumes that user ids are prefixed with a different prefix.""" return f"group:{user_group_name}" + + +def prefix_external_group(ext_group_name: str) -> str: + """Prefixes an external group name to eliminate collision with user emails / Danswer groups.""" + return f"external_group:{ext_group_name}" + + +def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str: + """External groups may collide across sources, every source needs its own prefix.""" + return f"{source.value.upper()}_{ext_group_name}" diff --git a/backend/danswer/background/celery/celery_app.py b/backend/danswer/background/celery/celery_app.py index aa7c53aa3d2..cdd9b1d5c44 100644 --- a/backend/danswer/background/celery/celery_app.py +++ b/backend/danswer/background/celery/celery_app.py @@ -128,11 +128,11 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None: return runnable_connector = instantiate_connector( - cc_pair.connector.source, - InputType.PRUNE, - cc_pair.connector.connector_specific_config, - cc_pair.credential, - db_session, + db_session=db_session, + source=cc_pair.connector.source, + input_type=InputType.PRUNE, + connector_specific_config=cc_pair.connector.connector_specific_config, + credential=cc_pair.credential, ) all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector( diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index 86b4285361f..a29ddd76c2b 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -56,11 +56,11 @@ def _get_connector_runner( try: runnable_connector = instantiate_connector( - attempt.connector_credential_pair.connector.source, - task, - attempt.connector_credential_pair.connector.connector_specific_config, - attempt.connector_credential_pair.credential, - db_session, 
+ db_session=db_session, + source=attempt.connector_credential_pair.connector.source, + input_type=task, + connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config, + credential=attempt.connector_credential_pair.credential, ) except Exception as e: logger.exception(f"Unable to instantiate connector due to {e}") diff --git a/backend/danswer/connectors/factory.py b/backend/danswer/connectors/factory.py index 1a3d605d3a5..42d3b0bd937 100644 --- a/backend/danswer/connectors/factory.py +++ b/backend/danswer/connectors/factory.py @@ -124,11 +124,11 @@ def identify_connector_class( def instantiate_connector( + db_session: Session, source: DocumentSource, input_type: InputType, connector_specific_config: dict[str, Any], credential: Credential, - db_session: Session, ) -> BaseConnector: connector_class = identify_connector_class(source, input_type) connector = connector_class(**connector_specific_config) diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index 80674b5a37d..bf267ab7786 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -6,7 +6,6 @@ from enum import Enum from itertools import chain from typing import Any -from typing import cast from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore @@ -21,19 +20,13 @@ from danswer.configs.constants import DocumentSource from danswer.configs.constants import IGNORE_FOR_QA from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder -from danswer.connectors.google_drive.connector_auth import ( - get_google_drive_creds_for_authorized_user, -) -from danswer.connectors.google_drive.connector_auth import ( - get_google_drive_creds_for_service_account, -) +from danswer.connectors.google_drive.connector_auth import get_google_drive_creds from danswer.connectors.google_drive.constants import ( DB_CREDENTIALS_DICT_DELEGATED_USER_KEY, ) from danswer.connectors.google_drive.constants import ( DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, ) -from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY from danswer.connectors.interfaces import GenerateDocumentsOutput from danswer.connectors.interfaces import LoadConnector from danswer.connectors.interfaces import PollConnector @@ -407,42 +400,7 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None (2) A credential which holds a service account key JSON file, which can then be used to impersonate any user in the workspace. """ - creds: OAuthCredentials | ServiceAccountCredentials | None = None - new_creds_dict = None - if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials: - access_token_json_str = cast( - str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY] - ) - creds = get_google_drive_creds_for_authorized_user( - token_json_str=access_token_json_str - ) - - # tell caller to update token stored in DB if it has changed - # (e.g. 
the token has been refreshed) - new_creds_json_str = creds.to_json() if creds else "" - if new_creds_json_str != access_token_json_str: - new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str} - - if DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials: - service_account_key_json_str = credentials[ - DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY - ] - creds = get_google_drive_creds_for_service_account( - service_account_key_json_str=service_account_key_json_str - ) - - # "Impersonate" a user if one is specified - delegated_user_email = cast( - str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY) - ) - if delegated_user_email: - creds = creds.with_subject(delegated_user_email) if creds else None # type: ignore - - if creds is None: - raise PermissionError( - "Unable to access Google Drive - unknown credential structure." - ) - + creds, new_creds_dict = get_google_drive_creds(credentials) self.creds = creds return new_creds_dict @@ -509,6 +467,7 @@ def _fetch_docs_from_drive( file["modifiedTime"] ).astimezone(timezone.utc), metadata={} if text_contents else {IGNORE_FOR_QA: "True"}, + additional_info=file.get("id"), ) ) except Exception as e: diff --git a/backend/danswer/connectors/google_drive/connector_auth.py b/backend/danswer/connectors/google_drive/connector_auth.py index 0f47727e6ee..cc68fec54ea 100644 --- a/backend/danswer/connectors/google_drive/connector_auth.py +++ b/backend/danswer/connectors/google_drive/connector_auth.py @@ -10,11 +10,13 @@ from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore from sqlalchemy.orm import Session +from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED from danswer.configs.app_configs import WEB_DOMAIN from danswer.configs.constants import DocumentSource from danswer.configs.constants import KV_CRED_KEY from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY +from danswer.connectors.google_drive.constants import BASE_SCOPES from danswer.connectors.google_drive.constants import ( DB_CREDENTIALS_DICT_DELEGATED_USER_KEY, ) @@ -22,7 +24,8 @@ DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, ) from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY -from danswer.connectors.google_drive.constants import SCOPES +from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES +from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES from danswer.db.credentials import update_credential_json from danswer.db.models import User from danswer.dynamic_configs.factory import get_dynamic_config_store @@ -34,15 +37,25 @@ logger = setup_logger() +def build_gdrive_scopes() -> list[str]: + base_scopes: list[str] = BASE_SCOPES + permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES + groups_scopes: list[str] = FETCH_GROUPS_SCOPES + + if ENTERPRISE_EDITION_ENABLED: + return base_scopes + permissions_scopes + groups_scopes + return base_scopes + permissions_scopes + + def _build_frontend_google_drive_redirect() -> str: return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback" def get_google_drive_creds_for_authorized_user( - token_json_str: str, + token_json_str: str, scopes: list[str] = build_gdrive_scopes() ) -> OAuthCredentials | None: creds_json = json.loads(token_json_str) - creds = OAuthCredentials.from_authorized_user_info(creds_json, SCOPES) + creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes) if creds.valid: return creds @@ -59,18 +72,67 @@ def 
get_google_drive_creds_for_authorized_user( return None -def get_google_drive_creds_for_service_account( - service_account_key_json_str: str, +def _get_google_drive_creds_for_service_account( + service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes() ) -> ServiceAccountCredentials | None: service_account_key = json.loads(service_account_key_json_str) creds = ServiceAccountCredentials.from_service_account_info( - service_account_key, scopes=SCOPES + service_account_key, scopes=scopes ) if not creds.valid or not creds.expired: creds.refresh(Request()) return creds if creds.valid else None +def get_google_drive_creds( + credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes() +) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]: + oauth_creds = None + service_creds = None + new_creds_dict = None + if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials: + access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]) + oauth_creds = get_google_drive_creds_for_authorized_user( + token_json_str=access_token_json_str, scopes=scopes + ) + + # tell caller to update token stored in DB if it has changed + # (e.g. the token has been refreshed) + new_creds_json_str = oauth_creds.to_json() if oauth_creds else "" + if new_creds_json_str != access_token_json_str: + new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str} + + elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials: + service_account_key_json_str = credentials[ + DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY + ] + service_creds = _get_google_drive_creds_for_service_account( + service_account_key_json_str=service_account_key_json_str, + scopes=scopes, + ) + + # "Impersonate" a user if one is specified + delegated_user_email = cast( + str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY) + ) + if delegated_user_email: + service_creds = ( + service_creds.with_subject(delegated_user_email) + if service_creds + else None + ) + + creds: ServiceAccountCredentials | OAuthCredentials | None = ( + oauth_creds or service_creds + ) + if creds is None: + raise PermissionError( + "Unable to access Google Drive - unknown credential structure." 
+ ) + + return creds, new_creds_dict + + def verify_csrf(credential_id: int, state: str) -> None: csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id))) if csrf != state: @@ -84,7 +146,7 @@ def get_auth_url(credential_id: int) -> str: credential_json = json.loads(creds_str) flow = InstalledAppFlow.from_client_config( credential_json, - scopes=SCOPES, + scopes=build_gdrive_scopes(), redirect_uri=_build_frontend_google_drive_redirect(), ) auth_url, _ = flow.authorization_url(prompt="consent") @@ -107,7 +169,7 @@ def update_credential_access_tokens( app_credentials = get_google_app_cred() flow = InstalledAppFlow.from_client_config( app_credentials.model_dump(), - scopes=SCOPES, + scopes=build_gdrive_scopes(), redirect_uri=_build_frontend_google_drive_redirect(), ) flow.fetch_token(code=auth_code) diff --git a/backend/danswer/connectors/google_drive/constants.py b/backend/danswer/connectors/google_drive/constants.py index 214bfd5cb97..0cca65c13df 100644 --- a/backend/danswer/connectors/google_drive/constants.py +++ b/backend/danswer/connectors/google_drive/constants.py @@ -1,7 +1,7 @@ DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens" DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key" DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user" -SCOPES = [ - "https://www.googleapis.com/auth/drive.readonly", - "https://www.googleapis.com/auth/drive.metadata.readonly", -] + +BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"] +FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"] +FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"] diff --git a/backend/danswer/connectors/models.py b/backend/danswer/connectors/models.py index 192aa1b206a..7d86d21980d 100644 --- a/backend/danswer/connectors/models.py +++ b/backend/danswer/connectors/models.py @@ -113,6 +113,9 @@ class DocumentBase(BaseModel): # The default title is semantic_identifier though unless otherwise specified title: str | None = None from_ingestion_api: bool = False + # Anything else that may be useful and is specific to this particular connector type that other + # parts of the code may need.
If you're unsure, this can be left as None + additional_info: Any = None def get_title_for_document_index( self, diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index cce45331ee7..0882796204d 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -211,7 +211,7 @@ def handle_message( with Session(get_sqlalchemy_engine()) as db_session: if message_info.email: - add_non_web_user_if_not_exists(message_info.email, db_session) + add_non_web_user_if_not_exists(db_session, message_info.email) # first check if we need to respond with a standard answer used_standard_answer = handle_standard_answers( diff --git a/backend/danswer/db/connector_credential_pair.py b/backend/danswer/db/connector_credential_pair.py index 004b5a754e4..ec9a3a08ef8 100644 --- a/backend/danswer/db/connector_credential_pair.py +++ b/backend/danswer/db/connector_credential_pair.py @@ -12,6 +12,7 @@ from danswer.configs.constants import DocumentSource from danswer.db.connector import fetch_connector_by_id from danswer.db.credentials import fetch_credential_by_id +from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.models import ConnectorCredentialPair from danswer.db.models import IndexAttempt @@ -24,6 +25,10 @@ from danswer.db.models import UserRole from danswer.server.models import StatusResponse from danswer.utils.logger import setup_logger +from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit +from ee.danswer.external_permissions.permission_sync_function_map import ( + check_if_valid_sync_source, +) logger = setup_logger() @@ -74,7 +79,7 @@ def _add_user_filters( .correlate(ConnectorCredentialPair) ) else: - where_clause |= ConnectorCredentialPair.is_public == True # noqa: E712 + where_clause |= ConnectorCredentialPair.access_type == AccessType.PUBLIC return stmt.where(where_clause) @@ -94,8 +99,7 @@ def get_connector_credential_pairs( ) # noqa if ids: stmt = stmt.where(ConnectorCredentialPair.id.in_(ids)) - results = db_session.scalars(stmt) - return list(results.all()) + return list(db_session.scalars(stmt).all()) def add_deletion_failure_message( @@ -309,9 +313,9 @@ def associate_default_cc_pair(db_session: Session) -> None: association = ConnectorCredentialPair( connector_id=0, credential_id=0, + access_type=AccessType.PUBLIC, name="DefaultCCPair", status=ConnectorCredentialPairStatus.ACTIVE, - is_public=True, ) db_session.add(association) db_session.commit() @@ -336,8 +340,9 @@ def add_credential_to_connector( connector_id: int, credential_id: int, cc_pair_name: str | None, - is_public: bool, + access_type: AccessType, groups: list[int] | None, + auto_sync_options: dict | None = None, ) -> StatusResponse: connector = fetch_connector_by_id(connector_id, db_session) credential = fetch_credential_by_id(credential_id, user, db_session) @@ -345,6 +350,13 @@ def add_credential_to_connector( if connector is None: raise HTTPException(status_code=404, detail="Connector does not exist") + if access_type == AccessType.SYNC: + if not check_if_valid_sync_source(connector.source): + raise HTTPException( + status_code=400, + detail=f"Connector of type {connector.source} does not support SYNC access type", + ) + if credential is None: error_msg = ( f"Credential {credential_id} does not exist or does not belong to user" @@ -375,12 +387,13 @@ def 
add_credential_to_connector( credential_id=credential_id, name=cc_pair_name, status=ConnectorCredentialPairStatus.ACTIVE, - is_public=is_public, + access_type=access_type, + auto_sync_options=auto_sync_options, ) db_session.add(association) db_session.flush() # make sure the association has an id - if groups: + if groups and access_type != AccessType.SYNC: _relate_groups_to_cc_pair__no_commit( db_session=db_session, cc_pair_id=association.id, @@ -423,6 +436,10 @@ def remove_credential_from_connector( ) if association is not None: + delete_user__ext_group_for_cc_pair__no_commit( + db_session=db_session, + cc_pair_id=association.id, + ) db_session.delete(association) db_session.commit() return StatusResponse( diff --git a/backend/danswer/db/document.py b/backend/danswer/db/document.py index 0d5bc276bc1..bffffddd755 100644 --- a/backend/danswer/db/document.py +++ b/backend/danswer/db/document.py @@ -4,7 +4,6 @@ from collections.abc import Sequence from datetime import datetime from datetime import timezone -from uuid import UUID from sqlalchemy import and_ from sqlalchemy import delete @@ -17,14 +16,17 @@ from sqlalchemy.engine.util import TransactionalContext from sqlalchemy.exc import OperationalError from sqlalchemy.orm import Session +from sqlalchemy.sql.expression import null from danswer.configs.constants import DEFAULT_BOOST +from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.feedback import delete_document_feedback_for_documents__no_commit from danswer.db.models import ConnectorCredentialPair from danswer.db.models import Credential from danswer.db.models import Document as DbDocument from danswer.db.models import DocumentByConnectorCredentialPair +from danswer.db.models import User from danswer.db.tag import delete_document_tags_for_documents__no_commit from danswer.db.utils import model_to_dict from danswer.document_index.interfaces import DocumentMetadata @@ -186,16 +188,14 @@ def get_document_counts_for_cc_pairs( def get_access_info_for_document( db_session: Session, document_id: str, -) -> tuple[str, list[UUID | None], bool] | None: +) -> tuple[str, list[str | None], bool] | None: """Gets access info for a single document by calling the get_access_info_for_documents function and passing a list with a single document ID. - Args: db_session (Session): The database session to use. document_id (str): The document ID to fetch access info for. - Returns: - Optional[Tuple[str, List[UUID | None], bool]]: A tuple containing the document ID, a list of user IDs, + Optional[Tuple[str, List[str | None], bool]]: A tuple containing the document ID, a list of user emails, and a boolean indicating if the document is globally public, or None if no results are found. """ results = get_access_info_for_documents(db_session, [document_id]) @@ -208,19 +208,27 @@ def get_access_info_for_document( def get_access_info_for_documents( db_session: Session, document_ids: list[str], -) -> Sequence[tuple[str, list[UUID | None], bool]]: +) -> Sequence[tuple[str, list[str | None], bool]]: """Gets back all relevant access info for the given documents. This includes the user_ids for cc pairs that the document is associated with + whether any of the associated cc pairs are intending to make the document globally public. 
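+ ("Globally public" here means that at least one associated cc pair has access_type == AccessType.PUBLIC.)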
+ Returns the list where each element contains: + - Document ID (which is also the ID of the DocumentByConnectorCredentialPair) + - List of emails of Danswer users with direct access to the doc (includes a "None" element if + the connector was set up by an admin when auth was off) + - bool for whether the document is public (the document later can also be marked public by + the automatic permission sync step) """ + stmt = select( + DocumentByConnectorCredentialPair.id, + func.array_agg(func.coalesce(User.email, null())).label("user_emails"), + func.bool_or(ConnectorCredentialPair.access_type == AccessType.PUBLIC).label( + "public_doc" + ), + ).where(DocumentByConnectorCredentialPair.id.in_(document_ids)) + stmt = ( - select( - DocumentByConnectorCredentialPair.id, - func.array_agg(Credential.user_id).label("user_ids"), - func.bool_or(ConnectorCredentialPair.is_public).label("public_doc"), - ) - .where(DocumentByConnectorCredentialPair.id.in_(document_ids)) - .join( + stmt.join( Credential, DocumentByConnectorCredentialPair.credential_id == Credential.id, ) @@ -233,6 +241,13 @@ def get_access_info_for_documents( == ConnectorCredentialPair.credential_id, ), ) + .outerjoin( + User, + and_( + Credential.user_id == User.id, + ConnectorCredentialPair.access_type != AccessType.SYNC, + ), + ) # don't include CC pairs that are being deleted # NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them .where(ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING) @@ -278,9 +293,19 @@ def upsert_documents( for doc in seen_documents.values() ] ) - # for now, there are no columns to update. If more metadata is added, then this - # needs to change to an `on_conflict_do_update` - on_conflict_stmt = insert_stmt.on_conflict_do_nothing() + + on_conflict_stmt = insert_stmt.on_conflict_do_update( + index_elements=["id"], # Conflict target + set_={ + "from_ingestion_api": insert_stmt.excluded.from_ingestion_api, + "boost": insert_stmt.excluded.boost, + "hidden": insert_stmt.excluded.hidden, + "semantic_id": insert_stmt.excluded.semantic_id, + "link": insert_stmt.excluded.link, + "primary_owners": insert_stmt.excluded.primary_owners, + "secondary_owners": insert_stmt.excluded.secondary_owners, + }, + ) db_session.execute(on_conflict_stmt) db_session.commit() diff --git a/backend/danswer/db/document_set.py b/backend/danswer/db/document_set.py index a8c1e4ebb1d..0ba6c4e9ab3 100644 --- a/backend/danswer/db/document_set.py +++ b/backend/danswer/db/document_set.py @@ -14,6 +14,7 @@ from danswer.db.connector_credential_pair import get_cc_pair_groups_for_ids from danswer.db.connector_credential_pair import get_connector_credential_pairs +from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.models import ConnectorCredentialPair from danswer.db.models import Document @@ -180,7 +181,7 @@ def _check_if_cc_pairs_are_owned_by_groups( ids=missing_cc_pair_ids, ) for cc_pair in cc_pairs: - if not cc_pair.is_public: + if cc_pair.access_type != AccessType.PUBLIC: raise ValueError( f"Connector Credential Pair with ID: '{cc_pair.id}'" " is not owned by the specified groups" ) @@ -704,7 +705,7 @@ def check_document_sets_are_public( ConnectorCredentialPair.id.in_( connector_credential_pair_ids # type:ignore ), - ConnectorCredentialPair.is_public.is_(False), + ConnectorCredentialPair.access_type != AccessType.PUBLIC, ) .limit(1) .first() diff --git a/backend/danswer/db/enums.py b/backend/danswer/db/enums.py index
eac048e10ab..8d9515d387a 100644 --- a/backend/danswer/db/enums.py +++ b/backend/danswer/db/enums.py @@ -51,3 +51,9 @@ class ConnectorCredentialPairStatus(str, PyEnum): def is_active(self) -> bool: return self == ConnectorCredentialPairStatus.ACTIVE + + +class AccessType(str, PyEnum): + PUBLIC = "public" + PRIVATE = "private" + SYNC = "sync" diff --git a/backend/danswer/db/feedback.py b/backend/danswer/db/feedback.py index 6df1f1f5051..219e2474729 100644 --- a/backend/danswer/db/feedback.py +++ b/backend/danswer/db/feedback.py @@ -16,6 +16,7 @@ from danswer.configs.constants import MessageType from danswer.configs.constants import SearchFeedbackType from danswer.db.chat import get_chat_message +from danswer.db.enums import AccessType from danswer.db.models import ChatMessageFeedback from danswer.db.models import ConnectorCredentialPair from danswer.db.models import Document as DbDocument @@ -94,7 +95,7 @@ def _add_user_filters( .correlate(CCPair) ) else: - where_clause |= CCPair.is_public == True # noqa: E712 + where_clause |= CCPair.access_type == AccessType.PUBLIC return stmt.where(where_clause) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 33d8d8bc0d8..fd2d1344afd 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -39,6 +39,7 @@ from danswer.configs.constants import DocumentSource from danswer.configs.constants import FileOrigin from danswer.configs.constants import MessageType +from danswer.db.enums import AccessType from danswer.configs.constants import NotificationType from danswer.configs.constants import SearchFeedbackType from danswer.configs.constants import TokenRateLimitScope @@ -388,10 +389,20 @@ class ConnectorCredentialPair(Base): # controls whether the documents indexed by this CC pair are visible to all # or if they are only visible to those with that are given explicit access # (e.g. via owning the credential or being a part of a group that is given access) - is_public: Mapped[bool] = mapped_column( - Boolean, - default=True, - nullable=False, + access_type: Mapped[AccessType] = mapped_column( + Enum(AccessType, native_enum=False), nullable=False + ) + + # special info needed for the auto-sync feature. The exact structure depends on the + # source type (defined in the connector's `source` field) + # E.g.
for google_drive perm sync: # {"customer_id": "123567", "company_domain": "@danswer.ai"} + auto_sync_options: Mapped[dict[str, Any] | None] = mapped_column( + postgresql.JSONB(), nullable=True + ) + last_time_perm_sync: Mapped[datetime.datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True ) # Time finished, not used for calculating backend jobs which uses time started (created) last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column( @@ -422,6 +433,7 @@ class ConnectorCredentialPair(Base): class Document(Base): __tablename__ = "document" + # NOTE: if more sensitive data is added here for display, make sure to add user/group permission # this should correspond to the ID of the document # (as is passed around in Danswer) @@ -465,7 +477,18 @@ class Document(Base): secondary_owners: Mapped[list[str] | None] = mapped_column( postgresql.ARRAY(String), nullable=True ) - # TODO if more sensitive data is added here for display, make sure to add user/group permission + # Permission sync columns + # Email addresses are saved at the document level for externally synced permissions + # This is because the normal flow of assigning permissions (through the cc_pair) + # doesn't apply here + external_user_emails: Mapped[list[str] | None] = mapped_column( + postgresql.ARRAY(String), nullable=True + ) + # These group ids have been prefixed by the source type + external_user_group_ids: Mapped[list[str] | None] = mapped_column( + postgresql.ARRAY(String), nullable=True + ) + is_public: Mapped[bool] = mapped_column(Boolean, default=False) retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship( "DocumentRetrievalFeedback", back_populates="document" ) @@ -1674,95 +1697,18 @@ class StandardAnswer(Base): """Tables related to Permission Sync""" -class PermissionSyncStatus(str, PyEnum): - IN_PROGRESS = "in_progress" - SUCCESS = "success" - FAILED = "failed" - - -class PermissionSyncJobType(str, PyEnum): - USER_LEVEL = "user_level" - GROUP_LEVEL = "group_level" - - -class PermissionSyncRun(Base): - """Represents one run of a permission sync job. For some given cc_pair, it is either sync-ing - the users or it is sync-ing the groups""" - - __tablename__ = "permission_sync_run" - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - # Not strictly needed but makes it easy to use without fetching from cc_pair - source_type: Mapped[DocumentSource] = mapped_column( - Enum(DocumentSource, native_enum=False) - ) - # Currently all sync jobs are handled as a group permission sync or a user permission sync - update_type: Mapped[PermissionSyncJobType] = mapped_column( - Enum(PermissionSyncJobType) - ) - cc_pair_id: Mapped[int | None] = mapped_column( - ForeignKey("connector_credential_pair.id"), nullable=True - ) - status: Mapped[PermissionSyncStatus] = mapped_column(Enum(PermissionSyncStatus)) - error_msg: Mapped[str | None] = mapped_column(Text, default=None) - updated_at: Mapped[datetime.datetime] = mapped_column( - DateTime(timezone=True), server_default=func.now(), onupdate=func.now() - ) - - cc_pair: Mapped[ConnectorCredentialPair] = relationship("ConnectorCredentialPair") - - -class ExternalPermission(Base): +class User__ExternalUserGroupId(Base): """Maps user info both internal and external to the name of the external group This maps the user to all of their external groups so that the external group name can be attached to the ACL list matching during query time.
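External group ids are stored prefixed with their source type (via prefix_group_w_source) so that group names from different connectors cannot collide.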
User level permissions can be handled by directly adding the Danswer user to the doc ACL list""" - __tablename__ = "external_permission" - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - user_id: Mapped[UUID | None] = mapped_column( - ForeignKey("user.id", ondelete="CASCADE"), nullable=True - ) - # Email is needed because we want to keep track of users not in Danswer to simplify process - # when the user joins - user_email: Mapped[str] = mapped_column(String) - source_type: Mapped[DocumentSource] = mapped_column( - Enum(DocumentSource, native_enum=False) - ) - external_permission_group: Mapped[str] = mapped_column(String) - user = relationship("User") - - -class EmailToExternalUserCache(Base): - """A way to map users IDs in the external tool to a user in Danswer or at least an email for - when the user joins. Used as a cache for when fetching external groups which have their own - user ids, this can easily be mapped back to users already known in Danswer without needing - to call external APIs to get the user emails. - - This way when groups are updated in the external tool and we need to update the mapping of - internal users to the groups, we can sync the internal users to the external groups they are - part of using this. - - Ie. User Chris is part of groups alpha, beta, and we can update this if Chris is no longer - part of alpha in some external tool. - """ - - __tablename__ = "email_to_external_user_cache" - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - external_user_id: Mapped[str] = mapped_column(String) - user_id: Mapped[UUID | None] = mapped_column( - ForeignKey("user.id", ondelete="CASCADE"), nullable=True - ) - # Email is needed because we want to keep track of users not in Danswer to simplify process - # when the user joins - user_email: Mapped[str] = mapped_column(String) - source_type: Mapped[DocumentSource] = mapped_column( - Enum(DocumentSource, native_enum=False) - ) + __tablename__ = "user__external_user_group_id" - user = relationship("User") + user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True) + # These group ids have been prefixed by the source type + external_user_group_id: Mapped[str] = mapped_column(String, primary_key=True) + cc_pair_id: Mapped[int] = mapped_column(ForeignKey("connector_credential_pair.id")) class UsageReport(Base): diff --git a/backend/danswer/db/users.py b/backend/danswer/db/users.py index 61ba6e475fe..a6319481bc4 100644 --- a/backend/danswer/db/users.py +++ b/backend/danswer/db/users.py @@ -22,6 +22,17 @@ def list_users( return db_session.scalars(stmt).unique().all() +def get_users_by_emails( + db_session: Session, emails: list[str] +) -> tuple[list[User], list[str]]: + # Use distinct to avoid duplicates + stmt = select(User).filter(User.email.in_(emails)) # type: ignore + found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list + found_users_emails = [user.email for user in found_users] + missing_user_emails = [email for email in emails if email not in found_users_emails] + return found_users, missing_user_emails + + def get_user_by_email(email: str, db_session: Session) -> User | None: user = db_session.query(User).filter(User.email == email).first() # type: ignore @@ -34,20 +45,50 @@ def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None: return user -def add_non_web_user_if_not_exists(email: str, db_session: Session) -> User: - user = get_user_by_email(email, db_session) - if user is not None: - return user - +def _generate_non_web_user(email: 
str) -> User: fastapi_users_pw_helper = PasswordHelper() password = fastapi_users_pw_helper.generate() hashed_pass = fastapi_users_pw_helper.hash(password) - user = User( + return User( email=email, hashed_password=hashed_pass, has_web_login=False, role=UserRole.BASIC, ) + + +def add_non_web_user_if_not_exists(db_session: Session, email: str) -> User: + user = get_user_by_email(email, db_session) + if user is not None: + return user + + user = _generate_non_web_user(email=email) db_session.add(user) db_session.commit() return user + + +def add_non_web_user_if_not_exists__no_commit(db_session: Session, email: str) -> User: + user = get_user_by_email(email, db_session) + if user is not None: + return user + + user = _generate_non_web_user(email=email) + db_session.add(user) + db_session.flush() # generate id + return user + + +def batch_add_non_web_user_if_not_exists__no_commit( + db_session: Session, emails: list[str] +) -> list[User]: + found_users, missing_user_emails = get_users_by_emails(db_session, emails) + + new_users: list[User] = [] + for email in missing_user_emails: + new_users.append(_generate_non_web_user(email=email)) + + db_session.add_all(new_users) + db_session.flush() # generate ids + + return found_users + new_users diff --git a/backend/danswer/indexing/indexing_pipeline.py b/backend/danswer/indexing/indexing_pipeline.py index 51cd23e7431..5d7412ea9d7 100644 --- a/backend/danswer/indexing/indexing_pipeline.py +++ b/backend/danswer/indexing/indexing_pipeline.py @@ -265,7 +265,13 @@ def index_doc_batch( Note that the documents should already be batched at this point so that it does not inflate the memory requirements""" - no_access = DocumentAccess.build(user_ids=[], user_groups=[], is_public=False) + no_access = DocumentAccess.build( + user_emails=[], + user_groups=[], + external_user_emails=[], + external_user_group_ids=[], + is_public=False, + ) ctx = index_doc_batch_prepare( document_batch=document_batch, diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py index 876886ca28d..b06aacc1e6a 100644 --- a/backend/danswer/server/documents/cc_pair.py +++ b/backend/danswer/server/documents/cc_pair.py @@ -18,6 +18,7 @@ ) from danswer.db.document import get_document_counts_for_cc_pairs from danswer.db.engine import get_session +from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair from danswer.db.index_attempt import cancel_indexing_attempts_past_model @@ -201,7 +202,7 @@ def associate_credential_to_connector( db_session=db_session, user=user, target_group_ids=metadata.groups, - object_is_public=metadata.is_public, + object_is_public=metadata.access_type == AccessType.PUBLIC, ) try: @@ -211,7 +212,8 @@ def associate_credential_to_connector( connector_id=connector_id, credential_id=credential_id, cc_pair_name=metadata.name, - is_public=True if metadata.is_public is None else metadata.is_public, + access_type=metadata.access_type, + auto_sync_options=metadata.auto_sync_options, groups=metadata.groups, ) diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py index 73e28b8fb0b..58dcf7e7691 100644 --- a/backend/danswer/server/documents/connector.py +++ b/backend/danswer/server/documents/connector.py @@ -64,6 +64,7 @@ from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed from danswer.db.document import get_document_counts_for_cc_pairs from 
danswer.db.engine import get_session +from danswer.db.enums import AccessType from danswer.db.index_attempt import create_index_attempt from danswer.db.index_attempt import get_index_attempts_for_cc_pair from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id @@ -559,7 +560,7 @@ def get_connector_indexing_status( cc_pair_status=cc_pair.status, connector=ConnectorSnapshot.from_connector_db_model(connector), credential=CredentialSnapshot.from_credential_db_model(credential), - public_doc=cc_pair.is_public, + access_type=cc_pair.access_type, owner=credential.user.email if credential.user else "", groups=group_cc_pair_relationships_dict.get(cc_pair.id, []), last_finished_status=( @@ -668,12 +669,15 @@ def create_connector_with_mock_credential( credential = create_credential( mock_credential, user=user, db_session=db_session ) + access_type = ( + AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE + ) response = add_credential_to_connector( db_session=db_session, user=user, connector_id=cast(int, connector_response.id), # will aways be an int credential_id=credential.id, - is_public=connector_data.is_public or False, + access_type=access_type, cc_pair_name=connector_data.name, groups=connector_data.groups, ) diff --git a/backend/danswer/server/documents/models.py b/backend/danswer/server/documents/models.py index 517813892b8..b4052303ba8 100644 --- a/backend/danswer/server/documents/models.py +++ b/backend/danswer/server/documents/models.py @@ -10,6 +10,7 @@ from danswer.configs.constants import DocumentSource from danswer.connectors.models import DocumentErrorSummary from danswer.connectors.models import InputType +from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.models import Connector from danswer.db.models import ConnectorCredentialPair @@ -218,7 +219,7 @@ class CCPairFullInfo(BaseModel): number_of_index_attempts: int last_index_attempt_status: IndexingStatus | None latest_deletion_attempt: DeletionAttemptSnapshot | None - is_public: bool + access_type: AccessType is_editable_for_current_user: bool deletion_failure_message: str | None @@ -261,7 +262,7 @@ def from_models( number_of_index_attempts=number_of_index_attempts, last_index_attempt_status=last_indexing_status, latest_deletion_attempt=latest_deletion_attempt, - is_public=cc_pair_model.is_public, + access_type=cc_pair_model.access_type, is_editable_for_current_user=is_editable_for_current_user, deletion_failure_message=cc_pair_model.deletion_failure_message, ) @@ -288,7 +289,7 @@ class ConnectorIndexingStatus(BaseModel): credential: CredentialSnapshot owner: str groups: list[int] - public_doc: bool + access_type: AccessType last_finished_status: IndexingStatus | None last_status: IndexingStatus | None last_success: datetime | None @@ -306,7 +307,8 @@ class ConnectorCredentialPairIdentifier(BaseModel): class ConnectorCredentialPairMetadata(BaseModel): name: str | None = None - is_public: bool | None = None + access_type: AccessType + auto_sync_options: dict[str, Any] | None = None groups: list[int] = Field(default_factory=list) diff --git a/backend/danswer/server/features/document_set/models.py b/backend/danswer/server/features/document_set/models.py index 55f3376545f..740cb6906cf 100644 --- a/backend/danswer/server/features/document_set/models.py +++ b/backend/danswer/server/features/document_set/models.py @@ -47,7 +47,6 @@ class DocumentSet(BaseModel): description: str cc_pair_descriptors: list[ConnectorCredentialPairDescriptor] 
is_up_to_date: bool - contains_non_public: bool is_public: bool # For Private Document Sets, who should be able to access these users: list[UUID] @@ -59,12 +58,6 @@ def from_model(cls, document_set_model: DocumentSetDBModel) -> "DocumentSet": id=document_set_model.id, name=document_set_model.name, description=document_set_model.description, - contains_non_public=any( - [ - not cc_pair.is_public - for cc_pair in document_set_model.connector_credential_pairs - ] - ), cc_pair_descriptors=[ ConnectorCredentialPairDescriptor( id=cc_pair.id, diff --git a/backend/danswer/server/manage/embedding/models.py b/backend/danswer/server/manage/embedding/models.py index b4ca7862b55..d6210118df5 100644 --- a/backend/danswer/server/manage/embedding/models.py +++ b/backend/danswer/server/manage/embedding/models.py @@ -18,6 +18,9 @@ class TestEmbeddingRequest(BaseModel): api_url: str | None = None model_name: str | None = None + # This disables the "model_" protected namespace for pydantic + model_config = {"protected_namespaces": ()} + class CloudEmbeddingProvider(BaseModel): provider_type: EmbeddingProvider diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index 46beee856f8..04d2c124493 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -48,6 +48,7 @@ from danswer.server.models import MinimalUserSnapshot from danswer.utils.logger import setup_logger from ee.danswer.db.api_key import is_api_key_email_address +from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_commit from ee.danswer.db.user_group import remove_curator_status__no_commit logger = setup_logger() @@ -243,6 +244,11 @@ async def delete_user( for oauth_account in user_to_delete.oauth_accounts: db_session.delete(oauth_account) + delete_user__ext_group_for_user__no_commit( + db_session=db_session, + user_id=user_to_delete.id, + ) + db_session.query(DocumentSet__User).filter( DocumentSet__User.user_id == user_to_delete.id ).delete() diff --git a/backend/ee/danswer/access/access.py b/backend/ee/danswer/access/access.py index 2b3cdb7a9dc..094298677a5 100644 --- a/backend/ee/danswer/access/access.py +++ b/backend/ee/danswer/access/access.py @@ -5,8 +5,11 @@ ) from danswer.access.access import _get_acl_for_user as get_acl_for_user_without_groups from danswer.access.models import DocumentAccess +from danswer.access.utils import prefix_external_group from danswer.access.utils import prefix_user_group +from danswer.db.document import get_documents_by_ids from danswer.db.models import User +from ee.danswer.db.external_perm import fetch_external_groups_for_user from ee.danswer.db.user_group import fetch_user_groups_for_documents from ee.danswer.db.user_group import fetch_user_groups_for_user @@ -17,7 +20,13 @@ def _get_access_for_document( ) -> DocumentAccess: id_to_access = _get_access_for_documents([document_id], db_session) if len(id_to_access) == 0: - return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False) + return DocumentAccess.build( + user_emails=[], + user_groups=[], + external_user_emails=[], + external_user_group_ids=[], + is_public=False, + ) return next(iter(id_to_access.values())) @@ -30,22 +39,48 @@ def _get_access_for_documents( document_ids=document_ids, db_session=db_session, ) - user_group_info = { + user_group_info: dict[str, list[str]] = { document_id: group_names for document_id, group_names in fetch_user_groups_for_documents( db_session=db_session, document_ids=document_ids, ) } + documents = 
get_documents_by_ids( + db_session=db_session, + document_ids=document_ids, + ) + doc_id_map = {doc.id: doc for doc in documents} - return { - document_id: DocumentAccess( - user_ids=non_ee_access.user_ids, - user_groups=user_group_info.get(document_id, []), # type: ignore - is_public=non_ee_access.is_public, + access_map = {} + for document_id, non_ee_access in non_ee_access_dict.items(): + document = doc_id_map[document_id] + + ext_u_emails = ( + set(document.external_user_emails) + if document.external_user_emails + else set() + ) + + ext_u_groups = ( + set(document.external_user_group_ids) + if document.external_user_group_ids + else set() ) - for document_id, non_ee_access in non_ee_access_dict.items() - } + + # If the document is determined to be "public" externally (through a SYNC connector) + # then it's given the same access level as if it were marked public within Danswer + is_public_anywhere = document.is_public or non_ee_access.is_public + + # To avoid collisions of group names between connectors, they need to be prefixed + access_map[document_id] = DocumentAccess( + user_emails=non_ee_access.user_emails, + user_groups=set(user_group_info.get(document_id, [])), + is_public=is_public_anywhere, + external_user_emails=ext_u_emails, + external_user_group_ids=ext_u_groups, + ) + return access_map def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]: @@ -56,7 +91,20 @@ def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]: NOTE: is imported in danswer.access.access by `fetch_versioned_implementation` DO NOT REMOVE.""" - user_groups = fetch_user_groups_for_user(db_session, user.id) if user else [] - return set( - [prefix_user_group(user_group.name) for user_group in user_groups] - ).union(get_acl_for_user_without_groups(user, db_session)) + db_user_groups = fetch_user_groups_for_user(db_session, user.id) if user else [] + prefixed_user_groups = [ + prefix_user_group(db_user_group.name) for db_user_group in db_user_groups + ] + + db_external_groups = ( + fetch_external_groups_for_user(db_session, user.id) if user else [] + ) + prefixed_external_groups = [ + prefix_external_group(db_external_group.external_user_group_id) + for db_external_group in db_external_groups + ] + + user_acl = set(prefixed_user_groups + prefixed_external_groups) + user_acl.update(get_acl_for_user_without_groups(user, db_session)) + + return user_acl diff --git a/backend/ee/danswer/background/celery/celery_app.py b/backend/ee/danswer/background/celery/celery_app.py index 2b4c96ccb1e..2b3475beee9 100644 --- a/backend/ee/danswer/background/celery/celery_app.py +++ b/backend/ee/danswer/background/celery/celery_app.py @@ -11,7 +11,13 @@ from danswer.utils.logger import setup_logger from danswer.utils.variable_functionality import global_version from ee.danswer.background.celery_utils import should_perform_chat_ttl_check +from ee.danswer.background.celery_utils import should_perform_external_permissions_check from ee.danswer.background.task_name_builders import name_chat_ttl_task +from ee.danswer.background.task_name_builders import name_sync_external_permissions_task +from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs +from ee.danswer.external_permissions.permission_sync import ( + run_permission_sync_entrypoint, +) from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report logger = setup_logger() @@ -20,6 +26,13 @@ global_version.set_ee() +@build_celery_task_wrapper(name_sync_external_permissions_task)
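+# The task name built by name_sync_external_permissions_task(cc_pair_id) is what +# should_perform_external_permissions_check looks up via get_latest_task, so a cc pair +# with a live, non-timed-out sync task is skipped instead of being scheduled twice.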
+@celery_app.task(soft_time_limit=JOB_TIMEOUT) +def sync_external_permissions_task(cc_pair_id: int) -> None: + with Session(get_sqlalchemy_engine()) as db_session: + run_permission_sync_entrypoint(db_session=db_session, cc_pair_id=cc_pair_id) + + @build_celery_task_wrapper(name_chat_ttl_task) @celery_app.task(soft_time_limit=JOB_TIMEOUT) def perform_ttl_management_task(retention_limit_days: int) -> None: @@ -30,6 +43,23 @@ def perform_ttl_management_task(retention_limit_days: int) -> None: ##### # Periodic Tasks ##### +@celery_app.task( + name="check_sync_external_permissions_task", + soft_time_limit=JOB_TIMEOUT, +) +def check_sync_external_permissions_task() -> None: + """Runs periodically to sync external permissions""" + with Session(get_sqlalchemy_engine()) as db_session: + cc_pairs = get_all_auto_sync_cc_pairs(db_session) + for cc_pair in cc_pairs: + if should_perform_external_permissions_check( + cc_pair=cc_pair, db_session=db_session + ): + sync_external_permissions_task.apply_async( + kwargs=dict(cc_pair_id=cc_pair.id), + ) + + @celery_app.task( name="check_ttl_management_task", soft_time_limit=JOB_TIMEOUT, @@ -64,7 +94,11 @@ def autogenerate_usage_report_task() -> None: # Celery Beat (Periodic Tasks) Settings ##### celery_app.conf.beat_schedule = { - "autogenerate-usage-report": { + "sync-external-permissions": { + "task": "check_sync_external_permissions_task", + "schedule": timedelta(seconds=60), # TODO: optimize this + }, + "autogenerate_usage_report": { "task": "autogenerate_usage_report_task", "schedule": timedelta(days=30), # TODO: change this to config flag }, diff --git a/backend/ee/danswer/background/celery_utils.py b/backend/ee/danswer/background/celery_utils.py index 879487180af..bd1c5d874db 100644 --- a/backend/ee/danswer/background/celery_utils.py +++ b/backend/ee/danswer/background/celery_utils.py @@ -6,10 +6,13 @@ from danswer.background.celery.celery_app import task_logger from danswer.background.celery.celery_redis import RedisUserGroup from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.enums import AccessType +from danswer.db.models import ConnectorCredentialPair from danswer.db.tasks import check_task_is_live_and_not_timed_out from danswer.db.tasks import get_latest_task from danswer.utils.logger import setup_logger from ee.danswer.background.task_name_builders import name_chat_ttl_task +from ee.danswer.background.task_name_builders import name_sync_external_permissions_task from ee.danswer.db.user_group import delete_user_group from ee.danswer.db.user_group import fetch_user_group from ee.danswer.db.user_group import mark_user_group_as_synced @@ -30,11 +33,30 @@ def should_perform_chat_ttl_check( return True if latest_task and check_task_is_live_and_not_timed_out(latest_task, db_session): - logger.info("TTL check is already being performed. Skipping.") + logger.debug(f"{task_name} is already being performed. Skipping.") return False return True +def should_perform_external_permissions_check( + cc_pair: ConnectorCredentialPair, db_session: Session +) -> bool: + if cc_pair.access_type != AccessType.SYNC: + return False + + task_name = name_sync_external_permissions_task(cc_pair_id=cc_pair.id) + + latest_task = get_latest_task(task_name, db_session) + if not latest_task: + return True + + if check_task_is_live_and_not_timed_out(latest_task, db_session): + logger.debug(f"{task_name} is already being performed. 
Skipping.") + return False + + return True + + def monitor_usergroup_taskset(key_bytes: bytes, r: Redis) -> None: """This function is likely to move in the worker refactor happening next.""" key = key_bytes.decode("utf-8") diff --git a/backend/ee/danswer/background/permission_sync.py b/backend/ee/danswer/background/permission_sync.py deleted file mode 100644 index c14094b6042..00000000000 --- a/backend/ee/danswer/background/permission_sync.py +++ /dev/null @@ -1,224 +0,0 @@ -import logging -import time -from datetime import datetime - -import dask -from dask.distributed import Client -from dask.distributed import Future -from distributed import LocalCluster -from sqlalchemy.orm import Session - -from danswer.background.indexing.dask_utils import ResourceLogger -from danswer.background.indexing.job_client import SimpleJob -from danswer.background.indexing.job_client import SimpleJobClient -from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT -from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED -from danswer.configs.constants import DocumentSource -from danswer.configs.constants import POSTGRES_PERMISSIONS_APP_NAME -from danswer.db.engine import get_sqlalchemy_engine -from danswer.db.engine import init_sqlalchemy_engine -from danswer.db.models import PermissionSyncStatus -from danswer.utils.logger import setup_logger -from ee.danswer.configs.app_configs import NUM_PERMISSION_WORKERS -from ee.danswer.connectors.factory import CONNECTOR_PERMISSION_FUNC_MAP -from ee.danswer.db.connector import fetch_sources_with_connectors -from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source -from ee.danswer.db.permission_sync import create_perm_sync -from ee.danswer.db.permission_sync import expire_perm_sync_timed_out -from ee.danswer.db.permission_sync import get_perm_sync_attempt -from ee.danswer.db.permission_sync import mark_all_inprogress_permission_sync_failed -from shared_configs.configs import LOG_LEVEL - -logger = setup_logger() - -# If the indexing dies, it's most likely due to resource constraints, -# restarting just delays the eventual failure, not useful to the user -dask.config.set({"distributed.scheduler.allowed-failures": 0}) - - -def cleanup_perm_sync_jobs( - existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob], - # Just reusing the same timeout, fine for now - timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT, -) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]: - existing_jobs_copy = existing_jobs.copy() - - with Session(get_sqlalchemy_engine()) as db_session: - # clean up completed jobs - for (attempt_id, details), job in existing_jobs.items(): - perm_sync_attempt = get_perm_sync_attempt( - attempt_id=attempt_id, db_session=db_session - ) - - # do nothing for ongoing jobs that haven't been stopped - if ( - not job.done() - and perm_sync_attempt.status == PermissionSyncStatus.IN_PROGRESS - ): - continue - - if job.status == "error": - logger.error(job.exception()) - - job.release() - del existing_jobs_copy[(attempt_id, details)] - - # clean up in-progress jobs that were never completed - expire_perm_sync_timed_out( - timeout_hours=timeout_hours, - db_session=db_session, - ) - - return existing_jobs_copy - - -def create_group_sync_jobs( - existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob], - client: Client | SimpleJobClient, -) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]: - """Creates new relational DB group permission sync job for each source that: - - has 
permission sync enabled - - has at least 1 connector (enabled or paused) - - has no sync already running - """ - existing_jobs_copy = existing_jobs.copy() - sources_w_runs = [ - key[1] - for key in existing_jobs_copy.keys() - if isinstance(key[1], DocumentSource) - ] - with Session(get_sqlalchemy_engine()) as db_session: - sources_w_connector = fetch_sources_with_connectors(db_session) - for source_type in sources_w_connector: - if source_type not in CONNECTOR_PERMISSION_FUNC_MAP: - continue - if source_type in sources_w_runs: - continue - - db_group_fnc, _ = CONNECTOR_PERMISSION_FUNC_MAP[source_type] - perm_sync = create_perm_sync( - source_type=source_type, - group_update=True, - cc_pair_id=None, - db_session=db_session, - ) - - run = client.submit(db_group_fnc, pure=False) - - logger.info( - f"Kicked off group permission sync for source type {source_type}" - ) - - if run: - existing_jobs_copy[(perm_sync.id, source_type)] = run - - return existing_jobs_copy - - -def create_connector_perm_sync_jobs( - existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob], - client: Client | SimpleJobClient, -) -> dict[tuple[int, int | DocumentSource], Future | SimpleJob]: - """Update Document Index ACL sync job for each cc-pair where: - - source type has permission sync enabled - - has no sync already running - """ - existing_jobs_copy = existing_jobs.copy() - cc_pairs_w_runs = [ - key[1] - for key in existing_jobs_copy.keys() - if isinstance(key[1], DocumentSource) - ] - with Session(get_sqlalchemy_engine()) as db_session: - sources_w_connector = fetch_sources_with_connectors(db_session) - for source_type in sources_w_connector: - if source_type not in CONNECTOR_PERMISSION_FUNC_MAP: - continue - - _, index_sync_fnc = CONNECTOR_PERMISSION_FUNC_MAP[source_type] - - cc_pairs = get_cc_pairs_by_source(source_type, db_session) - - for cc_pair in cc_pairs: - if cc_pair.id in cc_pairs_w_runs: - continue - - perm_sync = create_perm_sync( - source_type=source_type, - group_update=False, - cc_pair_id=cc_pair.id, - db_session=db_session, - ) - - run = client.submit(index_sync_fnc, cc_pair.id, pure=False) - - logger.info(f"Kicked off ACL sync for cc-pair {cc_pair.id}") - - if run: - existing_jobs_copy[(perm_sync.id, cc_pair.id)] = run - - return existing_jobs_copy - - -def permission_loop(delay: int = 60, num_workers: int = NUM_PERMISSION_WORKERS) -> None: - client: Client | SimpleJobClient - if DASK_JOB_CLIENT_ENABLED: - cluster_primary = LocalCluster( - n_workers=num_workers, - threads_per_worker=1, - # there are warning about high memory usage + "Event loop unresponsive" - # which are not relevant to us since our workers are expected to use a - # lot of memory + involve CPU intensive tasks that will not relinquish - # the event loop - silence_logs=logging.ERROR, - ) - client = Client(cluster_primary) - if LOG_LEVEL.lower() == "debug": - client.register_worker_plugin(ResourceLogger()) - else: - client = SimpleJobClient(n_workers=num_workers) - - existing_jobs: dict[tuple[int, int | DocumentSource], Future | SimpleJob] = {} - engine = get_sqlalchemy_engine() - - with Session(engine) as db_session: - # Any jobs still in progress on restart must have died - mark_all_inprogress_permission_sync_failed(db_session) - - while True: - start = time.time() - start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S") - logger.info(f"Running Permission Sync, current UTC time: {start_time_utc}") - - if existing_jobs: - logger.debug( - "Found existing permission sync jobs: " - 
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}" - ) - - try: - # TODO turn this on when it works - """ - existing_jobs = cleanup_perm_sync_jobs(existing_jobs=existing_jobs) - existing_jobs = create_group_sync_jobs( - existing_jobs=existing_jobs, client=client - ) - existing_jobs = create_connector_perm_sync_jobs( - existing_jobs=existing_jobs, client=client - ) - """ - except Exception as e: - logger.exception(f"Failed to run update due to {e}") - sleep_time = delay - (time.time() - start) - if sleep_time > 0: - time.sleep(sleep_time) - - -def update__main() -> None: - logger.notice("Starting Permission Syncing Loop") - init_sqlalchemy_engine(POSTGRES_PERMISSIONS_APP_NAME) - permission_loop() - - -if __name__ == "__main__": - update__main() diff --git a/backend/ee/danswer/background/task_name_builders.py b/backend/ee/danswer/background/task_name_builders.py index 4f1046adbbb..75a5aa36bd9 100644 --- a/backend/ee/danswer/background/task_name_builders.py +++ b/backend/ee/danswer/background/task_name_builders.py @@ -1,6 +1,6 @@ -def name_user_group_sync_task(user_group_id: int) -> str: - return f"user_group_sync_task__{user_group_id}" - - def name_chat_ttl_task(retention_limit_days: int) -> str: return f"chat_ttl_{retention_limit_days}_days" + + +def name_sync_external_permissions_task(cc_pair_id: int) -> str: + return f"sync_external_permissions_task__{cc_pair_id}" diff --git a/backend/ee/danswer/connectors/confluence/perm_sync.py b/backend/ee/danswer/connectors/confluence/perm_sync.py deleted file mode 100644 index 2985b47b0d1..00000000000 --- a/backend/ee/danswer/connectors/confluence/perm_sync.py +++ /dev/null @@ -1,12 +0,0 @@ -from danswer.utils.logger import setup_logger - - -logger = setup_logger() - - -def confluence_update_db_group() -> None: - logger.debug("Not yet implemented group sync for confluence, no-op") - - -def confluence_update_index_acl(cc_pair_id: int) -> None: - logger.debug("Not yet implemented ACL sync for confluence, no-op") diff --git a/backend/ee/danswer/connectors/factory.py b/backend/ee/danswer/connectors/factory.py deleted file mode 100644 index 52f9324948b..00000000000 --- a/backend/ee/danswer/connectors/factory.py +++ /dev/null @@ -1,8 +0,0 @@ -from danswer.configs.constants import DocumentSource -from ee.danswer.connectors.confluence.perm_sync import confluence_update_db_group -from ee.danswer.connectors.confluence.perm_sync import confluence_update_index_acl - - -CONNECTOR_PERMISSION_FUNC_MAP = { - DocumentSource.CONFLUENCE: (confluence_update_db_group, confluence_update_index_acl) -} diff --git a/backend/ee/danswer/db/connector_credential_pair.py b/backend/ee/danswer/db/connector_credential_pair.py index a2172913476..bb91c0de74f 100644 --- a/backend/ee/danswer/db/connector_credential_pair.py +++ b/backend/ee/danswer/db/connector_credential_pair.py @@ -3,6 +3,7 @@ from danswer.configs.constants import DocumentSource from danswer.db.connector_credential_pair import get_connector_credential_pair +from danswer.db.enums import AccessType from danswer.db.models import Connector from danswer.db.models import ConnectorCredentialPair from danswer.db.models import UserGroup__ConnectorCredentialPair @@ -32,14 +33,30 @@ def _delete_connector_credential_pair_user_groups_relationship__no_commit( def get_cc_pairs_by_source( - source_type: DocumentSource, db_session: Session, + source_type: DocumentSource, + only_sync: bool, ) -> list[ConnectorCredentialPair]: - cc_pairs = ( + query = ( db_session.query(ConnectorCredentialPair) 
.join(ConnectorCredentialPair.connector) .filter(Connector.source == source_type) - .all() ) + if only_sync: + query = query.filter(ConnectorCredentialPair.access_type == AccessType.SYNC) + + cc_pairs = query.all() return cc_pairs + + +def get_all_auto_sync_cc_pairs( + db_session: Session, +) -> list[ConnectorCredentialPair]: + return ( + db_session.query(ConnectorCredentialPair) + .where( + ConnectorCredentialPair.access_type == AccessType.SYNC, + ) + .all() + ) diff --git a/backend/ee/danswer/db/document.py b/backend/ee/danswer/db/document.py index 5a368ea170e..d67bc0e57e7 100644 --- a/backend/ee/danswer/db/document.py +++ b/backend/ee/danswer/db/document.py @@ -1,14 +1,47 @@ -from collections.abc import Sequence - from sqlalchemy import select from sqlalchemy.orm import Session -from danswer.db.models import Document +from danswer.access.models import ExternalAccess +from danswer.access.utils import prefix_group_w_source +from danswer.configs.constants import DocumentSource +from danswer.db.models import Document as DbDocument + + +def upsert_document_external_perms__no_commit( + db_session: Session, + doc_id: str, + external_access: ExternalAccess, + source_type: DocumentSource, +) -> None: + """ + This sets the permissions for a document in postgres. + NOTE: this will replace any existing external access; it will not do a union + """ + document = db_session.scalars( + select(DbDocument).where(DbDocument.id == doc_id) + ).first() + + prefixed_external_groups = [ + prefix_group_w_source( + ext_group_name=group_id, + source=source_type, + ) + for group_id in external_access.external_user_group_ids + ] + if not document: + # If the document does not exist, still store the external access + # so that if the document is added later, the external access is already stored + document = DbDocument( + id=doc_id, + semantic_id="", + external_user_emails=external_access.external_user_emails, + external_user_group_ids=prefixed_external_groups, + is_public=external_access.is_public, + ) + db_session.add(document) + return -def fetch_documents_from_ids( - db_session: Session, document_ids: list[str] -) -> Sequence[Document]: - return db_session.scalars( - select(Document).where(Document.id.in_(document_ids)) - ).all() + document.external_user_emails = list(external_access.external_user_emails) + document.external_user_group_ids = prefixed_external_groups + document.is_public = external_access.is_public diff --git a/backend/ee/danswer/db/external_perm.py b/backend/ee/danswer/db/external_perm.py new file mode 100644 index 00000000000..25881df55d3 --- /dev/null +++ b/backend/ee/danswer/db/external_perm.py @@ -0,0 +1,77 @@ +from collections.abc import Sequence +from uuid import UUID + +from pydantic import BaseModel +from sqlalchemy import delete +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.access.utils import prefix_group_w_source +from danswer.configs.constants import DocumentSource +from danswer.db.models import User__ExternalUserGroupId + + +class ExternalUserGroup(BaseModel): + id: str + user_ids: list[UUID] + + +def delete_user__ext_group_for_user__no_commit( + db_session: Session, + user_id: UUID, +) -> None: + db_session.execute( + delete(User__ExternalUserGroupId).where( + User__ExternalUserGroupId.user_id == user_id + ) + ) + + +def delete_user__ext_group_for_cc_pair__no_commit( + db_session: Session, + cc_pair_id: int, +) -> None: + db_session.execute( + delete(User__ExternalUserGroupId).where( + User__ExternalUserGroupId.cc_pair_id == cc_pair_id + ) + ) +
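+# A minimal usage sketch (the session plumbing is the caller's, as in +# remove_credential_from_connector in danswer/db/connector_credential_pair.py): these +# __no_commit helpers intentionally leave the commit to the caller, e.g. +# +#     with Session(get_sqlalchemy_engine()) as db_session: +#         delete_user__ext_group_for_cc_pair__no_commit(db_session=db_session, cc_pair_id=cc_pair_id) +#         db_session.commit()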
+ +def replace_user__ext_group_for_cc_pair__no_commit( + db_session: Session, + cc_pair_id: int, + group_defs: list[ExternalUserGroup], + source: DocumentSource, +) -> None: + """ + This function clears all existing external user group relations for a given cc_pair_id + and replaces them with the new group definitions. + """ + delete_user__ext_group_for_cc_pair__no_commit( + db_session=db_session, + cc_pair_id=cc_pair_id, + ) + + new_external_permissions = [ + User__ExternalUserGroupId( + user_id=user_id, + external_user_group_id=prefix_group_w_source(external_group.id, source), + cc_pair_id=cc_pair_id, + ) + for external_group in group_defs + for user_id in external_group.user_ids + ] + + db_session.add_all(new_external_permissions) + + +def fetch_external_groups_for_user( + db_session: Session, + user_id: UUID, +) -> Sequence[User__ExternalUserGroupId]: + return db_session.scalars( + select(User__ExternalUserGroupId).where( + User__ExternalUserGroupId.user_id == user_id + ) + ).all() diff --git a/backend/ee/danswer/db/permission_sync.py b/backend/ee/danswer/db/permission_sync.py deleted file mode 100644 index 7642bb65321..00000000000 --- a/backend/ee/danswer/db/permission_sync.py +++ /dev/null @@ -1,72 +0,0 @@ -from datetime import timedelta - -from sqlalchemy import func -from sqlalchemy import select -from sqlalchemy import update -from sqlalchemy.exc import NoResultFound -from sqlalchemy.orm import Session - -from danswer.configs.constants import DocumentSource -from danswer.db.models import PermissionSyncRun -from danswer.db.models import PermissionSyncStatus -from danswer.utils.logger import setup_logger - -logger = setup_logger() - - -def mark_all_inprogress_permission_sync_failed( - db_session: Session, -) -> None: - stmt = ( - update(PermissionSyncRun) - .where(PermissionSyncRun.status == PermissionSyncStatus.IN_PROGRESS) - .values(status=PermissionSyncStatus.FAILED) - ) - db_session.execute(stmt) - db_session.commit() - - -def get_perm_sync_attempt(attempt_id: int, db_session: Session) -> PermissionSyncRun: - stmt = select(PermissionSyncRun).where(PermissionSyncRun.id == attempt_id) - try: - return db_session.scalars(stmt).one() - except NoResultFound: - raise ValueError(f"No PermissionSyncRun found with id {attempt_id}") - - -def expire_perm_sync_timed_out( - timeout_hours: int, - db_session: Session, -) -> None: - cutoff_time = func.now() - timedelta(hours=timeout_hours) - - update_stmt = ( - update(PermissionSyncRun) - .where( - PermissionSyncRun.status == PermissionSyncStatus.IN_PROGRESS, - PermissionSyncRun.updated_at < cutoff_time, - ) - .values(status=PermissionSyncStatus.FAILED, error_msg="timed out") - ) - - db_session.execute(update_stmt) - db_session.commit() - - -def create_perm_sync( - source_type: DocumentSource, - group_update: bool, - cc_pair_id: int | None, - db_session: Session, -) -> PermissionSyncRun: - new_run = PermissionSyncRun( - source_type=source_type, - status=PermissionSyncStatus.IN_PROGRESS, - group_update=group_update, - cc_pair_id=cc_pair_id, - ) - - db_session.add(new_run) - db_session.commit() - - return new_run diff --git a/backend/ee/danswer/db/user_group.py b/backend/ee/danswer/db/user_group.py index ab666f747b5..f62fe19a7d7 100644 --- a/backend/ee/danswer/db/user_group.py +++ b/backend/ee/danswer/db/user_group.py @@ -199,7 +199,7 @@ def fetch_documents_for_user_group_paginated( def fetch_user_groups_for_documents( db_session: Session, document_ids: list[str], -) -> Sequence[tuple[int, list[str]]]: +) -> Sequence[tuple[str, list[str]]]: stmt = 
( select(Document.id, func.array_agg(UserGroup.name)) .join( diff --git a/backend/ee/danswer/connectors/__init__.py b/backend/ee/danswer/external_permissions/__init__.py similarity index 100% rename from backend/ee/danswer/connectors/__init__.py rename to backend/ee/danswer/external_permissions/__init__.py diff --git a/backend/ee/danswer/connectors/confluence/__init__.py b/backend/ee/danswer/external_permissions/confluence/__init__.py similarity index 100% rename from backend/ee/danswer/connectors/confluence/__init__.py rename to backend/ee/danswer/external_permissions/confluence/__init__.py diff --git a/backend/ee/danswer/external_permissions/confluence/doc_sync.py b/backend/ee/danswer/external_permissions/confluence/doc_sync.py new file mode 100644 index 00000000000..2044a2e583e --- /dev/null +++ b/backend/ee/danswer/external_permissions/confluence/doc_sync.py @@ -0,0 +1,19 @@ +from typing import Any + +from sqlalchemy.orm import Session + +from danswer.db.models import ConnectorCredentialPair +from danswer.utils.logger import setup_logger +from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo + + +logger = setup_logger() + + +def confluence_doc_sync( + db_session: Session, + cc_pair: ConnectorCredentialPair, + docs_with_additional_info: list[DocsWithAdditionalInfo], + sync_details: dict[str, Any], +) -> None: + logger.debug("Not yet implemented ACL sync for confluence, no-op") diff --git a/backend/ee/danswer/external_permissions/confluence/group_sync.py b/backend/ee/danswer/external_permissions/confluence/group_sync.py new file mode 100644 index 00000000000..6e9e6777d69 --- /dev/null +++ b/backend/ee/danswer/external_permissions/confluence/group_sync.py @@ -0,0 +1,19 @@ +from typing import Any + +from sqlalchemy.orm import Session + +from danswer.db.models import ConnectorCredentialPair +from danswer.utils.logger import setup_logger +from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo + + +logger = setup_logger() + + +def confluence_group_sync( + db_session: Session, + cc_pair: ConnectorCredentialPair, + docs_with_additional_info: list[DocsWithAdditionalInfo], + sync_details: dict[str, Any], +) -> None: + logger.debug("Not yet implemented group sync for confluence, no-op") diff --git a/backend/ee/danswer/external_permissions/google_drive/__init__.py b/backend/ee/danswer/external_permissions/google_drive/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py new file mode 100644 index 00000000000..80ffe471e91 --- /dev/null +++ b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py @@ -0,0 +1,148 @@ +from collections.abc import Iterator +from typing import Any +from typing import cast + +from googleapiclient.discovery import build # type: ignore +from googleapiclient.errors import HttpError # type: ignore +from sqlalchemy.orm import Session + +from danswer.access.models import ExternalAccess +from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder +from danswer.connectors.google_drive.connector_auth import ( + get_google_drive_creds, +) +from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES +from danswer.db.models import ConnectorCredentialPair +from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit +from danswer.utils.logger import setup_logger +from ee.danswer.db.document import 
upsert_document_external_perms__no_commit +from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo + +# Google Drive APIs are quite flakey and may 500 for an +# extended period of time. Trying to combat here by adding a very +# long retry period (~20 minutes of trying every minute) +add_retries = retry_builder(tries=5, delay=5, max_delay=30) + + +logger = setup_logger() + + +def _fetch_permissions_paginated( + drive_service: Any, drive_file_id: str +) -> Iterator[dict[str, Any]]: + next_token = None + + # Check if the file is trashed + # Returning nothing here will cause the external permissions to + # be empty which will get written to vespa (failing shut) + try: + file_metadata = add_retries( + lambda: drive_service.files() + .get(fileId=drive_file_id, fields="id, trashed") + .execute() + )() + except HttpError as e: + if e.resp.status == 404 or e.resp.status == 403: + return + logger.error(f"Failed to fetch permissions: {e}") + raise + + if file_metadata.get("trashed", False): + logger.debug(f"File with ID {drive_file_id} is trashed") + return + + # Get paginated permissions for the file id + while True: + try: + permissions_resp: dict[str, Any] = add_retries( + lambda: ( + drive_service.permissions() + .list( + fileId=drive_file_id, + fields="permissions(id, emailAddress, role, type, domain)", + supportsAllDrives=True, + pageToken=next_token, + ) + .execute() + ) + )() + except HttpError as e: + if e.resp.status == 404 or e.resp.status == 403: + break + logger.error(f"Failed to fetch permissions: {e}") + raise + + for permission in permissions_resp.get("permissions", []): + yield permission + + next_token = permissions_resp.get("nextPageToken") + if not next_token: + break + + +def _fetch_google_permissions_for_document_id( + db_session: Session, + drive_file_id: str, + raw_credentials_json: dict[str, str], + company_google_domains: list[str], +) -> ExternalAccess: + # Authenticate and construct service + google_drive_creds, _ = get_google_drive_creds( + raw_credentials_json, scopes=FETCH_PERMISSIONS_SCOPES + ) + if not google_drive_creds.valid: + raise ValueError("Invalid Google Drive credentials") + + drive_service = build("drive", "v3", credentials=google_drive_creds) + + user_emails: set[str] = set() + group_emails: set[str] = set() + public = False + for permission in _fetch_permissions_paginated(drive_service, drive_file_id): + permission_type = permission["type"] + if permission_type == "user": + user_emails.add(permission["emailAddress"]) + elif permission_type == "group": + group_emails.add(permission["emailAddress"]) + elif permission_type == "domain": + if permission["domain"] in company_google_domains: + public = True + elif permission_type == "anyone": + public = True + + batch_add_non_web_user_if_not_exists__no_commit(db_session, list(user_emails)) + + return ExternalAccess( + external_user_emails=user_emails, + external_user_group_ids=group_emails, + is_public=public, + ) + + +def gdrive_doc_sync( + db_session: Session, + cc_pair: ConnectorCredentialPair, + docs_with_additional_info: list[DocsWithAdditionalInfo], + sync_details: dict[str, Any], +) -> None: + """ + Adds the external permissions to the documents in postgres + if the document doesn't already exists in postgres, we create + it in postgres so that when it gets created later, the permissions are + already populated + """ + for doc in docs_with_additional_info: + ext_access = _fetch_google_permissions_for_document_id( + db_session=db_session, + drive_file_id=doc.additional_info, 
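
Both Drive fetchers in this file use the same retry-wrapped pagination idiom:
fetch one page inside add_retries, yield its items, then loop on
nextPageToken. Distilled into a standalone sketch (paginate_with_retries,
list_page, and items_key are illustrative names, not part of this patch):

from collections.abc import Callable
from collections.abc import Iterator
from typing import Any

from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder

add_retries = retry_builder(tries=5, delay=5, max_delay=30)


def paginate_with_retries(
    list_page: Callable[[str | None], dict[str, Any]],
    items_key: str,
) -> Iterator[dict[str, Any]]:
    """Yield items across pages, retrying each page fetch on transient errors."""
    next_token: str | None = None
    while True:
        # bind the current token as a default arg so the retried lambda
        # always requests the intended page
        resp = add_retries(lambda token=next_token: list_page(token))()
        yield from resp.get(items_key, [])
        next_token = resp.get("nextPageToken")
        if not next_token:
            break
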
+ raw_credentials_json=cc_pair.credential.credential_json, + company_google_domains=[ + cast(dict[str, str], sync_details)["company_domain"] + ], + ) + upsert_document_external_perms__no_commit( + db_session=db_session, + doc_id=doc.id, + external_access=ext_access, + source_type=cc_pair.connector.source, + ) diff --git a/backend/ee/danswer/external_permissions/google_drive/group_sync.py b/backend/ee/danswer/external_permissions/google_drive/group_sync.py new file mode 100644 index 00000000000..a0fe068be7c --- /dev/null +++ b/backend/ee/danswer/external_permissions/google_drive/group_sync.py @@ -0,0 +1,147 @@ +from collections.abc import Iterator +from typing import Any + +from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore +from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore +from googleapiclient.discovery import build # type: ignore +from googleapiclient.errors import HttpError # type: ignore +from sqlalchemy.orm import Session + +from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder +from danswer.connectors.google_drive.connector_auth import ( + get_google_drive_creds, +) +from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES +from danswer.db.models import ConnectorCredentialPair +from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit +from danswer.utils.logger import setup_logger +from ee.danswer.db.external_perm import ExternalUserGroup +from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit +from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo + +logger = setup_logger() + + +# Google Drive APIs are quite flakey and may 500 for an +# extended period of time. 
Trying to combat here by adding a very
+# long retry period (~20 minutes of trying every minute)
+add_retries = retry_builder(tries=5, delay=5, max_delay=30)
+
+
+def _fetch_groups_paginated(
+    google_drive_creds: ServiceAccountCredentials | OAuthCredentials,
+    identity_source: str | None = None,
+    customer_id: str | None = None,
+) -> Iterator[dict[str, Any]]:
+    # Note that Google Drive does not use or update the user_cache as the user email
+    # comes directly with the call to fetch the groups, therefore this is not a valid
+    # place to save on requests
+    if identity_source is None and customer_id is None:
+        raise ValueError(
+            "Either identity_source or customer_id must be provided to fetch groups"
+        )
+
+    cloud_identity_service = build(
+        "cloudidentity", "v1", credentials=google_drive_creds
+    )
+    parent = (
+        f"identitysources/{identity_source}"
+        if identity_source
+        else f"customers/{customer_id}"
+    )
+
+    while True:
+        try:
+            groups_resp: dict[str, Any] = add_retries(
+                lambda: (cloud_identity_service.groups().list(parent=parent).execute())
+            )()
+            for group in groups_resp.get("groups", []):
+                yield group
+
+            next_token = groups_resp.get("nextPageToken")
+            if not next_token:
+                break
+        except HttpError as e:
+            if e.resp.status == 404 or e.resp.status == 403:
+                break
+            logger.error(f"Error fetching groups: {e}")
+            raise
+
+
+def _fetch_group_members_paginated(
+    google_drive_creds: ServiceAccountCredentials | OAuthCredentials,
+    group_name: str,
+) -> Iterator[dict[str, Any]]:
+    cloud_identity_service = build(
+        "cloudidentity", "v1", credentials=google_drive_creds
+    )
+    next_token = None
+    while True:
+        try:
+            membership_info = add_retries(
+                lambda: (
+                    cloud_identity_service.groups()
+                    .memberships()
+                    .searchTransitiveMemberships(
+                        parent=group_name, pageToken=next_token
+                    )
+                    .execute()
+                )
+            )()
+
+            for member in membership_info.get("memberships", []):
+                yield member
+
+            next_token = membership_info.get("nextPageToken")
+            if not next_token:
+                break
+        except HttpError as e:
+            if e.resp.status == 404 or e.resp.status == 403:
+                break
+            logger.error(f"Error fetching group members: {e}")
+            raise
+
+
+def gdrive_group_sync(
+    db_session: Session,
+    cc_pair: ConnectorCredentialPair,
+    docs_with_additional_info: list[DocsWithAdditionalInfo],
+    sync_details: dict[str, Any],
+) -> None:
+    google_drive_creds, _ = get_google_drive_creds(
+        cc_pair.credential.credential_json,
+        scopes=FETCH_GROUPS_SCOPES,
+    )
+
+    danswer_groups: list[ExternalUserGroup] = []
+    for group in _fetch_groups_paginated(
+        google_drive_creds,
+        identity_source=sync_details.get("identity_source"),
+        customer_id=sync_details.get("customer_id"),
+    ):
+        # The id is the group email
+        group_email = group["groupKey"]["id"]
+
+        group_member_emails: list[str] = []
+        for member in _fetch_group_members_paginated(google_drive_creds, group["name"]):
+            member_keys = member["preferredMemberKey"]
+            member_emails = [member_key["id"] for member_key in member_keys]
+            for member_email in member_emails:
+                group_member_emails.append(member_email)
+
+        group_members = batch_add_non_web_user_if_not_exists__no_commit(
+            db_session=db_session, emails=group_member_emails
+        )
+        if group_members:
+            danswer_groups.append(
+                ExternalUserGroup(
+                    id=group_email, user_ids=[user.id for user in group_members]
+                )
+            )
+
+    replace_user__ext_group_for_cc_pair__no_commit(
+        db_session=db_session,
+        cc_pair_id=cc_pair.id,
+        group_defs=danswer_groups,
+        source=cc_pair.connector.source,
+    )
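
The membership loop in gdrive_group_sync above flattens each transitive
membership record down to its preferredMemberKey emails. A self-contained
illustration of that flattening; the record shape mirrors what the loop
consumes, but the values are fabricated:

from typing import Any

memberships: list[dict[str, Any]] = [
    {"preferredMemberKey": [{"id": "alice@example.com"}]},
    {"preferredMemberKey": [{"id": "bob@example.com"}, {"id": "bob@alt.example.com"}]},
]

# equivalent to the nested for-loops above
group_member_emails = [
    key["id"] for member in memberships for key in member["preferredMemberKey"]
]
assert group_member_emails == [
    "alice@example.com",
    "bob@example.com",
    "bob@alt.example.com",
]
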
diff --git a/backend/ee/danswer/external_permissions/permission_sync.py b/backend/ee/danswer/external_permissions/permission_sync.py
new file mode 100644
index 00000000000..a3829b64e79
--- /dev/null
+++ b/backend/ee/danswer/external_permissions/permission_sync.py
@@ -0,0 +1,126 @@
+from datetime import datetime
+from datetime import timezone
+
+from sqlalchemy.orm import Session
+
+from danswer.access.access import get_access_for_documents
+from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
+from danswer.db.search_settings import get_current_search_settings
+from danswer.document_index.factory import get_default_document_index
+from danswer.document_index.interfaces import UpdateRequest
+from danswer.utils.logger import setup_logger
+from ee.danswer.external_permissions.permission_sync_function_map import (
+    DOC_PERMISSIONS_FUNC_MAP,
+)
+from ee.danswer.external_permissions.permission_sync_function_map import (
+    FULL_FETCH_PERIOD_IN_SECONDS,
+)
+from ee.danswer.external_permissions.permission_sync_function_map import (
+    GROUP_PERMISSIONS_FUNC_MAP,
+)
+from ee.danswer.external_permissions.permission_sync_utils import (
+    get_docs_with_additional_info,
+)
+
+logger = setup_logger()
+
+
+def run_permission_sync_entrypoint(
+    db_session: Session,
+    cc_pair_id: int,
+) -> None:
+    # TODO: separate out group and doc sync
+    cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
+    if cc_pair is None:
+        raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")
+
+    source_type = cc_pair.connector.source
+
+    doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
+    group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
+
+    if doc_sync_func is None:
+        raise ValueError(
+            f"No permission sync function found for source type: {source_type}"
+        )
+
+    sync_details = cc_pair.auto_sync_options
+    if sync_details is None:
+        raise ValueError(f"No auto sync options found for source type: {source_type}")
+
+    # If the source type does not support polling, we only fetch the permissions
+    # every FULL_FETCH_PERIOD_IN_SECONDS seconds
+    full_fetch_period = FULL_FETCH_PERIOD_IN_SECONDS[source_type]
+    if full_fetch_period is not None:
+        last_sync = cc_pair.last_time_perm_sync
+        if (
+            last_sync
+            and (
+                datetime.now(timezone.utc) - last_sync.replace(tzinfo=timezone.utc)
+            ).total_seconds()
+            < full_fetch_period
+        ):
+            return
+
+    # Here we run the connector to grab all the ids
+    # this may grab ids before they are indexed but that is fine because
+    # we create a document in postgres to hold the permissions info
+    # until the indexing job has a chance to run
+    docs_with_additional_info = get_docs_with_additional_info(
+        db_session=db_session,
+        cc_pair=cc_pair,
+    )
+
+    # This function updates:
+    # - the user_email <-> external_user_group_id mapping
+    # in postgres without committing
+    logger.debug(f"Syncing groups for {source_type}")
+    if group_sync_func is not None:
+        group_sync_func(
+            db_session,
+            cc_pair,
+            docs_with_additional_info,
+            sync_details,
+        )
+
+    # This function updates:
+    # - the user_email <-> document mapping
+    # - the external_user_group_id <-> document mapping
+    # in postgres without committing
+    logger.debug(f"Syncing docs for {source_type}")
+    doc_sync_func(
+        db_session,
+        cc_pair,
+        docs_with_additional_info,
+        sync_details,
+    )
+
+    # This function fetches the updated access for the documents
+    # and returns a dictionary of document_ids and access
+    # This is the access we want to update vespa with
+    docs_access =
get_access_for_documents( + document_ids=[doc.id for doc in docs_with_additional_info], + db_session=db_session, + ) + + # Then we build the update requests to update vespa + update_reqs = [ + UpdateRequest(document_ids=[doc_id], access=doc_access) + for doc_id, doc_access in docs_access.items() + ] + + # Don't bother sync-ing secondary, it will be sync-ed after switch anyway + search_settings = get_current_search_settings(db_session) + document_index = get_default_document_index( + primary_index_name=search_settings.index_name, + secondary_index_name=None, + ) + + try: + # update vespa + document_index.update(update_reqs) + # update postgres + db_session.commit() + except Exception as e: + logger.error(f"Error updating document index: {e}") + db_session.rollback() diff --git a/backend/ee/danswer/external_permissions/permission_sync_function_map.py b/backend/ee/danswer/external_permissions/permission_sync_function_map.py new file mode 100644 index 00000000000..9bedbbc6321 --- /dev/null +++ b/backend/ee/danswer/external_permissions/permission_sync_function_map.py @@ -0,0 +1,54 @@ +from collections.abc import Callable +from typing import Any + +from sqlalchemy.orm import Session + +from danswer.configs.constants import DocumentSource +from danswer.db.models import ConnectorCredentialPair +from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_sync +from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync +from ee.danswer.external_permissions.google_drive.doc_sync import gdrive_doc_sync +from ee.danswer.external_permissions.google_drive.group_sync import gdrive_group_sync +from ee.danswer.external_permissions.permission_sync_utils import DocsWithAdditionalInfo + +GroupSyncFuncType = Callable[ + [Session, ConnectorCredentialPair, list[DocsWithAdditionalInfo], dict[str, Any]], + None, +] + +DocSyncFuncType = Callable[ + [Session, ConnectorCredentialPair, list[DocsWithAdditionalInfo], dict[str, Any]], + None, +] + +# These functions update: +# - the user_email <-> document mapping +# - the external_user_group_id <-> document mapping +# in postgres without committing +# THIS ONE IS NECESSARY FOR AUTO SYNC TO WORK +DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = { + DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync, + DocumentSource.CONFLUENCE: confluence_doc_sync, +} + +# These functions update: +# - the user_email <-> external_user_group_id mapping +# in postgres without committing +# THIS ONE IS OPTIONAL ON AN APP BY APP BASIS +GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = { + DocumentSource.GOOGLE_DRIVE: gdrive_group_sync, + DocumentSource.CONFLUENCE: confluence_group_sync, +} + + +# None means that the connector supports polling from last_time_perm_sync to now +FULL_FETCH_PERIOD_IN_SECONDS: dict[DocumentSource, int | None] = { + # Polling is supported + DocumentSource.GOOGLE_DRIVE: None, + # Polling is not supported so we fetch all doc permissions every 10 minutes + DocumentSource.CONFLUENCE: 10 * 60, +} + + +def check_if_valid_sync_source(source_type: DocumentSource) -> bool: + return source_type in DOC_PERMISSIONS_FUNC_MAP diff --git a/backend/ee/danswer/external_permissions/permission_sync_utils.py b/backend/ee/danswer/external_permissions/permission_sync_utils.py new file mode 100644 index 00000000000..7be817bdbdd --- /dev/null +++ b/backend/ee/danswer/external_permissions/permission_sync_utils.py @@ -0,0 +1,56 @@ +from datetime import datetime +from datetime import timezone +from typing import Any 
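
The function maps above make source capabilities data-driven: a source
supports permission auto-sync exactly when it appears in
DOC_PERMISSIONS_FUNC_MAP, while group sync stays optional per source. A small
sketch of interrogating them (describe_sync_support is an illustrative helper,
not part of this patch):

from danswer.configs.constants import DocumentSource
from ee.danswer.external_permissions.permission_sync_function_map import (
    DOC_PERMISSIONS_FUNC_MAP,
    GROUP_PERMISSIONS_FUNC_MAP,
    check_if_valid_sync_source,
)


def describe_sync_support(source: DocumentSource) -> str:
    if not check_if_valid_sync_source(source):
        return f"{source.value} does not support permission auto-sync"
    group_note = (
        "with group sync"
        if source in GROUP_PERMISSIONS_FUNC_MAP
        else "doc-level sync only"
    )
    doc_sync_name = DOC_PERMISSIONS_FUNC_MAP[source].__name__
    return f"{source.value} syncs via {doc_sync_name} ({group_note})"


print(describe_sync_support(DocumentSource.GOOGLE_DRIVE))
print(describe_sync_support(DocumentSource.CONFLUENCE))
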
+ +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from danswer.connectors.factory import instantiate_connector +from danswer.connectors.interfaces import PollConnector +from danswer.connectors.models import InputType +from danswer.db.models import ConnectorCredentialPair +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +class DocsWithAdditionalInfo(BaseModel): + id: str + additional_info: Any + + +def get_docs_with_additional_info( + db_session: Session, + cc_pair: ConnectorCredentialPair, +) -> list[DocsWithAdditionalInfo]: + # Get all document ids that need their permissions updated + runnable_connector = instantiate_connector( + db_session=db_session, + source=cc_pair.connector.source, + input_type=InputType.POLL, + connector_specific_config=cc_pair.connector.connector_specific_config, + credential=cc_pair.credential, + ) + + assert isinstance(runnable_connector, PollConnector) + + current_time = datetime.now(timezone.utc) + start_time = ( + cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp() + if cc_pair.last_time_perm_sync + else 0 + ) + cc_pair.last_time_perm_sync = current_time + + doc_batch_generator = runnable_connector.poll_source( + start=start_time, end=current_time.timestamp() + ) + + docs_with_additional_info = [ + DocsWithAdditionalInfo(id=doc.id, additional_info=doc.additional_info) + for doc_batch in doc_batch_generator + for doc in doc_batch + ] + logger.debug(f"Docs with additional info: {len(docs_with_additional_info)}") + + return docs_with_additional_info diff --git a/backend/scripts/query_time_check/seed_dummy_docs.py b/backend/scripts/query_time_check/seed_dummy_docs.py new file mode 100644 index 00000000000..96b6b4a0133 --- /dev/null +++ b/backend/scripts/query_time_check/seed_dummy_docs.py @@ -0,0 +1,166 @@ +""" +launch: +- api server +- postgres +- vespa +- model server (this is only needed so the api server can startup, no embedding is done) + +Run this script to seed the database with dummy documents. +Then run test_query_times.py to test query times. +""" +import random +from datetime import datetime + +from danswer.access.models import DocumentAccess +from danswer.configs.constants import DocumentSource +from danswer.connectors.models import Document +from danswer.db.engine import get_session_context_manager +from danswer.db.search_settings import get_current_search_settings +from danswer.document_index.vespa.index import VespaIndex +from danswer.indexing.models import ChunkEmbedding +from danswer.indexing.models import DocMetadataAwareIndexChunk +from danswer.indexing.models import IndexChunk +from danswer.utils.timing import log_function_time +from shared_configs.model_server_models import Embedding + + +TOTAL_DOC_SETS = 8 +TOTAL_ACL_ENTRIES_PER_CATEGORY = 80 + + +def generate_random_embedding(dim: int) -> Embedding: + return [random.uniform(-1, 1) for _ in range(dim)] + + +def generate_random_identifier() -> str: + return f"dummy_doc_{random.randint(1, 1000)}" + + +def generate_dummy_chunk( + doc_id: str, + chunk_id: int, + embedding_dim: int, + number_of_acl_entries: int, + number_of_document_sets: int, +) -> DocMetadataAwareIndexChunk: + document = Document( + id=doc_id, + source=DocumentSource.GOOGLE_DRIVE, + sections=[], + metadata={}, + semantic_identifier=generate_random_identifier(), + ) + + chunk = IndexChunk( + chunk_id=chunk_id, + blurb=f"Blurb for chunk {chunk_id} of document {doc_id}.", + content=f"Content for chunk {chunk_id} of document {doc_id}. 
This is dummy text for testing purposes.", + source_links={}, + section_continuation=False, + source_document=document, + title_prefix=f"Title prefix for doc {doc_id}", + metadata_suffix_semantic="", + metadata_suffix_keyword="", + mini_chunk_texts=None, + embeddings=ChunkEmbedding( + full_embedding=generate_random_embedding(embedding_dim), + mini_chunk_embeddings=[], + ), + title_embedding=generate_random_embedding(embedding_dim), + ) + + document_set_names = [] + for i in range(number_of_document_sets): + document_set_names.append(f"Document Set {i}") + + user_emails: set[str | None] = set() + user_groups: set[str] = set() + external_user_emails: set[str] = set() + external_user_group_ids: set[str] = set() + for i in range(number_of_acl_entries): + user_emails.add(f"user_{i}@example.com") + user_groups.add(f"group_{i}") + external_user_emails.add(f"external_user_{i}@example.com") + external_user_group_ids.add(f"external_group_{i}") + + return DocMetadataAwareIndexChunk.from_index_chunk( + index_chunk=chunk, + access=DocumentAccess( + user_emails=user_emails, + user_groups=user_groups, + external_user_emails=external_user_emails, + external_user_group_ids=external_user_group_ids, + is_public=random.choice([True, False]), + ), + document_sets={document_set for document_set in document_set_names}, + boost=random.randint(-1, 1), + ) + + +@log_function_time() +def do_insertion( + vespa_index: VespaIndex, all_chunks: list[DocMetadataAwareIndexChunk] +) -> None: + insertion_records = vespa_index.index(all_chunks) + print(f"Indexed {len(insertion_records)} documents.") + print( + f"New documents: {sum(1 for record in insertion_records if not record.already_existed)}" + ) + print( + f"Existing documents updated: {sum(1 for record in insertion_records if record.already_existed)}" + ) + + +@log_function_time() +def seed_dummy_docs( + number_of_document_sets: int, + number_of_acl_entries: int, + num_docs: int = 1000, + chunks_per_doc: int = 5, + batch_size: int = 100, +) -> None: + with get_session_context_manager() as db_session: + search_settings = get_current_search_settings(db_session) + index_name = search_settings.index_name + embedding_dim = search_settings.model_dim + + vespa_index = VespaIndex(index_name=index_name, secondary_index_name=None) + print(index_name) + + all_chunks = [] + chunk_count = 0 + for doc_num in range(num_docs): + doc_id = f"dummy_doc_{doc_num}_{datetime.now().isoformat()}" + for chunk_num in range(chunks_per_doc): + chunk = generate_dummy_chunk( + doc_id=doc_id, + chunk_id=chunk_num, + embedding_dim=embedding_dim, + number_of_acl_entries=number_of_acl_entries, + number_of_document_sets=number_of_document_sets, + ) + all_chunks.append(chunk) + chunk_count += 1 + + if len(all_chunks) >= chunks_per_doc * batch_size: + do_insertion(vespa_index, all_chunks) + print( + f"Indexed {chunk_count} chunks out of {num_docs * chunks_per_doc}." 
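+                    # flushes once batch_size documents' worth of chunks have
+                    # accumulated, i.e. every chunks_per_doc * batch_size chunks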
+                )
+                print(
+                    f"percentage: {chunk_count / (num_docs * chunks_per_doc) * 100:.2f}% \n"
+                )
+                all_chunks = []
+
+    if all_chunks:
+        do_insertion(vespa_index, all_chunks)
+
+
+if __name__ == "__main__":
+    seed_dummy_docs(
+        number_of_document_sets=TOTAL_DOC_SETS,
+        number_of_acl_entries=TOTAL_ACL_ENTRIES_PER_CATEGORY,
+        num_docs=100000,
+        chunks_per_doc=5,
+        batch_size=1000,
+    )
diff --git a/backend/scripts/query_time_check/test_query_times.py b/backend/scripts/query_time_check/test_query_times.py
new file mode 100644
index 00000000000..c839fc610e1
--- /dev/null
+++ b/backend/scripts/query_time_check/test_query_times.py
@@ -0,0 +1,122 @@
+"""
+RUN THIS AFTER SEED_DUMMY_DOCS.PY
+"""
+import random
+import time
+
+from danswer.configs.constants import DocumentSource
+from danswer.configs.model_configs import DOC_EMBEDDING_DIM
+from danswer.db.engine import get_session_context_manager
+from danswer.db.search_settings import get_current_search_settings
+from danswer.document_index.vespa.index import VespaIndex
+from danswer.search.models import IndexFilters
+from scripts.query_time_check.seed_dummy_docs import TOTAL_ACL_ENTRIES_PER_CATEGORY
+from scripts.query_time_check.seed_dummy_docs import TOTAL_DOC_SETS
+from shared_configs.model_server_models import Embedding
+
+# make sure these are smaller than TOTAL_ACL_ENTRIES_PER_CATEGORY and TOTAL_DOC_SETS, respectively
+NUMBER_OF_ACL_ENTRIES_PER_QUERY = 6
+NUMBER_OF_DOC_SETS_PER_QUERY = 2
+
+
+def get_slowest_99th_percentile(results: list[float]) -> float:
+    return sorted(results)[int(0.99 * len(results))]
+
+
+# Generate random filters
+def _random_filters() -> IndexFilters:
+    """
+    Generate random filters for the query containing:
+    - NUMBER_OF_ACL_ENTRIES_PER_QUERY user emails
+    - NUMBER_OF_ACL_ENTRIES_PER_QUERY groups
+    - NUMBER_OF_ACL_ENTRIES_PER_QUERY external groups
+    - NUMBER_OF_DOC_SETS_PER_QUERY document sets
+    """
+    access_control_list = [
+        f"user_email:user_{random.randint(0, TOTAL_ACL_ENTRIES_PER_CATEGORY - 1)}@example.com",
+    ]
+    acl_indices = random.sample(
+        range(TOTAL_ACL_ENTRIES_PER_CATEGORY), NUMBER_OF_ACL_ENTRIES_PER_QUERY
+    )
+    # acl_indices already holds the sampled values, so use them directly
+    # (indexing acl_indices with its own elements picks the wrong entries
+    # and can raise IndexError)
+    for i in acl_indices:
+        access_control_list.append(f"group:group_{i}")
+        access_control_list.append(f"external_group:external_group_{i}")
+
+    doc_sets = []
+    doc_set_indices = random.sample(
+        range(TOTAL_DOC_SETS), NUMBER_OF_DOC_SETS_PER_QUERY
+    )
+    for i in doc_set_indices:
+        doc_sets.append(f"document_set:Document Set {i}")
+
+    return IndexFilters(
+        source_type=[DocumentSource.GOOGLE_DRIVE],
+        document_set=doc_sets,
+        tags=[],
+        access_control_list=access_control_list,
+    )
+
+
+def test_hybrid_retrieval_times(
+    number_of_queries: int,
+) -> None:
+    with get_session_context_manager() as db_session:
+        search_settings = get_current_search_settings(db_session)
+        index_name = search_settings.index_name
+
+    vespa_index = VespaIndex(index_name=index_name, secondary_index_name=None)
+
+    # Generate random queries
+    queries = [f"Random Query {i}" for i in range(number_of_queries)]
+
+    # Generate random embeddings
+    embeddings = [
+        Embedding([random.random() for _ in range(DOC_EMBEDDING_DIM)])
+        for _ in range(number_of_queries)
+    ]
+
+    total_time = 0.0
+    results = []
+    for i in range(number_of_queries):
+        start_time = time.time()
+
+        vespa_index.hybrid_retrieval(
+            query=queries[i],
+            query_embedding=embeddings[i],
+            final_keywords=None,
+            filters=_random_filters(),
+            hybrid_alpha=0.5,
+            time_decay_multiplier=1.0,
+            num_to_retrieve=50,
+            offset=0,
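+            # hybrid_alpha=0.5 weighs keyword and vector scores evenly;
+            # time_decay_multiplier=1.0 keeps the default recency weighting
+            # (parameter meanings inferred from their names and defaults)
+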
title_content_ratio=0.5, + ) + + end_time = time.time() + query_time = end_time - start_time + total_time += query_time + results.append(query_time) + + print(f"Query {i+1}: {query_time:.4f} seconds") + + avg_time = total_time / number_of_queries + fast_time = min(results) + slow_time = max(results) + ninety_ninth_percentile = get_slowest_99th_percentile(results) + # Write results to a file + _OUTPUT_PATH = "query_times_results_large_more.txt" + with open(_OUTPUT_PATH, "w") as f: + f.write(f"Average query time: {avg_time:.4f} seconds\n") + f.write(f"Fastest query: {fast_time:.4f} seconds\n") + f.write(f"Slowest query: {slow_time:.4f} seconds\n") + f.write(f"99th percentile: {ninety_ninth_percentile:.4f} seconds\n") + print(f"Results written to {_OUTPUT_PATH}") + + print(f"\nAverage query time: {avg_time:.4f} seconds") + print(f"Fastest query: {fast_time:.4f} seconds") + print(f"Slowest query: {max(results):.4f} seconds") + print(f"99th percentile: {get_slowest_99th_percentile(results):.4f} seconds") + + +if __name__ == "__main__": + test_hybrid_retrieval_times(number_of_queries=1000) diff --git a/backend/tests/integration/common_utils/managers/cc_pair.py b/backend/tests/integration/common_utils/managers/cc_pair.py index 6498252bbe8..5512eca714d 100644 --- a/backend/tests/integration/common_utils/managers/cc_pair.py +++ b/backend/tests/integration/common_utils/managers/cc_pair.py @@ -5,6 +5,7 @@ import requests from danswer.connectors.models import InputType +from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.server.documents.models import ConnectorIndexingStatus @@ -22,7 +23,7 @@ def _cc_pair_creator( connector_id: int, credential_id: int, name: str | None = None, - is_public: bool = True, + access_type: AccessType = AccessType.PUBLIC, groups: list[int] | None = None, user_performing_action: TestUser | None = None, ) -> TestCCPair: @@ -30,7 +31,7 @@ def _cc_pair_creator( request = { "name": name, - "is_public": is_public, + "access_type": access_type, "groups": groups or [], } @@ -47,7 +48,7 @@ def _cc_pair_creator( name=name, connector_id=connector_id, credential_id=credential_id, - is_public=is_public, + access_type=access_type, groups=groups or [], ) @@ -56,7 +57,7 @@ class CCPairManager: @staticmethod def create_from_scratch( name: str | None = None, - is_public: bool = True, + access_type: AccessType = AccessType.PUBLIC, groups: list[int] | None = None, source: DocumentSource = DocumentSource.FILE, input_type: InputType = InputType.LOAD_STATE, @@ -69,7 +70,7 @@ def create_from_scratch( source=source, input_type=input_type, connector_specific_config=connector_specific_config, - is_public=is_public, + is_public=(access_type == AccessType.PUBLIC), groups=groups, user_performing_action=user_performing_action, ) @@ -77,7 +78,7 @@ def create_from_scratch( credential_json=credential_json, name=name, source=source, - curator_public=is_public, + curator_public=(access_type == AccessType.PUBLIC), groups=groups, user_performing_action=user_performing_action, ) @@ -85,7 +86,7 @@ def create_from_scratch( connector_id=connector.id, credential_id=credential.id, name=name, - is_public=is_public, + access_type=access_type, groups=groups, user_performing_action=user_performing_action, ) @@ -95,7 +96,7 @@ def create( connector_id: int, credential_id: int, name: str | None = None, - is_public: bool = True, + access_type: AccessType = AccessType.PUBLIC, groups: 
list[int] | None = None, user_performing_action: TestUser | None = None, ) -> TestCCPair: @@ -103,7 +104,7 @@ def create( connector_id=connector_id, credential_id=credential_id, name=name, - is_public=is_public, + access_type=access_type, groups=groups, user_performing_action=user_performing_action, ) @@ -172,7 +173,7 @@ def verify( retrieved_cc_pair.name == cc_pair.name and retrieved_cc_pair.connector.id == cc_pair.connector_id and retrieved_cc_pair.credential.id == cc_pair.credential_id - and retrieved_cc_pair.public_doc == cc_pair.is_public + and retrieved_cc_pair.access_type == cc_pair.access_type and set(retrieved_cc_pair.groups) == set(cc_pair.groups) ): return diff --git a/backend/tests/integration/common_utils/managers/document.py b/backend/tests/integration/common_utils/managers/document.py index 3f691eca8f9..dcd8def5c5a 100644 --- a/backend/tests/integration/common_utils/managers/document.py +++ b/backend/tests/integration/common_utils/managers/document.py @@ -3,6 +3,7 @@ import requests from danswer.configs.constants import DocumentSource +from danswer.db.enums import AccessType from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.constants import NUM_DOCS @@ -22,7 +23,7 @@ def _verify_document_permissions( ) -> None: acl_keys = set(retrieved_doc["access_control_list"].keys()) print(f"ACL keys: {acl_keys}") - if cc_pair.is_public: + if cc_pair.access_type == AccessType.PUBLIC: if "PUBLIC" not in acl_keys: raise ValueError( f"Document {retrieved_doc['document_id']} is public but" @@ -30,10 +31,10 @@ def _verify_document_permissions( ) if doc_creating_user is not None: - if f"user_id:{doc_creating_user.id}" not in acl_keys: + if f"user_email:{doc_creating_user.email}" not in acl_keys: raise ValueError( f"Document {retrieved_doc['document_id']} was created by user" - f" {doc_creating_user.id} but does not have the user_id:{doc_creating_user.id} ACL key" + f" {doc_creating_user.email} but does not have the user_email:{doc_creating_user.email} ACL key" ) if group_names is not None: diff --git a/backend/tests/integration/common_utils/test_models.py b/backend/tests/integration/common_utils/test_models.py index 2d8744327df..db46409ab2d 100644 --- a/backend/tests/integration/common_utils/test_models.py +++ b/backend/tests/integration/common_utils/test_models.py @@ -5,6 +5,7 @@ from pydantic import Field from danswer.auth.schemas import UserRole +from danswer.db.enums import AccessType from danswer.search.enums import RecencyBiasSetting from danswer.server.documents.models import DocumentSource from danswer.server.documents.models import InputType @@ -67,7 +68,7 @@ class TestCCPair(BaseModel): name: str connector_id: int credential_id: int - is_public: bool + access_type: AccessType groups: list[int] documents: list[SimpleTestDocument] = Field(default_factory=list) diff --git a/backend/tests/integration/tests/permissions/test_cc_pair_permissions.py b/backend/tests/integration/tests/permissions/test_cc_pair_permissions.py index c52c5826eae..c29a9faa5c9 100644 --- a/backend/tests/integration/tests/permissions/test_cc_pair_permissions.py +++ b/backend/tests/integration/tests/permissions/test_cc_pair_permissions.py @@ -5,6 +5,7 @@ import pytest from requests.exceptions import HTTPError +from danswer.db.enums import AccessType from danswer.server.documents.models import DocumentSource from tests.integration.common_utils.managers.cc_pair import CCPairManager from 
tests.integration.common_utils.managers.connector import ConnectorManager @@ -91,8 +92,8 @@ def test_cc_pair_permissions(reset: None) -> None: connector_id=connector_1.id, credential_id=credential_1.id, name="invalid_cc_pair_1", + access_type=AccessType.PUBLIC, groups=[user_group_1.id], - is_public=True, user_performing_action=curator, ) @@ -103,8 +104,8 @@ def test_cc_pair_permissions(reset: None) -> None: connector_id=connector_1.id, credential_id=credential_1.id, name="invalid_cc_pair_2", + access_type=AccessType.PRIVATE, groups=[user_group_1.id, user_group_2.id], - is_public=False, user_performing_action=curator, ) @@ -115,8 +116,8 @@ def test_cc_pair_permissions(reset: None) -> None: connector_id=connector_1.id, credential_id=credential_1.id, name="invalid_cc_pair_2", + access_type=AccessType.PRIVATE, groups=[], - is_public=False, user_performing_action=curator, ) @@ -129,8 +130,8 @@ def test_cc_pair_permissions(reset: None) -> None: # connector_id=connector_2.id, # credential_id=credential_1.id, # name="invalid_cc_pair_3", + # access_type=AccessType.PRIVATE, # groups=[user_group_1.id], - # is_public=False, # user_performing_action=curator, # ) @@ -141,8 +142,8 @@ def test_cc_pair_permissions(reset: None) -> None: connector_id=connector_1.id, credential_id=credential_2.id, name="invalid_cc_pair_4", + access_type=AccessType.PRIVATE, groups=[user_group_1.id], - is_public=False, user_performing_action=curator, ) @@ -154,8 +155,8 @@ def test_cc_pair_permissions(reset: None) -> None: name="valid_cc_pair", connector_id=connector_1.id, credential_id=credential_1.id, + access_type=AccessType.PRIVATE, groups=[user_group_1.id], - is_public=False, user_performing_action=curator, ) diff --git a/backend/tests/integration/tests/permissions/test_doc_set_permissions.py b/backend/tests/integration/tests/permissions/test_doc_set_permissions.py index 412b5d41fad..b8913485eb2 100644 --- a/backend/tests/integration/tests/permissions/test_doc_set_permissions.py +++ b/backend/tests/integration/tests/permissions/test_doc_set_permissions.py @@ -1,6 +1,7 @@ import pytest from requests.exceptions import HTTPError +from danswer.db.enums import AccessType from danswer.server.documents.models import DocumentSource from tests.integration.common_utils.managers.cc_pair import CCPairManager from tests.integration.common_utils.managers.document_set import DocumentSetManager @@ -47,14 +48,14 @@ def test_doc_set_permissions_setup(reset: None) -> None: # Admin creates a cc_pair private_cc_pair = CCPairManager.create_from_scratch( - is_public=False, + access_type=AccessType.PRIVATE, source=DocumentSource.INGESTION_API, user_performing_action=admin_user, ) # Admin creates a public cc_pair public_cc_pair = CCPairManager.create_from_scratch( - is_public=True, + access_type=AccessType.PUBLIC, source=DocumentSource.INGESTION_API, user_performing_action=admin_user, ) diff --git a/backend/tests/integration/tests/permissions/test_whole_curator_flow.py b/backend/tests/integration/tests/permissions/test_whole_curator_flow.py index 878ba1e17e8..00b1545711d 100644 --- a/backend/tests/integration/tests/permissions/test_whole_curator_flow.py +++ b/backend/tests/integration/tests/permissions/test_whole_curator_flow.py @@ -1,6 +1,7 @@ """ This test tests the happy path for curator permissions """ +from danswer.db.enums import AccessType from danswer.db.models import UserRole from danswer.server.documents.models import DocumentSource from tests.integration.common_utils.managers.cc_pair import CCPairManager @@ -64,8 +65,8 @@ def 
test_whole_curator_flow(reset: None) -> None: connector_id=test_connector.id, credential_id=test_credential.id, name="curator_test_cc_pair", + access_type=AccessType.PRIVATE, groups=[user_group_1.id], - is_public=False, user_performing_action=curator, ) diff --git a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx index 98d4601614d..614ae6c92bc 100644 --- a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx +++ b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx @@ -1,6 +1,6 @@ "use client"; -import { errorHandlingFetcher } from "@/lib/fetcher"; +import { FetchError, errorHandlingFetcher } from "@/lib/fetcher"; import useSWR, { mutate } from "swr"; import { HealthCheckBanner } from "@/components/health/healthcheck"; @@ -12,7 +12,6 @@ import { useFormContext } from "@/components/context/FormContext"; import { getSourceDisplayName } from "@/lib/sources"; import { SourceIcon } from "@/components/SourceIcon"; import { useState } from "react"; -import { submitConnector } from "@/components/admin/connectors/ConnectorForm"; import { deleteCredential, linkCredential } from "@/lib/credential"; import { submitFiles } from "./pages/utils/files"; import { submitGoogleSite } from "./pages/utils/google_site"; @@ -38,7 +37,8 @@ import { useGoogleDriveCredentials, } from "./pages/utils/hooks"; import { Formik } from "formik"; -import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector"; +import { AccessTypeForm } from "@/components/admin/connectors/AccessTypeForm"; +import { AccessTypeGroupSelector } from "@/components/admin/connectors/AccessTypeGroupSelector"; import NavigationRow from "./NavigationRow"; export interface AdvancedConfig { @@ -46,6 +46,64 @@ export interface AdvancedConfig { pruneFreq: number; indexingStart: string; } +import { Connector, ConnectorBase } from "@/lib/connectors/connectors"; + +const BASE_CONNECTOR_URL = "/api/manage/admin/connector"; + +export async function submitConnector( + connector: ConnectorBase, + connectorId?: number, + fakeCredential?: boolean, + isPublicCcpair?: boolean // exclusively for mock credentials, when also need to specify ccpair details +): Promise<{ message: string; isSuccess: boolean; response?: Connector }> { + const isUpdate = connectorId !== undefined; + if (!connector.connector_specific_config) { + connector.connector_specific_config = {} as T; + } + + try { + if (fakeCredential) { + const response = await fetch( + "/api/manage/admin/connector-with-mock-credential", + { + method: isUpdate ? "PATCH" : "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ ...connector, is_public: isPublicCcpair }), + } + ); + if (response.ok) { + const responseJson = await response.json(); + return { message: "Success!", isSuccess: true, response: responseJson }; + } else { + const errorData = await response.json(); + return { message: `Error: ${errorData.detail}`, isSuccess: false }; + } + } else { + const response = await fetch( + BASE_CONNECTOR_URL + (isUpdate ? `/${connectorId}` : ""), + { + method: isUpdate ? 
"PATCH" : "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(connector), + } + ); + + if (response.ok) { + const responseJson = await response.json(); + return { message: "Success!", isSuccess: true, response: responseJson }; + } else { + const errorData = await response.json(); + return { message: `Error: ${errorData.detail}`, isSuccess: false }; + } + } + } catch (error) { + return { message: `Error: ${error}`, isSuccess: false }; + } +} export default function AddConnector({ connector, @@ -84,10 +142,35 @@ export default function AddConnector({ const { liveGDriveCredential } = useGoogleDriveCredentials(); const { liveGmailCredential } = useGmailCredentials(); + const { + data: appCredentialData, + isLoading: isAppCredentialLoading, + error: isAppCredentialError, + } = useSWR<{ client_id: string }, FetchError>( + "/api/manage/admin/connector/google-drive/app-credential", + errorHandlingFetcher + ); + const { + data: serviceAccountKeyData, + isLoading: isServiceAccountKeyLoading, + error: isServiceAccountKeyError, + } = useSWR<{ service_account_email: string }, FetchError>( + "/api/manage/admin/connector/google-drive/service-account-key", + errorHandlingFetcher + ); + // Check if credential is activated const credentialActivated = - (connector === "google_drive" && liveGDriveCredential) || - (connector === "gmail" && liveGmailCredential) || + (connector === "google_drive" && + (liveGDriveCredential || + appCredentialData || + serviceAccountKeyData || + currentCredential)) || + (connector === "gmail" && + (liveGmailCredential || + appCredentialData || + serviceAccountKeyData || + currentCredential)) || currentCredential; // Check if there are no credentials @@ -159,10 +242,12 @@ export default function AddConnector({ const { name, groups, - is_public: isPublic, + access_type, pruneFreq, indexingStart, refreshFreq, + auto_sync_options, + is_public, ...connector_specific_config } = values; @@ -204,6 +289,7 @@ export default function AddConnector({ advancedConfiguration.refreshFreq, advancedConfiguration.pruneFreq, advancedConfiguration.indexingStart, + values.access_type == "public", name ); if (response) { @@ -219,7 +305,7 @@ export default function AddConnector({ setPopup, setSelectedFiles, name, - isPublic, + access_type == "public", groups ); if (response) { @@ -234,15 +320,15 @@ export default function AddConnector({ input_type: connector == "web" ? "load_state" : "poll", // single case name: name, source: connector, + is_public: access_type == "public", refresh_freq: advancedConfiguration.refreshFreq || null, prune_freq: advancedConfiguration.pruneFreq || null, indexing_start: advancedConfiguration.indexingStart || null, - is_public: isPublic, groups: groups, }, undefined, credentialActivated ? 
false : true, - isPublic + access_type == "public" ); // If no credential if (!credentialActivated) { @@ -261,8 +347,9 @@ export default function AddConnector({ response.id, credential?.id!, name, - isPublic, - groups + access_type, + groups, + auto_sync_options ); if (linkCredentialResponse.ok) { onSuccess(); @@ -366,11 +453,8 @@ export default function AddConnector({ selectedFiles={selectedFiles} /> - + + )} diff --git a/web/src/app/admin/connectors/[connector]/Sidebar.tsx b/web/src/app/admin/connectors/[connector]/Sidebar.tsx index e0c85029fb0..4b7f2970daf 100644 --- a/web/src/app/admin/connectors/[connector]/Sidebar.tsx +++ b/web/src/app/admin/connectors/[connector]/Sidebar.tsx @@ -6,12 +6,14 @@ import { Logo } from "@/components/Logo"; import { SettingsContext } from "@/components/settings/SettingsProvider"; import { credentialTemplates } from "@/lib/connectors/credentials"; import Link from "next/link"; +import { useUser } from "@/components/user/UserProvider"; import { useContext } from "react"; export default function Sidebar() { const { formStep, setFormStep, connector, allowAdvanced, allowCreate } = useFormContext(); const combinedSettings = useContext(SettingsContext); + const { isLoadingUser, isAdmin } = useUser(); if (!combinedSettings) { return null; } @@ -59,7 +61,9 @@ export default function Sidebar() { className="w-full p-2 bg-white border-border border rounded items-center hover:bg-background-200 cursor-pointer transition-all duration-150 flex gap-x-2" > -

Admin Page
+
+              {isAdmin ? "Admin Page" : "Curator Page"}
+

diff --git a/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx b/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx index 98e517e43aa..8fc1fa76787 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx @@ -222,36 +222,46 @@ export const DriveJsonUploadSection = ({ Found existing app credentials with the following Client ID:

{appCredentialData.client_id}

-          If you want to update these credentials, delete the existing
-          credentials through the button below, and then upload a new
-          credentials JSON.
-
+          {isAdmin ? (
+            <>
+              If you want to update these credentials, delete the existing
+              credentials through the button below, and then upload a new
+              credentials JSON.
+            </>
+          ) : (
+            To change these credentials, please contact an administrator.
+ )} ); } diff --git a/web/src/app/admin/connectors/[connector]/pages/utils/files.ts b/web/src/app/admin/connectors/[connector]/pages/utils/files.ts index 7535eec35bb..c7b36ac0fbe 100644 --- a/web/src/app/admin/connectors/[connector]/pages/utils/files.ts +++ b/web/src/app/admin/connectors/[connector]/pages/utils/files.ts @@ -80,7 +80,7 @@ export const submitFiles = async ( connector.id, credentialId, name, - isPublic, + isPublic ? "public" : "private", groups ); if (!credentialResponse.ok) { diff --git a/web/src/app/admin/connectors/[connector]/pages/utils/google_site.ts b/web/src/app/admin/connectors/[connector]/pages/utils/google_site.ts index 11d7f46ecea..abc1097cc36 100644 --- a/web/src/app/admin/connectors/[connector]/pages/utils/google_site.ts +++ b/web/src/app/admin/connectors/[connector]/pages/utils/google_site.ts @@ -10,6 +10,7 @@ export const submitGoogleSite = async ( refreshFreq: number, pruneFreq: number, indexingStart: Date, + is_public: boolean, name?: string ) => { const uploadCreateAndTriggerConnector = async () => { @@ -42,6 +43,7 @@ export const submitGoogleSite = async ( base_url: base_url, zip_path: filePaths[0], }, + is_public: is_public, refresh_freq: refreshFreq, prune_freq: pruneFreq, indexing_start: indexingStart, diff --git a/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx b/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx index 2778103e345..fb7e56cc60c 100644 --- a/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx +++ b/web/src/app/admin/documents/sets/DocumentSetCreationForm.tsx @@ -155,7 +155,7 @@ export const DocumentSetCreationForm = ({ // Filter visible cc pairs const visibleCcPairs = localCcPairs.filter( (ccPair) => - ccPair.public_doc || + ccPair.access_type === "public" || (ccPair.groups.length > 0 && props.values.groups.every((group) => ccPair.groups.includes(group) @@ -228,7 +228,7 @@ export const DocumentSetCreationForm = ({ // Filter non-visible cc pairs const nonVisibleCcPairs = localCcPairs.filter( (ccPair) => - !ccPair.public_doc && + !(ccPair.access_type === "public") && (ccPair.groups.length === 0 || !props.values.groups.every((group) => ccPair.groups.includes(group) diff --git a/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx b/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx index 2a38e5eddcf..2b78e11de2d 100644 --- a/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx +++ b/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx @@ -232,7 +232,7 @@ function ConnectorRow({ {getActivityBadge()} {isPaidEnterpriseFeaturesEnabled && ( - {ccPairsIndexingStatus.public_doc ? ( + {ccPairsIndexingStatus.access_type === "public" ? 
( status.cc_pair_status === ConnectorCredentialPairStatus.ACTIVE ).length, - public: statuses.filter((status) => status.public_doc).length, + public: statuses.filter((status) => status.access_type === "public") + .length, totalDocsIndexed: statuses.reduce( (sum, status) => sum + status.docs_indexed, 0 @@ -420,7 +421,7 @@ export function CCPairIndexingStatusTable({ credential_json: {}, admin_public: false, }, - public_doc: true, + access_type: "public", docs_indexed: 1000, last_success: "2023-07-01T12:00:00Z", last_finished_status: "success", diff --git a/web/src/app/ee/admin/groups/ConnectorEditor.tsx b/web/src/app/ee/admin/groups/ConnectorEditor.tsx index d632e5f6ad8..ab6bbb0bec9 100644 --- a/web/src/app/ee/admin/groups/ConnectorEditor.tsx +++ b/web/src/app/ee/admin/groups/ConnectorEditor.tsx @@ -16,7 +16,7 @@ export const ConnectorEditor = ({
{allCCPairs // remove public docs, since they don't make sense as part of a group - .filter((ccPair) => !ccPair.public_doc) + .filter((ccPair) => !(ccPair.access_type === "public")) .map((ccPair) => { const ind = selectedCCPairIds.indexOf(ccPair.cc_pair_id); let isSelected = ind !== -1; diff --git a/web/src/app/ee/admin/groups/UserGroupCreationForm.tsx b/web/src/app/ee/admin/groups/UserGroupCreationForm.tsx index c87bfb5e584..ec1ac52e609 100644 --- a/web/src/app/ee/admin/groups/UserGroupCreationForm.tsx +++ b/web/src/app/ee/admin/groups/UserGroupCreationForm.tsx @@ -28,6 +28,11 @@ export const UserGroupCreationForm = ({ }: UserGroupCreationFormProps) => { const isUpdate = existingUserGroup !== undefined; + // Filter out ccPairs that aren't access_type "private" + const privateCcPairs = ccPairs.filter( + (ccPair) => ccPair.access_type === "private" + ); + return (
@@ -96,7 +101,7 @@ export const UserGroupCreationForm = ({

-                Select which connectors this group has access to:
+                Select which private connectors this group has access to:

All documents indexed by the selected connectors will be @@ -104,7 +109,7 @@ export const UserGroupCreationForm = ({

setFieldValue("cc_pair_ids", ccPairsIds) diff --git a/web/src/app/ee/admin/groups/[groupId]/AddConnectorForm.tsx b/web/src/app/ee/admin/groups/[groupId]/AddConnectorForm.tsx index 3e1896edd16..8b5166ed851 100644 --- a/web/src/app/ee/admin/groups/[groupId]/AddConnectorForm.tsx +++ b/web/src/app/ee/admin/groups/[groupId]/AddConnectorForm.tsx @@ -76,7 +76,7 @@ export const AddConnectorForm: React.FC = ({ .includes(ccPair.cc_pair_id) ) // remove public docs, since they don't make sense as part of a group - .filter((ccPair) => !ccPair.public_doc) + .filter((ccPair) => !(ccPair.access_type === "public")) .map((ccPair) => { return { name: ccPair.name?.toString() || "", diff --git a/web/src/components/admin/connectors/AccessTypeForm.tsx b/web/src/components/admin/connectors/AccessTypeForm.tsx new file mode 100644 index 00000000000..f3a7eef113d --- /dev/null +++ b/web/src/components/admin/connectors/AccessTypeForm.tsx @@ -0,0 +1,88 @@ +import { DefaultDropdown } from "@/components/Dropdown"; +import { + AccessType, + ValidAutoSyncSources, + ConfigurableSources, + validAutoSyncSources, +} from "@/lib/types"; +import { Text, Title } from "@tremor/react"; +import { useUser } from "@/components/user/UserProvider"; +import { useField } from "formik"; +import { AutoSyncOptions } from "./AutoSyncOptions"; + +function isValidAutoSyncSource( + value: ConfigurableSources +): value is ValidAutoSyncSources { + return validAutoSyncSources.includes(value as ValidAutoSyncSources); +} + +export function AccessTypeForm({ + connector, +}: { + connector: ConfigurableSources; +}) { + const [access_type, meta, access_type_helpers] = + useField("access_type"); + + const isAutoSyncSupported = isValidAutoSyncSource(connector); + const { isLoadingUser, isAdmin } = useUser(); + + const options = [ + { + name: "Private", + value: "private", + description: + "Only users who have expliticly been given access to this connector (through the User Groups page) can access the documents pulled in by this connector", + }, + ]; + + if (isAdmin) { + options.push({ + name: "Public", + value: "public", + description: + "Everyone with an account on Danswer can access the documents pulled in by this connector", + }); + } + + if (isAutoSyncSupported && isAdmin) { + options.push({ + name: "Auto Sync", + value: "sync", + description: + "We will automatically sync permissions from the source. A document will be searchable in Danswer if and only if the user performing the search has permission to access the document in the source.", + }); + } + + return ( +
+
+ +
+

+ Control who has access to the documents indexed by this connector. +

+ + {isAdmin && ( + <> + + access_type_helpers.setValue(selected as AccessType) + } + includeDefault={false} + /> + + {access_type.value === "sync" && isAutoSyncSupported && ( +
+ +
+ )} + + )} +
+ ); +} diff --git a/web/src/components/admin/connectors/AccessTypeGroupSelector.tsx b/web/src/components/admin/connectors/AccessTypeGroupSelector.tsx new file mode 100644 index 00000000000..4a165515c56 --- /dev/null +++ b/web/src/components/admin/connectors/AccessTypeGroupSelector.tsx @@ -0,0 +1,147 @@ +import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; +import React, { useState, useEffect } from "react"; +import { FieldArray, ArrayHelpers, ErrorMessage, useField } from "formik"; +import { Text, Divider } from "@tremor/react"; +import { FiUsers } from "react-icons/fi"; +import { UserGroup, User, UserRole } from "@/lib/types"; +import { useUserGroups } from "@/lib/hooks"; +import { AccessType } from "@/lib/types"; +import { useUser } from "@/components/user/UserProvider"; + +// This should be included for all forms that require groups / public access +// to be set, and access to this / permissioning should be handled within this component itself. +export function AccessTypeGroupSelector({}: {}) { + const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups(); + const { isAdmin, user, isLoadingUser, isCurator } = useUser(); + const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled(); + const [shouldHideContent, setShouldHideContent] = useState(false); + + const [access_type, meta, access_type_helpers] = + useField("access_type"); + const [groups, groups_meta, groups_helpers] = useField("groups"); + + useEffect(() => { + if (user && userGroups && isPaidEnterpriseFeaturesEnabled) { + const isUserAdmin = user.role === UserRole.ADMIN; + if (!isPaidEnterpriseFeaturesEnabled) { + access_type_helpers.setValue("public"); + return; + } + if (!isUserAdmin) { + access_type_helpers.setValue("private"); + } + if (userGroups.length === 1 && !isUserAdmin) { + groups_helpers.setValue([userGroups[0].id]); + setShouldHideContent(true); + } else if (access_type.value !== "private") { + groups_helpers.setValue([]); + setShouldHideContent(false); + } else { + setShouldHideContent(false); + } + } + }, [user, userGroups, access_type.value]); + + if (isLoadingUser || userGroupsIsLoading) { + return
Loading...
; + } + if (!isPaidEnterpriseFeaturesEnabled) { + return null; + } + + if (shouldHideContent) { + return ( + <> + {userGroups && ( +
+ This Connector will be assigned to group {userGroups[0].name} + . +
+ )} + + ); + } + + return ( +
+ {(access_type.value === "private" || isCurator) && + userGroups && + userGroups?.length > 0 && ( + <> + +
+
+ Assign group access for this Connector +
+
+ {userGroupsIsLoading ? ( +
+ ) : ( + + {isAdmin ? ( + <> + This Connector will be visible/accessible by the groups + selected below + + ) : ( + <> + Curators must select one or more groups to give access to + this Connector + + )} + + )} + ( +
+ {userGroupsIsLoading ? ( +
+ ) : ( + userGroups && + userGroups.map((userGroup: UserGroup) => { + const ind = groups.value.indexOf(userGroup.id); + let isSelected = ind !== -1; + return ( +
{ + if (isSelected) { + arrayHelpers.remove(ind); + } else { + arrayHelpers.push(userGroup.id); + } + }} + > +
+ {" "} + {userGroup.name} +
+
+ ); + }) + )} +
+ )} + /> + + + )} +
+ ); +} diff --git a/web/src/components/admin/connectors/AutoSyncOptions.tsx b/web/src/components/admin/connectors/AutoSyncOptions.tsx new file mode 100644 index 00000000000..f060ad31f81 --- /dev/null +++ b/web/src/components/admin/connectors/AutoSyncOptions.tsx @@ -0,0 +1,30 @@ +import { TextFormField } from "@/components/admin/connectors/Field"; +import { useFormikContext } from "formik"; +import { ValidAutoSyncSources } from "@/lib/types"; +import { Divider } from "@tremor/react"; +import { autoSyncConfigBySource } from "@/lib/connectors/AutoSyncOptionFields"; + +export function AutoSyncOptions({ + connectorType, +}: { + connectorType: ValidAutoSyncSources; +}) { + return ( +
+ + <> + {Object.entries(autoSyncConfigBySource[connectorType]).map( + ([key, config]) => ( +
+ +
+ ) + )} + +
+ ); +} diff --git a/web/src/components/admin/connectors/ConnectorForm.tsx b/web/src/components/admin/connectors/ConnectorForm.tsx deleted file mode 100644 index 2befa6b2bc6..00000000000 --- a/web/src/components/admin/connectors/ConnectorForm.tsx +++ /dev/null @@ -1,382 +0,0 @@ -"use client"; - -import React, { useState } from "react"; -import { Formik, Form } from "formik"; -import * as Yup from "yup"; -import { Popup, usePopup } from "./Popup"; -import { ValidInputTypes, ValidSources } from "@/lib/types"; -import { deleteConnectorIfExistsAndIsUnlinked } from "@/lib/connector"; -import { FormBodyBuilder, RequireAtLeastOne } from "./types"; -import { BooleanFormField, TextFormField } from "./Field"; -import { createCredential, linkCredential } from "@/lib/credential"; -import { useSWRConfig } from "swr"; -import { Button, Divider } from "@tremor/react"; -import IsPublicField from "./IsPublicField"; -import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; -import { Connector, ConnectorBase } from "@/lib/connectors/connectors"; - -const BASE_CONNECTOR_URL = "/api/manage/admin/connector"; - -export async function submitConnector( - connector: ConnectorBase, - connectorId?: number, - fakeCredential?: boolean, - isPublicCcpair?: boolean // exclusively for mock credentials, when also need to specify ccpair details -): Promise<{ message: string; isSuccess: boolean; response?: Connector }> { - const isUpdate = connectorId !== undefined; - if (!connector.connector_specific_config) { - connector.connector_specific_config = {} as T; - } - - try { - if (fakeCredential) { - const response = await fetch( - "/api/manage/admin/connector-with-mock-credential", - { - method: isUpdate ? "PATCH" : "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ ...connector, is_public: isPublicCcpair }), - } - ); - if (response.ok) { - const responseJson = await response.json(); - return { message: "Success!", isSuccess: true, response: responseJson }; - } else { - const errorData = await response.json(); - return { message: `Error: ${errorData.detail}`, isSuccess: false }; - } - } else { - const response = await fetch( - BASE_CONNECTOR_URL + (isUpdate ? `/${connectorId}` : ""), - { - method: isUpdate ? 
"PATCH" : "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(connector), - } - ); - - if (response.ok) { - const responseJson = await response.json(); - return { message: "Success!", isSuccess: true, response: responseJson }; - } else { - const errorData = await response.json(); - return { message: `Error: ${errorData.detail}`, isSuccess: false }; - } - } - } catch (error) { - return { message: `Error: ${error}`, isSuccess: false }; - } -} - -const CCPairNameHaver = Yup.object().shape({ - cc_pair_name: Yup.string().required("Please enter a name for the connector"), -}); - -interface BaseProps { - nameBuilder: (values: T) => string; - ccPairNameBuilder?: (values: T) => string | null; - source: ValidSources; - inputType: ValidInputTypes; - // if specified, will automatically try and link the credential - credentialId?: number; - // If both are specified, will render formBody and then formBodyBuilder - formBody?: JSX.Element | null; - formBodyBuilder?: FormBodyBuilder; - validationSchema: Yup.ObjectSchema; - validate?: (values: T) => Record; - initialValues: T; - onSubmit?: ( - isSuccess: boolean, - responseJson: Connector | undefined - ) => void; - refreshFreq?: number; - pruneFreq?: number; - indexingStart?: Date; - // If specified, then we will create an empty credential and associate - // the connector with it. If credentialId is specified, then this will be ignored - shouldCreateEmptyCredentialForConnector?: boolean; -} - -type ConnectorFormProps = RequireAtLeastOne< - BaseProps, - "formBody" | "formBodyBuilder" ->; - -export function ConnectorForm({ - nameBuilder, - ccPairNameBuilder, - source, - inputType, - credentialId, - formBody, - formBodyBuilder, - validationSchema, - validate, - initialValues, - refreshFreq, - pruneFreq, - indexingStart, - onSubmit, - shouldCreateEmptyCredentialForConnector, -}: ConnectorFormProps): JSX.Element { - const { mutate } = useSWRConfig(); - const { popup, setPopup } = usePopup(); - - // only show this option for EE, since groups are not supported in CE - const showNonPublicOption = usePaidEnterpriseFeaturesEnabled(); - - const shouldHaveNameInput = credentialId !== undefined && !ccPairNameBuilder; - - const ccPairNameInitialValue = shouldHaveNameInput - ? { cc_pair_name: "" } - : {}; - const publicOptionInitialValue = showNonPublicOption - ? { is_public: false } - : {}; - - let finalValidationSchema = - validationSchema as Yup.ObjectSchema; - if (shouldHaveNameInput) { - finalValidationSchema = finalValidationSchema.concat(CCPairNameHaver); - } - if (showNonPublicOption) { - finalValidationSchema = finalValidationSchema.concat( - Yup.object().shape({ - is_public: Yup.boolean(), - }) - ); - } - - return ( - <> - {popup} - { - formikHelpers.setSubmitting(true); - const connectorName = nameBuilder(values); - const connectorConfig = Object.fromEntries( - Object.keys(initialValues).map((key) => [key, values[key]]) - ) as T; - - // best effort check to see if existing connector exists - // delete it if: - // 1. it exists - // 2. 
AND it has no credentials linked to it - // If the ^ are true, that means things have gotten into a bad - // state, and we should delete the connector to recover - const errorMsg = await deleteConnectorIfExistsAndIsUnlinked({ - source, - name: connectorName, - }); - if (errorMsg) { - setPopup({ - message: `Unable to delete existing connector - ${errorMsg}`, - type: "error", - }); - return; - } - - const { message, isSuccess, response } = await submitConnector({ - name: connectorName, - source, - input_type: inputType, - connector_specific_config: connectorConfig, - refresh_freq: refreshFreq || 0, - prune_freq: pruneFreq ?? null, - indexing_start: indexingStart || null, - }); - - if (!isSuccess || !response) { - setPopup({ message, type: "error" }); - formikHelpers.setSubmitting(false); - return; - } - - let credentialIdToLinkTo = credentialId; - // create empty credential if specified - if ( - shouldCreateEmptyCredentialForConnector && - credentialIdToLinkTo === undefined - ) { - const createCredentialResponse = await createCredential({ - credential_json: {}, - admin_public: true, - source: source, - }); - - if (!createCredentialResponse.ok) { - const errorMsg = await createCredentialResponse.text(); - setPopup({ - message: `Error creating credential for CC Pair - ${errorMsg}`, - type: "error", - }); - formikHelpers.setSubmitting(false); - return; - } - credentialIdToLinkTo = (await createCredentialResponse.json()).id; - } - - if (credentialIdToLinkTo !== undefined) { - const ccPairName = ccPairNameBuilder - ? ccPairNameBuilder(values) - : values.cc_pair_name; - const linkCredentialResponse = await linkCredential( - response.id, - credentialIdToLinkTo, - ccPairName as string, - values.is_public - ); - if (!linkCredentialResponse.ok) { - const linkCredentialErrorMsg = - await linkCredentialResponse.text(); - setPopup({ - message: `Error linking credential - ${linkCredentialErrorMsg}`, - type: "error", - }); - formikHelpers.setSubmitting(false); - return; - } - } - - mutate("/api/manage/admin/connector/indexing-status"); - setPopup({ message, type: isSuccess ? "success" : "error" }); - formikHelpers.setSubmitting(false); - if (isSuccess) { - formikHelpers.resetForm(); - } - if (onSubmit) { - onSubmit(isSuccess, response); - } - }} - > - {({ isSubmitting, values }) => ( -
- {shouldHaveNameInput && ( - - )} - {formBody && formBody} - {formBodyBuilder && formBodyBuilder(values)} - {showNonPublicOption && ( - <> - - - - - )} -
- -
- - )} -
- - ); -} - -interface UpdateConnectorBaseProps { - nameBuilder?: (values: T) => string; - existingConnector: Connector; - // If both are specified, uses formBody - formBody?: JSX.Element | null; - formBodyBuilder?: FormBodyBuilder; - validationSchema: Yup.ObjectSchema; - onSubmit?: (isSuccess: boolean, responseJson?: Connector) => void; -} - -type UpdateConnectorFormProps = RequireAtLeastOne< - UpdateConnectorBaseProps, - "formBody" | "formBodyBuilder" ->; - -export function UpdateConnectorForm({ - nameBuilder, - existingConnector, - formBody, - formBodyBuilder, - validationSchema, - onSubmit, -}: UpdateConnectorFormProps): JSX.Element { - const [popup, setPopup] = useState<{ - message: string; - type: "success" | "error"; - } | null>(null); - - return ( - <> - {popup && } - { - formikHelpers.setSubmitting(true); - - const { message, isSuccess, response } = await submitConnector( - { - name: nameBuilder ? nameBuilder(values) : existingConnector.name, - source: existingConnector.source, - input_type: existingConnector.input_type, - connector_specific_config: values, - refresh_freq: existingConnector.refresh_freq, - prune_freq: existingConnector.prune_freq, - indexing_start: existingConnector.indexing_start, - }, - existingConnector.id - ); - - setPopup({ message, type: isSuccess ? "success" : "error" }); - formikHelpers.setSubmitting(false); - if (isSuccess) { - formikHelpers.resetForm(); - } - setTimeout(() => { - setPopup(null); - }, 4000); - if (onSubmit) { - onSubmit(isSuccess, response); - } - }} - > - {({ isSubmitting, values }) => ( -
- {formBody ? formBody : formBodyBuilder && formBodyBuilder(values)} -
- -
-
- )} -
- - ); -} diff --git a/web/src/components/credentials/CredentialSection.tsx b/web/src/components/credentials/CredentialSection.tsx index d26cc78c286..a6002fd9a0d 100644 --- a/web/src/components/credentials/CredentialSection.tsx +++ b/web/src/components/credentials/CredentialSection.tsx @@ -28,7 +28,6 @@ import { ConfluenceCredentialJson, Credential, } from "@/lib/connectors/credentials"; -import { UserGroup } from "@/lib/types"; // Added this import export default function CredentialSection({ ccPair, diff --git a/web/src/components/credentials/actions/CreateCredential.tsx b/web/src/components/credentials/actions/CreateCredential.tsx index 0a5d3cb23ca..5188f3d02a4 100644 --- a/web/src/components/credentials/actions/CreateCredential.tsx +++ b/web/src/components/credentials/actions/CreateCredential.tsx @@ -232,7 +232,7 @@ export default function CreateCredential({ setShowAdvancedOptions={setShowAdvancedOptions} /> )} - {showAdvancedOptions && ( + {(showAdvancedOptions || !isAdmin) && ( +> = { + google_drive: { + customer_id: { + label: "Google Workspace Customer ID", + subtext: ( + <> + The unique identifier for your Google Workspace account. To find this, + check out the{" "} + + guide from Google + + . + + ), + }, + company_domain: { + label: "Google Workspace Company Domain", + subtext: ( + <> + The email domain for your Google Workspace account. +
+
+ For example, if your email provided through Google Workspace looks + something like chris@danswer.ai, then your company domain is{" "} + danswer.ai + + ), + }, + }, +}; diff --git a/web/src/lib/connectors/connectors.ts b/web/src/lib/connectors/connectors.ts index 7e56a498dcd..b526818b083 100644 --- a/web/src/lib/connectors/connectors.ts +++ b/web/src/lib/connectors/connectors.ts @@ -29,6 +29,7 @@ export interface Option { export interface SelectOption extends Option { type: "select"; options?: StringWithDescription[]; + default?: string; } export interface ListOption extends Option { @@ -599,7 +600,7 @@ For example, specifying .*-support.* as a "channel" will cause the connector to { name: "articles", value: "articles" }, { name: "tickets", value: "tickets" }, ], - default: 0, + default: "articles", }, ], }, diff --git a/web/src/lib/credential.ts b/web/src/lib/credential.ts index 03f6c6e75da..c6596804707 100644 --- a/web/src/lib/credential.ts +++ b/web/src/lib/credential.ts @@ -1,4 +1,5 @@ import { CredentialBase } from "./connectors/credentials"; +import { AccessType } from "@/lib/types"; export async function createCredential(credential: CredentialBase) { return await fetch(`/api/manage/credential`, { @@ -47,8 +48,9 @@ export function linkCredential( connectorId: number, credentialId: number, name?: string, - isPublic?: boolean, - groups?: number[] + accessType?: AccessType, + groups?: number[], + autoSyncOptions?: Record ) { return fetch( `/api/manage/connector/${connectorId}/credential/${credentialId}`, @@ -59,8 +61,9 @@ export function linkCredential( }, body: JSON.stringify({ name: name || null, - is_public: isPublic !== undefined ? isPublic : true, + access_type: accessType !== undefined ? accessType : "public", groups: groups || null, + auto_sync_options: autoSyncOptions || null, }), } ); diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index 936e4e6c84c..bf342ca4029 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -49,6 +49,7 @@ export type ValidStatuses = | "not_started"; export type TaskStatus = "PENDING" | "STARTED" | "SUCCESS" | "FAILURE"; export type Feedback = "like" | "dislike"; +export type AccessType = "public" | "private" | "sync"; export type SessionType = "Chat" | "Search" | "Slack"; export interface DocumentBoostStatus { @@ -90,7 +91,7 @@ export interface ConnectorIndexingStatus< cc_pair_status: ConnectorCredentialPairStatus; connector: Connector; credential: Credential; - public_doc: boolean; + access_type: AccessType; owner: string; groups: number[]; last_finished_status: ValidStatuses | null; @@ -258,3 +259,7 @@ export type ConfigurableSources = Exclude< ValidSources, "not_applicable" | "ingestion_api" >; + +// The sources that have auto-sync support on the backend +export const validAutoSyncSources = ["google_drive"] as const; +export type ValidAutoSyncSources = (typeof validAutoSyncSources)[number]; From 8a8e2b310ed9b1382c593e0ad6157e75af0bd2b6 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Thu, 19 Sep 2024 16:36:15 -0700 Subject: [PATCH 007/505] Assistants panel rework (#2509) * update user model * squash - update assistant gallery * rework assistant display logic + ux * update tool + assistant display * update a couple function names * update typing + some logic * remove unnecessary comments * finalize functionality * updated logic * fully functional * remove logs + ports * small update to logic * update typing * allow seeding of display priority * reorder migrations * update for alembic --- 
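The heart of this rework is that a user's assistant preferences are now tracked as three lists (chosen_assistants, visible_assistants, and hidden_assistants) that must stay mutually consistent whenever an assistant is shown or hidden. A minimal TypeScript sketch of that bookkeeping rule, mirroring the update_assistant_list helper this patch adds to backend/danswer/server/manage/users.py; the TypeScript names are illustrative only, not part of the patch:

interface AssistantPreferences {
  chosenAssistants: number[];
  visibleAssistants: number[];
  hiddenAssistants: number[];
}

// Keep the three lists mutually consistent when an assistant is shown or hidden.
// Illustrative sketch of the rule in update_assistant_list; not the shipped code.
function setAssistantVisibility(
  prefs: AssistantPreferences,
  assistantId: number,
  show: boolean
): AssistantPreferences {
  const without = (ids: number[]) => ids.filter((id) => id !== assistantId);
  const withId = (ids: number[]) =>
    ids.includes(assistantId) ? ids : [...ids, assistantId];
  return show
    ? {
        // shown: present in chosen + visible, absent from hidden
        chosenAssistants: withId(prefs.chosenAssistants),
        visibleAssistants: withId(prefs.visibleAssistants),
        hiddenAssistants: without(prefs.hiddenAssistants),
      }
    : {
        // hidden: present in hidden, absent from chosen + visible
        chosenAssistants: without(prefs.chosenAssistants),
        visibleAssistants: without(prefs.visibleAssistants),
        hiddenAssistants: withId(prefs.hiddenAssistants),
      };
}

This invariant, that an id never sits in both the visible and hidden sets at once, is exactly what the new PATCH /user/assistant-list/update/{assistant_id} endpoint below maintains server-side.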
.../versions/55546a7967ee_assistant_rework.py | 79 ++++ backend/danswer/chat/load_yamls.py | 2 +- backend/danswer/db/models.py | 20 +- backend/danswer/db/persona.py | 22 +- backend/danswer/db/slack_bot_config.py | 2 +- .../danswer/server/features/persona/models.py | 8 +- backend/danswer/server/manage/models.py | 5 + backend/danswer/server/manage/users.py | 62 +++ backend/ee/danswer/server/seeding.py | 1 + .../app/admin/assistants/AssistantEditor.tsx | 6 +- web/src/app/admin/assistants/PersonaTable.tsx | 6 +- web/src/app/admin/assistants/interfaces.ts | 3 +- web/src/app/admin/assistants/lib.ts | 7 + .../app/assistants/AssistantSharedStatus.tsx | 14 +- .../app/assistants/AssistantsPageTitle.tsx | 3 +- web/src/app/assistants/ToolsDisplay.tsx | 38 +- .../assistants/gallery/AssistantsGallery.tsx | 431 +++++++++++------- .../assistants/mine/AssistantSharingModal.tsx | 107 ++--- .../app/assistants/mine/AssistantsList.tsx | 401 ++++++++-------- web/src/app/chat/ChatPage.tsx | 20 +- .../modal/configuration/AssistantsTab.tsx | 1 + web/src/app/chat/page.tsx | 1 + .../sessionSidebar/ChatSessionDisplay.tsx | 1 - web/src/components/Dropdown.tsx | 2 +- .../components/assistants/AssistantIcon.tsx | 7 +- web/src/components/popover/DefaultPopover.tsx | 22 +- web/src/components/tooltip/CustomTooltip.tsx | 2 +- web/src/lib/assistants/checkOwnership.ts | 2 +- web/src/lib/assistants/orderAssistants.ts | 34 -- .../assistants/updateAssistantPreferences.ts | 31 +- web/src/lib/assistants/utils.ts | 120 +++++ web/src/lib/credential.ts | 5 +- web/src/lib/types.ts | 2 + 33 files changed, 914 insertions(+), 553 deletions(-) create mode 100644 backend/alembic/versions/55546a7967ee_assistant_rework.py create mode 100644 web/src/lib/assistants/utils.ts diff --git a/backend/alembic/versions/55546a7967ee_assistant_rework.py b/backend/alembic/versions/55546a7967ee_assistant_rework.py new file mode 100644 index 00000000000..a027321a7c6 --- /dev/null +++ b/backend/alembic/versions/55546a7967ee_assistant_rework.py @@ -0,0 +1,79 @@ +"""assistant_rework + +Revision ID: 55546a7967ee +Revises: 61ff3651add4 +Create Date: 2024-09-18 17:00:23.755399 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. 
+revision = "55546a7967ee" +down_revision = "61ff3651add4" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Reworking persona and user tables for new assistant features + # keep track of user's chosen assistants separate from their `ordering` + op.add_column("persona", sa.Column("builtin_persona", sa.Boolean(), nullable=True)) + op.execute("UPDATE persona SET builtin_persona = default_persona") + op.alter_column("persona", "builtin_persona", nullable=False) + op.drop_index("_default_persona_name_idx", table_name="persona") + op.create_index( + "_builtin_persona_name_idx", + "persona", + ["name"], + unique=True, + postgresql_where=sa.text("builtin_persona = true"), + ) + + op.add_column( + "user", sa.Column("visible_assistants", postgresql.JSONB(), nullable=True) + ) + op.add_column( + "user", sa.Column("hidden_assistants", postgresql.JSONB(), nullable=True) + ) + op.execute( + "UPDATE \"user\" SET visible_assistants = '[]'::jsonb, hidden_assistants = '[]'::jsonb" + ) + op.alter_column( + "user", + "visible_assistants", + nullable=False, + server_default=sa.text("'[]'::jsonb"), + ) + op.alter_column( + "user", + "hidden_assistants", + nullable=False, + server_default=sa.text("'[]'::jsonb"), + ) + op.drop_column("persona", "default_persona") + op.add_column( + "persona", sa.Column("is_default_persona", sa.Boolean(), nullable=True) + ) + + +def downgrade() -> None: + # Reverting changes made in upgrade + op.drop_column("user", "hidden_assistants") + op.drop_column("user", "visible_assistants") + op.drop_index("_builtin_persona_name_idx", table_name="persona") + + op.drop_column("persona", "is_default_persona") + op.add_column("persona", sa.Column("default_persona", sa.Boolean(), nullable=True)) + op.execute("UPDATE persona SET default_persona = builtin_persona") + op.alter_column("persona", "default_persona", nullable=False) + op.drop_column("persona", "builtin_persona") + op.create_index( + "_default_persona_name_idx", + "persona", + ["name"], + unique=True, + postgresql_where=sa.text("default_persona = true"), + ) diff --git a/backend/danswer/chat/load_yamls.py b/backend/danswer/chat/load_yamls.py index 0690f08b759..8d0fd34d8da 100644 --- a/backend/danswer/chat/load_yamls.py +++ b/backend/danswer/chat/load_yamls.py @@ -122,7 +122,7 @@ def load_personas_from_yaml( prompt_ids=prompt_ids, document_set_ids=doc_set_ids, tool_ids=tool_ids, - default_persona=True, + builtin_persona=True, is_public=True, display_priority=existing_persona.display_priority if existing_persona is not None diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index fd2d1344afd..70e583ad9c5 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -125,6 +125,12 @@ class User(SQLAlchemyBaseUserTableUUID, Base): chosen_assistants: Mapped[list[int]] = mapped_column( postgresql.JSONB(), nullable=False, default=[-2, -1, 0] ) + visible_assistants: Mapped[list[int]] = mapped_column( + postgresql.JSONB(), nullable=False, default=[] + ) + hidden_assistants: Mapped[list[int]] = mapped_column( + postgresql.JSONB(), nullable=False, default=[] + ) oidc_expiry: Mapped[datetime.datetime] = mapped_column( TIMESTAMPAware(timezone=True), nullable=True @@ -1308,9 +1314,15 @@ class Persona(Base): starter_messages: Mapped[list[StarterMessage] | None] = mapped_column( postgresql.JSONB(), nullable=True ) - # Default personas are configured via backend during deployment + # Built-in personas are configured via backend during deployment # Treated specially (cannot be user edited 
etc.) - default_persona: Mapped[bool] = mapped_column(Boolean, default=False) + builtin_persona: Mapped[bool] = mapped_column(Boolean, default=False) + + # Default personas are personas created by admins and are automatically added + # to all users' assistants list. + is_default_persona: Mapped[bool] = mapped_column( + Boolean, default=False, nullable=False + ) # controls whether the persona is available to be selected by users is_visible: Mapped[bool] = mapped_column(Boolean, default=True) # controls the ordering of personas in the UI @@ -1361,10 +1373,10 @@ class Persona(Base): # Default personas loaded via yaml cannot have the same name __table_args__ = ( Index( - "_default_persona_name_idx", + "_builtin_persona_name_idx", "name", unique=True, - postgresql_where=(default_persona == True), # noqa: E712 + postgresql_where=(builtin_persona == True), # noqa: E712 ), ) diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py index 6cb93ad9fc0..316919e58b4 100644 --- a/backend/danswer/db/persona.py +++ b/backend/danswer/db/persona.py @@ -178,6 +178,7 @@ def create_update_persona( except ValueError as e: logger.exception("Failed to create persona") raise HTTPException(status_code=400, detail=str(e)) + return PersonaSnapshot.from_model(persona) @@ -258,7 +259,7 @@ def get_personas( stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable) if not include_default: - stmt = stmt.where(Persona.default_persona.is_(False)) + stmt = stmt.where(Persona.builtin_persona.is_(False)) if not include_slack_bot_personas: stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX))) if not include_deleted: @@ -306,7 +307,7 @@ def mark_delete_persona_by_name( ) -> None: stmt = ( update(Persona) - .where(Persona.name == persona_name, Persona.default_persona == is_default) + .where(Persona.name == persona_name, Persona.builtin_persona == is_default) .values(deleted=True) ) @@ -406,7 +407,6 @@ def upsert_persona( document_set_ids: list[int] | None = None, tool_ids: list[int] | None = None, persona_id: int | None = None, - default_persona: bool = False, commit: bool = True, icon_color: str | None = None, icon_shape: int | None = None, @@ -414,6 +414,8 @@ def upsert_persona( display_priority: int | None = None, is_visible: bool = True, remove_image: bool | None = None, + builtin_persona: bool = False, + is_default_persona: bool = False, chunks_above: int = CONTEXT_CHUNKS_ABOVE, chunks_below: int = CONTEXT_CHUNKS_BELOW, ) -> Persona: @@ -454,8 +456,8 @@ def upsert_persona( validate_persona_tools(tools) if persona: - if not default_persona and persona.default_persona: - raise ValueError("Cannot update default persona with non-default.") + if not builtin_persona and persona.builtin_persona: + raise ValueError("Cannot update builtin persona with non-builtin.") # this checks if the user has permission to edit the persona persona = fetch_persona_by_id( @@ -470,7 +472,7 @@ def upsert_persona( persona.llm_relevance_filter = llm_relevance_filter persona.llm_filter_extraction = llm_filter_extraction persona.recency_bias = recency_bias - persona.default_persona = default_persona + persona.builtin_persona = builtin_persona persona.llm_model_provider_override = llm_model_provider_override persona.llm_model_version_override = llm_model_version_override persona.starter_messages = starter_messages @@ -482,6 +484,7 @@ def upsert_persona( persona.uploaded_image_id = uploaded_image_id persona.display_priority = display_priority persona.is_visible = is_visible + persona.is_default_persona 
= is_default_persona # Do not delete any associations manually added unless # a new updated list is provided @@ -509,7 +512,7 @@ def upsert_persona( llm_relevance_filter=llm_relevance_filter, llm_filter_extraction=llm_filter_extraction, recency_bias=recency_bias, - default_persona=default_persona, + builtin_persona=builtin_persona, prompts=prompts or [], document_sets=document_sets or [], llm_model_provider_override=llm_model_provider_override, @@ -521,6 +524,7 @@ def upsert_persona( uploaded_image_id=uploaded_image_id, display_priority=display_priority, is_visible=is_visible, + is_default_persona=is_default_persona, ) db_session.add(persona) @@ -550,7 +554,7 @@ def delete_old_default_personas( Need a more graceful fix later or those need to never have IDs""" stmt = ( update(Persona) - .where(Persona.default_persona, Persona.id > 0) + .where(Persona.builtin_persona, Persona.id > 0) .values(deleted=True, name=func.concat(Persona.name, "_old")) ) @@ -732,7 +736,7 @@ def delete_persona_by_name( persona_name: str, db_session: Session, is_default: bool = True ) -> None: stmt = delete(Persona).where( - Persona.name == persona_name, Persona.default_persona == is_default + Persona.name == persona_name, Persona.builtin_persona == is_default ) db_session.execute(stmt) diff --git a/backend/danswer/db/slack_bot_config.py b/backend/danswer/db/slack_bot_config.py index a37bd18c0ec..1398057cfc8 100644 --- a/backend/danswer/db/slack_bot_config.py +++ b/backend/danswer/db/slack_bot_config.py @@ -66,7 +66,7 @@ def create_slack_bot_persona( llm_model_version_override=None, starter_messages=None, is_public=True, - default_persona=False, + is_default_persona=False, db_session=db_session, commit=False, ) diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py index 777ef2037ee..0112cc110e2 100644 --- a/backend/danswer/server/features/persona/models.py +++ b/backend/danswer/server/features/persona/models.py @@ -38,6 +38,8 @@ class CreatePersonaRequest(BaseModel): icon_shape: int | None = None uploaded_image_id: str | None = None # New field for uploaded image remove_image: bool | None = None + is_default_persona: bool = False + display_priority: int | None = None class PersonaSnapshot(BaseModel): @@ -54,7 +56,7 @@ class PersonaSnapshot(BaseModel): llm_model_provider_override: str | None llm_model_version_override: str | None starter_messages: list[StarterMessage] | None - default_persona: bool + builtin_persona: bool prompts: list[PromptSnapshot] tools: list[ToolSnapshot] document_sets: list[DocumentSet] @@ -63,6 +65,7 @@ class PersonaSnapshot(BaseModel): icon_color: str | None icon_shape: int | None uploaded_image_id: str | None = None + is_default_persona: bool @classmethod def from_model( @@ -93,7 +96,8 @@ def from_model( llm_model_provider_override=persona.llm_model_provider_override, llm_model_version_override=persona.llm_model_version_override, starter_messages=persona.starter_messages, - default_persona=persona.default_persona, + builtin_persona=persona.builtin_persona, + is_default_persona=persona.is_default_persona, prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts], tools=[ToolSnapshot.from_model(tool) for tool in persona.tools], document_sets=[ diff --git a/backend/danswer/server/manage/models.py b/backend/danswer/server/manage/models.py index 7b0a3813a82..e3be4c4891d 100644 --- a/backend/danswer/server/manage/models.py +++ b/backend/danswer/server/manage/models.py @@ -40,6 +40,9 @@ class AuthTypeResponse(BaseModel): 
class UserPreferences(BaseModel): chosen_assistants: list[int] | None = None + hidden_assistants: list[int] = [] + visible_assistants: list[int] = [] + default_model: str | None = None @@ -73,6 +76,8 @@ def from_model( UserPreferences( chosen_assistants=user.chosen_assistants, default_model=user.default_model, + hidden_assistants=user.hidden_assistants, + visible_assistants=user.visible_assistants, ) ), # set to None if TRACK_EXTERNAL_IDP_EXPIRY is False so that we avoid cases diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index 04d2c124493..30b71ef0a6a 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -41,6 +41,7 @@ from danswer.server.manage.models import AllUsersResponse from danswer.server.manage.models import UserByEmail from danswer.server.manage.models import UserInfo +from danswer.server.manage.models import UserPreferences from danswer.server.manage.models import UserRoleResponse from danswer.server.manage.models import UserRoleUpdateRequest from danswer.server.models import FullUserSnapshot @@ -444,3 +445,64 @@ def update_user_assistant_list( .values(chosen_assistants=request.chosen_assistants) ) db_session.commit() + + +def update_assistant_list( + preferences: UserPreferences, assistant_id: int, show: bool +) -> UserPreferences: + visible_assistants = preferences.visible_assistants or [] + hidden_assistants = preferences.hidden_assistants or [] + chosen_assistants = preferences.chosen_assistants or [] + + if show: + if assistant_id not in visible_assistants: + visible_assistants.append(assistant_id) + if assistant_id in hidden_assistants: + hidden_assistants.remove(assistant_id) + if assistant_id not in chosen_assistants: + chosen_assistants.append(assistant_id) + else: + if assistant_id in visible_assistants: + visible_assistants.remove(assistant_id) + if assistant_id not in hidden_assistants: + hidden_assistants.append(assistant_id) + if assistant_id in chosen_assistants: + chosen_assistants.remove(assistant_id) + + preferences.visible_assistants = visible_assistants + preferences.hidden_assistants = hidden_assistants + preferences.chosen_assistants = chosen_assistants + return preferences + + +@router.patch("/user/assistant-list/update/{assistant_id}") +def update_user_assistant_visibility( + assistant_id: int, + show: bool, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + if user is None: + if AUTH_TYPE == AuthType.DISABLED: + store = get_dynamic_config_store() + no_auth_user = fetch_no_auth_user(store) + preferences = no_auth_user.preferences + updated_preferences = update_assistant_list(preferences, assistant_id, show) + set_no_auth_user_preferences(store, updated_preferences) + return + else: + raise RuntimeError("This should never happen") + + user_preferences = UserInfo.from_model(user).preferences + updated_preferences = update_assistant_list(user_preferences, assistant_id, show) + + db_session.execute( + update(User) + .where(User.id == user.id) # type: ignore + .values( + hidden_assistants=updated_preferences.hidden_assistants, + visible_assistants=updated_preferences.visible_assistants, + chosen_assistants=updated_preferences.chosen_assistants, + ) + ) + db_session.commit() diff --git a/backend/ee/danswer/server/seeding.py b/backend/ee/danswer/server/seeding.py index ab6c4b017f9..b161f057030 100644 --- a/backend/ee/danswer/server/seeding.py +++ b/backend/ee/danswer/server/seeding.py @@ -85,6 +85,7 @@ def 
_seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> is_public=persona.is_public, db_session=db_session, tool_ids=persona.tool_ids, + display_priority=persona.display_priority, ) diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index 821949a5b79..4951f459b82 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -383,6 +383,7 @@ export function AssistantEditor({ } else { [promptResponse, personaResponse] = await createPersona({ ...values, + is_default_persona: admin!, num_chunks: numChunks, users: user && !checkUserIsNoAuthUser(user.id) ? [user.id] : undefined, @@ -414,10 +415,7 @@ export function AssistantEditor({ shouldAddAssistantToUserPreferences && user?.preferences?.chosen_assistants ) { - const success = await addAssistantToList( - assistantId, - user.preferences.chosen_assistants - ); + const success = await addAssistantToList(assistantId); if (success) { setPopup({ message: `"${assistant.name}" has been added to your list.`, diff --git a/web/src/app/admin/assistants/PersonaTable.tsx b/web/src/app/admin/assistants/PersonaTable.tsx index be81d6aa402..e30e858f13c 100644 --- a/web/src/app/admin/assistants/PersonaTable.tsx +++ b/web/src/app/admin/assistants/PersonaTable.tsx @@ -20,7 +20,7 @@ import { UserRole, User } from "@/lib/types"; import { useUser } from "@/components/user/UserProvider"; function PersonaTypeDisplay({ persona }: { persona: Persona }) { - if (persona.default_persona) { + if (persona.is_default_persona) { return Built-In; } @@ -119,7 +119,7 @@ export function PersonasTable({ id: persona.id.toString(), cells: [
- {!persona.default_persona && ( + {!persona.is_default_persona && ( @@ -173,7 +173,7 @@ export function PersonasTable({
,
- {!persona.default_persona && isEditable ? ( + {!persona.is_default_persona && isEditable ? (
{ diff --git a/web/src/app/admin/assistants/interfaces.ts b/web/src/app/admin/assistants/interfaces.ts index 0696b5ae885..a4528e26b2f 100644 --- a/web/src/app/admin/assistants/interfaces.ts +++ b/web/src/app/admin/assistants/interfaces.ts @@ -35,7 +35,8 @@ export interface Persona { llm_model_provider_override?: string; llm_model_version_override?: string; starter_messages: StarterMessage[] | null; - default_persona: boolean; + builtin_persona: boolean; + is_default_persona: boolean; users: MinimalUserSnapshot[]; groups: number[]; icon_shape?: number; diff --git a/web/src/app/admin/assistants/lib.ts b/web/src/app/admin/assistants/lib.ts index ca2c36108b3..61ede3f1643 100644 --- a/web/src/app/admin/assistants/lib.ts +++ b/web/src/app/admin/assistants/lib.ts @@ -21,6 +21,7 @@ interface PersonaCreationRequest { icon_shape: number | null; remove_image?: boolean; uploaded_image: File | null; + is_default_persona: boolean; } interface PersonaUpdateRequest { @@ -125,6 +126,11 @@ function buildPersonaAPIBody( remove_image, } = creationRequest; + const is_default_persona = + "is_default_persona" in creationRequest + ? creationRequest.is_default_persona + : false; + return { name, description, @@ -145,6 +151,7 @@ function buildPersonaAPIBody( icon_shape, uploaded_image_id, remove_image, + is_default_persona, }; } diff --git a/web/src/app/assistants/AssistantSharedStatus.tsx b/web/src/app/assistants/AssistantSharedStatus.tsx index c5127c87e91..f9c68e7919e 100644 --- a/web/src/app/assistants/AssistantSharedStatus.tsx +++ b/web/src/app/assistants/AssistantSharedStatus.tsx @@ -6,9 +6,11 @@ import { FiLock, FiUnlock } from "react-icons/fi"; export function AssistantSharedStatusDisplay({ assistant, user, + size = "sm", }: { assistant: Persona; user: User | null; + size?: "sm" | "md" | "lg"; }) { const isOwnedByUser = checkUserOwnsAssistant(user, assistant); @@ -18,7 +20,9 @@ export function AssistantSharedStatusDisplay({ if (assistant.is_public) { return ( -
+
Public
@@ -27,7 +31,9 @@ export function AssistantSharedStatusDisplay({ if (assistantSharedUsersWithoutOwner.length > 0) { return ( -
+
{isOwnedByUser ? ( `Shared with: ${ @@ -54,7 +60,9 @@ export function AssistantSharedStatusDisplay({ } return ( -
+
Private
diff --git a/web/src/app/assistants/AssistantsPageTitle.tsx b/web/src/app/assistants/AssistantsPageTitle.tsx index 5bcdea7ea66..f0acc0c2c63 100644 --- a/web/src/app/assistants/AssistantsPageTitle.tsx +++ b/web/src/app/assistants/AssistantsPageTitle.tsx @@ -6,10 +6,11 @@ export function AssistantsPageTitle({ return (

{children} diff --git a/web/src/app/assistants/ToolsDisplay.tsx b/web/src/app/assistants/ToolsDisplay.tsx index 2be7670c0ee..8efa02d5516 100644 --- a/web/src/app/assistants/ToolsDisplay.tsx +++ b/web/src/app/assistants/ToolsDisplay.tsx @@ -1,41 +1,5 @@ -import { Bubble } from "@/components/Bubble"; -import { ToolSnapshot } from "@/lib/tools/interfaces"; -import { FiImage, FiSearch, FiGlobe, FiMoreHorizontal } from "react-icons/fi"; +import { FiImage, FiSearch } from "react-icons/fi"; import { Persona } from "../admin/assistants/interfaces"; -import { CustomTooltip } from "@/components/tooltip/CustomTooltip"; -import { useState } from "react"; - -export function ToolsDisplay({ tools }: { tools: ToolSnapshot[] }) { - return ( -
-

Tools:

- {tools.map((tool) => { - let toolName = tool.name; - let toolIcon = null; - - if (tool.name === "SearchTool") { - toolName = "Search"; - toolIcon = ; - } else if (tool.name === "ImageGenerationTool") { - toolName = "Image Generation"; - toolIcon = ; - } else if (tool.name === "InternetSearchTool") { - toolName = "Internet Search"; - toolIcon = ; - } - - return ( - -
- {toolIcon} - {toolName} -
-
- ); - })} -
- ); -} export function AssistantTools({ assistant, diff --git a/web/src/app/assistants/gallery/AssistantsGallery.tsx b/web/src/app/assistants/gallery/AssistantsGallery.tsx index d4882aa6081..8926238b454 100644 --- a/web/src/app/assistants/gallery/AssistantsGallery.tsx +++ b/web/src/app/assistants/gallery/AssistantsGallery.tsx @@ -6,17 +6,137 @@ import { User } from "@/lib/types"; import { Button } from "@tremor/react"; import Link from "next/link"; import { useState } from "react"; -import { FiMinus, FiPlus, FiX } from "react-icons/fi"; -import { NavigationButton } from "../NavigationButton"; +import { FiList, FiMinus, FiPlus } from "react-icons/fi"; import { AssistantsPageTitle } from "../AssistantsPageTitle"; import { addAssistantToList, removeAssistantFromList, } from "@/lib/assistants/updateAssistantPreferences"; -import { usePopup } from "@/components/admin/connectors/Popup"; +import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup"; import { useRouter } from "next/navigation"; -import { AssistantTools, ToolsDisplay } from "../ToolsDisplay"; +import { AssistantTools } from "../ToolsDisplay"; +import { classifyAssistants } from "@/lib/assistants/utils"; +export function AssistantGalleryCard({ + assistant, + user, + setPopup, + selectedAssistant, +}: { + assistant: Persona; + user: User | null; + setPopup: (popup: PopupSpec) => void; + selectedAssistant: boolean; +}) { + const router = useRouter(); + return ( +
+
+ +

+ {assistant.name} +

+ {user && ( +
+ {selectedAssistant ? ( + + ) : ( + + )} +
+ )} +
+ +

{assistant.description}

+

+ Author: {assistant.owner?.email || "Danswer"} +

+ {assistant.tools.length > 0 && ( + + )} +
+ ); +} export function AssistantsGallery({ assistants, user, @@ -25,182 +145,171 @@ export function AssistantsGallery({ user: User | null; }) { - function filterAssistants(assistants: Persona[], query: string): Persona[] { - return assistants.filter( - (assistant) => - assistant.name.toLowerCase().includes(query.toLowerCase()) || - assistant.description.toLowerCase().includes(query.toLowerCase()) - ); - } - const router = useRouter(); const [searchQuery, setSearchQuery] = useState(""); const { popup, setPopup } = usePopup(); - const allAssistantIds = assistants.map((assistant) => assistant.id); - const filteredAssistants = filterAssistants(assistants, searchQuery); + const { visibleAssistants, hiddenAssistants: _ } = classifyAssistants( + user, + assistants + ); + + const defaultAssistants = assistants + .filter((assistant) => assistant.is_default_persona) + .filter( + (assistant) => + assistant.name.toLowerCase().includes(searchQuery.toLowerCase()) || + assistant.description.toLowerCase().includes(searchQuery.toLowerCase()) + ); + + const nonDefaultAssistants = assistants + .filter((assistant) => !assistant.is_default_persona) + .filter( + (assistant) => + assistant.name.toLowerCase().includes(searchQuery.toLowerCase()) || + assistant.description.toLowerCase().includes(searchQuery.toLowerCase()) + ); return ( <> {popup}
Assistant Gallery -
- - View Your Assistants - + +
+ + +
-

- Discover and create custom assistants that combine instructions, extra - knowledge, and any combination of tools. -

- -
- setSearchQuery(e.target.value)} - className=" - w-full - p-2 - border - border-gray-300 - rounded - focus:outline-none - focus:ring-2 - focus:ring-blue-500 - " - /> +
+
+ setSearchQuery(e.target.value)} + className=" + w-full + py-3 + px-4 + pl-10 + text-lg + border-2 + border-background-strong + rounded-full + bg-background-50 + text-text-700 + placeholder-text-400 + focus:outline-none + focus:ring-2 + focus:ring-primary-500 + focus:border-transparent + transition duration-300 ease-in-out + " + /> +
+ + + +
+
-
- {filteredAssistants.map((assistant) => ( + + {defaultAssistants.length == 0 && + nonDefaultAssistants.length == 0 && + assistants.length != 0 && ( +
+ No assistants found for this search +
+ )} + + {defaultAssistants.length > 0 && ( + <> +
+

+ Default Assistants +

+ +

+ These are assistants created by your admins and are preferred. +

+
-
- -

+
+

+ Other Assistants +

+

+ These are community-contributed assistants. +

+
+ +
- {assistant.name} -

- {user && ( -
- {!user.preferences?.chosen_assistants || - user.preferences?.chosen_assistants?.includes( - assistant.id - ) ? ( - - ) : ( - - )} -
- )} -
- -

{assistant.description}

-

- Author: {assistant.owner?.email || "Danswer"} -

- {assistant.tools.length > 0 && ( - - )} + > + {nonDefaultAssistants.map((assistant) => ( + + ))}
- ))} -
+ + )}
); diff --git a/web/src/app/assistants/mine/AssistantSharingModal.tsx b/web/src/app/assistants/mine/AssistantSharingModal.tsx index 0fa55fea424..b0e96b00070 100644 --- a/web/src/app/assistants/mine/AssistantSharingModal.tsx +++ b/web/src/app/assistants/mine/AssistantSharingModal.tsx @@ -73,7 +73,11 @@ export function AssistantSharingModal({ let sharedStatus = null; if (assistant.is_public || !sharedUsersWithoutOwner.length) { sharedStatus = ( - + ); } else { sharedStatus = ( @@ -122,27 +126,30 @@ export function AssistantSharingModal({ <> {popup} - {" "} -
{assistantName}
+
+ +

+ {assistantName} +

} onOutsideClick={onClose} > -
+
{isUpdating && } - - Control which other users should have access to this assistant. - +

+ Manage access to this assistant by sharing it with other users. +

-
-

Current status:

- {sharedStatus} +
+

Current Status

+
{sharedStatus}
-

Share Assistant:

-
+
+

Share Assistant

{ - return { - name: user.email, - value: user.id, - }; - })} + .map((user) => ({ + name: user.email, + value: user.id, + }))} onSelect={(option) => { setSelectedUsers([ ...Array.from( @@ -170,18 +175,22 @@ export function AssistantSharingModal({ ]); }} itemComponent={({ option }) => ( -
- - {option.name} -
- -
+
+ + {option.name} +
)} /> -
- {selectedUsers.length > 0 && - selectedUsers.map((selectedUser) => ( +
+ + {selectedUsers.length > 0 && ( +
+

+ Selected Users: +

+
+ {selectedUsers.map((selectedUser) => (
{ @@ -191,35 +200,29 @@ export function AssistantSharingModal({ ) ); }} - className={` - flex - rounded-lg - px-2 - py-1 - border - border-border - hover:bg-hover-light - cursor-pointer`} + className="flex items-center bg-blue-50 text-blue-700 rounded-full px-3 py-1 text-sm hover:bg-blue-100 transition-colors duration-200 cursor-pointer" > - {selectedUser.email} + {selectedUser.email} +
))} +
+ )} - {selectedUsers.length > 0 && ( - - )} -
+ {selectedUsers.length > 0 && ( + + )}
diff --git a/web/src/app/assistants/mine/AssistantsList.tsx b/web/src/app/assistants/mine/AssistantsList.tsx index 880dcad49bc..211ac39463d 100644 --- a/web/src/app/assistants/mine/AssistantsList.tsx +++ b/web/src/app/assistants/mine/AssistantsList.tsx @@ -1,25 +1,26 @@ "use client"; -import { Dispatch, SetStateAction, useEffect, useState } from "react"; +import React, { + Dispatch, + ReactNode, + SetStateAction, + useEffect, + useState, +} from "react"; import { MinimalUserSnapshot, User } from "@/lib/types"; import { Persona } from "@/app/admin/assistants/interfaces"; -import { Divider, Text } from "@tremor/react"; +import { Button, Divider, Text } from "@tremor/react"; import { FiEdit2, - FiFigma, - FiMenu, + FiList, FiMinus, FiMoreHorizontal, FiPlus, - FiSearch, - FiShare, FiShare2, - FiToggleLeft, FiTrash, FiX, } from "react-icons/fi"; import Link from "next/link"; -import { orderAssistantsForUser } from "@/lib/assistants/orderAssistants"; import { addAssistantToList, removeAssistantFromList, @@ -29,14 +30,12 @@ import { AssistantIcon } from "@/components/assistants/AssistantIcon"; import { DefaultPopover } from "@/components/popover/DefaultPopover"; import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup"; import { useRouter } from "next/navigation"; -import { NavigationButton } from "../NavigationButton"; import { AssistantsPageTitle } from "../AssistantsPageTitle"; import { checkUserOwnsAssistant } from "@/lib/assistants/checkOwnership"; import { AssistantSharingModal } from "./AssistantSharingModal"; import { AssistantSharedStatusDisplay } from "../AssistantSharedStatus"; import useSWR from "swr"; import { errorHandlingFetcher } from "@/lib/fetcher"; -import { AssistantTools } from "../ToolsDisplay"; import { DndContext, @@ -62,6 +61,12 @@ import { } from "@/app/admin/assistants/lib"; import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal"; import { MakePublicAssistantModal } from "@/app/chat/modal/MakePublicAssistantModal"; +import { + classifyAssistants, + getUserCreatedAssistants, + orderAssistantsForUser, +} from "@/lib/assistants/utils"; +import { CustomTooltip } from "@/components/tooltip/CustomTooltip"; function DraggableAssistantListItem(props: any) { const { @@ -83,12 +88,12 @@ function DraggableAssistantListItem(props: any) { }; return ( -
+
- +
); @@ -97,21 +102,21 @@ function DraggableAssistantListItem(props: any) { function AssistantListItem({ assistant, user, - allAssistantIds, allUsers, isVisible, setPopup, deleteAssistant, shareAssistant, + isDragging, }: { assistant: Persona; user: User | null; allUsers: MinimalUserSnapshot[]; - allAssistantIds: string[]; isVisible: boolean; deleteAssistant: Dispatch>; shareAssistant: Dispatch>; setPopup: (popupSpec: PopupSpec | null) => void; + isDragging?: boolean; }) { const router = useRouter(); const [showSharingModal, setShowSharingModal] = useState(false); @@ -133,155 +138,154 @@ function AssistantListItem({ show={showSharingModal} />
-
-
-
- -

- {assistant.name} -

-
- -
{assistant.description}
-
- - {assistant.tools.length != 0 && ( - +
+ + +

+ {assistant.name} +

+ + {/* {isOwnedByUser && ( */} +
+
+ {assistant.tools.length > 0 && ( +

+ {assistant.tools.length} tool + {assistant.tools.length > 1 && "s"} +

)} +
-
- {isOwnedByUser && ( -
- {!assistant.is_public && ( -
{ - e.stopPropagation(); - setShowSharingModal(true); - }} - > - -
- )} + {isOwnedByUser ? ( - + - - - -
- } - side="bottom" - align="start" - sideOffset={5} + ) : ( + - {[ - isVisible ? ( -
{ - if ( - currentChosenAssistants && - currentChosenAssistants.length === 1 - ) { - setPopup({ - message: `Cannot remove "${assistant.name}" - you must have at least one assistant.`, - type: "error", - }); - return; - } - const success = await removeAssistantFromList( - assistant.id, - currentChosenAssistants || allAssistantIds - ); - if (success) { - setPopup({ - message: `"${assistant.name}" has been removed from your list.`, - type: "success", - }); - router.refresh(); - } else { - setPopup({ - message: `"${assistant.name}" could not be removed from your list.`, - type: "error", - }); - } - }} - > - {isOwnedByUser ? "Hide" : "Remove"} -
- ) : ( -
{ - const success = await addAssistantToList( - assistant.id, - currentChosenAssistants || allAssistantIds - ); - if (success) { - setPopup({ - message: `"${assistant.name}" has been added to your list.`, - type: "success", - }); - router.refresh(); - } else { - setPopup({ - message: `"${assistant.name}" could not be added to your list.`, - type: "error", - }); - } - }} - > - Add -
- ), - isOwnedByUser ? ( -
deleteAssistant(assistant)} - > - Delete -
- ) : ( - <> - ), - isOwnedByUser ? ( -
shareAssistant(assistant)} - > - {assistant.is_public ? : } Make{" "} - {assistant.is_public ? "Private" : "Public"} -
- ) : ( - <> - ), - ]} - -
- )} +
+ +
+ + )} + + + +
+ } + side="bottom" + align="end" + sideOffset={5} + > + {[ + isVisible ? ( + + ) : ( + + ), + isOwnedByUser ? ( + + ) : null, + isOwnedByUser ? ( + + ) : null, + !assistant.is_public ? ( + + ) : null, + ]} + +
+ {/* )} */}
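The hunk below swaps the old orderAssistantsForUser filtering for the classifyAssistants split defined in web/src/lib/assistants/utils.ts, then orders the visible group by the user's saved preference. That ordering rule (sort assistants by their index in the user's chosen_assistants list, pushing unranked ids to the end) condenses to the following sketch; types and names are simplified for illustration and are not the shipped implementation:

interface AssistantLike {
  id: number;
}

// Sort assistants by their position in the user's chosen_assistants list;
// assistants the user never ordered sink to the end of the result.
function orderByUserPreference<T extends AssistantLike>(
  assistants: T[],
  chosenAssistants: number[] | null | undefined
): T[] {
  if (!chosenAssistants || chosenAssistants.length === 0) {
    return assistants;
  }
  const rank = new Map<number, number>(
    chosenAssistants.map((id, index) => [id, index] as [number, number])
  );
  return [...assistants].sort(
    (a, b) =>
      (rank.get(a.id) ?? Number.MAX_SAFE_INTEGER) -
      (rank.get(b.id) ?? Number.MAX_SAFE_INTEGER)
  );
}

The Number.MAX_SAFE_INTEGER fallback keeps assistants the user never explicitly ranked at the bottom instead of dropping them, matching the behavior of the deleted web/src/lib/assistants/orderAssistants.ts further down.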
@@ -294,17 +298,19 @@ export function AssistantsList({ user: User | null; assistants: Persona[]; }) { - const [filteredAssistants, setFilteredAssistants] = useState([]); + // Define the distinct groups of assistants + const { visibleAssistants, hiddenAssistants } = classifyAssistants( + user, + assistants + ); - useEffect(() => { - setFilteredAssistants(orderAssistantsForUser(assistants, user)); - }, [user, assistants, orderAssistantsForUser]); + const [currentlyVisibleAssistants, setCurrentlyVisibleAssistants] = useState< + Persona[] + >(orderAssistantsForUser(visibleAssistants, user)); - const ownedButHiddenAssistants = assistants.filter( - (assistant) => - checkUserOwnsAssistant(user, assistant) && - user?.preferences?.chosen_assistants && - !user?.preferences?.chosen_assistants?.includes(assistant.id) + const ownedButHiddenAssistants = getUserCreatedAssistants( + user, + hiddenAssistants ); const allAssistantIds = assistants.map((assistant) => @@ -332,9 +338,9 @@ export function AssistantsList({ async function handleDragEnd(event: DragEndEvent) { const { active, over } = event; - filteredAssistants; + if (over && active.id !== over.id) { - setFilteredAssistants((assistants) => { + setCurrentlyVisibleAssistants((assistants) => { const oldIndex = assistants.findIndex( (a) => a.id.toString() === active.id ); @@ -390,44 +396,36 @@ export function AssistantsList({ /> )} -
- My Assistants - -
- - -
- - Create New Assistant -
-
- - - - -
- - View Available Assistants -
-
- -
+
+ Your Assistants -

- Assistants allow you to customize your experience for a specific - purpose. Specifically, they combine instructions, extra knowledge, and - any combination of tools. -

+
+ - + +
-

Active Assistants

+

+ Active Assistants +

- +

The order the assistants appear below will be the order they appear in the Assistants dropdown. The first assistant listed will be your default assistant when you start a new chat. Drag and drop to reorder. - +

a.id.toString())} + items={currentlyVisibleAssistants.map((a) => a.id.toString())} strategy={verticalListSortingStrategy} > -
- {filteredAssistants.map((assistant, index) => ( +
+ {currentlyVisibleAssistants.map((assistant, index) => ( Your Hidden Assistants

- +

Assistants you've created that aren't currently visible in the Assistants selector. - +

{ownedButHiddenAssistants.map((assistant, index) => ( @@ -475,7 +473,6 @@ export function AssistantsList({ key={assistant.id} assistant={assistant} user={user} - allAssistantIds={allAssistantIds} allUsers={users || []} isVisible={false} setPopup={setPopup} diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 8e4bea443a3..403c40b3dc4 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -86,7 +86,6 @@ import { import { ChatInputBar } from "./input/ChatInputBar"; import { useChatContext } from "@/components/context/ChatContext"; import { v4 as uuidv4 } from "uuid"; -import { orderAssistantsForUser } from "@/lib/assistants/orderAssistants"; import { ChatPopup } from "./ChatPopup"; import FunctionalHeader from "@/components/chat_search/Header"; @@ -101,6 +100,10 @@ import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal"; import { SEARCH_TOOL_NAME } from "./tools/constants"; import { useUser } from "@/components/user/UserProvider"; import { ApiKeyModal } from "@/components/llm/ApiKeyModal"; +import { + classifyAssistants, + orderAssistantsForUser, +} from "@/lib/assistants/utils"; const TEMP_USER_MESSAGE_ID = -1; const TEMP_ASSISTANT_MESSAGE_ID = -2; @@ -136,7 +139,6 @@ export function ChatPage({ const { user, refreshUser, isLoadingUser } = useUser(); - // chat session const existingChatIdRaw = searchParams.get("chatId"); const currentPersonaId = searchParams.get(SEARCH_PARAM_NAMES.PERSONA_ID); @@ -155,7 +157,10 @@ export function ChatPage({ const loadedIdSessionRef = useRef(existingChatSessionId); // Assistants - const filteredAssistants = orderAssistantsForUser(availableAssistants, user); + const { visibleAssistants, hiddenAssistants: _ } = classifyAssistants( + user, + availableAssistants + ); const existingChatSessionAssistantId = selectedChatSession?.persona_id; const [selectedAssistant, setSelectedAssistant] = useState< @@ -210,7 +215,7 @@ export function ChatPage({ const liveAssistant = alternativeAssistant || selectedAssistant || - filteredAssistants[0] || + visibleAssistants[0] || availableAssistants[0]; useEffect(() => { @@ -680,7 +685,7 @@ export function ChatPage({ useEffect(() => { if (messageHistory.length === 0 && chatSessionIdRef.current === null) { setSelectedAssistant( - filteredAssistants.find((persona) => persona.id === defaultAssistantId) + visibleAssistants.find((persona) => persona.id === defaultAssistantId) ); } }, [defaultAssistantId]); @@ -2379,7 +2384,10 @@ export function ChatPage({ showDocs={() => setDocumentSelection(true)} selectedDocuments={selectedDocuments} // assistant stuff - assistantOptions={filteredAssistants} + assistantOptions={orderAssistantsForUser( + visibleAssistants, + user + )} selectedAssistant={liveAssistant} setSelectedAssistant={onAssistantChange} setAlternativeAssistant={setAlternativeAssistant} diff --git a/web/src/app/chat/modal/configuration/AssistantsTab.tsx b/web/src/app/chat/modal/configuration/AssistantsTab.tsx index d4933771f9b..dcf31138ccf 100644 --- a/web/src/app/chat/modal/configuration/AssistantsTab.tsx +++ b/web/src/app/chat/modal/configuration/AssistantsTab.tsx @@ -20,6 +20,7 @@ import { getFinalLLM } from "@/lib/llm/utils"; import React, { useState } from "react"; import { updateUserAssistantList } from "@/lib/assistants/updateAssistantPreferences"; import { DraggableAssistantCard } from "@/components/assistants/AssistantCards"; +import { orderAssistantsForUser } from "@/lib/assistants/utils"; export function AssistantsTab({ selectedAssistant, diff --git 
a/web/src/app/chat/page.tsx b/web/src/app/chat/page.tsx index e043f544657..72a1cf1aa57 100644 --- a/web/src/app/chat/page.tsx +++ b/web/src/app/chat/page.tsx @@ -6,6 +6,7 @@ import { ChatProvider } from "@/components/context/ChatContext"; import { fetchChatData } from "@/lib/chat/fetchChatData"; import WrappedChat from "./WrappedChat"; import { ProviderContextProvider } from "@/components/chat_search/ProviderContext"; +import { orderAssistantsForUser } from "@/lib/assistants/utils"; export default async function Page({ searchParams, diff --git a/web/src/app/chat/sessionSidebar/ChatSessionDisplay.tsx b/web/src/app/chat/sessionSidebar/ChatSessionDisplay.tsx index 35256ada98a..aeea9b97779 100644 --- a/web/src/app/chat/sessionSidebar/ChatSessionDisplay.tsx +++ b/web/src/app/chat/sessionSidebar/ChatSessionDisplay.tsx @@ -207,7 +207,6 @@ export function ChatSessionDisplay({
{ e.preventDefault(); - // e.stopPropagation(); setIsMoreOptionsDropdownOpen( !isMoreOptionsDropdownOpen ); diff --git a/web/src/components/Dropdown.tsx b/web/src/components/Dropdown.tsx index 79fe0308311..5e36dca9843 100644 --- a/web/src/components/Dropdown.tsx +++ b/web/src/components/Dropdown.tsx @@ -192,7 +192,7 @@ export function SearchMultiSelectDropdown({ export const CustomDropdown = ({ children, dropdown, - direction = "down", // Default to 'down' if not specified + direction = "down", }: { children: JSX.Element | string; dropdown: JSX.Element | string; diff --git a/web/src/components/assistants/AssistantIcon.tsx b/web/src/components/assistants/AssistantIcon.tsx index 4855e95768a..07ab05aa145 100644 --- a/web/src/components/assistants/AssistantIcon.tsx +++ b/web/src/components/assistants/AssistantIcon.tsx @@ -3,6 +3,7 @@ import React from "react"; import { Tooltip } from "../tooltip/Tooltip"; import { createSVG } from "@/lib/assistantIconUtils"; import { buildImgUrl } from "@/app/chat/files/images/utils"; +import { CustomTooltip } from "../tooltip/CustomTooltip"; export function darkerGenerateColorFromId(id: string): string { const hash = Array.from(id).reduce( @@ -30,7 +31,7 @@ export function AssistantIcon({ const color = darkerGenerateColorFromId(assistant.id.toString()); return ( - + { // Prioritization order: image, graph, defaults assistant.uploaded_image_id ? ( @@ -44,7 +45,7 @@ export function AssistantIcon({
{createSVG( { encodedGrid: assistant.icon_shape, filledSquares: 0 }, @@ -62,6 +63,6 @@ export function AssistantIcon({ /> ) } - + ); } diff --git a/web/src/components/popover/DefaultPopover.tsx b/web/src/components/popover/DefaultPopover.tsx index 6032d790921..ff017d972c4 100644 --- a/web/src/components/popover/DefaultPopover.tsx +++ b/web/src/components/popover/DefaultPopover.tsx @@ -5,7 +5,7 @@ import { Popover } from "./Popover"; export function DefaultPopover(props: { content: JSX.Element; - children: JSX.Element[]; + children: (JSX.Element | null)[]; side?: "top" | "right" | "bottom" | "left"; align?: "start" | "center" | "end"; sideOffset?: number; @@ -39,15 +39,17 @@ export function DefaultPopover(props: { overscroll-contain `} > - {props.children.map((child, index) => ( -
setPopoverOpen(false)} - > - {child} -
- ))} + {props.children + .filter((child) => child !== null) + .map((child, index) => ( +
setPopoverOpen(false)} + > + {child} +
+ ))}
} {...props} diff --git a/web/src/components/tooltip/CustomTooltip.tsx b/web/src/components/tooltip/CustomTooltip.tsx index 6695b99eb08..2f4ca2d1254 100644 --- a/web/src/components/tooltip/CustomTooltip.tsx +++ b/web/src/components/tooltip/CustomTooltip.tsx @@ -121,7 +121,7 @@ export const CustomTooltip = ({ {isVisible && createPortal(
[
-      id,
-      index,
-    ])
-  );
-
-  let filteredAssistants = assistants.filter((assistant) =>
-    chosenAssistantsSet.has(assistant.id)
-  );
-
-  if (filteredAssistants.length == 0) {
-    return assistants;
-  }
-
-  filteredAssistants.sort((a, b) => {
-    const orderA = assistantOrderMap.get(a.id) ?? Number.MAX_SAFE_INTEGER;
-    const orderB = assistantOrderMap.get(b.id) ?? Number.MAX_SAFE_INTEGER;
-    return orderA - orderB;
-  });
-  return filteredAssistants;
-  }
-
-  return assistants;
-}
diff --git a/web/src/lib/assistants/updateAssistantPreferences.ts b/web/src/lib/assistants/updateAssistantPreferences.ts
index 61a0128b221..e996867a35d 100644
--- a/web/src/lib/assistants/updateAssistantPreferences.ts
+++ b/web/src/lib/assistants/updateAssistantPreferences.ts
@@ -11,24 +11,33 @@ export async function updateUserAssistantList(
   return response.ok;
 }
 
+export async function updateAssistantVisibility(
+  assistantId: number,
+  show: boolean
+): Promise<boolean> {
+  const response = await fetch(
+    `/api/user/assistant-list/update/${assistantId}?show=${show}`,
+    {
+      method: "PATCH",
+      headers: {
+        "Content-Type": "application/json",
+      },
+    }
+  );
+
+  return response.ok;
+}
 
 export async function removeAssistantFromList(
-  assistantId: number,
-  chosenAssistants: number[]
+  assistantId: number
 ): Promise<boolean> {
-  const updatedAssistants = chosenAssistants.filter((id) => id !== assistantId);
-  return updateUserAssistantList(updatedAssistants);
+  return updateAssistantVisibility(assistantId, false);
 }
 
 export async function addAssistantToList(
-  assistantId: number,
-  chosenAssistants: number[]
+  assistantId: number
 ): Promise<boolean> {
-  if (!chosenAssistants.includes(assistantId)) {
-    const updatedAssistants = [...chosenAssistants, assistantId];
-    return updateUserAssistantList(updatedAssistants);
-  }
-  return false;
+  return updateAssistantVisibility(assistantId, true);
 }
 
 export async function moveAssistantUp(
diff --git a/web/src/lib/assistants/utils.ts b/web/src/lib/assistants/utils.ts
new file mode 100644
index 00000000000..208c2ae9827
--- /dev/null
+++ b/web/src/lib/assistants/utils.ts
@@ -0,0 +1,120 @@
+import { Persona } from "@/app/admin/assistants/interfaces";
+import { User } from "../types";
+import { checkUserIsNoAuthUser } from "../user";
+
+export function checkUserOwnsAssistant(user: User | null, assistant: Persona) {
+  return checkUserIdOwnsAssistant(user?.id, assistant);
+}
+
+export function checkUserIdOwnsAssistant(
+  userId: string | undefined,
+  assistant: Persona
+) {
+  return (
+    (!userId ||
+      checkUserIsNoAuthUser(userId) ||
+      assistant.owner?.id === userId) &&
+    !assistant.is_default_persona
+  );
+}
+
+export function classifyAssistants(user: User | null, assistants: Persona[]) {
+  if (!user) {
+    return {
+      visibleAssistants: assistants.filter(
+        (assistant) => assistant.is_default_persona
+      ),
+      hiddenAssistants: [],
+    };
+  }
+
+  const visibleAssistants = assistants.filter((assistant) => {
+    const isVisible = user.preferences?.visible_assistants?.includes(
+      assistant.id
+    );
+    const isNotHidden = !user.preferences?.hidden_assistants?.includes(
+      assistant.id
+    );
+    const isSelected = user.preferences?.chosen_assistants?.includes(
+      assistant.id
+    );
+    const isBuiltIn = assistant.builtin_persona;
+    const isDefault = assistant.is_default_persona;
+
+    const isOwnedByUser = checkUserOwnsAssistant(user, assistant);
+
+    const isShown =
+      (isVisible && isNotHidden && isSelected) ||
+      (isNotHidden && (isBuiltIn || isDefault || isOwnedByUser));
+    return isShown;
+  });
+
+  const hiddenAssistants = assistants.filter((assistant) => {
+    return !visibleAssistants.includes(assistant);
+  });
+
+  return {
+    visibleAssistants,
+    hiddenAssistants,
+  };
+}
+
+export function orderAssistantsForUser(
+  assistants: Persona[],
+  user: User | null
+) {
+  let orderedAssistants = [...assistants];
+
+  if (user?.preferences?.chosen_assistants) {
+    const chosenAssistantsSet = new Set(user.preferences.chosen_assistants);
+    const assistantOrderMap = new Map(
+      user.preferences.chosen_assistants.map((id: number, index: number) => [
+        id,
+        index,
+      ])
+    );
+
+    // Sort chosen assistants based on user preferences
+    orderedAssistants.sort((a, b) => {
+      const orderA = assistantOrderMap.get(a.id);
+      const orderB = assistantOrderMap.get(b.id);
+
+      if (orderA !== undefined && orderB !== undefined) {
+        return orderA - orderB;
+      } else if (orderA !== undefined) {
+        return -1;
+      } else if (orderB !== undefined) {
+        return 1;
+      } else {
+        return 0;
+      }
+    });
+
+    // Filter out assistants not in the user's chosen list
+    orderedAssistants = orderedAssistants.filter((assistant) =>
+      chosenAssistantsSet.has(assistant.id)
+    );
+  }
+
+  // Sort remaining assistants based on display_priority
+  const remainingAssistants = assistants.filter(
+    (assistant) => !orderedAssistants.includes(assistant)
+  );
+  remainingAssistants.sort((a, b) => {
+    const priorityA = a.display_priority ?? Number.MAX_SAFE_INTEGER;
+    const priorityB = b.display_priority ?? Number.MAX_SAFE_INTEGER;
+    return priorityA - priorityB;
+  });
+
+  // Combine ordered chosen assistants with remaining assistants
+  return [...orderedAssistants, ...remainingAssistants];
+}
+
+export function getUserCreatedAssistants(
+  user: User | null,
+  assistants: Persona[]
+) {
+  return assistants.filter((assistant) =>
+    checkUserOwnsAssistant(user, assistant)
+  );
+}
diff --git a/web/src/lib/credential.ts b/web/src/lib/credential.ts
index c6596804707..f39f3ddedc0 100644
--- a/web/src/lib/credential.ts
+++ b/web/src/lib/credential.ts
@@ -32,10 +32,7 @@ export async function deleteCredential(
   });
 }
 
-export async function forceDeleteCredential(
-  credentialId: number,
-  force?: boolean
-) {
+export async function forceDeleteCredential(credentialId: number) {
   return await fetch(`/api/manage/credential/force/${credentialId}`, {
     method: "DELETE",
     headers: {
diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts
index bf342ca4029..1ea67d14064 100644
--- a/web/src/lib/types.ts
+++ b/web/src/lib/types.ts
@@ -5,6 +5,8 @@ import { ConnectorCredentialPairStatus } from "@/app/admin/connector/[ccPairId]/
 
 export interface UserPreferences {
   chosen_assistants: number[] | null;
+  visible_assistants: number[];
+  hidden_assistants: number[];
   default_model: string | null;
 }
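The split into classifyAssistants and orderAssistantsForUser keeps visibility policy separate from user ordering. A minimal sketch of the intended call pattern (assuming a `user` and the full assistant list are already loaded, as in ChatPage above):

    import {
      classifyAssistants,
      orderAssistantsForUser,
    } from "@/lib/assistants/utils";

    // Classify first; hidden assistants still back the "owned but hidden" section.
    const { visibleAssistants, hiddenAssistants } = classifyAssistants(
      user,
      assistants
    );
    // Then order: the user's chosen_assistants order wins, everything else
    // falls back to display_priority.
    const ordered = orderAssistantsForUser(visibleAssistants, user);
    const defaultAssistant = ordered[0]; // mirrors the liveAssistant fallback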
From 9f179940f8f1e8072bcda950662cd9d7913f7bee Mon Sep 17 00:00:00 2001
From: pablodanswer
Date: Thu, 19 Sep 2024 16:54:18 -0700
Subject: [PATCH 008/505] Asana connector (community originated) (#2485)

* initial Asana connector
* hint on how to get Asana workspace ID
* re-format with black
* re-order imports
* update asana connector for clarity
* minor robustification
* minor update to naming
* update for best practice
* update connector

---------

Co-authored-by: Daniel Naber
---
 backend/danswer/configs/constants.py          |   1 +
 backend/danswer/connectors/asana/__init__.py  |   0
 backend/danswer/connectors/asana/asana_api.py | 233 ++++++++++++++++++
 backend/danswer/connectors/asana/connector.py | 120 +++++++++
 backend/danswer/connectors/factory.py         |   2 +
 backend/requirements/default.txt              |   1 +
 web/public/Asana.png                          | Bin 0 -> 6716 bytes
 web/src/components/icons/icons.tsx            |  13 +-
 web/src/lib/connectors/connectors.ts          |  38 +++
 web/src/lib/connectors/credentials.ts         |  10 +
 web/src/lib/sources.ts                        |   7 +
 web/src/lib/types.ts                          |   1 +
 12 files changed, 425 insertions(+), 1 deletion(-)
 create mode 100644 backend/danswer/connectors/asana/__init__.py
 create mode 100755 backend/danswer/connectors/asana/asana_api.py
 create mode 100755 backend/danswer/connectors/asana/connector.py
 create mode 100644 web/public/Asana.png

diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py
index 5ff2ccc7fd1..670ad3771f9 100644
--- a/backend/danswer/configs/constants.py
+++ b/backend/danswer/configs/constants.py
@@ -99,6 +99,7 @@ class DocumentSource(str, Enum):
     CLICKUP = "clickup"
     MEDIAWIKI = "mediawiki"
     WIKIPEDIA = "wikipedia"
+    ASANA = "asana"
     S3 = "s3"
     R2 = "r2"
     GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
diff --git a/backend/danswer/connectors/asana/__init__.py b/backend/danswer/connectors/asana/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/backend/danswer/connectors/asana/asana_api.py b/backend/danswer/connectors/asana/asana_api.py
new file mode 100755
index 00000000000..57c470c4531
--- /dev/null
+++ b/backend/danswer/connectors/asana/asana_api.py
@@ -0,0 +1,233 @@
+import time
+from collections.abc import Iterator
+from datetime import datetime
+from typing import Dict
+
+import asana  # type: ignore
+
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+# https://github.com/Asana/python-asana/tree/master?tab=readme-ov-file#documentation-for-api-endpoints
+class AsanaTask:
+    def __init__(
+        self,
+        id: str,
+        title: str,
+        text: str,
+        link: str,
+        last_modified: datetime,
+        project_gid: str,
+        project_name: str,
+    ) -> None:
+        self.id = id
+        self.title = title
+        self.text = text
+        self.link = link
+        self.last_modified = last_modified
+        self.project_gid = project_gid
+        self.project_name = project_name
+
+    def __str__(self) -> str:
+        return f"ID: {self.id}\nTitle: {self.title}\nLast modified: {self.last_modified}\nText: {self.text}"
+
+
+class AsanaAPI:
+    def __init__(
+        self, api_token: str, workspace_gid: str, team_gid: str | None
+    ) -> None:
+        self._user = None  # type: ignore
+        self.workspace_gid = workspace_gid
+        self.team_gid = team_gid
+
+        self.configuration = asana.Configuration()
+        self.api_client = asana.ApiClient(self.configuration)
+        self.tasks_api = asana.TasksApi(self.api_client)
+        self.stories_api = asana.StoriesApi(self.api_client)
+        self.users_api = asana.UsersApi(self.api_client)
+        self.project_api = asana.ProjectsApi(self.api_client)
+        self.workspaces_api = asana.WorkspacesApi(self.api_client)
+
+        self.api_error_count = 0
+        self.configuration.access_token = api_token
+        self.task_count = 0
+
+    def get_tasks(
+        self, project_gids: list[str] | None, start_date: str
+    ) -> Iterator[AsanaTask]:
+        """Get all tasks from the projects with the given gids that were modified since the given date.
+ If project_gids is None, get all tasks from all projects in the workspace.""" + logger.info("Starting to fetch Asana projects") + projects = self.project_api.get_projects( + opts={ + "workspace": self.workspace_gid, + "opt_fields": "gid,name,archived,modified_at", + } + ) + start_seconds = int(time.mktime(datetime.now().timetuple())) + projects_list = [] + project_count = 0 + for project_info in projects: + project_gid = project_info["gid"] + if project_gids is None or project_gid in project_gids: + projects_list.append(project_gid) + else: + logger.debug( + f"Skipping project: {project_gid} - not in accepted project_gids" + ) + project_count += 1 + if project_count % 100 == 0: + logger.info(f"Processed {project_count} projects") + + logger.info(f"Found {len(projects_list)} projects to process") + for project_gid in projects_list: + for task in self._get_tasks_for_project( + project_gid, start_date, start_seconds + ): + yield task + logger.info(f"Completed fetching {self.task_count} tasks from Asana") + if self.api_error_count > 0: + logger.warning( + f"Encountered {self.api_error_count} API errors during task fetching" + ) + + def _get_tasks_for_project( + self, project_gid: str, start_date: str, start_seconds: int + ) -> Iterator[AsanaTask]: + project = self.project_api.get_project(project_gid, opts={}) + if project["archived"]: + logger.info(f"Skipping archived project: {project['name']} ({project_gid})") + return [] + if not project["team"] or not project["team"]["gid"]: + logger.info( + f"Skipping project without a team: {project['name']} ({project_gid})" + ) + return [] + if project["privacy_setting"] == "private": + if self.team_gid and project["team"]["gid"] != self.team_gid: + logger.info( + f"Skipping private project not in configured team: {project['name']} ({project_gid})" + ) + return [] + else: + logger.info( + f"Processing private project in configured team: {project['name']} ({project_gid})" + ) + + simple_start_date = start_date.split(".")[0].split("+")[0] + logger.info( + f"Fetching tasks modified since {simple_start_date} for project: {project['name']} ({project_gid})" + ) + + opts = { + "opt_fields": "name,memberships,memberships.project,completed_at,completed_by,created_at," + "created_by,custom_fields,dependencies,due_at,due_on,external,html_notes,liked,likes," + "modified_at,notes,num_hearts,parent,projects,resource_subtype,resource_type,start_on," + "workspace,permalink_url", + "modified_since": start_date, + } + tasks_from_api = self.tasks_api.get_tasks_for_project(project_gid, opts) + for data in tasks_from_api: + self.task_count += 1 + if self.task_count % 10 == 0: + end_seconds = time.mktime(datetime.now().timetuple()) + runtime_seconds = end_seconds - start_seconds + if runtime_seconds > 0: + logger.info( + f"Processed {self.task_count} tasks in {runtime_seconds:.0f} seconds " + f"({self.task_count / runtime_seconds:.2f} tasks/second)" + ) + + logger.debug(f"Processing Asana task: {data['name']}") + + text = self._construct_task_text(data) + + try: + text += self._fetch_and_add_comments(data["gid"]) + + last_modified_date = self.format_date(data["modified_at"]) + text += f"Last modified: {last_modified_date}\n" + + task = AsanaTask( + id=data["gid"], + title=data["name"], + text=text, + link=data["permalink_url"], + last_modified=datetime.fromisoformat(data["modified_at"]), + project_gid=project_gid, + project_name=project["name"], + ) + yield task + except Exception: + logger.error( + f"Error processing task {data['gid']} in project {project_gid}", + 
exc_info=True, + ) + self.api_error_count += 1 + + def _construct_task_text(self, data: Dict) -> str: + text = f"{data['name']}\n\n" + + if data["notes"]: + text += f"{data['notes']}\n\n" + + if data["created_by"] and data["created_by"]["gid"]: + creator = self.get_user(data["created_by"]["gid"])["name"] + created_date = self.format_date(data["created_at"]) + text += f"Created by: {creator} on {created_date}\n" + + if data["due_on"]: + due_date = self.format_date(data["due_on"]) + text += f"Due date: {due_date}\n" + + if data["completed_at"]: + completed_date = self.format_date(data["completed_at"]) + text += f"Completed on: {completed_date}\n" + + text += "\n" + return text + + def _fetch_and_add_comments(self, task_gid: str) -> str: + text = "" + stories_opts: Dict[str, str] = {} + story_start = time.time() + stories = self.stories_api.get_stories_for_task(task_gid, stories_opts) + + story_count = 0 + comment_count = 0 + + for story in stories: + story_count += 1 + if story["resource_subtype"] == "comment_added": + comment = self.stories_api.get_story( + story["gid"], opts={"opt_fields": "text,created_by,created_at"} + ) + commenter = self.get_user(comment["created_by"]["gid"])["name"] + text += f"Comment by {commenter}: {comment['text']}\n\n" + comment_count += 1 + + story_duration = time.time() - story_start + logger.debug( + f"Processed {story_count} stories (including {comment_count} comments) in {story_duration:.2f} seconds" + ) + + return text + + def get_user(self, user_gid: str) -> Dict: + if self._user is not None: + return self._user + self._user = self.users_api.get_user(user_gid, {"opt_fields": "name,email"}) + + if not self._user: + logger.warning(f"Unable to fetch user information for user_gid: {user_gid}") + return {"name": "Unknown"} + return self._user + + def format_date(self, date_str: str) -> str: + date = datetime.fromisoformat(date_str) + return time.strftime("%Y-%m-%d", date.timetuple()) + + def get_time(self) -> str: + return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) diff --git a/backend/danswer/connectors/asana/connector.py b/backend/danswer/connectors/asana/connector.py new file mode 100755 index 00000000000..3e2c9a8aaf6 --- /dev/null +++ b/backend/danswer/connectors/asana/connector.py @@ -0,0 +1,120 @@ +import datetime +from typing import Any + +from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE +from danswer.configs.app_configs import INDEX_BATCH_SIZE +from danswer.configs.constants import DocumentSource +from danswer.connectors.asana import asana_api +from danswer.connectors.interfaces import GenerateDocumentsOutput +from danswer.connectors.interfaces import LoadConnector +from danswer.connectors.interfaces import PollConnector +from danswer.connectors.interfaces import SecondsSinceUnixEpoch +from danswer.connectors.models import Document +from danswer.connectors.models import Section +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +class AsanaConnector(LoadConnector, PollConnector): + def __init__( + self, + asana_workspace_id: str, + asana_project_ids: str | None = None, + asana_team_id: str | None = None, + batch_size: int = INDEX_BATCH_SIZE, + continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE, + ) -> None: + self.workspace_id = asana_workspace_id + self.project_ids_to_index: list[str] | None = ( + asana_project_ids.split(",") if asana_project_ids is not None else None + ) + self.asana_team_id = asana_team_id + self.batch_size = batch_size + self.continue_on_failure = continue_on_failure + 
logger.info(
+            f"AsanaConnector initialized with workspace_id: {asana_workspace_id}"
+        )
+
+    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
+        self.api_token = credentials["asana_api_token_secret"]
+        self.asana_client = asana_api.AsanaAPI(
+            api_token=self.api_token,
+            workspace_gid=self.workspace_id,
+            team_gid=self.asana_team_id,
+        )
+        logger.info("Asana credentials loaded and API client initialized")
+        return None
+
+    def poll_source(
+        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch | None
+    ) -> GenerateDocumentsOutput:
+        start_time = datetime.datetime.fromtimestamp(start).isoformat()
+        logger.info(f"Starting Asana poll from {start_time}")
+        asana = asana_api.AsanaAPI(
+            api_token=self.api_token,
+            workspace_gid=self.workspace_id,
+            team_gid=self.asana_team_id,
+        )
+        docs_batch: list[Document] = []
+        tasks = asana.get_tasks(self.project_ids_to_index, start_time)
+
+        for task in tasks:
+            doc = self._message_to_doc(task)
+            docs_batch.append(doc)
+
+            if len(docs_batch) >= self.batch_size:
+                logger.info(f"Yielding batch of {len(docs_batch)} documents")
+                yield docs_batch
+                docs_batch = []
+
+        if docs_batch:
+            logger.info(f"Yielding final batch of {len(docs_batch)} documents")
+            yield docs_batch
+
+        logger.info("Asana poll completed")
+
+    def load_from_state(self) -> GenerateDocumentsOutput:
+        logger.notice("Starting full index of all Asana tasks")
+        return self.poll_source(start=0, end=None)
+
+    def _message_to_doc(self, task: asana_api.AsanaTask) -> Document:
+        logger.debug(f"Converting Asana task {task.id} to Document")
+        return Document(
+            id=task.id,
+            sections=[Section(link=task.link, text=task.text)],
+            doc_updated_at=task.last_modified,
+            source=DocumentSource.ASANA,
+            semantic_identifier=task.title,
+            metadata={
+                "group": task.project_gid,
+                "project": task.project_name,
+            },
+        )
+
+
+if __name__ == "__main__":
+    import time
+    import os
+
+    logger.notice("Starting Asana connector test")
+    connector = AsanaConnector(
+        os.environ["WORKSPACE_ID"],
+        os.environ["PROJECT_IDS"],
+        os.environ["TEAM_ID"],
+    )
+    connector.load_credentials(
+        {
+            "asana_api_token_secret": os.environ["API_TOKEN"],
+        }
+    )
+    logger.info("Loading all documents from Asana")
+    all_docs = connector.load_from_state()
+    current = time.time()
+    one_day_ago = current - 24 * 60 * 60  # 1 day
+    logger.info("Polling for documents updated in the last 24 hours")
+    latest_docs = connector.poll_source(one_day_ago, current)
+    for docs in latest_docs:
+        for doc in docs:
+            print(doc.id)
+    logger.notice("Asana connector test completed")
diff --git a/backend/danswer/connectors/factory.py b/backend/danswer/connectors/factory.py
index 42d3b0bd937..6df16bea641 100644
--- a/backend/danswer/connectors/factory.py
+++ b/backend/danswer/connectors/factory.py
@@ -4,6 +4,7 @@
 from sqlalchemy.orm import Session
 
 from danswer.configs.constants import DocumentSource
+from danswer.connectors.asana.connector import AsanaConnector
 from danswer.connectors.axero.connector import AxeroConnector
 from danswer.connectors.blob.connector import BlobStorageConnector
 from danswer.connectors.bookstack.connector import BookstackConnector
@@ -91,6 +92,7 @@ def identify_connector_class(
     DocumentSource.CLICKUP: ClickupConnector,
     DocumentSource.MEDIAWIKI: MediaWikiConnector,
     DocumentSource.WIKIPEDIA: WikipediaConnector,
+    DocumentSource.ASANA: AsanaConnector,
     DocumentSource.S3: BlobStorageConnector,
     DocumentSource.R2: BlobStorageConnector,
     DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
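With the factory entry above, Asana resolves like any other source. A rough sketch of the end-to-end hookup (simplified: the real identify_connector_class also accepts an input type, and the workspace ID and token here are placeholders):

    from danswer.configs.constants import DocumentSource
    from danswer.connectors.factory import identify_connector_class

    connector_cls = identify_connector_class(DocumentSource.ASANA)  # AsanaConnector
    connector = connector_cls(asana_workspace_id="1234567890123456")  # placeholder ID
    connector.load_credentials({"asana_api_token_secret": "<api-token>"})
    # load_from_state() is just poll_source(start=0, end=None) under the hood
    for batch in connector.load_from_state():
        print(f"indexed {len(batch)} Asana tasks")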
diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt
index 5b9d57b9d35..1558ce87641 100644
--- a/backend/requirements/default.txt
+++ b/backend/requirements/default.txt
@@ -70,6 +70,7 @@ transformers==4.39.2
 uvicorn==0.21.1
 zulip==0.8.2
 hubspot-api-client==8.1.0
+asana==5.0.8
 zenpy==2.0.41
 dropbox==11.36.2
 boto3-stubs[s3]==1.34.133
diff --git a/web/public/Asana.png b/web/public/Asana.png
new file mode 100644
index 0000000000000000000000000000000000000000..3d4c0c81d49cbb7785ff6aac299c8f9e94022c59
GIT binary patch
[6716-byte PNG literal omitted]
diff --git a/web/src/components/icons/icons.tsx b/web/src/components/icons/icons.tsx
index b5e735b0e65..b14e532f417 100644
--- a/web/src/components/icons/icons.tsx
+++ b/web/src/components/icons/icons.tsx
@@ -52,7 +52,7 @@ import litellmIcon from "../../../public/LiteLLM.jpg";
 
 import awsWEBP from "../../../public/Amazon.webp";
 import azureIcon from "../../../public/Azure.png";
-
+import asanaIcon from "../../../public/Asana.png";
 import anthropicSVG from "../../../public/Anthropic.svg";
 import nomicSVG from "../../../public/nomic.svg";
 import microsoftIcon from "../../../public/microsoft.png";
@@ -2811,3 +2811,14 @@ export const WindowsIcon = ({
     );
 };
+export const AsanaIcon = ({
+  size = 16,
+  className = defaultTailwindCSS,
+}: IconProps) => (
+  <div
+    style={{ width: `${size}px`, height: `${size}px` }}
+    className={`w-[${size}px] h-[${size}px] ` + className}
+  >
+    <Image src={asanaIcon} alt="Logo" width="96" height="96" />
+  </div>
+); diff --git a/web/src/lib/connectors/connectors.ts b/web/src/lib/connectors/connectors.ts index b526818b083..2a85990a5a6 100644 --- a/web/src/lib/connectors/connectors.ts +++ b/web/src/lib/connectors/connectors.ts @@ -763,6 +763,38 @@ For example, specifying .*-support.* as a "channel" will cause the connector to }, ], }, + asana: { + description: "Configure Asana connector", + values: [ + { + type: "text", + query: "Enter your Asana workspace ID:", + label: "Workspace ID", + name: "asana_workspace_id", + optional: false, + description: + "The ID of the Asana workspace to index. You can find this at https://app.asana.com/api/1.0/workspaces. It's a number that looks like 1234567890123456.", + }, + { + type: "text", + query: "Enter project IDs to index (optional):", + label: "Project IDs", + name: "asana_project_ids", + description: + "IDs of specific Asana projects to index, separated by commas. Leave empty to index all projects in the workspace. Example: 1234567890123456,2345678901234567", + optional: true, + }, + { + type: "text", + query: "Enter the Team ID (optional):", + label: "Team ID", + name: "asana_team_id", + optional: true, + description: + "ID of a team to use for accessing team-visible tasks. This allows indexing of team-visible tasks in addition to public tasks. Leave empty if you don't want to use this feature.", + }, + ], + }, mediawiki: { description: "Configure MediaWiki connector", values: [ @@ -1056,6 +1088,12 @@ export interface MediaWikiBaseConfig { recurse_depth?: number; } +export interface AsanaConfig { + asana_workspace_id: string; + asana_project_ids?: string; + asana_team_id?: string; +} + export interface MediaWikiConfig extends MediaWikiBaseConfig { hostname: string; } diff --git a/web/src/lib/connectors/credentials.ts b/web/src/lib/connectors/credentials.ts index 424a07c82fe..481a2b3380a 100644 --- a/web/src/lib/connectors/credentials.ts +++ b/web/src/lib/connectors/credentials.ts @@ -166,6 +166,10 @@ export interface SharepointCredentialJson { sp_directory_id: string; } +export interface AsanaCredentialJson { + asana_api_token_secret: string; +} + export interface TeamsCredentialJson { teams_client_id: string; teams_client_secret: string; @@ -241,6 +245,9 @@ export const credentialTemplates: Record = { sp_client_secret: "", sp_directory_id: "", } as SharepointCredentialJson, + asana: { + asana_api_token_secret: "", + } as AsanaCredentialJson, teams: { teams_client_id: "", teams_client_secret: "", @@ -412,6 +419,9 @@ export const credentialDisplayNames: Record = { sp_client_secret: "SharePoint Client Secret", sp_directory_id: "SharePoint Directory ID", + // Asana + asana_api_token_secret: "Asana API Token", + // Teams teams_client_id: "Microsoft Teams Client ID", teams_client_secret: "Microsoft Teams Client Secret", diff --git a/web/src/lib/sources.ts b/web/src/lib/sources.ts index bbc63847adb..6ee40e3c5f9 100644 --- a/web/src/lib/sources.ts +++ b/web/src/lib/sources.ts @@ -32,6 +32,7 @@ import { ZulipIcon, MediaWikiIcon, WikipediaIcon, + AsanaIcon, S3Icon, OCIStorageIcon, GoogleStorageIcon, @@ -230,6 +231,12 @@ const SOURCE_METADATA_MAP: SourceMap = { category: SourceCategory.Wiki, docs: "https://docs.danswer.dev/connectors/wikipedia", }, + asana: { + icon: AsanaIcon, + displayName: "Asana", + category: SourceCategory.ProjectManagement, + docs: "https://docs.danswer.dev/connectors/asana", + }, mediawiki: { icon: MediaWikiIcon, displayName: "MediaWiki", diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index 1ea67d14064..9d385073648 100644 --- 
a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -247,6 +247,7 @@ const validSources = [ "clickup", "wikipedia", "mediawiki", + "asana", "s3", "r2", "google_cloud_storage", From 16d1c19d9f856f10916a9b2ee2ebdb219dc9c7ef Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 19 Sep 2024 17:29:49 -0700 Subject: [PATCH 009/505] Added bool to disable chat_session_id check for search_docs for api --- backend/danswer/chat/process_message.py | 2 ++ backend/danswer/db/chat.py | 4 +++- backend/ee/danswer/server/query_and_chat/chat_backend.py | 2 ++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 93ad3bdd34e..2e79e200676 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -271,6 +271,7 @@ def stream_chat_message_objects( use_existing_user_message: bool = False, litellm_additional_headers: dict[str, str] | None = None, is_connected: Callable[[], bool] | None = None, + enforce_chat_session_id_for_search_docs: bool = True, ) -> ChatPacketStream: """Streams in order: 1. [conditional] Retrieved documents if a search needs to be run @@ -442,6 +443,7 @@ def stream_chat_message_objects( chat_session=chat_session, user_id=user_id, db_session=db_session, + enforce_chat_session_id_for_search_docs=enforce_chat_session_id_for_search_docs, ) # Generates full documents currently diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 8599714ce8b..feb2e2b4b51 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -598,6 +598,7 @@ def get_doc_query_identifiers_from_model( chat_session: ChatSession, user_id: UUID | None, db_session: Session, + enforce_chat_session_id_for_search_docs: bool, ) -> list[tuple[str, int]]: """Given a list of search_doc_ids""" search_docs = ( @@ -617,7 +618,8 @@ def get_doc_query_identifiers_from_model( for doc in search_docs ] ): - raise ValueError("Invalid reference doc, not from this chat session.") + if enforce_chat_session_id_for_search_docs: + raise ValueError("Invalid reference doc, not from this chat session.") except IndexError: # This happens when the doc has no chat_messages associated with it. 
# which happens as an edge case where the chat message failed to save diff --git a/backend/ee/danswer/server/query_and_chat/chat_backend.py b/backend/ee/danswer/server/query_and_chat/chat_backend.py index f5b56b7d43e..dd637dcf081 100644 --- a/backend/ee/danswer/server/query_and_chat/chat_backend.py +++ b/backend/ee/danswer/server/query_and_chat/chat_backend.py @@ -182,6 +182,7 @@ def handle_simplified_chat_message( new_msg_req=full_chat_msg_info, user=user, db_session=db_session, + enforce_chat_session_id_for_search_docs=False, ) return _convert_packet_stream_to_response(packets) @@ -301,6 +302,7 @@ def handle_send_message_simple_with_history( new_msg_req=full_chat_msg_info, user=user, db_session=db_session, + enforce_chat_session_id_for_search_docs=False, ) return _convert_packet_stream_to_response(packets) From c82a36ad6882c7c83334f0e6a30eeb62a2aba651 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Thu, 19 Sep 2024 17:20:50 -0700 Subject: [PATCH 010/505] Saml account fastapi deletion (#2512) * saml account fastapi deletion * update error detail --- backend/danswer/server/manage/users.py | 5 ++++- web/src/components/admin/users/SignedUpUserTable.tsx | 3 ++- web/src/lib/admin/users/userMutationFetcher.ts | 1 + 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index 30b71ef0a6a..e72b85dedad 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -33,6 +33,7 @@ from danswer.db.models import AccessToken from danswer.db.models import DocumentSet__User from danswer.db.models import Persona__User +from danswer.db.models import SamlAccount from danswer.db.models import User from danswer.db.models import User__UserGroup from danswer.db.users import get_user_by_email @@ -249,7 +250,9 @@ async def delete_user( db_session=db_session, user_id=user_to_delete.id, ) - + db_session.query(SamlAccount).filter( + SamlAccount.user_id == user_to_delete.id + ).delete() db_session.query(DocumentSet__User).filter( DocumentSet__User.user_id == user_to_delete.id ).delete() diff --git a/web/src/components/admin/users/SignedUpUserTable.tsx b/web/src/components/admin/users/SignedUpUserTable.tsx index cb30e4d89bc..b718221d6c0 100644 --- a/web/src/components/admin/users/SignedUpUserTable.tsx +++ b/web/src/components/admin/users/SignedUpUserTable.tsx @@ -143,7 +143,8 @@ const DeactivaterButton = ({ type: "success", }); }, - onError: (errorMsg) => setPopup({ message: errorMsg, type: "error" }), + onError: (errorMsg) => + setPopup({ message: errorMsg.message, type: "error" }), } ); return ( diff --git a/web/src/lib/admin/users/userMutationFetcher.ts b/web/src/lib/admin/users/userMutationFetcher.ts index d0c090c89ce..3a3e134e967 100644 --- a/web/src/lib/admin/users/userMutationFetcher.ts +++ b/web/src/lib/admin/users/userMutationFetcher.ts @@ -11,6 +11,7 @@ const userMutationFetcher = async ( body: JSON.stringify(body), }).then(async (res) => { if (res.ok) return res.json(); + const errorDetail = (await res.json()).detail; throw Error(errorDetail); }); From 5f2644985cb721364a3a85c5b7c2bea737f023b9 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Fri, 20 Sep 2024 08:44:28 -0700 Subject: [PATCH 011/505] Route name (#2520) * clearer refresh logic * rename path --- web/src/app/admin/assistants/lib.ts | 2 +- web/src/app/assistants/mine/AssistantsList.tsx | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/web/src/app/admin/assistants/lib.ts 
b/web/src/app/admin/assistants/lib.ts index 61ede3f1643..db1659e9d61 100644 --- a/web/src/app/admin/assistants/lib.ts +++ b/web/src/app/admin/assistants/lib.ts @@ -331,7 +331,7 @@ export const togglePersonaVisibility = async ( personaId: number, isVisible: boolean ) => { - const response = await fetch(`/api/persona/${personaId}/visible`, { + const response = await fetch(`/api/admin/persona/${personaId}/visible`, { method: "PATCH", headers: { "Content-Type": "application/json", diff --git a/web/src/app/assistants/mine/AssistantsList.tsx b/web/src/app/assistants/mine/AssistantsList.tsx index 211ac39463d..8798de2fd4c 100644 --- a/web/src/app/assistants/mine/AssistantsList.tsx +++ b/web/src/app/assistants/mine/AssistantsList.tsx @@ -306,7 +306,12 @@ export function AssistantsList({ const [currentlyVisibleAssistants, setCurrentlyVisibleAssistants] = useState< Persona[] - >(orderAssistantsForUser(visibleAssistants, user)); + >([]); + + useEffect(() => { + const orderedAssistants = orderAssistantsForUser(visibleAssistants, user); + setCurrentlyVisibleAssistants(orderedAssistants); + }, [assistants, user]); const ownedButHiddenAssistants = getUserCreatedAssistants( user, From 00229d2abedf898f1cb2510bfd3577311beb9716 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Fri, 20 Sep 2024 09:39:34 -0700 Subject: [PATCH 012/505] Add start date to persona (#2407) * add start date to persona * remove logs * rename * update assistant editor * update alembic * update alembic * update alembic * udpate alembic * remove rebase artifacts --- .../797089dfb4d2_persona_start_date.py | 27 +++++++++++++++++++ backend/danswer/db/models.py | 3 +++ backend/danswer/db/persona.py | 4 +++ .../search/preprocessing/preprocessing.py | 5 +++- .../danswer/server/features/persona/models.py | 5 +++- .../app/admin/assistants/AssistantEditor.tsx | 25 +++++++++++++++-- web/src/app/admin/assistants/interfaces.ts | 1 + web/src/app/admin/assistants/lib.ts | 5 +++- web/src/components/admin/connectors/Field.tsx | 4 ++- 9 files changed, 73 insertions(+), 6 deletions(-) create mode 100644 backend/alembic/versions/797089dfb4d2_persona_start_date.py diff --git a/backend/alembic/versions/797089dfb4d2_persona_start_date.py b/backend/alembic/versions/797089dfb4d2_persona_start_date.py new file mode 100644 index 00000000000..52ade3dea4e --- /dev/null +++ b/backend/alembic/versions/797089dfb4d2_persona_start_date.py @@ -0,0 +1,27 @@ +"""persona_start_date + +Revision ID: 797089dfb4d2 +Revises: 55546a7967ee +Create Date: 2024-09-11 14:51:49.785835 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "797089dfb4d2" +down_revision = "55546a7967ee" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "persona", + sa.Column("search_start_date", sa.DateTime(timezone=True), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("persona", "search_start_date") diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 70e583ad9c5..aecb1d9f973 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -1314,6 +1314,9 @@ class Persona(Base): starter_messages: Mapped[list[StarterMessage] | None] = mapped_column( postgresql.JSONB(), nullable=True ) + search_start_date: Mapped[datetime.datetime | None] = mapped_column( + DateTime(timezone=True), default=None + ) # Built-in personas are configured via backend during deployment # Treated specially (cannot be user edited etc.) 
    builtin_persona: Mapped[bool] = mapped_column(Boolean, default=False)
diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py
index 316919e58b4..36d2d25c402 100644
--- a/backend/danswer/db/persona.py
+++ b/backend/danswer/db/persona.py
@@ -1,4 +1,5 @@
 from collections.abc import Sequence
+from datetime import datetime
 from functools import lru_cache
 from uuid import UUID
@@ -414,6 +415,7 @@ def upsert_persona(
     display_priority: int | None = None,
     is_visible: bool = True,
     remove_image: bool | None = None,
+    search_start_date: datetime | None = None,
     builtin_persona: bool = False,
     is_default_persona: bool = False,
     chunks_above: int = CONTEXT_CHUNKS_ABOVE,
@@ -484,6 +486,7 @@ def upsert_persona(
         persona.uploaded_image_id = uploaded_image_id
         persona.display_priority = display_priority
         persona.is_visible = is_visible
+        persona.search_start_date = search_start_date
         persona.is_default_persona = is_default_persona
 
         # Do not delete any associations manually added unless
@@ -524,6 +527,7 @@ def upsert_persona(
             uploaded_image_id=uploaded_image_id,
             display_priority=display_priority,
             is_visible=is_visible,
+            search_start_date=search_start_date,
             is_default_persona=is_default_persona,
         )
         db_session.add(persona)
diff --git a/backend/danswer/search/preprocessing/preprocessing.py b/backend/danswer/search/preprocessing/preprocessing.py
index 43a6a43ce88..37fb254884a 100644
--- a/backend/danswer/search/preprocessing/preprocessing.py
+++ b/backend/danswer/search/preprocessing/preprocessing.py
@@ -67,6 +67,9 @@ def retrieval_preprocessing(
     ]
 
     time_filter = preset_filters.time_cutoff
+    if time_filter is None and persona:
+        time_filter = persona.search_start_date
+
     source_filter = preset_filters.source_type
 
     auto_detect_time_filter = True
@@ -154,7 +157,7 @@ def retrieval_preprocessing(
     final_filters = IndexFilters(
         source_type=preset_filters.source_type or predicted_source_filters,
         document_set=preset_filters.document_set,
-        time_cutoff=preset_filters.time_cutoff or predicted_time_cutoff,
+        time_cutoff=time_filter or predicted_time_cutoff,
         tags=preset_filters.tags,  # Tags are never auto-extracted
         access_control_list=user_acl_filters,
     )
diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py
index 0112cc110e2..016defda369 100644
--- a/backend/danswer/server/features/persona/models.py
+++ b/backend/danswer/server/features/persona/models.py
@@ -1,3 +1,4 @@
+from datetime import datetime
 from uuid import UUID
 
 from pydantic import BaseModel
@@ -12,7 +13,6 @@
 from danswer.server.models import MinimalUserSnapshot
 from danswer.utils.logger import setup_logger
 
-
 logger = setup_logger()
@@ -40,6 +40,7 @@ class CreatePersonaRequest(BaseModel):
     remove_image: bool | None = None
     is_default_persona: bool = False
     display_priority: int | None = None
+    search_start_date: datetime | None = None
 
 
 class PersonaSnapshot(BaseModel):
@@ -66,6 +67,7 @@ class PersonaSnapshot(BaseModel):
     icon_shape: int | None
     uploaded_image_id: str | None = None
     is_default_persona: bool
+    search_start_date: datetime | None = None
 
     @classmethod
     def from_model(
@@ -112,6 +114,7 @@ def from_model(
             icon_color=persona.icon_color,
             icon_shape=persona.icon_shape,
             uploaded_image_id=persona.uploaded_image_id,
+            search_start_date=persona.search_start_date,
         )
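Taken together, these changes make search_start_date a per-assistant fallback time cutoff rather than a hard filter of its own. A condensed sketch of the resulting precedence (names from preprocessing.py above; final_cutoff is an illustrative local):

    # 1. An explicit time filter on the request always wins.
    time_filter = preset_filters.time_cutoff
    # 2. Otherwise fall back to the persona's configured start date.
    if time_filter is None and persona:
        time_filter = persona.search_start_date
    # 3. LLM auto-detection only fills the gap when neither is set.
    final_cutoff = time_filter or predicted_time_cutoff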
diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx
index 4951f459b82..1f961614dbc 100644
--- a/web/src/app/admin/assistants/AssistantEditor.tsx
+++ b/web/src/app/admin/assistants/AssistantEditor.tsx
@@ -2,7 +2,7 @@
 
 import { generateRandomIconShape, createSVG } from "@/lib/assistantIconUtils";
 
-import { CCPairBasicInfo, DocumentSet, User, UserRole } from "@/lib/types";
+import { CCPairBasicInfo, DocumentSet, User } from "@/lib/types";
 import { Button, Divider, Italic } from "@tremor/react";
 import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector";
 import {
@@ -230,6 +230,9 @@ export function AssistantEditor({
       existingPersona?.document_sets?.map((documentSet) => documentSet.id) ??
       ([] as number[]),
     num_chunks: existingPersona?.num_chunks ?? null,
+    search_start_date: existingPersona?.search_start_date
+      ? existingPersona?.search_start_date.toString().split("T")[0]
+      : null,
     include_citations: existingPersona?.prompts[0]?.include_citations ?? true,
     llm_relevance_filter: existingPersona?.llm_relevance_filter ?? false,
     llm_model_provider_override:
@@ -284,6 +287,7 @@ export function AssistantEditor({
         ),
       })
     ),
+    search_start_date: Yup.date().nullable(),
     icon_color: Yup.string(),
     icon_shape: Yup.number(),
     uploaded_image: Yup.mixed().nullable(),
@@ -373,6 +377,9 @@ export function AssistantEditor({
         id: existingPersona.id,
         existingPromptId: existingPrompt?.id,
         ...values,
+        search_start_date: values.search_start_date
+          ? new Date(values.search_start_date)
+          : null,
         num_chunks: numChunks,
         users:
           user && !checkUserIsNoAuthUser(user.id) ? [user.id] : undefined,
@@ -385,6 +392,9 @@ export function AssistantEditor({
         ...values,
         is_default_persona: admin!,
         num_chunks: numChunks,
+        search_start_date: values.search_start_date
+          ? new Date(values.search_start_date)
+          : null,
        users:
          user && !checkUserIsNoAuthUser(user.id) ? [user.id] : undefined,
        groups,
@@ -912,7 +922,7 @@ export function AssistantEditor({
                 )}
-
+
+ + ) => void; + width?: string; }) { let heightString = defaultHeight || ""; if (isTextArea && !heightString) { @@ -182,7 +184,7 @@ export function TextFormField({ }; return ( -
+
{!removeLabel && (