diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index 4ddd51f749f..8b2433d31c8 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -3,8 +3,6 @@ from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore -from googleapiclient.discovery import build # type: ignore -from googleapiclient.discovery import Resource # type: ignore from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.connectors.google_drive.connector_auth import ( @@ -24,6 +22,9 @@ from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval from danswer.connectors.google_drive.models import GoogleDriveFileType +from danswer.connectors.google_drive.resources import get_admin_service +from danswer.connectors.google_drive.resources import get_drive_service +from danswer.connectors.google_drive.resources import get_google_docs_service from danswer.connectors.interfaces import GenerateDocumentsOutput from danswer.connectors.interfaces import GenerateSlimDocumentOutput from danswer.connectors.interfaces import LoadConnector @@ -103,42 +104,49 @@ def __init__( shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls) self.shared_folder_ids = _extract_ids_from_urls(shared_folder_url_list) - self.primary_admin_email: str | None = None + self._primary_admin_email: str | None = None self.google_domain: str | None = None - self.creds: OAuthCredentials | ServiceAccountCredentials | None = None + self._creds: OAuthCredentials | ServiceAccountCredentials | None = None self._TRAVERSED_PARENT_IDS: set[str] = set() + @property + def primary_admin_email(self) -> str: + if self._primary_admin_email is None: + raise RuntimeError( + "Primary admin email missing, " + "should not call this property " + "before calling load_credentials" + ) + return self._primary_admin_email + + @property + def creds(self) -> OAuthCredentials | ServiceAccountCredentials: + if self._creds is None: + raise RuntimeError( + "Creds missing, " + "should not call this property " + "before calling load_credentials" + ) + return self._creds + def _update_traversed_parent_ids(self, folder_id: str) -> None: self._TRAVERSED_PARENT_IDS.add(folder_id) def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None: primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] self.google_domain = primary_admin_email.split("@")[1] - self.primary_admin_email = primary_admin_email + self._primary_admin_email = primary_admin_email - self.creds, new_creds_dict = get_google_drive_creds(credentials) + self._creds, new_creds_dict = get_google_drive_creds(credentials) return new_creds_dict - def get_google_resource( - self, - service_name: str = "drive", - service_version: str = "v3", - user_email: str | None = None, - ) -> Resource: - if isinstance(self.creds, ServiceAccountCredentials): - creds = self.creds.with_subject(user_email or self.primary_admin_email) - service = build(service_name, service_version, credentials=creds) - elif isinstance(self.creds, OAuthCredentials): - service = build(service_name, service_version, credentials=self.creds) - else: - raise PermissionError("No credentials found") - - return service - def _get_all_user_emails(self) -> list[str]: - admin_service = self.get_google_resource("admin", "directory_v1") + admin_service = get_admin_service( + creds=self.creds, + user_email=self.primary_admin_email, + ) emails = [] for user in execute_paginated_retrieval( retrieval_function=admin_service.users().list, @@ -156,7 +164,10 @@ def _fetch_drive_items( start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[GoogleDriveFileType]: - primary_drive_service = self.get_google_resource() + primary_drive_service = get_drive_service( + creds=self.creds, + user_email=self.primary_admin_email, + ) if self.include_shared_drives: shared_drive_urls = self.shared_drive_ids @@ -212,7 +223,7 @@ def _fetch_drive_items( for email in all_user_emails: logger.info(f"Fetching personal files for user: {email}") - user_drive_service = self.get_google_resource(user_email=email) + user_drive_service = get_drive_service(self.creds, user_email=email) yield from get_files_in_my_drive( service=user_drive_service, @@ -233,11 +244,16 @@ def _extract_docs_from_google_drive( start=start, end=end, ): - user_email = file.get("owners", [{}])[0].get("emailAddress") - service = self.get_google_resource(user_email=user_email) + user_email = ( + file.get("owners", [{}])[0].get("emailAddress") + or self.primary_admin_email + ) + user_drive_service = get_drive_service(self.creds, user_email=user_email) + docs_service = get_google_docs_service(self.creds, user_email=user_email) if doc := convert_drive_item_to_document( file=file, - service=service, + drive_service=user_drive_service, + docs_service=docs_service, ): doc_batch.append(doc) if len(doc_batch) >= self.batch_size: diff --git a/backend/danswer/connectors/google_drive/connector_auth.py b/backend/danswer/connectors/google_drive/connector_auth.py index 80cbda6772a..6ea2784662b 100644 --- a/backend/danswer/connectors/google_drive/connector_auth.py +++ b/backend/danswer/connectors/google_drive/connector_auth.py @@ -28,6 +28,8 @@ logger = setup_logger() +# NOTE: do not need https://www.googleapis.com/auth/documents.readonly +# this is counted under `/auth/drive.readonly` GOOGLE_DRIVE_SCOPES = [ "https://www.googleapis.com/auth/drive.readonly", "https://www.googleapis.com/auth/drive.metadata.readonly", diff --git a/backend/danswer/connectors/google_drive/doc_conversion.py b/backend/danswer/connectors/google_drive/doc_conversion.py index 688190c2267..81e709f53f4 100644 --- a/backend/danswer/connectors/google_drive/doc_conversion.py +++ b/backend/danswer/connectors/google_drive/doc_conversion.py @@ -2,7 +2,6 @@ from datetime import datetime from datetime import timezone -from googleapiclient.discovery import Resource # type: ignore from googleapiclient.errors import HttpError # type: ignore from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE @@ -13,6 +12,9 @@ from danswer.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT from danswer.connectors.google_drive.models import GDriveMimeType from danswer.connectors.google_drive.models import GoogleDriveFileType +from danswer.connectors.google_drive.resources import GoogleDocsService +from danswer.connectors.google_drive.resources import GoogleDriveService +from danswer.connectors.google_drive.section_extraction import get_document_sections from danswer.connectors.models import Document from danswer.connectors.models import Section from danswer.file_processing.extract_file_text import docx_to_text @@ -25,86 +27,137 @@ logger = setup_logger() -def _extract_text(file: dict[str, str], service: Resource) -> str: +def _extract_sections_basic( + file: dict[str, str], service: GoogleDriveService +) -> list[Section]: mime_type = file["mimeType"] + link = file["webViewLink"] if mime_type not in set(item.value for item in GDriveMimeType): # Unsupported file types can still have a title, finding this way is still useful - return UNSUPPORTED_FILE_TYPE_CONTENT - - if mime_type in [ - GDriveMimeType.DOC.value, - GDriveMimeType.PPT.value, - GDriveMimeType.SPREADSHEET.value, - ]: - export_mime_type = ( - "text/plain" - if mime_type != GDriveMimeType.SPREADSHEET.value - else "text/csv" - ) - return ( - service.files() - .export(fileId=file["id"], mimeType=export_mime_type) - .execute() - .decode("utf-8") - ) - elif mime_type in [ - GDriveMimeType.PLAIN_TEXT.value, - GDriveMimeType.MARKDOWN.value, - ]: - return service.files().get_media(fileId=file["id"]).execute().decode("utf-8") - if mime_type in [ - GDriveMimeType.WORD_DOC.value, - GDriveMimeType.POWERPOINT.value, - GDriveMimeType.PDF.value, - ]: - response = service.files().get_media(fileId=file["id"]).execute() - if get_unstructured_api_key(): - return unstructured_to_text( - file=io.BytesIO(response), file_name=file.get("name", file["id"]) + return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)] + + try: + if mime_type in [ + GDriveMimeType.DOC.value, + GDriveMimeType.PPT.value, + GDriveMimeType.SPREADSHEET.value, + ]: + export_mime_type = ( + "text/plain" + if mime_type != GDriveMimeType.SPREADSHEET.value + else "text/csv" ) + text = ( + service.files() + .export(fileId=file["id"], mimeType=export_mime_type) + .execute() + .decode("utf-8") + ) + return [Section(link=link, text=text)] + elif mime_type in [ + GDriveMimeType.PLAIN_TEXT.value, + GDriveMimeType.MARKDOWN.value, + ]: + return [ + Section( + link=link, + text=service.files() + .get_media(fileId=file["id"]) + .execute() + .decode("utf-8"), + ) + ] + if mime_type in [ + GDriveMimeType.WORD_DOC.value, + GDriveMimeType.POWERPOINT.value, + GDriveMimeType.PDF.value, + ]: + response = service.files().get_media(fileId=file["id"]).execute() + if get_unstructured_api_key(): + return [ + Section( + link=link, + text=unstructured_to_text( + file=io.BytesIO(response), + file_name=file.get("name", file["id"]), + ), + ) + ] + + if mime_type == GDriveMimeType.WORD_DOC.value: + return [ + Section(link=link, text=docx_to_text(file=io.BytesIO(response))) + ] + elif mime_type == GDriveMimeType.PDF.value: + text, _ = read_pdf_file(file=io.BytesIO(response)) + return [Section(link=link, text=text)] + elif mime_type == GDriveMimeType.POWERPOINT.value: + return [ + Section(link=link, text=pptx_to_text(file=io.BytesIO(response))) + ] - if mime_type == GDriveMimeType.WORD_DOC.value: - return docx_to_text(file=io.BytesIO(response)) - elif mime_type == GDriveMimeType.PDF.value: - text, _ = read_pdf_file(file=io.BytesIO(response)) - return text - elif mime_type == GDriveMimeType.POWERPOINT.value: - return pptx_to_text(file=io.BytesIO(response)) + return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)] - return UNSUPPORTED_FILE_TYPE_CONTENT + except Exception: + return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)] def convert_drive_item_to_document( - file: GoogleDriveFileType, service: Resource + file: GoogleDriveFileType, + drive_service: GoogleDriveService, + docs_service: GoogleDocsService, ) -> Document | None: try: # Skip files that are shortcuts if file.get("mimeType") == DRIVE_SHORTCUT_TYPE: logger.info("Ignoring Drive Shortcut Filetype") return None - try: - text_contents = _extract_text(file, service) or "" - except HttpError as e: - reason = e.error_details[0]["reason"] if e.error_details else e.reason - message = e.error_details[0]["message"] if e.error_details else e.reason - if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON: + + sections: list[Section] = [] + + # Special handling for Google Docs to preserve structure, link + # to headers + if file.get("mimeType") == GDriveMimeType.DOC.value: + try: + sections = get_document_sections(docs_service, file["id"]) + except Exception as e: logger.warning( - f"Could not export file '{file['name']}' due to '{message}', skipping..." + f"Ran into exception '{e}' when pulling sections from Google Doc '{file['name']}'." + " Falling back to basic extraction." ) - return None - raise + # NOTE: this will run for either (1) the above failed or (2) the file is not a Google Doc + if not sections: + try: + # For all other file types just extract the text + sections = _extract_sections_basic(file, drive_service) + + except HttpError as e: + reason = e.error_details[0]["reason"] if e.error_details else e.reason + message = e.error_details[0]["message"] if e.error_details else e.reason + if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON: + logger.warning( + f"Could not export file '{file['name']}' due to '{message}', skipping..." + ) + return None + + raise + + if not sections: + return None return Document( id=file["webViewLink"], - sections=[Section(link=file["webViewLink"], text=text_contents)], + sections=sections, source=DocumentSource.GOOGLE_DRIVE, semantic_identifier=file["name"], doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone( timezone.utc ), - metadata={} if text_contents else {IGNORE_FOR_QA: "True"}, + metadata={} + if any(section.text for section in sections) + else {IGNORE_FOR_QA: "True"}, additional_info=file.get("id"), ) except Exception as e: diff --git a/backend/danswer/connectors/google_drive/google_utils.py b/backend/danswer/connectors/google_drive/google_utils.py index 5f772e5ad63..a2e029a41df 100644 --- a/backend/danswer/connectors/google_drive/google_utils.py +++ b/backend/danswer/connectors/google_drive/google_utils.py @@ -28,7 +28,7 @@ def execute_paginated_retrieval( if next_page_token: request_kwargs["pageToken"] = next_page_token - results = add_retries(lambda: retrieval_function(**request_kwargs).execute())() + results = (lambda: retrieval_function(**request_kwargs).execute())() next_page_token = results.get("nextPageToken") for item in results.get(list_key, []): diff --git a/backend/danswer/connectors/google_drive/resources.py b/backend/danswer/connectors/google_drive/resources.py new file mode 100644 index 00000000000..df5b36376b5 --- /dev/null +++ b/backend/danswer/connectors/google_drive/resources.py @@ -0,0 +1,52 @@ +from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore +from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore +from googleapiclient.discovery import build # type: ignore +from googleapiclient.discovery import Resource # type: ignore + + +class GoogleDriveService(Resource): + pass + + +class GoogleDocsService(Resource): + pass + + +class AdminService(Resource): + pass + + +def _get_google_service( + service_name: str, + service_version: str, + creds: ServiceAccountCredentials | OAuthCredentials, + user_email: str | None = None, +) -> GoogleDriveService: + if isinstance(creds, ServiceAccountCredentials): + creds = creds.with_subject(user_email) + service = build(service_name, service_version, credentials=creds) + elif isinstance(creds, OAuthCredentials): + service = build(service_name, service_version, credentials=creds) + + return service + + +def get_google_docs_service( + creds: ServiceAccountCredentials | OAuthCredentials, + user_email: str | None = None, +) -> GoogleDocsService: + return _get_google_service("docs", "v1", creds, user_email) + + +def get_drive_service( + creds: ServiceAccountCredentials | OAuthCredentials, + user_email: str | None = None, +) -> GoogleDriveService: + return _get_google_service("drive", "v3", creds, user_email) + + +def get_admin_service( + creds: ServiceAccountCredentials | OAuthCredentials, + user_email: str, +) -> AdminService: + return _get_google_service("admin", "directory_v1", creds, user_email) diff --git a/backend/danswer/connectors/google_drive/section_extraction.py b/backend/danswer/connectors/google_drive/section_extraction.py new file mode 100644 index 00000000000..bcd162b1c79 --- /dev/null +++ b/backend/danswer/connectors/google_drive/section_extraction.py @@ -0,0 +1,105 @@ +from typing import Any + +from pydantic import BaseModel + +from danswer.connectors.google_drive.resources import GoogleDocsService +from danswer.connectors.models import Section + + +class CurrentHeading(BaseModel): + id: str + text: str + + +def _build_gdoc_section_link(doc_id: str, heading_id: str) -> str: + """Builds a Google Doc link that jumps to a specific heading""" + # NOTE: doesn't support docs with multiple tabs atm, if we need that ask + # @Chris + return ( + f"https://docs.google.com/document/d/{doc_id}/edit?tab=t.0#heading={heading_id}" + ) + + +def _extract_id_from_heading(paragraph: dict[str, Any]) -> str: + """Extracts the id from a heading paragraph element""" + return paragraph["paragraphStyle"]["headingId"] + + +def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str: + """Extracts the text content from a paragraph element""" + text_elements = [] + for element in paragraph.get("elements", []): + if "textRun" in element: + text_elements.append(element["textRun"].get("content", "")) + return "".join(text_elements) + + +def get_document_sections( + docs_service: GoogleDocsService, + doc_id: str, +) -> list[Section]: + """Extracts sections from a Google Doc, including their headings and content""" + # Fetch the document structure + doc = docs_service.documents().get(documentId=doc_id).execute() + + # Get the content + content = doc.get("body", {}).get("content", []) + + sections: list[Section] = [] + current_section: list[str] = [] + current_heading: CurrentHeading | None = None + + for element in content: + if "paragraph" not in element: + continue + + paragraph = element["paragraph"] + + # Check if this is a heading + if ( + "paragraphStyle" in paragraph + and "namedStyleType" in paragraph["paragraphStyle"] + ): + style = paragraph["paragraphStyle"]["namedStyleType"] + is_heading = style.startswith("HEADING_") + is_title = style.startswith("TITLE") + + if is_heading or is_title: + # If we were building a previous section, add it to sections list + if current_heading is not None and current_section: + heading_text = current_heading.text + section_text = f"{heading_text}\n" + "\n".join(current_section) + sections.append( + Section( + text=section_text.strip(), + link=_build_gdoc_section_link(doc_id, current_heading.id), + ) + ) + current_section = [] + + # Start new heading + heading_id = _extract_id_from_heading(paragraph) + heading_text = _extract_text_from_paragraph(paragraph) + current_heading = CurrentHeading( + id=heading_id, + text=heading_text, + ) + continue + + # Add content to current section + if current_heading is not None: + text = _extract_text_from_paragraph(paragraph) + if text.strip(): + current_section.append(text) + + # Don't forget to add the last section + if current_heading is not None and current_section: + section_text = f"{current_heading.text}\n" + "\n".join(current_section) + sections.append( + Section( + text=section_text.strip(), + link=_build_gdoc_section_link(doc_id, current_heading.id), + ) + ) + + return sections diff --git a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py index d1df0cb0846..9f736b22a00 100644 --- a/backend/ee/danswer/external_permissions/google_drive/doc_sync.py +++ b/backend/ee/danswer/external_permissions/google_drive/doc_sync.py @@ -7,6 +7,7 @@ from danswer.access.models import ExternalAccess from danswer.connectors.google_drive.connector import GoogleDriveConnector from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval +from danswer.connectors.google_drive.resources import get_drive_service from danswer.connectors.interfaces import GenerateSlimDocumentOutput from danswer.connectors.models import SlimDocument from danswer.db.models import ConnectorCredentialPair @@ -56,7 +57,10 @@ def _fetch_permissions_for_permission_ids( return permissions owner_email = permission_info.get("owner_email") - drive_service = google_drive_connector.get_google_resource(user_email=owner_email) + drive_service = get_drive_service( + creds=google_drive_connector.creds, + user_email=(owner_email or google_drive_connector.primary_admin_email), + ) # Otherwise, fetch all permissions and update cache fetched_permissions = execute_paginated_retrieval( diff --git a/backend/ee/danswer/external_permissions/google_drive/group_sync.py b/backend/ee/danswer/external_permissions/google_drive/group_sync.py index c3afa962392..919866749ff 100644 --- a/backend/ee/danswer/external_permissions/google_drive/group_sync.py +++ b/backend/ee/danswer/external_permissions/google_drive/group_sync.py @@ -2,6 +2,7 @@ from danswer.connectors.google_drive.connector import GoogleDriveConnector from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval +from danswer.connectors.google_drive.resources import get_admin_service from danswer.db.models import ConnectorCredentialPair from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit from danswer.utils.logger import setup_logger @@ -19,8 +20,9 @@ def gdrive_group_sync( **cc_pair.connector.connector_specific_config ) google_drive_connector.load_credentials(cc_pair.credential.credential_json) - - admin_service = google_drive_connector.get_google_resource("admin", "directory_v1") + admin_service = get_admin_service( + google_drive_connector.creds, google_drive_connector.primary_admin_email + ) danswer_groups: list[ExternalUserGroup] = [] for group in execute_paginated_retrieval( diff --git a/backend/tests/daily/connectors/google_drive/conftest.py b/backend/tests/daily/connectors/google_drive/conftest.py index 0b516d0359c..c330dcf2d60 100644 --- a/backend/tests/daily/connectors/google_drive/conftest.py +++ b/backend/tests/daily/connectors/google_drive/conftest.py @@ -31,6 +31,30 @@ def load_env_vars(env_file: str = ".env") -> None: load_env_vars() +def parse_credentials(env_str: str) -> dict: + """ + Parse a double-escaped JSON string from environment variables into a Python dictionary. + + Args: + env_str (str): The double-escaped JSON string from environment variables + + Returns: + dict: Parsed OAuth credentials + """ + # first try normally + try: + return json.loads(env_str) + except Exception: + # First, try remove extra escaping backslashes + unescaped = env_str.replace('\\"', '"') + + # remove leading / trailing quotes + unescaped = unescaped.strip('"') + + # Now parse the JSON + return json.loads(unescaped) + + @pytest.fixture def google_drive_oauth_connector_factory() -> Callable[..., GoogleDriveConnector]: def _connector_factory( @@ -50,7 +74,7 @@ def _connector_factory( ) json_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"] - refried_json_string = json.loads(json_string) + refried_json_string = json.dumps(parse_credentials(json_string)) credentials_json = { DB_CREDENTIALS_DICT_TOKEN_KEY: refried_json_string, @@ -84,7 +108,7 @@ def _connector_factory( ) json_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"] - refried_json_string = json.loads(json_string) + refried_json_string = json.dumps(parse_credentials(json_string)) # Load Service Account Credentials connector.load_credentials( diff --git a/backend/tests/daily/connectors/google_drive/helpers.py b/backend/tests/daily/connectors/google_drive/helpers.py index a1bc8feec38..7a120412e9d 100644 --- a/backend/tests/daily/connectors/google_drive/helpers.py +++ b/backend/tests/daily/connectors/google_drive/helpers.py @@ -18,6 +18,7 @@ _FOLDER_2_FILE_IDS = list(range(45, 50)) _FOLDER_2_1_FILE_IDS = list(range(50, 55)) _FOLDER_2_2_FILE_IDS = list(range(55, 60)) +_SECTIONS_FILE_IDS = [61] _PUBLIC_FOLDER_RANGE = _FOLDER_1_2_FILE_IDS _PUBLIC_FILE_IDS = list(range(55, 57)) @@ -64,6 +65,7 @@ "FOLDER_2": _FOLDER_2_FILE_IDS, "FOLDER_2_1": _FOLDER_2_1_FILE_IDS, "FOLDER_2_2": _FOLDER_2_2_FILE_IDS, + "SECTIONS": _SECTIONS_FILE_IDS, } # Dictionary for emails @@ -100,6 +102,7 @@ + _FOLDER_2_FILE_IDS + _FOLDER_2_1_FILE_IDS + _FOLDER_2_2_FILE_IDS + + _SECTIONS_FILE_IDS ), # This user has access to drive 1 # This user has redundant access to folder 1 because of group access @@ -127,6 +130,21 @@ "TEST_USER_3": _TEST_USER_3_FILE_IDS, } +SPECIAL_FILE_ID_TO_CONTENT_MAP: dict[int, str] = { + 61: ( + "Title\n\n" + "This is a Google Doc with sections - " + "Section 1\n\n" + "Section 1 content - " + "Sub-Section 1-1\n\n" + "Sub-Section 1-1 content - " + "Sub-Section 1-2\n\n" + "Sub-Section 1-2 content - " + "Section 2\n\n" + "Section 2 content" + ), +} + file_name_template = "file_{}.txt" file_text_template = "This is file {}" @@ -142,18 +160,28 @@ def print_discrepencies(expected: set[str], retrieved: set[str]) -> None: print(expected - retrieved) +def get_file_content(file_id: int) -> str: + if file_id in SPECIAL_FILE_ID_TO_CONTENT_MAP: + return SPECIAL_FILE_ID_TO_CONTENT_MAP[file_id] + + return file_text_template.format(file_id) + + def assert_retrieved_docs_match_expected( retrieved_docs: list[Document], expected_file_ids: Sequence[int] ) -> None: expected_file_names = { file_name_template.format(file_id) for file_id in expected_file_ids } - expected_file_texts = { - file_text_template.format(file_id) for file_id in expected_file_ids - } + expected_file_texts = {get_file_content(file_id) for file_id in expected_file_ids} retrieved_file_names = set([doc.semantic_identifier for doc in retrieved_docs]) - retrieved_texts = set([doc.sections[0].text for doc in retrieved_docs]) + retrieved_texts = set( + [ + " - ".join([section.text for section in doc.sections]) + for doc in retrieved_docs + ] + ) # Check file names print_discrepencies(expected_file_names, retrieved_file_names) diff --git a/backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py b/backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py index f39b15600b4..a4adaa4673a 100644 --- a/backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py +++ b/backend/tests/daily/connectors/google_drive/test_google_drive_oauth.py @@ -41,6 +41,7 @@ def test_include_all( + DRIVE_ID_MAPPING["FOLDER_2"] + DRIVE_ID_MAPPING["FOLDER_2_1"] + DRIVE_ID_MAPPING["FOLDER_2_2"] + + DRIVE_ID_MAPPING["SECTIONS"] ) assert_retrieved_docs_match_expected( retrieved_docs=retrieved_docs, @@ -75,6 +76,7 @@ def test_include_shared_drives_only( + DRIVE_ID_MAPPING["FOLDER_2"] + DRIVE_ID_MAPPING["FOLDER_2_1"] + DRIVE_ID_MAPPING["FOLDER_2_2"] + + DRIVE_ID_MAPPING["SECTIONS"] ) assert_retrieved_docs_match_expected( retrieved_docs=retrieved_docs, diff --git a/backend/tests/daily/connectors/google_drive/test_google_drive_sections.py b/backend/tests/daily/connectors/google_drive/test_google_drive_sections.py new file mode 100644 index 00000000000..4f4556a06d6 --- /dev/null +++ b/backend/tests/daily/connectors/google_drive/test_google_drive_sections.py @@ -0,0 +1,71 @@ +import time +from collections.abc import Callable +from unittest.mock import MagicMock +from unittest.mock import patch + +from danswer.connectors.google_drive.connector import GoogleDriveConnector +from danswer.connectors.models import Document + + +SECTIONS_FOLDER_URL = ( + "https://drive.google.com/drive/u/5/folders/1loe6XJ-pJxu9YYPv7cF3Hmz296VNzA33" +) + + +@patch( + "danswer.file_processing.extract_file_text.get_unstructured_api_key", + return_value=None, +) +def test_google_drive_sections( + mock_get_api_key: MagicMock, + google_drive_oauth_connector_factory: Callable[..., GoogleDriveConnector], + google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector], +) -> None: + oauth_connector = google_drive_oauth_connector_factory( + include_shared_drives=False, + include_my_drives=False, + shared_folder_urls=SECTIONS_FOLDER_URL, + ) + service_acct_connector = google_drive_service_acct_connector_factory( + include_shared_drives=False, + include_my_drives=False, + shared_folder_urls=SECTIONS_FOLDER_URL, + ) + for connector in [oauth_connector, service_acct_connector]: + retrieved_docs: list[Document] = [] + for doc_batch in connector.poll_source(0, time.time()): + retrieved_docs.extend(doc_batch) + + # Verify we got the 1 doc with sections + assert len(retrieved_docs) == 1 + + # Verify each section has the expected structure + doc = retrieved_docs[0] + assert len(doc.sections) == 5 + + header_section = doc.sections[0] + assert header_section.text == "Title\n\nThis is a Google Doc with sections" + assert header_section.link is not None + assert header_section.link.endswith( + "?tab=t.0#heading=h.hfjc17k6qwzt" + ) or header_section.link.endswith("?tab=t.0#heading=h.hfjc17k6qwzt") + + section_1 = doc.sections[1] + assert section_1.text == "Section 1\n\nSection 1 content" + assert section_1.link is not None + assert section_1.link.endswith("?tab=t.0#heading=h.8slfx752a3g5") + + section_2 = doc.sections[2] + assert section_2.text == "Sub-Section 1-1\n\nSub-Section 1-1 content" + assert section_2.link is not None + assert section_2.link.endswith("?tab=t.0#heading=h.4kj3ayade1bp") + + section_3 = doc.sections[3] + assert section_3.text == "Sub-Section 1-2\n\nSub-Section 1-2 content" + assert section_3.link is not None + assert section_3.link.endswith("?tab=t.0#heading=h.pm6wrpzgk69l") + + section_4 = doc.sections[4] + assert section_4.text == "Section 2\n\nSection 2 content" + assert section_4.link is not None + assert section_4.link.endswith("?tab=t.0#heading=h.2m0s9youe2k9") diff --git a/backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py b/backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py index b36a53b30f6..a7f081fd323 100644 --- a/backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py +++ b/backend/tests/daily/connectors/google_drive/test_google_drive_service_acct.py @@ -44,6 +44,7 @@ def test_include_all( + DRIVE_ID_MAPPING["FOLDER_2"] + DRIVE_ID_MAPPING["FOLDER_2_1"] + DRIVE_ID_MAPPING["FOLDER_2_2"] + + DRIVE_ID_MAPPING["SECTIONS"] ) assert_retrieved_docs_match_expected( retrieved_docs=retrieved_docs, @@ -78,6 +79,7 @@ def test_include_shared_drives_only( + DRIVE_ID_MAPPING["FOLDER_2"] + DRIVE_ID_MAPPING["FOLDER_2_1"] + DRIVE_ID_MAPPING["FOLDER_2_2"] + + DRIVE_ID_MAPPING["SECTIONS"] ) assert_retrieved_docs_match_expected( retrieved_docs=retrieved_docs, diff --git a/backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py b/backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py index e731c8b27ce..d7e79da09d5 100644 --- a/backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py +++ b/backend/tests/daily/connectors/google_drive/test_google_drive_slim_docs.py @@ -6,6 +6,7 @@ from danswer.access.models import ExternalAccess from danswer.connectors.google_drive.connector import GoogleDriveConnector from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval +from danswer.connectors.google_drive.resources import get_admin_service from ee.danswer.external_permissions.google_drive.doc_sync import ( _get_permissions_from_slim_doc, ) @@ -72,7 +73,10 @@ def assert_correct_access_for_user( # This function is supposed to map to the group_sync.py file for the google drive connector # TODO: Call it directly def get_group_map(google_drive_connector: GoogleDriveConnector) -> dict[str, list[str]]: - admin_service = google_drive_connector.get_google_resource("admin", "directory_v1") + admin_service = get_admin_service( + creds=google_drive_connector.creds, + user_email=google_drive_connector.primary_admin_email, + ) group_map: dict[str, list[str]] = {} for group in execute_paginated_retrieval( @@ -138,6 +142,7 @@ def test_all_permissions( + DRIVE_ID_MAPPING["FOLDER_2"] + DRIVE_ID_MAPPING["FOLDER_2_1"] + DRIVE_ID_MAPPING["FOLDER_2_2"] + + DRIVE_ID_MAPPING["SECTIONS"] ) # Should get everything