Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add drive sections #3040

Merged
merged 6 commits into from
Nov 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 44 additions & 28 deletions backend/danswer/connectors/google_drive/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.discovery import Resource # type: ignore

from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.connectors.google_drive.connector_auth import (
Expand All @@ -24,6 +22,9 @@
from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from danswer.connectors.google_drive.google_utils import execute_paginated_retrieval
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.google_drive.resources import get_admin_service
from danswer.connectors.google_drive.resources import get_drive_service
from danswer.connectors.google_drive.resources import get_google_docs_service
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
Expand Down Expand Up @@ -103,42 +104,49 @@ def __init__(
shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls)
self.shared_folder_ids = _extract_ids_from_urls(shared_folder_url_list)

self.primary_admin_email: str | None = None
self._primary_admin_email: str | None = None
self.google_domain: str | None = None

self.creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None

self._TRAVERSED_PARENT_IDS: set[str] = set()

@property
def primary_admin_email(self) -> str:
    """The primary admin email captured by load_credentials.

    Raises:
        RuntimeError: if read before load_credentials has populated it.
    """
    email = self._primary_admin_email
    if email is None:
        raise RuntimeError(
            "Primary admin email missing, "
            "should not call this property "
            "before calling load_credentials"
        )
    return email

@property
def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
    """The Google credentials captured by load_credentials.

    Raises:
        RuntimeError: if read before load_credentials has populated them.
    """
    if self._creds is not None:
        return self._creds
    raise RuntimeError(
        "Creds missing, "
        "should not call this property "
        "before calling load_credentials"
    )

def _update_traversed_parent_ids(self, folder_id: str) -> None:
self._TRAVERSED_PARENT_IDS.add(folder_id)

def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
    """Load Google credentials and remember the primary admin identity.

    Derives the Google domain from the admin email's host part, stashes the
    email and creds on the instance, and returns whatever new creds dict
    get_google_drive_creds produced (None when nothing changed).
    """
    admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
    # Domain is everything after the '@' — assumes a well-formed email;
    # a malformed value would raise IndexError here.
    self.google_domain = admin_email.split("@")[1]
    self._primary_admin_email = admin_email

    self._creds, refreshed_creds_dict = get_google_drive_creds(credentials)
    return refreshed_creds_dict

def get_google_resource(
    self,
    service_name: str = "drive",
    service_version: str = "v3",
    user_email: str | None = None,
) -> Resource:
    """Build a Google API client resource from the loaded credentials.

    Service-account credentials are delegated to user_email (falling back
    to the primary admin email); OAuth credentials are used directly.

    Raises:
        PermissionError: when no credentials have been loaded.
    """
    if isinstance(self.creds, ServiceAccountCredentials):
        delegated = self.creds.with_subject(user_email or self.primary_admin_email)
        return build(service_name, service_version, credentials=delegated)
    if isinstance(self.creds, OAuthCredentials):
        return build(service_name, service_version, credentials=self.creds)
    raise PermissionError("No credentials found")

def _get_all_user_emails(self) -> list[str]:
admin_service = self.get_google_resource("admin", "directory_v1")
admin_service = get_admin_service(
creds=self.creds,
user_email=self.primary_admin_email,
)
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
Expand All @@ -156,7 +164,10 @@ def _fetch_drive_items(
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
primary_drive_service = self.get_google_resource()
primary_drive_service = get_drive_service(
creds=self.creds,
user_email=self.primary_admin_email,
)

if self.include_shared_drives:
shared_drive_urls = self.shared_drive_ids
Expand Down Expand Up @@ -212,7 +223,7 @@ def _fetch_drive_items(

for email in all_user_emails:
logger.info(f"Fetching personal files for user: {email}")
user_drive_service = self.get_google_resource(user_email=email)
user_drive_service = get_drive_service(self.creds, user_email=email)

yield from get_files_in_my_drive(
service=user_drive_service,
Expand All @@ -233,11 +244,16 @@ def _extract_docs_from_google_drive(
start=start,
end=end,
):
user_email = file.get("owners", [{}])[0].get("emailAddress")
service = self.get_google_resource(user_email=user_email)
user_email = (
file.get("owners", [{}])[0].get("emailAddress")
or self.primary_admin_email
)
user_drive_service = get_drive_service(self.creds, user_email=user_email)
docs_service = get_google_docs_service(self.creds, user_email=user_email)
if doc := convert_drive_item_to_document(
file=file,
service=service,
drive_service=user_drive_service,
docs_service=docs_service,
):
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
Expand Down
2 changes: 2 additions & 0 deletions backend/danswer/connectors/google_drive/connector_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

logger = setup_logger()

# NOTE: do not need https://www.googleapis.com/auth/documents.readonly
# this is counted under `/auth/drive.readonly`
GOOGLE_DRIVE_SCOPES = [
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly",
Expand Down
161 changes: 107 additions & 54 deletions backend/danswer/connectors/google_drive/doc_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from datetime import datetime
from datetime import timezone

from googleapiclient.discovery import Resource # type: ignore
from googleapiclient.errors import HttpError # type: ignore

from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
Expand All @@ -13,6 +12,9 @@
from danswer.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
from danswer.connectors.google_drive.models import GDriveMimeType
from danswer.connectors.google_drive.models import GoogleDriveFileType
from danswer.connectors.google_drive.resources import GoogleDocsService
from danswer.connectors.google_drive.resources import GoogleDriveService
from danswer.connectors.google_drive.section_extraction import get_document_sections
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import docx_to_text
Expand All @@ -25,86 +27,137 @@
logger = setup_logger()


def _extract_text(file: dict[str, str], service: Resource) -> str:
def _extract_sections_basic(
file: dict[str, str], service: GoogleDriveService
) -> list[Section]:
mime_type = file["mimeType"]
link = file["webViewLink"]

if mime_type not in set(item.value for item in GDriveMimeType):
# Unsupported file types can still have a title, finding this way is still useful
return UNSUPPORTED_FILE_TYPE_CONTENT

if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = (
"text/plain"
if mime_type != GDriveMimeType.SPREADSHEET.value
else "text/csv"
)
return (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
return service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response = service.files().get_media(fileId=file["id"]).execute()
if get_unstructured_api_key():
return unstructured_to_text(
file=io.BytesIO(response), file_name=file.get("name", file["id"])
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]

try:
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = (
"text/plain"
if mime_type != GDriveMimeType.SPREADSHEET.value
else "text/csv"
)
text = (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
return [Section(link=link, text=text)]
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
return [
Section(
link=link,
text=service.files()
.get_media(fileId=file["id"])
.execute()
.decode("utf-8"),
)
]
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response = service.files().get_media(fileId=file["id"]).execute()
if get_unstructured_api_key():
return [
Section(
link=link,
text=unstructured_to_text(
file=io.BytesIO(response),
file_name=file.get("name", file["id"]),
),
)
]

if mime_type == GDriveMimeType.WORD_DOC.value:
return [
Section(link=link, text=docx_to_text(file=io.BytesIO(response)))
]
elif mime_type == GDriveMimeType.PDF.value:
text, _ = read_pdf_file(file=io.BytesIO(response))
return [Section(link=link, text=text)]
elif mime_type == GDriveMimeType.POWERPOINT.value:
return [
Section(link=link, text=pptx_to_text(file=io.BytesIO(response)))
]

if mime_type == GDriveMimeType.WORD_DOC.value:
return docx_to_text(file=io.BytesIO(response))
elif mime_type == GDriveMimeType.PDF.value:
text, _ = read_pdf_file(file=io.BytesIO(response))
return text
elif mime_type == GDriveMimeType.POWERPOINT.value:
return pptx_to_text(file=io.BytesIO(response))
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]

return UNSUPPORTED_FILE_TYPE_CONTENT
except Exception:
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]


def convert_drive_item_to_document(
file: GoogleDriveFileType, service: Resource
file: GoogleDriveFileType,
drive_service: GoogleDriveService,
docs_service: GoogleDocsService,
) -> Document | None:
try:
# Skip files that are shortcuts
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
logger.info("Ignoring Drive Shortcut Filetype")
return None
try:
text_contents = _extract_text(file, service) or ""
except HttpError as e:
reason = e.error_details[0]["reason"] if e.error_details else e.reason
message = e.error_details[0]["message"] if e.error_details else e.reason
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:

sections: list[Section] = []

# Special handling for Google Docs to preserve structure, link
# to headers
if file.get("mimeType") == GDriveMimeType.DOC.value:
try:
sections = get_document_sections(docs_service, file["id"])
except Exception as e:
logger.warning(
f"Could not export file '{file['name']}' due to '{message}', skipping..."
f"Ran into exception '{e}' when pulling sections from Google Doc '{file['name']}'."
" Falling back to basic extraction."
)
return None

raise
# NOTE: this will run for either (1) the above failed or (2) the file is not a Google Doc
if not sections:
try:
# For all other file types just extract the text
sections = _extract_sections_basic(file, drive_service)

except HttpError as e:
reason = e.error_details[0]["reason"] if e.error_details else e.reason
message = e.error_details[0]["message"] if e.error_details else e.reason
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
logger.warning(
f"Could not export file '{file['name']}' due to '{message}', skipping..."
)
return None

raise

if not sections:
return None

return Document(
id=file["webViewLink"],
sections=[Section(link=file["webViewLink"], text=text_contents)],
sections=sections,
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file["name"],
doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone(
timezone.utc
),
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
metadata={}
if any(section.text for section in sections)
else {IGNORE_FOR_QA: "True"},
additional_info=file.get("id"),
)
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/connectors/google_drive/google_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def execute_paginated_retrieval(
if next_page_token:
request_kwargs["pageToken"] = next_page_token

results = add_retries(lambda: retrieval_function(**request_kwargs).execute())()
results = (lambda: retrieval_function(**request_kwargs).execute())()

next_page_token = results.get("nextPageToken")
for item in results.get(list_key, []):
Expand Down
Loading
Loading