From 54dc1ac91776a6ff21f6e8b80e85d32cb957f701 Mon Sep 17 00:00:00 2001 From: "Richard Kuo (Danswer)" Date: Thu, 14 Nov 2024 11:14:12 -0800 Subject: [PATCH 01/11] unnecessary python setup --- .github/workflows/pr-helm-chart-testing.yml | 31 +++++++++++---------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/.github/workflows/pr-helm-chart-testing.yml b/.github/workflows/pr-helm-chart-testing.yml index 5b00303010d..f26ab43e780 100644 --- a/.github/workflows/pr-helm-chart-testing.yml +++ b/.github/workflows/pr-helm-chart-testing.yml @@ -23,21 +23,6 @@ jobs: with: version: v3.14.4 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.11' - cache: 'pip' - cache-dependency-path: | - backend/requirements/default.txt - backend/requirements/dev.txt - backend/requirements/model_server.txt - - run: | - python -m pip install --upgrade pip - pip install --retries 5 --timeout 30 -r backend/requirements/default.txt - pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt - pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt - - name: Set up chart-testing uses: helm/chart-testing-action@v2.6.1 @@ -52,6 +37,22 @@ jobs: echo "changed=true" >> "$GITHUB_OUTPUT" fi +# rkuo: I don't think we need python? +# - name: Set up Python +# uses: actions/setup-python@v5 +# with: +# python-version: '3.11' +# cache: 'pip' +# cache-dependency-path: | +# backend/requirements/default.txt +# backend/requirements/dev.txt +# backend/requirements/model_server.txt +# - run: | +# python -m pip install --upgrade pip +# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt +# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt +# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt + # lint all charts if any changes were detected - name: Run chart-testing (lint) if: steps.list-changed.outputs.changed == 'true' From 97932dc44b06c3499d185e3f1cda162081bf2adf Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Thu, 14 Nov 2024 17:28:03 -0800 Subject: [PATCH 02/11] Fix Quotes Prompting (#3137) --- backend/danswer/llm/answering/answer.py | 1 + backend/danswer/llm/answering/prompts/build.py | 7 +++++++ .../tools/tool_implementations/search_like_tool_utils.py | 5 ++++- examples/assistants-api/topics_analyzer.py | 4 ++-- 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index c447db452ff..f9c9dbcb100 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -263,6 +263,7 @@ def processed_streamed_output(self) -> AnswerStream: message_history=self.message_history, llm_config=self.llm.config, single_message_history=self.single_message_history, + raw_user_text=self.question, ) prompt_builder.update_system_prompt( default_build_system_message(self.prompt_config) diff --git a/backend/danswer/llm/answering/prompts/build.py b/backend/danswer/llm/answering/prompts/build.py index b5b774f522d..ac9ce6f1abd 100644 --- a/backend/danswer/llm/answering/prompts/build.py +++ b/backend/danswer/llm/answering/prompts/build.py @@ -59,6 +59,7 @@ def __init__( message_history: list[PreviousMessage], llm_config: LLMConfig, single_message_history: str | None = None, + raw_user_text: str | None = None, ) -> None: self.max_tokens = compute_max_llm_input_tokens(llm_config) @@ -88,6 +89,12 @@ def __init__( self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = [] + 
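+        # Keep the untouched user question around (falling back to the already
+        # built user_message when no raw text is supplied) so later prompt
+        # construction, e.g. the quotes flow, can start from the original text
+        # rather than from an already processed message.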
self.raw_user_message = ( + HumanMessage(content=raw_user_text) + if raw_user_text is not None + else user_message + ) + def update_system_prompt(self, system_message: SystemMessage | None) -> None: if not system_message: self.system_message_and_token_cnt = None diff --git a/backend/danswer/tools/tool_implementations/search_like_tool_utils.py b/backend/danswer/tools/tool_implementations/search_like_tool_utils.py index 6701f1602ea..121841d0ba3 100644 --- a/backend/danswer/tools/tool_implementations/search_like_tool_utils.py +++ b/backend/danswer/tools/tool_implementations/search_like_tool_utils.py @@ -55,9 +55,12 @@ def build_next_prompt_for_search_like_tool( ) ) elif answer_style_config.quotes_config: + # For Quotes, the system prompt is included in the user prompt + prompt_builder.update_system_prompt(None) + prompt_builder.update_user_prompt( build_quotes_user_message( - message=prompt_builder.user_message_and_token_cnt[0], + message=prompt_builder.raw_user_message, context_docs=final_context_documents, history_str=prompt_builder.single_message_history or "", prompt=prompt_config, diff --git a/examples/assistants-api/topics_analyzer.py b/examples/assistants-api/topics_analyzer.py index a8ef6c27a23..d99c384363a 100644 --- a/examples/assistants-api/topics_analyzer.py +++ b/examples/assistants-api/topics_analyzer.py @@ -28,7 +28,7 @@ """ -def wait_on_run(client: OpenAI, run, thread): +def wait_on_run(client: OpenAI, run, thread): # type: ignore while run.status == "queued" or run.status == "in_progress": run = client.beta.threads.runs.retrieve( thread_id=thread.id, @@ -38,7 +38,7 @@ def wait_on_run(client: OpenAI, run, thread): return run -def show_response(messages) -> None: +def show_response(messages) -> None: # type: ignore # Get only the assistant's response text for message in messages.data[::-1]: if message.role == "assistant": From ddff7ecc3f4ea0838f3b5ae72658e90b9eedce7e Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Thu, 14 Nov 2024 18:09:30 -0800 Subject: [PATCH 03/11] minor configuration updates (#3134) --- .../vespa/app_config/schemas/danswer_chunk.sd | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/danswer/document_index/vespa/app_config/schemas/danswer_chunk.sd b/backend/danswer/document_index/vespa/app_config/schemas/danswer_chunk.sd index b98c9343f3f..e712266fa08 100644 --- a/backend/danswer/document_index/vespa/app_config/schemas/danswer_chunk.sd +++ b/backend/danswer/document_index/vespa/app_config/schemas/danswer_chunk.sd @@ -15,7 +15,7 @@ schema DANSWER_CHUNK_NAME { # Must have an additional field for whether to skip title embeddings # This information cannot be extracted from either the title field nor title embedding field skip_title type bool { - indexing: attribute + indexing: attribute } # May not always match the `semantic_identifier` e.g. 
for Slack docs the # `semantic_identifier` will be the channel name, but the `title` will be empty @@ -36,7 +36,7 @@ schema DANSWER_CHUNK_NAME { } # Title embedding (x1) field title_embedding type tensor(x[VARIABLE_DIM]) { - indexing: attribute + indexing: attribute | index attribute { distance-metric: angular } @@ -44,7 +44,7 @@ schema DANSWER_CHUNK_NAME { # Content embeddings (chunk + optional mini chunks embeddings) # "t" and "x" are arbitrary names, not special keywords field embeddings type tensor(t{},x[VARIABLE_DIM]) { - indexing: attribute + indexing: attribute | index attribute { distance-metric: angular } From 24be13c015ec51fe525e36e299fb9705d0767c3f Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Thu, 14 Nov 2024 20:13:29 -0800 Subject: [PATCH 04/11] Improved tokenizer fallback (#3132) * silence warning * improved fallback logic * k * minor cosmetic update * minor logic update * nit --- .../natural_language_processing/utils.py | 89 ++++++++++--------- 1 file changed, 46 insertions(+), 43 deletions(-) diff --git a/backend/danswer/natural_language_processing/utils.py b/backend/danswer/natural_language_processing/utils.py index a8250570e84..56631495a4e 100644 --- a/backend/danswer/natural_language_processing/utils.py +++ b/backend/danswer/natural_language_processing/utils.py @@ -89,67 +89,70 @@ def _check_tokenizer_cache( model_provider: EmbeddingProvider | None, model_name: str | None ) -> BaseTokenizer: global _TOKENIZER_CACHE - id_tuple = (model_provider, model_name) if id_tuple not in _TOKENIZER_CACHE: - if model_provider in [EmbeddingProvider.OPENAI, EmbeddingProvider.AZURE]: - if model_name is None: - raise ValueError( - "model_name is required for OPENAI and AZURE embeddings" - ) + tokenizer = None - _TOKENIZER_CACHE[id_tuple] = TiktokenTokenizer(model_name) - return _TOKENIZER_CACHE[id_tuple] + if model_name: + tokenizer = _try_initialize_tokenizer(model_name, model_provider) - try: - if model_name is None: - model_name = DOCUMENT_ENCODER_MODEL - - logger.debug(f"Initializing HuggingFaceTokenizer for: {model_name}") - _TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(model_name) - except Exception as primary_error: - logger.error( - f"Error initializing HuggingFaceTokenizer for {model_name}: {primary_error}" - ) - logger.warning( + if not tokenizer: + logger.info( f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}" ) + tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL) - try: - # Cache this tokenizer name to the default so we don't have to try to load it again - # and fail again - _TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer( - DOCUMENT_ENCODER_MODEL - ) - except Exception as fallback_error: - logger.error( - f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}" - ) - raise ValueError( - f"Failed to initialize tokenizer for {model_name} and fallback model" - ) from fallback_error + _TOKENIZER_CACHE[id_tuple] = tokenizer return _TOKENIZER_CACHE[id_tuple] +def _try_initialize_tokenizer( + model_name: str, model_provider: EmbeddingProvider | None +) -> BaseTokenizer | None: + tokenizer: BaseTokenizer | None = None + + if model_provider is not None: + # Try using TiktokenTokenizer first if model_provider exists + try: + tokenizer = TiktokenTokenizer(model_name) + logger.info(f"Initialized TiktokenTokenizer for: {model_name}") + return tokenizer + except Exception as tiktoken_error: + logger.debug( + f"TiktokenTokenizer not available for model {model_name}: {tiktoken_error}" + ) + else: + # If no provider specified, try 
HuggingFaceTokenizer + try: + tokenizer = HuggingFaceTokenizer(model_name) + logger.info(f"Initialized HuggingFaceTokenizer for: {model_name}") + return tokenizer + except Exception as hf_error: + logger.warning( + f"Error initializing HuggingFaceTokenizer for {model_name}: {hf_error}" + ) + + # If both initializations fail, return None + return None + + _DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL) def get_tokenizer( model_name: str | None, provider_type: EmbeddingProvider | str | None ) -> BaseTokenizer: - if provider_type is not None: - if isinstance(provider_type, str): - try: - provider_type = EmbeddingProvider(provider_type) - except ValueError: - logger.debug( - f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer." - ) - return _DEFAULT_TOKENIZER - return _check_tokenizer_cache(provider_type, model_name) - return _DEFAULT_TOKENIZER + if isinstance(provider_type, str): + try: + provider_type = EmbeddingProvider(provider_type) + except ValueError: + logger.debug( + f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer." + ) + return _DEFAULT_TOKENIZER + return _check_tokenizer_cache(provider_type, model_name) def tokenizer_trim_content( From 7015e6f2abd59c3b87b8eb315e547cf8d5961605 Mon Sep 17 00:00:00 2001 From: rkuo-danswer Date: Fri, 15 Nov 2024 16:47:52 -0800 Subject: [PATCH 05/11] Bugfix/overlapping connectors (#3138) * fix tenant logging * upsert only new/updated docs, but always upsert document to cc pair relationship * better logging and rough cut at testing --- .../background/indexing/run_indexing.py | 16 +++--- backend/danswer/db/document.py | 29 +++------- backend/danswer/indexing/indexing_pipeline.py | 57 ++++++++++++------- .../connector/test_connector_deletion.py | 53 +++++++++++++++++ 4 files changed, 106 insertions(+), 49 deletions(-) diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index 2cd1fc07e2a..8f3ae65fa3d 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -433,11 +433,13 @@ def run_indexing_entrypoint( with get_session_with_tenant(tenant_id) as db_session: attempt = transition_attempt_to_in_progress(index_attempt_id, db_session) + tenant_str = "" + if tenant_id is not None: + tenant_str = f" for tenant {tenant_id}" + logger.info( - f"Indexing starting for tenant {tenant_id}: " - if tenant_id is not None - else "" - + f"connector='{attempt.connector_credential_pair.connector.name}' " + f"Indexing starting{tenant_str}: " + f"connector='{attempt.connector_credential_pair.connector.name}' " f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' " f"credentials='{attempt.connector_credential_pair.connector_id}'" ) @@ -445,10 +447,8 @@ def run_indexing_entrypoint( _run_indexing(db_session, attempt, tenant_id, callback) logger.info( - f"Indexing finished for tenant {tenant_id}: " - if tenant_id is not None - else "" - + f"connector='{attempt.connector_credential_pair.connector.name}' " + f"Indexing finished{tenant_str}: " + f"connector='{attempt.connector_credential_pair.connector.name}' " f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' " f"credentials='{attempt.connector_credential_pair.connector_id}'" ) diff --git a/backend/danswer/db/document.py b/backend/danswer/db/document.py index 32efbf511be..75d6c8d0875 100644 --- a/backend/danswer/db/document.py +++ 
b/backend/danswer/db/document.py @@ -308,7 +308,7 @@ def get_access_info_for_documents( return db_session.execute(stmt).all() # type: ignore -def _upsert_documents( +def upsert_documents( db_session: Session, document_metadata_batch: list[DocumentMetadata], initial_boost: int = DEFAULT_BOOST, @@ -364,24 +364,24 @@ def _upsert_documents( db_session.commit() -def _upsert_document_by_connector_credential_pair( - db_session: Session, document_metadata_batch: list[DocumentMetadata] +def upsert_document_by_connector_credential_pair( + db_session: Session, connector_id: int, credential_id: int, document_ids: list[str] ) -> None: """NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause.""" - if not document_metadata_batch: - logger.info("`document_metadata_batch` is empty. Skipping.") + if not document_ids: + logger.info("`document_ids` is empty. Skipping.") return insert_stmt = insert(DocumentByConnectorCredentialPair).values( [ model_to_dict( DocumentByConnectorCredentialPair( - id=document_metadata.document_id, - connector_id=document_metadata.connector_id, - credential_id=document_metadata.credential_id, + id=doc_id, + connector_id=connector_id, + credential_id=credential_id, ) ) - for document_metadata in document_metadata_batch + for doc_id in document_ids ] ) # for now, there are no columns to update. If more metadata is added, then this @@ -442,17 +442,6 @@ def mark_document_as_synced(document_id: str, db_session: Session) -> None: db_session.commit() -def upsert_documents_complete( - db_session: Session, - document_metadata_batch: list[DocumentMetadata], -) -> None: - _upsert_documents(db_session, document_metadata_batch) - _upsert_document_by_connector_credential_pair(db_session, document_metadata_batch) - logger.info( - f"Upserted {len(document_metadata_batch)} document store entries into DB" - ) - - def delete_document_by_connector_credential_pair__no_commit( db_session: Session, document_id: str, diff --git a/backend/danswer/indexing/indexing_pipeline.py b/backend/danswer/indexing/indexing_pipeline.py index 2324edf4d03..f23bf91d2f1 100644 --- a/backend/danswer/indexing/indexing_pipeline.py +++ b/backend/danswer/indexing/indexing_pipeline.py @@ -20,7 +20,8 @@ from danswer.db.document import prepare_to_modify_documents from danswer.db.document import update_docs_last_modified__no_commit from danswer.db.document import update_docs_updated_at__no_commit -from danswer.db.document import upsert_documents_complete +from danswer.db.document import upsert_document_by_connector_credential_pair +from danswer.db.document import upsert_documents from danswer.db.document_set import fetch_document_sets_for_documents from danswer.db.index_attempt import create_index_attempt_error from danswer.db.models import Document as DBDocument @@ -62,7 +63,7 @@ def _upsert_documents_in_db( db_session: Session, ) -> None: # Metadata here refers to basic document info, not metadata about the actual content - doc_m_batch: list[DocumentMetadata] = [] + document_metadata_list: list[DocumentMetadata] = [] for doc in documents: first_link = next( (section.link for section in doc.sections if section.link), "" @@ -77,12 +78,9 @@ def _upsert_documents_in_db( secondary_owners=get_experts_stores_representations(doc.secondary_owners), from_ingestion_api=doc.from_ingestion_api, ) - doc_m_batch.append(db_doc_metadata) + document_metadata_list.append(db_doc_metadata) - upsert_documents_complete( - db_session=db_session, - document_metadata_batch=doc_m_batch, - ) + upsert_documents(db_session, 
document_metadata_list) # Insert document content metadata for doc in documents: @@ -109,7 +107,10 @@ def get_doc_ids_to_update( documents: list[Document], db_docs: list[DBDocument] ) -> list[Document]: """Figures out which documents actually need to be updated. If a document is already present - and the `updated_at` hasn't changed, we shouldn't need to do anything with it.""" + and the `updated_at` hasn't changed, we shouldn't need to do anything with it. + + NB: Still need to associate the document in the DB if multiple connectors are + indexing the same doc.""" id_update_time_map = { doc.id: doc.doc_updated_at for doc in db_docs if doc.doc_updated_at } @@ -197,7 +198,7 @@ def index_doc_batch_prepare( ) -> DocumentBatchPrepareContext | None: """This sets up the documents in the relational DB (source of truth) for permissions, metadata, etc. This preceeds indexing it into the actual document index.""" - documents = [] + documents: list[Document] = [] for document in document_batch: empty_contents = not any(section.text.strip() for section in document.sections) if ( @@ -223,32 +224,46 @@ def index_doc_batch_prepare( else: documents.append(document) - document_ids = [document.id for document in documents] + # Create a trimmed list of docs that don't have a newer updated at + # Shortcuts the time-consuming flow on connector index retries + document_ids: list[str] = [document.id for document in documents] db_docs: list[DBDocument] = get_documents_by_ids( db_session=db_session, document_ids=document_ids, ) - # Skip indexing docs that don't have a newer updated at - # Shortcuts the time-consuming flow on connector index retries updatable_docs = ( get_doc_ids_to_update(documents=documents, db_docs=db_docs) if not ignore_time_skip else documents ) - # No docs to update either because the batch is empty or every doc was already indexed - if not updatable_docs: - return None + # for all updatable docs, upsert into the DB + # Does not include doc_updated_at which is also used to indicate a successful update + if updatable_docs: + _upsert_documents_in_db( + documents=updatable_docs, + index_attempt_metadata=index_attempt_metadata, + db_session=db_session, + ) - # Create records in the source of truth about these documents, - # does not include doc_updated_at which is also used to indicate a successful update - _upsert_documents_in_db( - documents=documents, - index_attempt_metadata=index_attempt_metadata, - db_session=db_session, + logger.info( + f"Upserted {len(updatable_docs)} changed docs out of " + f"{len(documents)} total docs into the DB" + ) + + # for all docs, upsert the document to cc pair relationship + upsert_document_by_connector_credential_pair( + db_session, + index_attempt_metadata.connector_id, + index_attempt_metadata.credential_id, + document_ids, ) + # No docs to process because the batch is empty or every doc was already indexed + if not updatable_docs: + return None + id_to_db_doc_map = {doc.id: doc for doc in db_docs} return DocumentBatchPrepareContext( updatable_docs=updatable_docs, id_to_db_doc_map=id_to_db_doc_map diff --git a/backend/tests/integration/tests/connector/test_connector_deletion.py b/backend/tests/integration/tests/connector/test_connector_deletion.py index 663aedfc335..b14a75e0045 100644 --- a/backend/tests/integration/tests/connector/test_connector_deletion.py +++ b/backend/tests/integration/tests/connector/test_connector_deletion.py @@ -48,6 +48,59 @@ def test_connector_creation(reset: None) -> None: assert cc_pair_info.creator_email == admin_user.email +# 
TODO(rkuo): will enable this once i have credentials on github +# def test_overlapping_connector_creation(reset: None) -> None: +# # Creating an admin user (first user created is automatically an admin) +# admin_user: DATestUser = UserManager.create(name="admin_user") + +# config = { +# "wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"], +# "space": os.environ["CONFLUENCE_TEST_SPACE"], +# "is_cloud": True, +# "page_id": "", +# } + +# credential = { +# "confluence_username": os.environ["CONFLUENCE_USER_NAME"], +# "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"], +# } + +# # store the time before we create the connector so that we know after +# # when the indexing should have started +# now = datetime.now(timezone.utc) + +# # create connector +# cc_pair_1 = CCPairManager.create_from_scratch( +# source=DocumentSource.CONFLUENCE, +# connector_specific_config=config, +# credential_json=credential, +# user_performing_action=admin_user, +# ) + +# CCPairManager.wait_for_indexing( +# cc_pair_1, now, timeout=60, user_performing_action=admin_user +# ) + +# cc_pair_2 = CCPairManager.create_from_scratch( +# source=DocumentSource.CONFLUENCE, +# connector_specific_config=config, +# credential_json=credential, +# user_performing_action=admin_user, +# ) + +# CCPairManager.wait_for_indexing( +# cc_pair_2, now, timeout=60, user_performing_action=admin_user +# ) + +# info_1 = CCPairManager.get_single(cc_pair_1.id) +# assert info_1 + +# info_2 = CCPairManager.get_single(cc_pair_2.id) +# assert info_2 + +# assert info_1.num_docs_indexed == info_2.num_docs_indexed + + def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None: # Creating an admin user (first user created is automatically an admin) admin_user: DATestUser = UserManager.create(name="admin_user") From 259fc049b7bf993458fb9a6ef02b9bf348872862 Mon Sep 17 00:00:00 2001 From: Weves Date: Fri, 15 Nov 2024 19:43:20 -0800 Subject: [PATCH 06/11] Add error message on JSON decode error in CustomTool --- .../tool_implementations/custom/custom_tool.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/backend/danswer/tools/tool_implementations/custom/custom_tool.py b/backend/danswer/tools/tool_implementations/custom/custom_tool.py index 9e9aa121681..7386961ee75 100644 --- a/backend/danswer/tools/tool_implementations/custom/custom_tool.py +++ b/backend/danswer/tools/tool_implementations/custom/custom_tool.py @@ -13,6 +13,7 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import SystemMessage from pydantic import BaseModel +from requests import JSONDecodeError from danswer.configs.constants import FileOrigin from danswer.db.engine import get_session_with_default_tenant @@ -241,6 +242,8 @@ def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: ) content_type = response.headers.get("Content-Type", "") + tool_result: Any + response_type: str if "text/csv" in content_type: file_ids = self._save_and_get_file_references( response.content, content_type @@ -256,8 +259,15 @@ def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: response_type = "image" else: - tool_result = response.json() - response_type = "json" + try: + tool_result = response.json() + response_type = "json" + except JSONDecodeError: + logger.exception( + f"Failed to parse response as JSON for tool '{self._name}'" + ) + tool_result = response.text + response_type = "text" logger.info( f"Returning tool response for {self._name} with type {response_type}" From 
6e83fe3a3910bbcbe4201faafdcf49187c4939b0 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Fri, 15 Nov 2024 19:38:30 -0800 Subject: [PATCH 07/11] reworked drive+confluence frontend and implied backend changes (#3143) * reworked drive+confluence frontend and implied backend changes * fixed oauth admin tests * fixed service account tests * frontend cleanup * copy change * details! * added key * so good * whoops! * fixed mnore treljsertjoslijt * has issue with boolean form * should be done --- .../workflows/pr-python-connector-tests.yml | 1 + .../connectors/confluence/connector.py | 8 +- .../connectors/google_drive/connector.py | 194 +++++++++--- .../connectors/google_drive/file_retrieval.py | 68 +++- .../daily/connectors/google_drive/conftest.py | 44 ++- .../google_drive/consts_and_utils.py | 19 +- ...gle_drive_oauth.py => test_admin_oauth.py} | 77 +++-- ...gle_drive_sections.py => test_sections.py} | 9 + ...e_service_acct.py => test_service_acct.py} | 80 ++--- ...e_drive_slim_docs.py => test_slim_docs.py} | 5 + .../google_drive/test_user_1_oauth.py | 218 +++++++++++++ .../pages/DynamicConnectorCreationForm.tsx | 139 +-------- .../[connector]/pages/FieldRendering.tsx | 224 ++++++++++++++ .../credentials/CredentialFields.tsx | 37 ++- web/src/components/ui/fully_wrapped_tabs.tsx | 94 ++++++ web/src/lib/connectors/connectors.tsx | 290 ++++++++++++------ 16 files changed, 1102 insertions(+), 405 deletions(-) rename backend/tests/daily/connectors/google_drive/{test_google_drive_oauth.py => test_admin_oauth.py} (86%) rename backend/tests/daily/connectors/google_drive/{test_google_drive_sections.py => test_sections.py} (88%) rename backend/tests/daily/connectors/google_drive/{test_google_drive_service_acct.py => test_service_acct.py} (88%) rename backend/tests/daily/connectors/google_drive/{test_google_drive_slim_docs.py => test_slim_docs.py} (97%) create mode 100644 backend/tests/daily/connectors/google_drive/test_user_1_oauth.py create mode 100644 web/src/app/admin/connectors/[connector]/pages/FieldRendering.tsx create mode 100644 web/src/components/ui/fully_wrapped_tabs.tsx diff --git a/.github/workflows/pr-python-connector-tests.yml b/.github/workflows/pr-python-connector-tests.yml index 88053427726..6e122860ee9 100644 --- a/.github/workflows/pr-python-connector-tests.yml +++ b/.github/workflows/pr-python-connector-tests.yml @@ -20,6 +20,7 @@ env: JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} # Google GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }} + GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1 }} GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }} GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }} GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }} diff --git a/backend/danswer/connectors/confluence/connector.py b/backend/danswer/connectors/confluence/connector.py index 6a376e3d55d..9f1fe2e05f5 100644 --- a/backend/danswer/connectors/confluence/connector.py +++ b/backend/danswer/connectors/confluence/connector.py @@ -81,15 +81,15 @@ def __init__( if cql_query: # if a cql_query is provided, we will use it to fetch the pages cql_page_query = cql_query - elif space: - # if no cql_query is provided, we will use the space to fetch the pages - cql_page_query += f" and space='{quote(space)}'" elif page_id: + # if a cql_query is not provided, we will use the 
page_id to fetch the page if index_recursively: cql_page_query += f" and ancestor='{page_id}'" else: - # if neither a space nor a cql_query is provided, we will use the page_id to fetch the page cql_page_query += f" and id='{page_id}'" + elif space: + # if no cql_query or page_id is provided, we will use the space to fetch the pages + cql_page_query += f" and space='{quote(space)}'" self.cql_page_query = cql_page_query self.cql_time_filter = "" diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index 614b465e74d..ad929eb0905 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -15,6 +15,7 @@ convert_drive_item_to_document, ) from danswer.connectors.google_drive.file_retrieval import crawl_folders_for_files +from danswer.connectors.google_drive.file_retrieval import get_all_files_for_oauth from danswer.connectors.google_drive.file_retrieval import get_all_files_in_my_drive from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive from danswer.connectors.google_drive.models import GoogleDriveFileType @@ -82,12 +83,31 @@ def _process_files_batch( yield doc_batch +def _clean_requested_drive_ids( + requested_drive_ids: set[str], + requested_folder_ids: set[str], + all_drive_ids_available: set[str], +) -> tuple[set[str], set[str]]: + invalid_requested_drive_ids = requested_drive_ids - all_drive_ids_available + filtered_folder_ids = requested_folder_ids - all_drive_ids_available + if invalid_requested_drive_ids: + logger.warning( + f"Some shared drive IDs were not found. IDs: {invalid_requested_drive_ids}" + ) + logger.warning("Checking for folder access instead...") + filtered_folder_ids.update(invalid_requested_drive_ids) + + valid_requested_drive_ids = requested_drive_ids - invalid_requested_drive_ids + return valid_requested_drive_ids, filtered_folder_ids + + class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): def __init__( self, - include_shared_drives: bool = True, + include_shared_drives: bool = False, + include_my_drives: bool = False, + include_files_shared_with_me: bool = False, shared_drive_urls: str | None = None, - include_my_drives: bool = True, my_drive_emails: str | None = None, shared_folder_urls: str | None = None, batch_size: int = INDEX_BATCH_SIZE, @@ -120,22 +140,36 @@ def __init__( if ( not include_shared_drives and not include_my_drives + and not include_files_shared_with_me and not shared_folder_urls + and not my_drive_emails + and not shared_drive_urls ): raise ValueError( - "At least one of include_shared_drives, include_my_drives," - " or shared_folder_urls must be true" + "Nothing to index. 
Please specify at least one of the following: " + "include_shared_drives, include_my_drives, include_files_shared_with_me, " + "shared_folder_urls, or my_drive_emails" ) self.batch_size = batch_size - self.include_shared_drives = include_shared_drives + specific_requests_made = False + if bool(shared_drive_urls) or bool(my_drive_emails) or bool(shared_folder_urls): + specific_requests_made = True + + self.include_files_shared_with_me = ( + False if specific_requests_made else include_files_shared_with_me + ) + self.include_my_drives = False if specific_requests_made else include_my_drives + self.include_shared_drives = ( + False if specific_requests_made else include_shared_drives + ) + shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls) self._requested_shared_drive_ids = set( _extract_ids_from_urls(shared_drive_url_list) ) - self.include_my_drives = include_my_drives self._requested_my_drive_emails = set( _extract_str_list_from_comma_str(my_drive_emails) ) @@ -225,26 +259,20 @@ def _get_all_drive_ids(self) -> set[str]: creds=self.creds, user_email=self.primary_admin_email, ) + is_service_account = isinstance(self.creds, ServiceAccountCredentials) all_drive_ids = set() - # We don't want to fail if we're using OAuth because you can - # access your my drive as a non admin user in an org still - ignore_fetch_failure = isinstance(self.creds, OAuthCredentials) for drive in execute_paginated_retrieval( retrieval_function=primary_drive_service.drives().list, list_key="drives", - continue_on_404_or_403=ignore_fetch_failure, - useDomainAdminAccess=True, + useDomainAdminAccess=is_service_account, fields="drives(id)", ): all_drive_ids.add(drive["id"]) if not all_drive_ids: logger.warning( - "No drives found. This is likely because oauth user " - "is not an admin and cannot view all drive IDs. " - "Continuing with only the shared drive IDs specified in the config." + "No drives found even though we are indexing shared drives was requested." ) - all_drive_ids = set(self._requested_shared_drive_ids) return all_drive_ids @@ -261,14 +289,9 @@ def _impersonate_user_for_retrieval( # if we are including my drives, try to get the current user's my # drive if any of the following are true: - # - no specific emails were requested + # - include_my_drives is true # - the current user's email is in the requested emails - # - we are using OAuth (in which case we assume that is the only email we will try) - if self.include_my_drives and ( - not self._requested_my_drive_emails - or user_email in self._requested_my_drive_emails - or isinstance(self.creds, OAuthCredentials) - ): + if self.include_my_drives or user_email in self._requested_my_drive_emails: yield from get_all_files_in_my_drive( service=drive_service, update_traversed_ids_func=self._update_traversed_parent_ids, @@ -299,7 +322,7 @@ def _impersonate_user_for_retrieval( end=end, ) - def _fetch_drive_items( + def _manage_service_account_retrieval( self, is_slim: bool, start: SecondsSinceUnixEpoch | None = None, @@ -309,29 +332,16 @@ def _fetch_drive_items( all_drive_ids: set[str] = self._get_all_drive_ids() - # remove drive ids from the folder ids because they are queried differently - filtered_folder_ids = self._requested_folder_ids - all_drive_ids - - # Remove drive_ids that are not in the all_drive_ids and check them as folders instead - invalid_drive_ids = self._requested_shared_drive_ids - all_drive_ids - if invalid_drive_ids: - logger.warning( - f"Some shared drive IDs were not found. 
IDs: {invalid_drive_ids}" + drive_ids_to_retrieve: set[str] = set() + folder_ids_to_retrieve: set[str] = set() + if self._requested_shared_drive_ids or self._requested_folder_ids: + drive_ids_to_retrieve, folder_ids_to_retrieve = _clean_requested_drive_ids( + requested_drive_ids=self._requested_shared_drive_ids, + requested_folder_ids=self._requested_folder_ids, + all_drive_ids_available=all_drive_ids, ) - logger.warning("Checking for folder access instead...") - filtered_folder_ids.update(invalid_drive_ids) - - # If including shared drives, use the requested IDs if provided, - # otherwise use all drive IDs - filtered_drive_ids = set() - if self.include_shared_drives: - if self._requested_shared_drive_ids: - # Remove invalid drive IDs from requested IDs - filtered_drive_ids = ( - self._requested_shared_drive_ids - invalid_drive_ids - ) - else: - filtered_drive_ids = all_drive_ids + elif self.include_shared_drives: + drive_ids_to_retrieve = all_drive_ids # Process users in parallel using ThreadPoolExecutor with ThreadPoolExecutor(max_workers=10) as executor: @@ -340,8 +350,8 @@ def _fetch_drive_items( self._impersonate_user_for_retrieval, email, is_slim, - filtered_drive_ids, - filtered_folder_ids, + drive_ids_to_retrieve, + folder_ids_to_retrieve, start, end, ): email @@ -353,13 +363,101 @@ def _fetch_drive_items( yield from future.result() remaining_folders = ( - filtered_drive_ids | filtered_folder_ids + drive_ids_to_retrieve | folder_ids_to_retrieve ) - self._retrieved_ids if remaining_folders: logger.warning( f"Some folders/drives were not retrieved. IDs: {remaining_folders}" ) + def _manage_oauth_retrieval( + self, + is_slim: bool, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> Iterator[GoogleDriveFileType]: + drive_service = get_drive_service(self.creds, self.primary_admin_email) + + if self.include_files_shared_with_me or self.include_my_drives: + yield from get_all_files_for_oauth( + service=drive_service, + include_files_shared_with_me=self.include_files_shared_with_me, + include_my_drives=self.include_my_drives, + include_shared_drives=self.include_shared_drives, + is_slim=is_slim, + start=start, + end=end, + ) + + all_requested = ( + self.include_files_shared_with_me + and self.include_my_drives + and self.include_shared_drives + ) + if all_requested: + # If all 3 are true, we already yielded from get_all_files_for_oauth + return + + all_drive_ids = self._get_all_drive_ids() + drive_ids_to_retrieve: set[str] = set() + folder_ids_to_retrieve: set[str] = set() + if self._requested_shared_drive_ids or self._requested_folder_ids: + drive_ids_to_retrieve, folder_ids_to_retrieve = _clean_requested_drive_ids( + requested_drive_ids=self._requested_shared_drive_ids, + requested_folder_ids=self._requested_folder_ids, + all_drive_ids_available=all_drive_ids, + ) + elif self.include_shared_drives: + drive_ids_to_retrieve = all_drive_ids + + for drive_id in drive_ids_to_retrieve: + yield from get_files_in_shared_drive( + service=drive_service, + drive_id=drive_id, + is_slim=is_slim, + update_traversed_ids_func=self._update_traversed_parent_ids, + start=start, + end=end, + ) + + # Even if no folders were requested, we still check if any drives were requested + # that could be folders. 
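+        # IDs that were requested as shared drives but did not show up in the
+        # drive listing were merged into folder_ids_to_retrieve by
+        # _clean_requested_drive_ids, so they get crawled here as plain folders.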
+ remaining_folders = folder_ids_to_retrieve - self._retrieved_ids + for folder_id in remaining_folders: + yield from crawl_folders_for_files( + service=drive_service, + parent_id=folder_id, + traversed_parent_ids=self._retrieved_ids, + update_traversed_ids_func=self._update_traversed_parent_ids, + start=start, + end=end, + ) + + remaining_folders = ( + drive_ids_to_retrieve | folder_ids_to_retrieve + ) - self._retrieved_ids + if remaining_folders: + logger.warning( + f"Some folders/drives were not retrieved. IDs: {remaining_folders}" + ) + + def _fetch_drive_items( + self, + is_slim: bool, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, + ) -> Iterator[GoogleDriveFileType]: + retrieval_method = ( + self._manage_service_account_retrieval + if isinstance(self.creds, ServiceAccountCredentials) + else self._manage_oauth_retrieval + ) + return retrieval_method( + is_slim=is_slim, + start=start, + end=end, + ) + def _extract_docs_from_google_drive( self, start: SecondsSinceUnixEpoch | None = None, diff --git a/backend/danswer/connectors/google_drive/file_retrieval.py b/backend/danswer/connectors/google_drive/file_retrieval.py index c0ed998561d..962d531b076 100644 --- a/backend/danswer/connectors/google_drive/file_retrieval.py +++ b/backend/danswer/connectors/google_drive/file_retrieval.py @@ -140,8 +140,8 @@ def get_files_in_shared_drive( ) -> Iterator[GoogleDriveFileType]: # If we know we are going to folder crawl later, we can cache the folders here # Get all folders being queried and add them to the traversed set - query = f"mimeType = '{DRIVE_FOLDER_TYPE}'" - query += " and trashed = false" + folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'" + folder_query += " and trashed = false" found_folders = False for file in execute_paginated_retrieval( retrieval_function=service.files().list, @@ -152,7 +152,7 @@ def get_files_in_shared_drive( supportsAllDrives=True, includeItemsFromAllDrives=True, fields="nextPageToken, files(id)", - q=query, + q=folder_query, ): update_traversed_ids_func(file["id"]) found_folders = True @@ -160,9 +160,9 @@ def get_files_in_shared_drive( update_traversed_ids_func(drive_id) # Get all files in the shared drive - query = f"mimeType != '{DRIVE_FOLDER_TYPE}'" - query += " and trashed = false" - query += _generate_time_range_filter(start, end) + file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'" + file_query += " and trashed = false" + file_query += _generate_time_range_filter(start, end) yield from execute_paginated_retrieval( retrieval_function=service.files().list, list_key="files", @@ -172,7 +172,7 @@ def get_files_in_shared_drive( supportsAllDrives=True, includeItemsFromAllDrives=True, fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS, - q=query, + q=file_query, ) @@ -185,14 +185,16 @@ def get_all_files_in_my_drive( ) -> Iterator[GoogleDriveFileType]: # If we know we are going to folder crawl later, we can cache the folders here # Get all folders being queried and add them to the traversed set - query = "trashed = false and 'me' in owners" + folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'" + folder_query += " and trashed = false" + folder_query += " and 'me' in owners" found_folders = False for file in execute_paginated_retrieval( retrieval_function=service.files().list, list_key="files", corpora="user", fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS, - q=query, + q=folder_query, ): update_traversed_ids_func(file["id"]) found_folders = True @@ -200,18 +202,52 @@ def get_all_files_in_my_drive( 
update_traversed_ids_func(get_root_folder_id(service)) # Then get the files - query = "trashed = false and 'me' in owners" - query += _generate_time_range_filter(start, end) - fields = "files(id, name, mimeType, webViewLink, modifiedTime, createdTime)" - if not is_slim: - fields += ", files(permissions, permissionIds, owners)" - + file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'" + file_query += " and trashed = false" + file_query += " and 'me' in owners" + file_query += _generate_time_range_filter(start, end) yield from execute_paginated_retrieval( retrieval_function=service.files().list, list_key="files", corpora="user", fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS, - q=query, + q=file_query, + ) + + +def get_all_files_for_oauth( + service: Any, + include_files_shared_with_me: bool, + include_my_drives: bool, + # One of the above 2 should be true + include_shared_drives: bool, + is_slim: bool = False, + start: SecondsSinceUnixEpoch | None = None, + end: SecondsSinceUnixEpoch | None = None, +) -> Iterator[GoogleDriveFileType]: + should_get_all = ( + include_shared_drives and include_my_drives and include_files_shared_with_me + ) + corpora = "allDrives" if should_get_all else "user" + + file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'" + file_query += " and trashed = false" + file_query += _generate_time_range_filter(start, end) + + if not should_get_all: + if include_files_shared_with_me and not include_my_drives: + file_query += " and not 'me' in owners" + if not include_files_shared_with_me and include_my_drives: + file_query += " and 'me' in owners" + + yield from execute_paginated_retrieval( + retrieval_function=service.files().list, + list_key="files", + corpora=corpora, + includeItemsFromAllDrives=should_get_all, + supportsAllDrives=should_get_all, + fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS, + q=file_query, ) diff --git a/backend/tests/daily/connectors/google_drive/conftest.py b/backend/tests/daily/connectors/google_drive/conftest.py index 4b618b28e1d..4f525b24592 100644 --- a/backend/tests/daily/connectors/google_drive/conftest.py +++ b/backend/tests/daily/connectors/google_drive/conftest.py @@ -21,6 +21,16 @@ load_env_vars() +_USER_TO_OAUTH_CREDENTIALS_MAP = { + "admin@onyx-test.com": "GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR", + "test_user_1@onyx-test.com": "GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1", +} + +_USER_TO_SERVICE_ACCOUNT_CREDENTIALS_MAP = { + "admin@onyx-test.com": "GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR", +} + + def parse_credentials(env_str: str) -> dict: """ Parse a double-escaped JSON string from environment variables into a Python dictionary. 
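# Roughly how the reworked fixtures below resolve credentials for the user being
# impersonated: the env var name is looked up in _USER_TO_OAUTH_CREDENTIALS_MAP
# (defined above) keyed on primary_admin_email, and parse_credentials un-escapes
# the stored secret. A minimal illustration (the helper name is hypothetical and
# json.loads merely stands in for parse_credentials):
import json
import os


def _oauth_secret_for(primary_admin_email: str) -> str:
    raw = os.environ[_USER_TO_OAUTH_CREDENTIALS_MAP[primary_admin_email]]
    return json.dumps(json.loads(raw))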
@@ -48,23 +58,25 @@ def parse_credentials(env_str: str) -> dict: @pytest.fixture def google_drive_oauth_connector_factory() -> Callable[..., GoogleDriveConnector]: def _connector_factory( - primary_admin_email: str = "admin@onyx-test.com", - include_shared_drives: bool = True, - shared_drive_urls: str | None = None, - include_my_drives: bool = True, - my_drive_emails: str | None = None, - shared_folder_urls: str | None = None, + primary_admin_email: str, + include_shared_drives: bool, + shared_drive_urls: str | None, + include_my_drives: bool, + my_drive_emails: str | None, + shared_folder_urls: str | None, + include_files_shared_with_me: bool, ) -> GoogleDriveConnector: print("Creating GoogleDriveConnector with OAuth credentials") connector = GoogleDriveConnector( include_shared_drives=include_shared_drives, shared_drive_urls=shared_drive_urls, include_my_drives=include_my_drives, + include_files_shared_with_me=include_files_shared_with_me, my_drive_emails=my_drive_emails, shared_folder_urls=shared_folder_urls, ) - json_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"] + json_string = os.environ[_USER_TO_OAUTH_CREDENTIALS_MAP[primary_admin_email]] refried_json_string = json.dumps(parse_credentials(json_string)) credentials_json = { @@ -82,12 +94,13 @@ def google_drive_service_acct_connector_factory() -> ( Callable[..., GoogleDriveConnector] ): def _connector_factory( - primary_admin_email: str = "admin@onyx-test.com", - include_shared_drives: bool = True, - shared_drive_urls: str | None = None, - include_my_drives: bool = True, - my_drive_emails: str | None = None, - shared_folder_urls: str | None = None, + primary_admin_email: str, + include_shared_drives: bool, + shared_drive_urls: str | None, + include_my_drives: bool, + my_drive_emails: str | None, + shared_folder_urls: str | None, + include_files_shared_with_me: bool, ) -> GoogleDriveConnector: print("Creating GoogleDriveConnector with service account credentials") connector = GoogleDriveConnector( @@ -96,9 +109,12 @@ def _connector_factory( include_my_drives=include_my_drives, my_drive_emails=my_drive_emails, shared_folder_urls=shared_folder_urls, + include_files_shared_with_me=include_files_shared_with_me, ) - json_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"] + json_string = os.environ[ + _USER_TO_SERVICE_ACCOUNT_CREDENTIALS_MAP[primary_admin_email] + ] refried_json_string = json.dumps(parse_credentials(json_string)) # Load Service Account Credentials diff --git a/backend/tests/daily/connectors/google_drive/consts_and_utils.py b/backend/tests/daily/connectors/google_drive/consts_and_utils.py index f808f6812e7..1df59a58101 100644 --- a/backend/tests/daily/connectors/google_drive/consts_and_utils.py +++ b/backend/tests/daily/connectors/google_drive/consts_and_utils.py @@ -7,14 +7,14 @@ ADMIN_FILE_IDS = list(range(0, 5)) -ADMIN_FOLDER_3_FILE_IDS = list(range(65, 70)) +ADMIN_FOLDER_3_FILE_IDS = list(range(65, 70)) # This folder is shared with test_user_1 TEST_USER_1_FILE_IDS = list(range(5, 10)) TEST_USER_2_FILE_IDS = list(range(10, 15)) TEST_USER_3_FILE_IDS = list(range(15, 20)) SHARED_DRIVE_1_FILE_IDS = list(range(20, 25)) FOLDER_1_FILE_IDS = list(range(25, 30)) FOLDER_1_1_FILE_IDS = list(range(30, 35)) -FOLDER_1_2_FILE_IDS = list(range(35, 40)) +FOLDER_1_2_FILE_IDS = list(range(35, 40)) # This folder is public SHARED_DRIVE_2_FILE_IDS = list(range(40, 45)) FOLDER_2_FILE_IDS = list(range(45, 50)) FOLDER_2_1_FILE_IDS = list(range(50, 55)) @@ -75,26 +75,29 @@ + FOLDER_2_2_FILE_IDS + SECTIONS_FILE_IDS ), - # 
This user has access to drive 1 - # This user has redundant access to folder 1 because of group access - # This user has been given individual access to files in Admin's My Drive TEST_USER_1_EMAIL: ( TEST_USER_1_FILE_IDS + # This user has access to drive 1 + SHARED_DRIVE_1_FILE_IDS + # This user has redundant access to folder 1 because of group access + FOLDER_1_FILE_IDS + FOLDER_1_1_FILE_IDS + FOLDER_1_2_FILE_IDS + # This user has been given shared access to folder 3 in Admin's My Drive + + ADMIN_FOLDER_3_FILE_IDS + # This user has been given shared access to files 0 and 1 in Admin's My Drive + list(range(0, 2)) ), - # Group 1 includes this user, giving access to folder 1 - # This user has also been given access to folder 2-1 - # This user has also been given individual access to files in folder 2 TEST_USER_2_EMAIL: ( TEST_USER_2_FILE_IDS + # Group 1 includes this user, giving access to folder 1 + FOLDER_1_FILE_IDS + FOLDER_1_1_FILE_IDS + # This folder is public + FOLDER_1_2_FILE_IDS + # Folder 2-1 is shared with this user + FOLDER_2_1_FILE_IDS + # This user has been given shared access to files 45 and 46 in folder 2 + list(range(45, 47)) ), # This user can only see his own files and public files diff --git a/backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py b/backend/tests/daily/connectors/google_drive/test_admin_oauth.py similarity index 86% rename from backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py rename to backend/tests/daily/connectors/google_drive/test_admin_oauth.py index cf9776c1806..74625ed0b2d 100644 --- a/backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py +++ b/backend/tests/daily/connectors/google_drive/test_admin_oauth.py @@ -5,6 +5,7 @@ from danswer.connectors.google_drive.connector import GoogleDriveConnector from danswer.connectors.models import Document +from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import ( @@ -26,8 +27,6 @@ from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_FILE_IDS -from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_EMAIL -from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_3_EMAIL @patch( @@ -40,8 +39,13 @@ def test_include_all( ) -> None: print("\n\nRunning test_include_all") connector = google_drive_oauth_connector_factory( + primary_admin_email=ADMIN_EMAIL, include_shared_drives=True, include_my_drives=True, + include_files_shared_with_me=False, + shared_folder_urls=None, + my_drive_emails=None, + shared_drive_urls=None, ) retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): @@ -77,8 +81,13 @@ def test_include_shared_drives_only( ) -> None: print("\n\nRunning test_include_shared_drives_only") connector = google_drive_oauth_connector_factory( + primary_admin_email=ADMIN_EMAIL, include_shared_drives=True, include_my_drives=False, + include_files_shared_with_me=False, + shared_folder_urls=None, + my_drive_emails=None, + shared_drive_urls=None, ) retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): @@ -112,8 
+121,13 @@ def test_include_my_drives_only( ) -> None: print("\n\nRunning test_include_my_drives_only") connector = google_drive_oauth_connector_factory( + primary_admin_email=ADMIN_EMAIL, include_shared_drives=False, include_my_drives=True, + include_files_shared_with_me=False, + shared_folder_urls=None, + my_drive_emails=None, + shared_drive_urls=None, ) retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): @@ -138,8 +152,12 @@ def test_drive_one_only( print("\n\nRunning test_drive_one_only") drive_urls = [SHARED_DRIVE_1_URL] connector = google_drive_oauth_connector_factory( + primary_admin_email=ADMIN_EMAIL, include_shared_drives=True, include_my_drives=False, + include_files_shared_with_me=False, + shared_folder_urls=None, + my_drive_emails=None, shared_drive_urls=",".join([str(url) for url in drive_urls]), ) retrieved_docs: list[Document] = [] @@ -170,19 +188,20 @@ def test_folder_and_shared_drive( drive_urls = [SHARED_DRIVE_1_URL] folder_urls = [FOLDER_2_URL] connector = google_drive_oauth_connector_factory( + primary_admin_email=ADMIN_EMAIL, include_shared_drives=True, - include_my_drives=True, - shared_drive_urls=",".join([str(url) for url in drive_urls]), + include_my_drives=False, + include_files_shared_with_me=False, shared_folder_urls=",".join([str(url) for url in folder_urls]), + my_drive_emails=None, + shared_drive_urls=",".join([str(url) for url in drive_urls]), ) retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): retrieved_docs.extend(doc_batch) expected_file_ids = ( - ADMIN_FILE_IDS - + ADMIN_FOLDER_3_FILE_IDS - + SHARED_DRIVE_1_FILE_IDS + SHARED_DRIVE_1_FILE_IDS + FOLDER_1_FILE_IDS + FOLDER_1_1_FILE_IDS + FOLDER_1_2_FILE_IDS @@ -216,10 +235,13 @@ def test_folders_only( FOLDER_1_1_URL, ] connector = google_drive_oauth_connector_factory( - include_shared_drives=False, + primary_admin_email=ADMIN_EMAIL, + include_shared_drives=True, include_my_drives=False, - shared_drive_urls=",".join([str(url) for url in shared_drive_urls]), + include_files_shared_with_me=False, shared_folder_urls=",".join([str(url) for url in folder_urls]), + my_drive_emails=None, + shared_drive_urls=",".join([str(url) for url in shared_drive_urls]), ) retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): @@ -238,37 +260,6 @@ def test_folders_only( ) -@patch( - "danswer.file_processing.extract_file_text.get_unstructured_api_key", - return_value=None, -) -def test_specific_emails( - mock_get_api_key: MagicMock, - google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], -) -> None: - print("\n\nRunning test_specific_emails") - my_drive_emails = [ - TEST_USER_1_EMAIL, - TEST_USER_3_EMAIL, - ] - connector = google_drive_oauth_connector_factory( - include_shared_drives=False, - include_my_drives=True, - my_drive_emails=",".join([str(email) for email in my_drive_emails]), - ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) - - # No matter who is specified, when using oauth, if include_my_drives is True, - # we will get all the files from the admin's My Drive - expected_file_ids = ADMIN_FILE_IDS + ADMIN_FOLDER_3_FILE_IDS - assert_retrieved_docs_match_expected( - retrieved_docs=retrieved_docs, - expected_file_ids=expected_file_ids, - ) - - @patch( "danswer.file_processing.extract_file_text.get_unstructured_api_key", return_value=None, @@ -282,9 +273,13 @@ def test_personal_folders_only( 
FOLDER_3_URL, ] connector = google_drive_oauth_connector_factory( - include_shared_drives=False, + primary_admin_email=ADMIN_EMAIL, + include_shared_drives=True, include_my_drives=False, + include_files_shared_with_me=False, shared_folder_urls=",".join([str(url) for url in folder_urls]), + my_drive_emails=None, + shared_drive_urls=None, ) retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): diff --git a/backend/tests/daily/connectors/google_drive/test_google_drive_sections.py b/backend/tests/daily/connectors/google_drive/test_sections.py similarity index 88% rename from backend/tests/daily/connectors/google_drive/test_google_drive_sections.py rename to backend/tests/daily/connectors/google_drive/test_sections.py index b6d2423a092..989bf9e9e7b 100644 --- a/backend/tests/daily/connectors/google_drive/test_google_drive_sections.py +++ b/backend/tests/daily/connectors/google_drive/test_sections.py @@ -5,6 +5,7 @@ from danswer.connectors.google_drive.connector import GoogleDriveConnector from danswer.connectors.models import Document +from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FOLDER_URL @@ -18,14 +19,22 @@ def test_google_drive_sections( google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: oauth_connector = google_drive_oauth_connector_factory( + primary_admin_email=ADMIN_EMAIL, include_shared_drives=False, include_my_drives=False, + include_files_shared_with_me=False, shared_folder_urls=SECTIONS_FOLDER_URL, + shared_drive_urls=None, + my_drive_emails=None, ) service_acct_connector = google_drive_service_acct_connector_factory( + primary_admin_email=ADMIN_EMAIL, include_shared_drives=False, include_my_drives=False, + include_files_shared_with_me=False, shared_folder_urls=SECTIONS_FOLDER_URL, + shared_drive_urls=None, + my_drive_emails=None, ) for connector in [oauth_connector, service_acct_connector]: retrieved_docs: list[Document] = [] diff --git a/backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py b/backend/tests/daily/connectors/google_drive/test_service_acct.py similarity index 88% rename from backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py rename to backend/tests/daily/connectors/google_drive/test_service_acct.py index 5dcbb7f5b64..602849a41dc 100644 --- a/backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py +++ b/backend/tests/daily/connectors/google_drive/test_service_acct.py @@ -5,6 +5,7 @@ from danswer.connectors.google_drive.connector import GoogleDriveConnector from danswer.connectors.models import Document +from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS from tests.daily.connectors.google_drive.consts_and_utils import ( @@ -43,8 +44,13 @@ def test_include_all( ) -> None: print("\n\nRunning test_include_all") connector = google_drive_service_acct_connector_factory( + primary_admin_email=ADMIN_EMAIL, include_shared_drives=True, include_my_drives=True, + include_files_shared_with_me=False, + shared_folder_urls=None, + shared_drive_urls=None, + my_drive_emails=None, ) retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): @@ -83,8 +89,13 @@ def test_include_shared_drives_only( ) -> None: 
print("\n\nRunning test_include_shared_drives_only") connector = google_drive_service_acct_connector_factory( + primary_admin_email=ADMIN_EMAIL, include_shared_drives=True, include_my_drives=False, + include_files_shared_with_me=False, + shared_folder_urls=None, + shared_drive_urls=None, + my_drive_emails=None, ) retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): @@ -118,8 +129,13 @@ def test_include_my_drives_only( ) -> None: print("\n\nRunning test_include_my_drives_only") connector = google_drive_service_acct_connector_factory( + primary_admin_email=ADMIN_EMAIL, include_shared_drives=False, include_my_drives=True, + include_files_shared_with_me=False, + shared_folder_urls=None, + shared_drive_urls=None, + my_drive_emails=None, ) retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): @@ -150,9 +166,13 @@ def test_drive_one_only( print("\n\nRunning test_drive_one_only") urls = [SHARED_DRIVE_1_URL] connector = google_drive_service_acct_connector_factory( - include_shared_drives=True, + primary_admin_email=ADMIN_EMAIL, + include_shared_drives=False, include_my_drives=False, + include_files_shared_with_me=False, + shared_folder_urls=None, shared_drive_urls=",".join([str(url) for url in urls]), + my_drive_emails=None, ) retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): @@ -183,10 +203,13 @@ def test_folder_and_shared_drive( drive_urls = [SHARED_DRIVE_1_URL] folder_urls = [FOLDER_2_URL] connector = google_drive_service_acct_connector_factory( - include_shared_drives=True, - include_my_drives=True, + primary_admin_email=ADMIN_EMAIL, + include_shared_drives=False, + include_my_drives=False, + include_files_shared_with_me=False, shared_drive_urls=",".join([str(url) for url in drive_urls]), shared_folder_urls=",".join([str(url) for url in folder_urls]), + my_drive_emails=None, ) retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): @@ -194,12 +217,7 @@ def test_folder_and_shared_drive( # Should get everything except for the top level files in drive 2 expected_file_ids = ( - ADMIN_FILE_IDS - + ADMIN_FOLDER_3_FILE_IDS - + TEST_USER_1_FILE_IDS - + TEST_USER_2_FILE_IDS - + TEST_USER_3_FILE_IDS - + SHARED_DRIVE_1_FILE_IDS + SHARED_DRIVE_1_FILE_IDS + FOLDER_1_FILE_IDS + FOLDER_1_1_FILE_IDS + FOLDER_1_2_FILE_IDS @@ -233,10 +251,13 @@ def test_folders_only( FOLDER_1_1_URL, ] connector = google_drive_service_acct_connector_factory( + primary_admin_email=ADMIN_EMAIL, include_shared_drives=False, include_my_drives=False, + include_files_shared_with_me=False, shared_drive_urls=",".join([str(url) for url in shared_drive_urls]), shared_folder_urls=",".join([str(url) for url in folder_urls]), + my_drive_emails=None, ) retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): @@ -269,8 +290,12 @@ def test_specific_emails( TEST_USER_3_EMAIL, ] connector = google_drive_service_acct_connector_factory( + primary_admin_email=ADMIN_EMAIL, include_shared_drives=False, - include_my_drives=True, + include_my_drives=False, + include_files_shared_with_me=False, + shared_folder_urls=None, + shared_drive_urls=None, my_drive_emails=",".join([str(email) for email in my_drive_emails]), ) retrieved_docs: list[Document] = [] @@ -293,42 +318,17 @@ def get_specific_folders_in_my_drive( google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: print("\n\nRunning get_specific_folders_in_my_drive") - 
my_drive_emails = [ - TEST_USER_1_EMAIL, - TEST_USER_3_EMAIL, - ] - connector = google_drive_service_acct_connector_factory( - include_shared_drives=False, - include_my_drives=True, - my_drive_emails=",".join([str(email) for email in my_drive_emails]), - ) - retrieved_docs: list[Document] = [] - for doc_batch in connector.poll_source(0, time.time()): - retrieved_docs.extend(doc_batch) - - expected_file_ids = TEST_USER_1_FILE_IDS + TEST_USER_3_FILE_IDS - assert_retrieved_docs_match_expected( - retrieved_docs=retrieved_docs, - expected_file_ids=expected_file_ids, - ) - - -@patch( - "danswer.file_processing.extract_file_text.get_unstructured_api_key", - return_value=None, -) -def test_personal_folders_only( - mock_get_api_key: MagicMock, - google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], -) -> None: - print("\n\nRunning test_personal_folders_only") folder_urls = [ FOLDER_3_URL, ] - connector = google_drive_oauth_connector_factory( + connector = google_drive_service_acct_connector_factory( + primary_admin_email=ADMIN_EMAIL, include_shared_drives=False, include_my_drives=False, + include_files_shared_with_me=False, shared_folder_urls=",".join([str(url) for url in folder_urls]), + shared_drive_urls=None, + my_drive_emails=None, ) retrieved_docs: list[Document] = [] for doc_batch in connector.poll_source(0, time.time()): diff --git a/backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py b/backend/tests/daily/connectors/google_drive/test_slim_docs.py similarity index 97% rename from backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py rename to backend/tests/daily/connectors/google_drive/test_slim_docs.py index c9d70d814a1..7a836421317 100644 --- a/backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py +++ b/backend/tests/daily/connectors/google_drive/test_slim_docs.py @@ -126,8 +126,13 @@ def test_all_permissions( google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], ) -> None: google_drive_connector = google_drive_service_acct_connector_factory( + primary_admin_email=ADMIN_EMAIL, include_shared_drives=True, include_my_drives=True, + include_files_shared_with_me=False, + shared_folder_urls=None, + shared_drive_urls=None, + my_drive_emails=None, ) access_map: dict[str, ExternalAccess] = {} diff --git a/backend/tests/daily/connectors/google_drive/test_user_1_oauth.py b/backend/tests/daily/connectors/google_drive/test_user_1_oauth.py new file mode 100644 index 00000000000..9f173536269 --- /dev/null +++ b/backend/tests/daily/connectors/google_drive/test_user_1_oauth.py @@ -0,0 +1,218 @@ +import time +from collections.abc import Callable +from unittest.mock import MagicMock +from unittest.mock import patch + +from danswer.connectors.google_drive.connector import GoogleDriveConnector +from danswer.connectors.models import Document +from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS +from tests.daily.connectors.google_drive.consts_and_utils import ( + assert_retrieved_docs_match_expected, +) +from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS +from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS +from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS +from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_URL +from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL +from 
tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS +from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_EMAIL +from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_FILE_IDS + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_all( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_all") + connector = google_drive_oauth_connector_factory( + primary_admin_email=TEST_USER_1_EMAIL, + include_files_shared_with_me=True, + include_shared_drives=True, + include_my_drives=True, + shared_folder_urls=None, + shared_drive_urls=None, + my_drive_emails=None, + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + expected_file_ids = ( + # These are the files from my drive + TEST_USER_1_FILE_IDS + # These are the files from shared drives + + SHARED_DRIVE_1_FILE_IDS + + FOLDER_1_FILE_IDS + + FOLDER_1_1_FILE_IDS + + FOLDER_1_2_FILE_IDS + # These are the files shared with me from admin + + ADMIN_FOLDER_3_FILE_IDS + + list(range(0, 2)) + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_shared_drives_only( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_shared_drives_only") + connector = google_drive_oauth_connector_factory( + primary_admin_email=TEST_USER_1_EMAIL, + include_files_shared_with_me=False, + include_shared_drives=True, + include_my_drives=False, + shared_folder_urls=None, + shared_drive_urls=None, + my_drive_emails=None, + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + expected_file_ids = ( + # These are the files from shared drives + SHARED_DRIVE_1_FILE_IDS + + FOLDER_1_FILE_IDS + + FOLDER_1_1_FILE_IDS + + FOLDER_1_2_FILE_IDS + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_shared_with_me_only( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_shared_with_me_only") + connector = google_drive_oauth_connector_factory( + primary_admin_email=TEST_USER_1_EMAIL, + include_files_shared_with_me=True, + include_shared_drives=False, + include_my_drives=False, + shared_folder_urls=None, + shared_drive_urls=None, + my_drive_emails=None, + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + expected_file_ids = ( + # These are the files shared with me from admin + ADMIN_FOLDER_3_FILE_IDS + + list(range(0, 2)) + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_my_drive_only( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., 
GoogleDriveConnector], +) -> None: + print("\n\nRunning test_my_drive_only") + connector = google_drive_oauth_connector_factory( + primary_admin_email=TEST_USER_1_EMAIL, + include_files_shared_with_me=False, + include_shared_drives=False, + include_my_drives=True, + shared_folder_urls=None, + shared_drive_urls=None, + my_drive_emails=None, + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + # These are the files from my drive + expected_file_ids = TEST_USER_1_FILE_IDS + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_shared_my_drive_folder( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_shared_my_drive_folder") + connector = google_drive_oauth_connector_factory( + primary_admin_email=TEST_USER_1_EMAIL, + include_files_shared_with_me=False, + include_shared_drives=False, + include_my_drives=True, + shared_folder_urls=FOLDER_3_URL, + shared_drive_urls=None, + my_drive_emails=None, + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + expected_file_ids = ( + # this is a folder from admin's drive that is shared with me + ADMIN_FOLDER_3_FILE_IDS + ) + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_shared_drive_folder( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + print("\n\nRunning test_shared_drive_folder") + connector = google_drive_oauth_connector_factory( + primary_admin_email=TEST_USER_1_EMAIL, + include_files_shared_with_me=False, + include_shared_drives=False, + include_my_drives=True, + shared_folder_urls=FOLDER_1_URL, + shared_drive_urls=None, + my_drive_emails=None, + ) + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + expected_file_ids = FOLDER_1_FILE_IDS + FOLDER_1_1_FILE_IDS + FOLDER_1_2_FILE_IDS + assert_retrieved_docs_match_expected( + retrieved_docs=retrieved_docs, + expected_file_ids=expected_file_ids, + ) diff --git a/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx b/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx index 857ef7ec4d0..0b7d961aeef 100644 --- a/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx @@ -1,20 +1,13 @@ import React, { Dispatch, FC, SetStateAction, useState } from "react"; -import CredentialSubText, { - AdminBooleanFormField, -} from "@/components/credentials/CredentialFields"; -import { FileUpload } from "@/components/admin/connectors/FileUpload"; +import CredentialSubText from "@/components/credentials/CredentialFields"; import { ConnectionConfiguration } from "@/lib/connectors/connectors"; -import SelectInput from "./ConnectorInput/SelectInput"; -import NumberInput from "./ConnectorInput/NumberInput"; import { TextFormField } from "@/components/admin/connectors/Field"; -import ListInput 
from "./ConnectorInput/ListInput"; -import FileInput from "./ConnectorInput/FileInput"; import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle"; import { AccessTypeForm } from "@/components/admin/connectors/AccessTypeForm"; import { AccessTypeGroupSelector } from "@/components/admin/connectors/AccessTypeGroupSelector"; import { ConfigurableSources } from "@/lib/types"; import { Credential } from "@/lib/connectors/credentials"; -import CollapsibleSection from "@/app/admin/assistants/CollapsibleSection"; +import { RenderField } from "./FieldRendering"; export interface DynamicConnectionFormProps { config: ConnectionConfiguration; @@ -25,105 +18,6 @@ export interface DynamicConnectionFormProps { currentCredential: Credential | null; } -interface RenderFieldProps { - field: any; - values: any; - selectedFiles: File[]; - setSelectedFiles: Dispatch>; - connector: ConfigurableSources; - currentCredential: Credential | null; -} - -const RenderField: FC = ({ - field, - values, - selectedFiles, - setSelectedFiles, - connector, - currentCredential, -}) => { - if ( - field.visibleCondition && - !field.visibleCondition(values, currentCredential) - ) { - return null; - } - - const label = - typeof field.label === "function" - ? field.label(currentCredential) - : field.label; - const description = - typeof field.description === "function" - ? field.description(currentCredential) - : field.description; - - const fieldContent = ( - <> - {field.type === "file" ? ( - - ) : field.type === "zip" ? ( - - ) : field.type === "list" ? ( - - ) : field.type === "select" ? ( - - ) : field.type === "number" ? ( - - ) : field.type === "checkbox" ? ( - - ) : ( - - )} - - ); - - if ( - field.visibleCondition && - field.visibleCondition(values, currentCredential) - ) { - return ( - - {fieldContent} - - ); - } else { - return
{fieldContent}
; - } -}; - const DynamicConnectionForm: FC = ({ config, selectedFiles, @@ -136,14 +30,12 @@ const DynamicConnectionForm: FC = ({ return ( <> -

{config.description}

- {config.subtext && ( {config.subtext} )} = ({ setShowAdvancedOptions={setShowAdvancedOptions} /> {showAdvancedOptions && - config.advanced_values.map((field) => ( - - ))} + config.advanced_values.map( + (field) => + !field.hidden && ( + + ) + )} )} diff --git a/web/src/app/admin/connectors/[connector]/pages/FieldRendering.tsx b/web/src/app/admin/connectors/[connector]/pages/FieldRendering.tsx new file mode 100644 index 00000000000..c3c862a8119 --- /dev/null +++ b/web/src/app/admin/connectors/[connector]/pages/FieldRendering.tsx @@ -0,0 +1,224 @@ +import React, { Dispatch, FC, SetStateAction } from "react"; +import { AdminBooleanFormField } from "@/components/credentials/CredentialFields"; +import { FileUpload } from "@/components/admin/connectors/FileUpload"; +import { TabOption } from "@/lib/connectors/connectors"; +import SelectInput from "./ConnectorInput/SelectInput"; +import NumberInput from "./ConnectorInput/NumberInput"; +import { TextFormField } from "@/components/admin/connectors/Field"; +import ListInput from "./ConnectorInput/ListInput"; +import FileInput from "./ConnectorInput/FileInput"; +import { ConfigurableSources } from "@/lib/types"; +import { Credential } from "@/lib/connectors/credentials"; +import CollapsibleSection from "@/app/admin/assistants/CollapsibleSection"; +import { + Tabs, + TabsContent, + TabsList, + TabsTrigger, +} from "@/components/ui/fully_wrapped_tabs"; + +interface TabsFieldProps { + tabField: TabOption; + values: any; + selectedFiles: File[]; + setSelectedFiles: Dispatch>; + connector: ConfigurableSources; + currentCredential: Credential | null; +} + +const TabsField: FC = ({ + tabField, + values, + selectedFiles, + setSelectedFiles, + connector, + currentCredential, +}) => { + return ( +
+ {tabField.label && ( +
+

+ {typeof tabField.label === "function" + ? tabField.label(currentCredential) + : tabField.label} +

+ {tabField.description && ( +

+ {typeof tabField.description === "function" + ? tabField.description(currentCredential) + : tabField.description} +

+ )} +
+ )} + + { + // Clear values from other tabs but preserve defaults + tabField.tabs.forEach((tab) => { + if (tab.value !== newTab) { + tab.fields.forEach((field) => { + // Only clear if not default value + if (values[field.name] !== field.default) { + values[field.name] = field.default; + } + }); + } + }); + }} + > + + {tabField.tabs.map((tab) => ( + + {tab.label} + + ))} + + {tabField.tabs.map((tab) => ( + + {tab.fields.map((subField, index, array) => { + // Check visibility condition first + if ( + subField.visibleCondition && + !subField.visibleCondition(values, currentCredential) + ) { + return null; + } + + return ( +
+ +
+ ); + })} +
+ ))} +
+
+ ); +}; + +interface RenderFieldProps { + field: any; + values: any; + selectedFiles: File[]; + setSelectedFiles: Dispatch>; + connector: ConfigurableSources; + currentCredential: Credential | null; +} + +export const RenderField: FC = ({ + field, + values, + selectedFiles, + setSelectedFiles, + connector, + currentCredential, +}) => { + const label = + typeof field.label === "function" + ? field.label(currentCredential) + : field.label; + const description = + typeof field.description === "function" + ? field.description(currentCredential) + : field.description; + + if (field.type === "tab") { + return ( + + ); + } + + const fieldContent = ( + <> + {field.type === "file" ? ( + + ) : field.type === "zip" ? ( + + ) : field.type === "list" ? ( + + ) : field.type === "select" ? ( + + ) : field.type === "number" ? ( + + ) : field.type === "checkbox" ? ( + + ) : field.type === "text" ? ( + + ) : field.type === "string_tab" ? ( +
{description}
+ ) : ( + <>INVALID FIELD TYPE + )} + + ); + + if (field.wrapInCollapsible) { + return ( + + {fieldContent} + + ); + } + + return
{fieldContent}
; +}; diff --git a/web/src/components/credentials/CredentialFields.tsx b/web/src/components/credentials/CredentialFields.tsx index 3d25fcb75fe..d23b30834f3 100644 --- a/web/src/components/credentials/CredentialFields.tsx +++ b/web/src/components/credentials/CredentialFields.tsx @@ -1,4 +1,4 @@ -import { ErrorMessage, Field } from "formik"; +import { ErrorMessage, Field, useField } from "formik"; import { ExplanationText, @@ -96,18 +96,18 @@ export function AdminTextField({ name={name} id={name} className={` - ${small && "text-sm"} - border - border-border - rounded - w-full - bg-input - py-2 - px-3 - mt-1 - ${heightString} - ${fontSize} - ${isCode ? " font-mono" : ""} + ${small && "text-sm"} + border + border-border + rounded + w-full + bg-input + py-2 + px-3 + mt-1 + ${heightString} + ${fontSize} + ${isCode ? " font-mono" : ""} `} disabled={disabled} placeholder={placeholder} @@ -143,13 +143,18 @@ export const AdminBooleanFormField = ({ alignTop, onChange, }: BooleanFormFieldProps) => { + const [field, meta, helpers] = useField(name); + return (