From 2e748c2c534bcea0af0b2ff0df9cb57697b010ef Mon Sep 17 00:00:00 2001 From: Weves Date: Tue, 3 Dec 2024 15:35:49 -0800 Subject: [PATCH 1/7] Egnyte connector --- backend/danswer/configs/constants.py | 1 + .../danswer/connectors/egnyte/connector.py | 143 ++++++++++++++++++ backend/danswer/connectors/factory.py | 2 + web/src/lib/connectors/connectors.tsx | 24 +++ web/src/lib/sources.ts | 6 + web/src/lib/types.ts | 1 + 6 files changed, 177 insertions(+) create mode 100644 backend/danswer/connectors/egnyte/connector.py diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index e84c229a696..b9b5f7deb26 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -132,6 +132,7 @@ class DocumentSource(str, Enum): NOT_APPLICABLE = "not_applicable" FRESHDESK = "freshdesk" FIREFLIES = "fireflies" + EGNYTE = "egnyte" DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE] diff --git a/backend/danswer/connectors/egnyte/connector.py b/backend/danswer/connectors/egnyte/connector.py new file mode 100644 index 00000000000..a5ed05c9b83 --- /dev/null +++ b/backend/danswer/connectors/egnyte/connector.py @@ -0,0 +1,143 @@ +import os +from datetime import datetime +from datetime import timezone +from typing import Any + +import requests + +from danswer.configs.app_configs import INDEX_BATCH_SIZE +from danswer.connectors.interfaces import GenerateDocumentsOutput +from danswer.connectors.interfaces import LoadConnector +from danswer.connectors.interfaces import PollConnector +from danswer.connectors.interfaces import SecondsSinceUnixEpoch +from danswer.connectors.models import ConnectorMissingCredentialError +from danswer.connectors.models import Document +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + +_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1" +_TIMEOUT = 60 + + +class EgnyteConnector(LoadConnector, PollConnector): + def __init__( + self, + domain: str | None = None, + folder_path: str | None = None, + batch_size: int = INDEX_BATCH_SIZE, + ) -> None: + self.domain = domain + self.folder_path = folder_path or "/" # Root folder if not specified + self.batch_size = batch_size + self.access_token: str | None = None + + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + self.domain = credentials["domain"] + self.access_token = credentials["access_token"] + return None + + def _get_files_list( + self, + path: str, + start_time: datetime | None = None, + end_time: datetime | None = None, + ) -> list[dict[str, Any]]: + if not self.access_token or not self.domain: + raise ConnectorMissingCredentialError("Egnyte") + + headers = { + "Authorization": f"Bearer {self.access_token}", + } + + params: dict[str, Any] = { + "list_content": True, + "folder_id": path, + } + + if start_time: + params["last_modified_after"] = start_time.isoformat() + if end_time: + params["last_modified_before"] = end_time.isoformat() + + url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs" + response = requests.get(url, headers=headers, params=params, timeout=_TIMEOUT) + + if not response.ok: + raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}") + + return response.json().get("files", []) + + def _process_files( + self, + start_time: datetime | None = None, + end_time: datetime | None = None, + ) -> GenerateDocumentsOutput: + files = self._get_files_list(self.folder_path, start_time, end_time) + + current_batch: list[Document] = [] + for file in 
files: + if not file["is_folder"]: + try: + # Get file content + headers = { + "Authorization": f"Bearer {self.access_token}", + } + url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}" + response = requests.get(url, headers=headers, timeout=_TIMEOUT) + + if not response.ok: + logger.error(f"Failed to fetch file content: {file['path']}") + continue + + # doc = process_file( + # file_name=file["name"], + # file_data=response.content, + # source=DocumentSource.EGNYTE, + # url=file.get("url", ""), + # metadata={ + # "folder_path": os.path.dirname(file["path"]), + # "size": file.get("size", 0), + # "last_modified": file.get("last_modified", ""), + # }, + # ) + # TOOD: Implement this + doc = None + + if doc is not None: + current_batch.append(doc) + + if len(current_batch) >= self.batch_size: + yield current_batch + current_batch = [] + + except Exception as e: + logger.error(f"Failed to process file {file['path']}: {str(e)}") + continue + + if current_batch: + yield current_batch + + def load_from_state(self) -> GenerateDocumentsOutput: + yield from self._process_files() + + def poll_source( + self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch + ) -> GenerateDocumentsOutput: + start_time = datetime.fromtimestamp(start, tz=timezone.utc) + end_time = datetime.fromtimestamp(end, tz=timezone.utc) + + yield from self._process_files(start_time=start_time, end_time=end_time) + + +if __name__ == "__main__": + connector = EgnyteConnector() + connector.load_credentials( + { + "domain": os.environ["EGNYTE_DOMAIN"], + "access_token": os.environ["EGNYTE_ACCESS_TOKEN"], + } + ) + document_batches = connector.load_from_state() + print(next(document_batches)) diff --git a/backend/danswer/connectors/factory.py b/backend/danswer/connectors/factory.py index 40f926b31d1..87d1539d3d2 100644 --- a/backend/danswer/connectors/factory.py +++ b/backend/danswer/connectors/factory.py @@ -15,6 +15,7 @@ from danswer.connectors.discourse.connector import DiscourseConnector from danswer.connectors.document360.connector import Document360Connector from danswer.connectors.dropbox.connector import DropboxConnector +from danswer.connectors.egnyte.connector import EgnyteConnector from danswer.connectors.file.connector import LocalFileConnector from danswer.connectors.fireflies.connector import FirefliesConnector from danswer.connectors.freshdesk.connector import FreshdeskConnector @@ -103,6 +104,7 @@ def identify_connector_class( DocumentSource.XENFORO: XenforoConnector, DocumentSource.FRESHDESK: FreshdeskConnector, DocumentSource.FIREFLIES: FirefliesConnector, + DocumentSource.EGNYTE: EgnyteConnector, } connector_by_source = connector_map.get(source, {}) diff --git a/web/src/lib/connectors/connectors.tsx b/web/src/lib/connectors/connectors.tsx index ed4cd39e329..3af19b779a9 100644 --- a/web/src/lib/connectors/connectors.tsx +++ b/web/src/lib/connectors/connectors.tsx @@ -1050,6 +1050,30 @@ For example, specifying .*-support.* as a "channel" will cause the connector to values: [], advanced_values: [], }, + egnyte: { + description: "Configure Egnyte connector", + values: [ + { + type: "text", + query: "Enter your Egnyte domain:", + label: "Domain", + name: "domain", + optional: false, + description: + "Your Egnyte domain (e.g., if your Egnyte URL is 'company.egnyte.com', enter 'company')", + }, + { + type: "text", + query: "Enter folder path to index:", + label: "Folder Path", + name: "folder_path", + optional: true, + description: + "The folder path to index (e.g., '/Shared/Documents'). 
Leave empty to index everything.", + }, + ], + advanced_values: [], + }, }; export function createConnectorInitialValues( connector: ConfigurableSources diff --git a/web/src/lib/sources.ts b/web/src/lib/sources.ts index c77c5277d35..8c51a1bb2fe 100644 --- a/web/src/lib/sources.ts +++ b/web/src/lib/sources.ts @@ -304,6 +304,12 @@ export const SOURCE_METADATA_MAP: SourceMap = { displayName: "Not Applicable", category: SourceCategory.Other, }, + egnyte: { + icon: FileIcon, + displayName: "Egnyte", + category: SourceCategory.Storage, + docs: "https://docs.danswer.dev/connectors/egnyte", + }, } as SourceMap; function fillSourceMetadata( diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index 29f0acf995e..d08fb0d7563 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -309,6 +309,7 @@ export enum ValidSources { IngestionApi = "ingestion_api", Freshdesk = "freshdesk", Fireflies = "fireflies", + Egnyte = "egnyte", } export const validAutoSyncSources = [ From faa4e0cd6e8548d69863f86029c5dc71f431afbc Mon Sep 17 00:00:00 2001 From: Weves Date: Mon, 9 Dec 2024 19:34:13 -0800 Subject: [PATCH 2/7] Egnyte connector more --- .../danswer/connectors/egnyte/connector.py | 37 ++++++- backend/danswer/connectors/interfaces.py | 18 ++++ backend/danswer/main.py | 2 + .../server/documents/standard_oauth.py | 101 ++++++++++++++++++ backend/danswer/utils/subclasses.py | 77 +++++++++++++ .../[connector]/AddConnectorPage.tsx | 15 ++- web/src/lib/connectors/credentials.ts | 9 ++ 7 files changed, 253 insertions(+), 6 deletions(-) create mode 100644 backend/danswer/server/documents/standard_oauth.py create mode 100644 backend/danswer/utils/subclasses.py diff --git a/backend/danswer/connectors/egnyte/connector.py b/backend/danswer/connectors/egnyte/connector.py index a5ed05c9b83..dd8922f8b6c 100644 --- a/backend/danswer/connectors/egnyte/connector.py +++ b/backend/danswer/connectors/egnyte/connector.py @@ -8,20 +8,27 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE from danswer.connectors.interfaces import GenerateDocumentsOutput from danswer.connectors.interfaces import LoadConnector +from danswer.connectors.interfaces import OAuthConnector from danswer.connectors.interfaces import PollConnector from danswer.connectors.interfaces import SecondsSinceUnixEpoch from danswer.connectors.models import ConnectorMissingCredentialError from danswer.connectors.models import Document from danswer.utils.logger import setup_logger +from danswer.utils.special_types import JSON_ro logger = setup_logger() +_EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE") +_EGNYTE_DOMAIN = os.getenv("EGNYTE_DOMAIN") +_EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID") +_EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET") + _EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1" _TIMEOUT = 60 -class EgnyteConnector(LoadConnector, PollConnector): +class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector): def __init__( self, domain: str | None = None, @@ -33,6 +40,34 @@ def __init__( self.batch_size = batch_size self.access_token: str | None = None + @classmethod + def oauth_id(cls) -> str: + return "egnyte" + + @classmethod + def redirect_uri(cls, base_domain: str) -> str: + if not _EGNYTE_CLIENT_ID: + raise ValueError("EGNYTE_CLIENT_ID environment variable must be set") + if not _EGNYTE_DOMAIN: + raise ValueError("EGNYTE_DOMAIN environment variable must be set") + + if _EGNYTE_LOCALHOST_OVERRIDE: + base_domain = _EGNYTE_LOCALHOST_OVERRIDE + + callback_uri = 
f"{base_domain.strip('/')}/connector/oauth/callback/egnyte" + return ( + f"https://{_EGNYTE_DOMAIN}.egnyte.com/puboauth/token" + f"?client_id={_EGNYTE_CLIENT_ID}" + f"&redirect_uri={callback_uri}" + f"&scope=Egnyte.filesystem" + # f"&state=danswer" + f"&response_type=code" + ) + + @classmethod + def code_to_token(cls, code: str) -> JSON_ro: + pass + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self.domain = credentials["domain"] self.access_token = credentials["access_token"] diff --git a/backend/danswer/connectors/interfaces.py b/backend/danswer/connectors/interfaces.py index c53b3de5f2f..e4ffe23568d 100644 --- a/backend/danswer/connectors/interfaces.py +++ b/backend/danswer/connectors/interfaces.py @@ -4,6 +4,7 @@ from danswer.connectors.models import Document from danswer.connectors.models import SlimDocument +from danswer.utils.special_types import JSON_ro SecondsSinceUnixEpoch = float @@ -64,6 +65,23 @@ def retrieve_all_slim_documents( raise NotImplementedError +class OAuthConnector(BaseConnector): + @classmethod + @abc.abstractmethod + def oauth_id(cls) -> str: + raise NotImplementedError + + @classmethod + @abc.abstractmethod + def redirect_uri(cls, base_domain: str) -> str: + raise NotImplementedError + + @classmethod + @abc.abstractmethod + def code_to_token(cls, code: str) -> JSON_ro: + raise NotImplementedError + + # Event driven class EventConnector(BaseConnector): @abc.abstractmethod diff --git a/backend/danswer/main.py b/backend/danswer/main.py index b78607b92f9..79bd145fbc7 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -52,6 +52,7 @@ from danswer.server.documents.credential import router as credential_router from danswer.server.documents.document import router as document_router from danswer.server.documents.indexing import router as indexing_router +from danswer.server.documents.standard_oauth import router as oauth_router from danswer.server.features.document_set.api import router as document_set_router from danswer.server.features.folder.api import router as folder_router from danswer.server.features.input_prompt.api import ( @@ -282,6 +283,7 @@ def get_application() -> FastAPI: ) include_router_with_global_prefix_prepended(application, long_term_logs_router) include_router_with_global_prefix_prepended(application, api_key_router) + include_router_with_global_prefix_prepended(application, oauth_router) if AUTH_TYPE == AuthType.DISABLED: # Server logs this during auth setup verification step diff --git a/backend/danswer/server/documents/standard_oauth.py b/backend/danswer/server/documents/standard_oauth.py new file mode 100644 index 00000000000..b30e56a1033 --- /dev/null +++ b/backend/danswer/server/documents/standard_oauth.py @@ -0,0 +1,101 @@ +from typing import Annotated + +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from fastapi import Query +from fastapi import Request +from sqlalchemy.orm import Session + +from danswer.auth.users import current_user +from danswer.connectors.interfaces import OAuthConnector +from danswer.db.credentials import create_credential +from danswer.db.engine import get_session +from danswer.db.models import User +from danswer.server.documents.models import CredentialBase +from danswer.utils.logger import setup_logger +from danswer.utils.subclasses import find_all_subclasses_in_dir + +logger = setup_logger() + +router = APIRouter(prefix="/connector/oauth") + +# Cache for OAuth connectors, populated at module load time 
+_OAUTH_CONNECTORS: dict[str, type[OAuthConnector]] = {} + + +def _discover_oauth_connectors() -> dict[str, type[OAuthConnector]]: + """Walk through the connectors package to find all OAuthConnector implementations""" + global _OAUTH_CONNECTORS + if _OAUTH_CONNECTORS: # Return cached connectors if already discovered + return _OAUTH_CONNECTORS + + oauth_connectors = find_all_subclasses_in_dir(OAuthConnector, "danswer.connectors") + + _OAUTH_CONNECTORS = {cls.oauth_id(): cls for cls in oauth_connectors} + return _OAUTH_CONNECTORS + + +# Discover OAuth connectors at module load time +_discover_oauth_connectors() + + +@router.get("/authorize/{source}") +def oauth_authorize( + request: Request, + source: str, + _: User = Depends(current_user), +) -> dict: + """Initiates the OAuth flow by redirecting to the provider's auth page""" + oauth_connectors = _discover_oauth_connectors() + + if source not in oauth_connectors: + raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}") + + connector_cls = oauth_connectors[source] + base_url = str(request.base_url) + if "127.0.0.1" in base_url: + base_url = base_url.replace("127.0.0.1", "localhost") + return {"redirect_url": connector_cls.redirect_uri(base_url)} + + +@router.get("/callback/{source}") +async def oauth_callback( + source: str, + code: Annotated[str, Query()], + state: Annotated[str | None, Query()] = None, + db_session: Session = Depends(get_session), + user: User = Depends(current_user), +) -> dict: + """Handles the OAuth callback and exchanges the code for tokens""" + oauth_connectors = _discover_oauth_connectors() + + if source not in oauth_connectors: + raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}") + + connector_cls = oauth_connectors[source] + + try: + token_info = connector_cls.code_to_token(code) + + # Create a new credential with the token info + credential_data = CredentialBase( + credential_json=token_info, + admin_public=True, # Or based on some logic/parameter + source=source, + name=f"{source.title()} OAuth Credential", + ) + + credential = create_credential( + credential_data=credential_data, + user=user, + db_session=db_session, + ) + + return { + "credential_id": credential.id, + "token_info": token_info, + "message": "Successfully authenticated and created credential", + } + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) diff --git a/backend/danswer/utils/subclasses.py b/backend/danswer/utils/subclasses.py new file mode 100644 index 00000000000..72408f98b08 --- /dev/null +++ b/backend/danswer/utils/subclasses.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import importlib +import os +import pkgutil +import sys +from types import ModuleType +from typing import List +from typing import Type +from typing import TypeVar + +T = TypeVar("T") + + +def import_all_modules_from_dir(dir_path: str) -> List[ModuleType]: + """ + Imports all modules found in the given directory and its subdirectories, + returning a list of imported module objects. 
+ """ + dir_path = os.path.abspath(dir_path) + + if dir_path not in sys.path: + sys.path.insert(0, dir_path) + + imported_modules: List[ModuleType] = [] + + for _, package_name, _ in pkgutil.walk_packages([dir_path]): + try: + module = importlib.import_module(package_name) + imported_modules.append(module) + except Exception as e: + # Handle or log exceptions as needed + print(f"Could not import {package_name}: {e}") + + return imported_modules + + +def all_subclasses(cls: Type[T]) -> List[Type[T]]: + """ + Recursively find all subclasses of the given class. + """ + direct_subs = cls.__subclasses__() + result: List[Type[T]] = [] + for subclass in direct_subs: + result.append(subclass) + # Extend the result by recursively calling all_subclasses + result.extend(all_subclasses(subclass)) + return result + + +def find_all_subclasses_in_dir(parent_class: Type[T], directory: str) -> List[Type[T]]: + """ + Imports all modules from the given directory (and subdirectories), + then returns all classes that are subclasses of parent_class. + + :param parent_class: The class to find subclasses of. + :param directory: The directory to search for subclasses. + :return: A list of all subclasses of parent_class found in the directory. + """ + # First import all modules to ensure classes are loaded into memory + import_all_modules_from_dir(directory) + + # Gather all subclasses of the given parent class + subclasses = all_subclasses(parent_class) + return subclasses + + +# Example usage: +if __name__ == "__main__": + + class Animal: + pass + + # Suppose "mymodules" contains files that define classes inheriting from Animal + found_subclasses = find_all_subclasses_in_dir(Animal, "mymodules") + for sc in found_subclasses: + print("Found subclass:", sc.__name__) diff --git a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx index 62556993a91..eda454510bc 100644 --- a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx +++ b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx @@ -49,6 +49,7 @@ import { useRouter } from "next/navigation"; import CardSection from "@/components/admin/CardSection"; import { prepareOAuthAuthorizationRequest } from "@/lib/oauth_utils"; import { EE_ENABLED, NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants"; +import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth"; export interface AdvancedConfig { refreshFreq: number; pruneFreq: number; @@ -442,11 +443,15 @@ export default function AddConnector({ {/* Button to pop up a form to manually enter credentials */} diff --git a/web/src/lib/connectors/credentials.ts b/web/src/lib/connectors/credentials.ts index 8b52cb5c9a1..15ee7a47838 100644 --- a/web/src/lib/connectors/credentials.ts +++ b/web/src/lib/connectors/credentials.ts @@ -195,6 +195,11 @@ export interface FirefliesCredentialJson { export interface MediaWikiCredentialJson {} export interface WikipediaCredentialJson extends MediaWikiCredentialJson {} +export interface EgnyteCredentialJson { + domain: string; + access_token: string; +} + export const credentialTemplates: Record = { github: { github_access_token: "" } as GithubCredentialJson, gitlab: { @@ -298,6 +303,10 @@ export const credentialTemplates: Record = { fireflies: { fireflies_api_key: "", } as FirefliesCredentialJson, + egnyte: { + domain: "", + access_token: "", + } as EgnyteCredentialJson, xenforo: null, google_sites: null, file: null, From aed44ed0b85ef142709511fc11db0321c742c914 Mon Sep 17 00:00:00 2001 From: 
Weves Date: Tue, 10 Dec 2024 13:37:49 -0800 Subject: [PATCH 3/7] More stuff --- .../danswer/connectors/egnyte/connector.py | 312 ++++++++++++++---- backend/danswer/connectors/interfaces.py | 3 +- .../server/documents/standard_oauth.py | 13 +- .../oauth/callback/[source]/route.tsx | 50 +++ web/src/lib/connectors/connectors.tsx | 9 - web/src/lib/connectors/oauth.ts | 23 ++ 6 files changed, 338 insertions(+), 72 deletions(-) create mode 100644 web/src/app/connector/oauth/callback/[source]/route.tsx create mode 100644 web/src/lib/connectors/oauth.ts diff --git a/backend/danswer/connectors/egnyte/connector.py b/backend/danswer/connectors/egnyte/connector.py index dd8922f8b6c..7cb2cfa2e6a 100644 --- a/backend/danswer/connectors/egnyte/connector.py +++ b/backend/danswer/connectors/egnyte/connector.py @@ -1,18 +1,31 @@ +import io import os +from collections.abc import Generator from datetime import datetime from datetime import timezone from typing import Any +from typing import IO import requests +from retry import retry from danswer.configs.app_configs import INDEX_BATCH_SIZE +from danswer.configs.constants import DocumentSource from danswer.connectors.interfaces import GenerateDocumentsOutput from danswer.connectors.interfaces import LoadConnector from danswer.connectors.interfaces import OAuthConnector from danswer.connectors.interfaces import PollConnector from danswer.connectors.interfaces import SecondsSinceUnixEpoch +from danswer.connectors.models import BasicExpertInfo from danswer.connectors.models import ConnectorMissingCredentialError from danswer.connectors.models import Document +from danswer.connectors.models import Section +from danswer.file_processing.extract_file_text import check_file_ext_is_valid +from danswer.file_processing.extract_file_text import detect_encoding +from danswer.file_processing.extract_file_text import extract_file_text +from danswer.file_processing.extract_file_text import get_file_ext +from danswer.file_processing.extract_file_text import is_text_file_extension +from danswer.file_processing.extract_file_text import read_text_file from danswer.utils.logger import setup_logger from danswer.utils.special_types import JSON_ro @@ -20,35 +33,76 @@ logger = setup_logger() _EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE") -_EGNYTE_DOMAIN = os.getenv("EGNYTE_DOMAIN") +_EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN") _EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID") _EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET") _EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1" +_EGNYTE_APP_BASE = "https://{domain}.egnyte.com" _TIMEOUT = 60 +def _request_with_retries( + method: str, + url: str, + headers: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + timeout: int = _TIMEOUT, + stream: bool = False, +) -> requests.Response: + @retry(tries=8, delay=1, backoff=2, logger=logger) + def _make_request() -> requests.Response: + if method == "GET": + response = requests.get( + url, headers=headers, params=params, timeout=timeout, stream=stream + ) + elif method == "POST": + response = requests.post( + url, headers=headers, json=params, timeout=timeout, stream=stream + ) + elif method == "PUT": + response = requests.put( + url, headers=headers, json=params, timeout=timeout, stream=stream + ) + elif method == "DELETE": + response = requests.delete( + url, headers=headers, params=params, timeout=timeout, stream=stream + ) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + + response.raise_for_status() + return 
response + + return _make_request() + + +def _parse_last_modified(last_modified: str) -> datetime: + return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace( + tzinfo=timezone.utc + ) + + class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector): def __init__( self, - domain: str | None = None, folder_path: str | None = None, batch_size: int = INDEX_BATCH_SIZE, ) -> None: - self.domain = domain - self.folder_path = folder_path or "/" # Root folder if not specified + self.domain = "" # will always be set in `load_credentials` + self.folder_path = folder_path or "" # Root folder if not specified self.batch_size = batch_size self.access_token: str | None = None @classmethod - def oauth_id(cls) -> str: + def oauth_id(cls) -> DocumentSource: return "egnyte" @classmethod def redirect_uri(cls, base_domain: str) -> str: if not _EGNYTE_CLIENT_ID: raise ValueError("EGNYTE_CLIENT_ID environment variable must be set") - if not _EGNYTE_DOMAIN: + if not _EGNYTE_BASE_DOMAIN: raise ValueError("EGNYTE_DOMAIN environment variable must be set") if _EGNYTE_LOCALHOST_OVERRIDE: @@ -56,17 +110,47 @@ def redirect_uri(cls, base_domain: str) -> str: callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte" return ( - f"https://{_EGNYTE_DOMAIN}.egnyte.com/puboauth/token" + f"https://{_EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token" f"?client_id={_EGNYTE_CLIENT_ID}" f"&redirect_uri={callback_uri}" f"&scope=Egnyte.filesystem" + # TODO: Add state support # f"&state=danswer" f"&response_type=code" ) @classmethod def code_to_token(cls, code: str) -> JSON_ro: - pass + if not _EGNYTE_CLIENT_ID: + raise ValueError("EGNYTE_CLIENT_ID environment variable must be set") + if not _EGNYTE_CLIENT_SECRET: + raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set") + if not _EGNYTE_BASE_DOMAIN: + raise ValueError("EGNYTE_DOMAIN environment variable must be set") + + # Exchange code for token + url = f"https://{_EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token" + data = { + "client_id": _EGNYTE_CLIENT_ID, + "client_secret": _EGNYTE_CLIENT_SECRET, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": f"{_EGNYTE_LOCALHOST_OVERRIDE or ''}/connector/oauth/callback/egnyte", + "scope": "Egnyte.filesystem", + } + headers = {"Content-Type": "application/x-www-form-urlencoded"} + + response = _request_with_retries( + method="POST", url=url, data=data, headers=headers + ) + if not response.ok: + raise RuntimeError(f"Failed to exchange code for token: {response.text}") + + token_data = response.json() + return { + "domain": _EGNYTE_BASE_DOMAIN, + "access_token": token_data["access_token"], + } def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self.domain = credentials["domain"] @@ -76,8 +160,6 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None def _get_files_list( self, path: str, - start_time: datetime | None = None, - end_time: datetime | None = None, ) -> list[dict[str, Any]]: if not self.access_token or not self.domain: raise ConnectorMissingCredentialError("Egnyte") @@ -88,69 +170,106 @@ def _get_files_list( params: dict[str, Any] = { "list_content": True, - "folder_id": path, } - if start_time: - params["last_modified_after"] = start_time.isoformat() - if end_time: - params["last_modified_before"] = end_time.isoformat() - - url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs" - response = requests.get(url, headers=headers, params=params, timeout=_TIMEOUT) - + url = 
f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{path or ''}" + response = _request_with_retries( + method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT + ) if not response.ok: raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}") - return response.json().get("files", []) + data = response.json() + all_files: list[dict[str, Any]] = [] + + # Add files from current directory + all_files.extend(data.get("files", [])) + + # Recursively traverse folders + for item in data.get("folders", []): + all_files.extend(self._get_files_list(item["path"])) + + return all_files + + def _filter_files( + self, + files: list[dict[str, Any]], + start_time: datetime | None = None, + end_time: datetime | None = None, + ) -> list[dict[str, Any]]: + filtered_files = [] + for file in files: + if file["is_folder"]: + continue + + file_modified = _parse_last_modified(file["last_modified"]) + if start_time and file_modified < start_time: + continue + if end_time and file_modified > end_time: + continue + + filtered_files.append(file) + + return filtered_files def _process_files( self, start_time: datetime | None = None, end_time: datetime | None = None, - ) -> GenerateDocumentsOutput: - files = self._get_files_list(self.folder_path, start_time, end_time) + ) -> Generator[list[Document], None, None]: + files = self._get_files_list(self.folder_path) + files = self._filter_files(files, start_time, end_time) current_batch: list[Document] = [] for file in files: - if not file["is_folder"]: - try: - # Get file content - headers = { - "Authorization": f"Bearer {self.access_token}", - } - url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}" - response = requests.get(url, headers=headers, timeout=_TIMEOUT) - - if not response.ok: - logger.error(f"Failed to fetch file content: {file['path']}") - continue - - # doc = process_file( - # file_name=file["name"], - # file_data=response.content, - # source=DocumentSource.EGNYTE, - # url=file.get("url", ""), - # metadata={ - # "folder_path": os.path.dirname(file["path"]), - # "size": file.get("size", 0), - # "last_modified": file.get("last_modified", ""), - # }, - # ) - # TOOD: Implement this - doc = None - - if doc is not None: - current_batch.append(doc) - - if len(current_batch) >= self.batch_size: - yield current_batch - current_batch = [] - - except Exception as e: - logger.error(f"Failed to process file {file['path']}: {str(e)}") + try: + # Set up request with streaming enabled + headers = { + "Authorization": f"Bearer {self.access_token}", + } + url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}" + response = _request_with_retries( + method="GET", + url=url, + headers=headers, + timeout=_TIMEOUT, + stream=True, + ) + + if not response.ok: + logger.error( + f"Failed to fetch file content: {file['path']} (status code: {response.status_code})" + ) continue + # Stream the response content into a BytesIO buffer + buffer = io.BytesIO() + for chunk in response.iter_content(chunk_size=8192): + if chunk: + buffer.write(chunk) + + # Reset buffer's position to the start + buffer.seek(0) + + # Process the streamed file content + doc = process_egnyte_file( + file_metadata=file, + file_content=buffer, + base_url=_EGNYTE_APP_BASE.format(domain=self.domain), + folder_path=self.folder_path, + ) + + if doc is not None: + current_batch.append(doc) + + if len(current_batch) >= self.batch_size: + yield current_batch + current_batch = [] + + except Exception as e: + logger.error(f"Failed to process file 
{file['path']}: {str(e)}") + continue + if current_batch: yield current_batch @@ -166,6 +285,83 @@ def poll_source( yield from self._process_files(start_time=start_time, end_time=end_time) +def process_egnyte_file( + file_metadata: dict[str, Any], + file_content: IO, + base_url: str, + folder_path: str | None = None, +) -> Document | None: + """Process an Egnyte file into a Document object + + Args: + file_data: The file data from Egnyte API + file_content: The raw content of the file in bytes + base_url: The base URL for the Egnyte instance + folder_path: Optional folder path to filter results + """ + # Skip if file path doesn't match folder path filter + if folder_path and not file_metadata["path"].startswith(folder_path): + raise ValueError( + f"File path {file_metadata['path']} does not match folder path {folder_path}" + ) + + file_name = file_metadata["name"] + extension = get_file_ext(file_name) + if not check_file_ext_is_valid(extension): + logger.warning(f"Skipping file '{file_name}' with extension '{extension}'") + return None + + # Extract text content based on file type + if is_text_file_extension(file_name): + encoding = detect_encoding(file_content) + file_content_raw, file_metadata = read_text_file( + file_content, encoding=encoding, ignore_danswer_metadata=False + ) + else: + file_content_raw = extract_file_text( + file=file_content, + file_name=file_name, + break_on_unprocessable=True, + ) + + # Build the web URL for the file + web_url = f"{base_url}/navigate/file/{file_metadata['group_id']}" + + # Create document metadata + metadata: dict[str, str | list[str]] = { + "file_path": file_metadata["path"], + "last_modified": file_metadata.get("last_modified", ""), + } + + # Add lock info if present + if lock_info := file_metadata.get("lock_info"): + metadata[ + "lock_owner" + ] = f"{lock_info.get('first_name', '')} {lock_info.get('last_name', '')}" + + # Create the document owners + primary_owner = None + if uploaded_by := file_metadata.get("uploaded_by"): + primary_owner = BasicExpertInfo( + email=uploaded_by, # Using username as email since that's what we have + ) + + # Create the document + return Document( + id=f"egnyte-{file_metadata['entry_id']}", + sections=[Section(text=file_content_raw.strip(), link=web_url)], + source=DocumentSource.EGNYTE, + semantic_identifier=file_name, + metadata=metadata, + doc_updated_at=( + _parse_last_modified(file_metadata["last_modified"]) + if "last_modified" in file_metadata + else None + ), + primary_owners=[primary_owner] if primary_owner else None, + ) + + if __name__ == "__main__": connector = EgnyteConnector() connector.load_credentials( diff --git a/backend/danswer/connectors/interfaces.py b/backend/danswer/connectors/interfaces.py index e4ffe23568d..8a93584befb 100644 --- a/backend/danswer/connectors/interfaces.py +++ b/backend/danswer/connectors/interfaces.py @@ -2,6 +2,7 @@ from collections.abc import Iterator from typing import Any +from danswer.configs.constants import DocumentSource from danswer.connectors.models import Document from danswer.connectors.models import SlimDocument from danswer.utils.special_types import JSON_ro @@ -68,7 +69,7 @@ def retrieve_all_slim_documents( class OAuthConnector(BaseConnector): @classmethod @abc.abstractmethod - def oauth_id(cls) -> str: + def oauth_id(cls) -> DocumentSource: raise NotImplementedError @classmethod diff --git a/backend/danswer/server/documents/standard_oauth.py b/backend/danswer/server/documents/standard_oauth.py index b30e56a1033..56d02948c5f 100644 --- 
a/backend/danswer/server/documents/standard_oauth.py +++ b/backend/danswer/server/documents/standard_oauth.py @@ -8,6 +8,8 @@ from sqlalchemy.orm import Session from danswer.auth.users import current_user +from danswer.configs.app_configs import WEB_DOMAIN +from danswer.configs.constants import DocumentSource from danswer.connectors.interfaces import OAuthConnector from danswer.db.credentials import create_credential from danswer.db.engine import get_session @@ -21,7 +23,7 @@ router = APIRouter(prefix="/connector/oauth") # Cache for OAuth connectors, populated at module load time -_OAUTH_CONNECTORS: dict[str, type[OAuthConnector]] = {} +_OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {} def _discover_oauth_connectors() -> dict[str, type[OAuthConnector]]: @@ -93,9 +95,12 @@ async def oauth_callback( ) return { - "credential_id": credential.id, - "token_info": token_info, - "message": "Successfully authenticated and created credential", + "redirect_url": f"{WEB_DOMAIN}/admin/connectors/{source}?step=0&credentialId={credential.id}" } except Exception as e: raise HTTPException(status_code=400, detail=str(e)) + + +@router.get("/available-sources") +def available_sources() -> list[DocumentSource]: + return list(_discover_oauth_connectors().keys()) diff --git a/web/src/app/connector/oauth/callback/[source]/route.tsx b/web/src/app/connector/oauth/callback/[source]/route.tsx new file mode 100644 index 00000000000..815ca5a5a09 --- /dev/null +++ b/web/src/app/connector/oauth/callback/[source]/route.tsx @@ -0,0 +1,50 @@ +import { INTERNAL_URL } from "@/lib/constants"; +import { NextRequest, NextResponse } from "next/server"; + +// TODO: deprecate this and just go directly to the backend via /api/... +// For some reason Egnyte doesn't work when using /api, so leaving this as is for now +// If we do try and remove this, make sure we test the Egnyte connector oauth flow +export async function GET(request: NextRequest) { + if (process.env.NODE_ENV !== "development") { + return NextResponse.json( + { message: "This API is only available in development mode." }, + { status: 404 } + ); + } + + try { + const backendUrl = new URL(INTERNAL_URL); + // Copy path and query parameters from incoming request + backendUrl.pathname = request.nextUrl.pathname; + backendUrl.search = request.nextUrl.search; + + const response = await fetch(backendUrl, { + method: "GET", + headers: request.headers, + body: request.body, + signal: request.signal, + // @ts-ignore + duplex: "half", + }); + + const responseData = await response.json(); + if (responseData.redirect_url) { + return NextResponse.redirect(responseData.redirect_url); + } + + return new NextResponse(JSON.stringify(responseData), { + status: response.status, + headers: response.headers, + }); + } catch (error: unknown) { + console.error("Proxy error:", error); + return NextResponse.json( + { + message: "Proxy error", + error: + error instanceof Error ? 
error.message : "An unknown error occurred", + }, + { status: 500 } + ); + } +} diff --git a/web/src/lib/connectors/connectors.tsx b/web/src/lib/connectors/connectors.tsx index 3af19b779a9..81330bedf60 100644 --- a/web/src/lib/connectors/connectors.tsx +++ b/web/src/lib/connectors/connectors.tsx @@ -1053,15 +1053,6 @@ For example, specifying .*-support.* as a "channel" will cause the connector to egnyte: { description: "Configure Egnyte connector", values: [ - { - type: "text", - query: "Enter your Egnyte domain:", - label: "Domain", - name: "domain", - optional: false, - description: - "Your Egnyte domain (e.g., if your Egnyte URL is 'company.egnyte.com', enter 'company')", - }, { type: "text", query: "Enter folder path to index:", diff --git a/web/src/lib/connectors/oauth.ts b/web/src/lib/connectors/oauth.ts new file mode 100644 index 00000000000..37a0c579198 --- /dev/null +++ b/web/src/lib/connectors/oauth.ts @@ -0,0 +1,23 @@ +import { ValidSources } from "../types"; + +export async function getConnectorOauthRedirectUrl( + connector: ValidSources +): Promise { + const response = await fetch(`/api/connector/oauth/authorize/${connector}`); + + if (!response.ok) { + console.error(`Failed to fetch OAuth redirect URL for ${connector}`); + return null; + } + + const data = await response.json(); + return data.redirect_url as string; +} + +export async function getSourceHasStandardOAuthSupport( + source: ValidSources +): Promise { + const response = await fetch("/api/connector/oauth/available-sources"); + const sources = (await response.json()) as ValidSources[]; + return sources.includes(source); +} From b459f51a5bf279240eb03a758ee6b58f6dcc84d2 Mon Sep 17 00:00:00 2001 From: Weves Date: Tue, 10 Dec 2024 14:35:02 -0800 Subject: [PATCH 4/7] Close to final --- .../danswer/connectors/egnyte/connector.py | 210 +++++++++--------- backend/danswer/connectors/interfaces.py | 5 +- .../server/documents/standard_oauth.py | 101 ++++++--- web/public/Egnyte.png | Bin 0 -> 12716 bytes .../[connector]/AddConnectorPage.tsx | 14 +- .../credentials/CredentialSection.tsx | 15 +- .../credentials/actions/ModifyCredential.tsx | 11 +- web/src/components/icons/icons.tsx | 15 ++ web/src/lib/connectors/oauth.ts | 14 +- web/src/lib/sources.ts | 3 +- 10 files changed, 217 insertions(+), 171 deletions(-) create mode 100644 web/public/Egnyte.png diff --git a/backend/danswer/connectors/egnyte/connector.py b/backend/danswer/connectors/egnyte/connector.py index 7cb2cfa2e6a..a420473daaf 100644 --- a/backend/danswer/connectors/egnyte/connector.py +++ b/backend/danswer/connectors/egnyte/connector.py @@ -3,7 +3,9 @@ from collections.abc import Generator from datetime import datetime from datetime import timezone +from logging import Logger from typing import Any +from typing import cast from typing import IO import requests @@ -27,7 +29,6 @@ from danswer.file_processing.extract_file_text import is_text_file_extension from danswer.file_processing.extract_file_text import read_text_file from danswer.utils.logger import setup_logger -from danswer.utils.special_types import JSON_ro logger = setup_logger() @@ -45,32 +46,26 @@ def _request_with_retries( method: str, url: str, + data: dict[str, Any] | None = None, headers: dict[str, Any] | None = None, params: dict[str, Any] | None = None, timeout: int = _TIMEOUT, stream: bool = False, + tries: int = 8, + delay: float = 1, + backoff: float = 2, ) -> requests.Response: - @retry(tries=8, delay=1, backoff=2, logger=logger) + @retry(tries=tries, delay=delay, backoff=backoff, 
logger=cast(Logger, logger)) def _make_request() -> requests.Response: - if method == "GET": - response = requests.get( - url, headers=headers, params=params, timeout=timeout, stream=stream - ) - elif method == "POST": - response = requests.post( - url, headers=headers, json=params, timeout=timeout, stream=stream - ) - elif method == "PUT": - response = requests.put( - url, headers=headers, json=params, timeout=timeout, stream=stream - ) - elif method == "DELETE": - response = requests.delete( - url, headers=headers, params=params, timeout=timeout, stream=stream - ) - else: - raise ValueError(f"Unsupported HTTP method: {method}") - + response = requests.request( + method, + url, + data=data, + headers=headers, + params=params, + timeout=timeout, + stream=stream, + ) response.raise_for_status() return response @@ -83,6 +78,83 @@ def _parse_last_modified(last_modified: str) -> datetime: ) +def _process_egnyte_file( + file_metadata: dict[str, Any], + file_content: IO, + base_url: str, + folder_path: str | None = None, +) -> Document | None: + """Process an Egnyte file into a Document object + + Args: + file_data: The file data from Egnyte API + file_content: The raw content of the file in bytes + base_url: The base URL for the Egnyte instance + folder_path: Optional folder path to filter results + """ + # Skip if file path doesn't match folder path filter + if folder_path and not file_metadata["path"].startswith(folder_path): + raise ValueError( + f"File path {file_metadata['path']} does not match folder path {folder_path}" + ) + + file_name = file_metadata["name"] + extension = get_file_ext(file_name) + if not check_file_ext_is_valid(extension): + logger.warning(f"Skipping file '{file_name}' with extension '{extension}'") + return None + + # Extract text content based on file type + if is_text_file_extension(file_name): + encoding = detect_encoding(file_content) + file_content_raw, file_metadata = read_text_file( + file_content, encoding=encoding, ignore_danswer_metadata=False + ) + else: + file_content_raw = extract_file_text( + file=file_content, + file_name=file_name, + break_on_unprocessable=True, + ) + + # Build the web URL for the file + web_url = f"{base_url}/navigate/file/{file_metadata['group_id']}" + + # Create document metadata + metadata: dict[str, str | list[str]] = { + "file_path": file_metadata["path"], + "last_modified": file_metadata.get("last_modified", ""), + } + + # Add lock info if present + if lock_info := file_metadata.get("lock_info"): + metadata[ + "lock_owner" + ] = f"{lock_info.get('first_name', '')} {lock_info.get('last_name', '')}" + + # Create the document owners + primary_owner = None + if uploaded_by := file_metadata.get("uploaded_by"): + primary_owner = BasicExpertInfo( + email=uploaded_by, # Using username as email since that's what we have + ) + + # Create the document + return Document( + id=f"egnyte-{file_metadata['entry_id']}", + sections=[Section(text=file_content_raw.strip(), link=web_url)], + source=DocumentSource.EGNYTE, + semantic_identifier=file_name, + metadata=metadata, + doc_updated_at=( + _parse_last_modified(file_metadata["last_modified"]) + if "last_modified" in file_metadata + else None + ), + primary_owners=[primary_owner] if primary_owner else None, + ) + + class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector): def __init__( self, @@ -96,10 +168,10 @@ def __init__( @classmethod def oauth_id(cls) -> DocumentSource: - return "egnyte" + return DocumentSource.EGNYTE @classmethod - def redirect_uri(cls, base_domain: str) -> str: + 
def redirect_uri(cls, base_domain: str, state: str) -> str: if not _EGNYTE_CLIENT_ID: raise ValueError("EGNYTE_CLIENT_ID environment variable must be set") if not _EGNYTE_BASE_DOMAIN: @@ -114,13 +186,12 @@ def redirect_uri(cls, base_domain: str) -> str: f"?client_id={_EGNYTE_CLIENT_ID}" f"&redirect_uri={callback_uri}" f"&scope=Egnyte.filesystem" - # TODO: Add state support - # f"&state=danswer" + f"&state={state}" f"&response_type=code" ) @classmethod - def code_to_token(cls, code: str) -> JSON_ro: + def code_to_token(cls, code: str) -> dict[str, Any]: if not _EGNYTE_CLIENT_ID: raise ValueError("EGNYTE_CLIENT_ID environment variable must be set") if not _EGNYTE_CLIENT_SECRET: @@ -141,7 +212,13 @@ def code_to_token(cls, code: str) -> JSON_ro: headers = {"Content-Type": "application/x-www-form-urlencoded"} response = _request_with_retries( - method="POST", url=url, data=data, headers=headers + method="POST", + url=url, + data=data, + headers=headers, + # try a lot faster since this is a realtime flow + backoff=0, + delay=0.1, ) if not response.ok: raise RuntimeError(f"Failed to exchange code for token: {response.text}") @@ -252,7 +329,7 @@ def _process_files( buffer.seek(0) # Process the streamed file content - doc = process_egnyte_file( + doc = _process_egnyte_file( file_metadata=file, file_content=buffer, base_url=_EGNYTE_APP_BASE.format(domain=self.domain), @@ -285,83 +362,6 @@ def poll_source( yield from self._process_files(start_time=start_time, end_time=end_time) -def process_egnyte_file( - file_metadata: dict[str, Any], - file_content: IO, - base_url: str, - folder_path: str | None = None, -) -> Document | None: - """Process an Egnyte file into a Document object - - Args: - file_data: The file data from Egnyte API - file_content: The raw content of the file in bytes - base_url: The base URL for the Egnyte instance - folder_path: Optional folder path to filter results - """ - # Skip if file path doesn't match folder path filter - if folder_path and not file_metadata["path"].startswith(folder_path): - raise ValueError( - f"File path {file_metadata['path']} does not match folder path {folder_path}" - ) - - file_name = file_metadata["name"] - extension = get_file_ext(file_name) - if not check_file_ext_is_valid(extension): - logger.warning(f"Skipping file '{file_name}' with extension '{extension}'") - return None - - # Extract text content based on file type - if is_text_file_extension(file_name): - encoding = detect_encoding(file_content) - file_content_raw, file_metadata = read_text_file( - file_content, encoding=encoding, ignore_danswer_metadata=False - ) - else: - file_content_raw = extract_file_text( - file=file_content, - file_name=file_name, - break_on_unprocessable=True, - ) - - # Build the web URL for the file - web_url = f"{base_url}/navigate/file/{file_metadata['group_id']}" - - # Create document metadata - metadata: dict[str, str | list[str]] = { - "file_path": file_metadata["path"], - "last_modified": file_metadata.get("last_modified", ""), - } - - # Add lock info if present - if lock_info := file_metadata.get("lock_info"): - metadata[ - "lock_owner" - ] = f"{lock_info.get('first_name', '')} {lock_info.get('last_name', '')}" - - # Create the document owners - primary_owner = None - if uploaded_by := file_metadata.get("uploaded_by"): - primary_owner = BasicExpertInfo( - email=uploaded_by, # Using username as email since that's what we have - ) - - # Create the document - return Document( - id=f"egnyte-{file_metadata['entry_id']}", - 
sections=[Section(text=file_content_raw.strip(), link=web_url)], - source=DocumentSource.EGNYTE, - semantic_identifier=file_name, - metadata=metadata, - doc_updated_at=( - _parse_last_modified(file_metadata["last_modified"]) - if "last_modified" in file_metadata - else None - ), - primary_owners=[primary_owner] if primary_owner else None, - ) - - if __name__ == "__main__": connector = EgnyteConnector() connector.load_credentials( diff --git a/backend/danswer/connectors/interfaces.py b/backend/danswer/connectors/interfaces.py index 8a93584befb..9dd18acea8b 100644 --- a/backend/danswer/connectors/interfaces.py +++ b/backend/danswer/connectors/interfaces.py @@ -5,7 +5,6 @@ from danswer.configs.constants import DocumentSource from danswer.connectors.models import Document from danswer.connectors.models import SlimDocument -from danswer.utils.special_types import JSON_ro SecondsSinceUnixEpoch = float @@ -74,12 +73,12 @@ def oauth_id(cls) -> DocumentSource: @classmethod @abc.abstractmethod - def redirect_uri(cls, base_domain: str) -> str: + def redirect_uri(cls, base_domain: str, state: str) -> str: raise NotImplementedError @classmethod @abc.abstractmethod - def code_to_token(cls, code: str) -> JSON_ro: + def code_to_token(cls, code: str) -> dict[str, Any]: raise NotImplementedError diff --git a/backend/danswer/server/documents/standard_oauth.py b/backend/danswer/server/documents/standard_oauth.py index 56d02948c5f..da66b0715b2 100644 --- a/backend/danswer/server/documents/standard_oauth.py +++ b/backend/danswer/server/documents/standard_oauth.py @@ -1,10 +1,13 @@ +import uuid from typing import Annotated +from typing import cast from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from fastapi import Query from fastapi import Request +from pydantic import BaseModel from sqlalchemy.orm import Session from danswer.auth.users import current_user @@ -12,8 +15,10 @@ from danswer.configs.constants import DocumentSource from danswer.connectors.interfaces import OAuthConnector from danswer.db.credentials import create_credential +from danswer.db.engine import get_current_tenant_id from danswer.db.engine import get_session from danswer.db.models import User +from danswer.redis.redis_pool import get_redis_client from danswer.server.documents.models import CredentialBase from danswer.utils.logger import setup_logger from danswer.utils.subclasses import find_all_subclasses_in_dir @@ -22,17 +27,21 @@ router = APIRouter(prefix="/connector/oauth") +_OAUTH_STATE_KEY_FMT = "oauth_state:{state}" + # Cache for OAuth connectors, populated at module load time _OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {} -def _discover_oauth_connectors() -> dict[str, type[OAuthConnector]]: +def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]: """Walk through the connectors package to find all OAuthConnector implementations""" global _OAUTH_CONNECTORS if _OAUTH_CONNECTORS: # Return cached connectors if already discovered return _OAUTH_CONNECTORS - oauth_connectors = find_all_subclasses_in_dir(OAuthConnector, "danswer.connectors") + oauth_connectors = find_all_subclasses_in_dir( + cast(type[OAuthConnector], OAuthConnector), "danswer.connectors" + ) _OAUTH_CONNECTORS = {cls.oauth_id(): cls for cls in oauth_connectors} return _OAUTH_CONNECTORS @@ -42,12 +51,18 @@ def _discover_oauth_connectors() -> dict[str, type[OAuthConnector]]: _discover_oauth_connectors() +class AuthorizeResponse(BaseModel): + redirect_url: str + + 
@router.get("/authorize/{source}") def oauth_authorize( request: Request, - source: str, + source: DocumentSource, + desired_return_url: Annotated[str | None, Query()] = None, _: User = Depends(current_user), -) -> dict: + tenant_id: str | None = Depends(get_current_tenant_id), +) -> AuthorizeResponse: """Initiates the OAuth flow by redirecting to the provider's auth page""" oauth_connectors = _discover_oauth_connectors() @@ -58,17 +73,30 @@ def oauth_authorize( base_url = str(request.base_url) if "127.0.0.1" in base_url: base_url = base_url.replace("127.0.0.1", "localhost") - return {"redirect_url": connector_cls.redirect_uri(base_url)} + + # store state in redis + if not desired_return_url: + desired_return_url = f"{WEB_DOMAIN}/admin/connectors/{source}?step=0" + redis_client = get_redis_client(tenant_id=tenant_id) + state = str(uuid.uuid4()) + redis_client.set(_OAUTH_STATE_KEY_FMT.format(state=state), desired_return_url) + + return AuthorizeResponse(redirect_url=connector_cls.redirect_uri(base_url, state)) + + +class CallbackResponse(BaseModel): + redirect_url: str @router.get("/callback/{source}") async def oauth_callback( - source: str, + source: DocumentSource, code: Annotated[str, Query()], - state: Annotated[str | None, Query()] = None, + state: Annotated[str, Query()], db_session: Session = Depends(get_session), user: User = Depends(current_user), -) -> dict: + tenant_id: str | None = Depends(get_current_tenant_id), +) -> CallbackResponse: """Handles the OAuth callback and exchanges the code for tokens""" oauth_connectors = _discover_oauth_connectors() @@ -77,30 +105,35 @@ async def oauth_callback( connector_cls = oauth_connectors[source] - try: - token_info = connector_cls.code_to_token(code) - - # Create a new credential with the token info - credential_data = CredentialBase( - credential_json=token_info, - admin_public=True, # Or based on some logic/parameter - source=source, - name=f"{source.title()} OAuth Credential", - ) - - credential = create_credential( - credential_data=credential_data, - user=user, - db_session=db_session, + # get state from redis + redis_client = get_redis_client(tenant_id=tenant_id) + original_url_bytes = cast( + bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state)) + ) + if not original_url_bytes: + raise HTTPException(status_code=400, detail="Invalid OAuth state") + original_url = original_url_bytes.decode("utf-8") + + token_info = connector_cls.code_to_token(code) + + # Create a new credential with the token info + credential_data = CredentialBase( + credential_json=token_info, + admin_public=True, # Or based on some logic/parameter + source=source, + name=f"{source.title()} OAuth Credential", + ) + + credential = create_credential( + credential_data=credential_data, + user=user, + db_session=db_session, + ) + + return CallbackResponse( + redirect_url=( + f"{original_url}?credentialId={credential.id}" + if "?" 
+            else f"{original_url}&credentialId={credential.id}"
         )
-
-        return {
-            "redirect_url": f"{WEB_DOMAIN}/admin/connectors/{source}?step=0&credentialId={credential.id}"
-        }
-    except Exception as e:
-        raise HTTPException(status_code=400, detail=str(e))
-
-
-@router.get("/available-sources")
-def available_sources() -> list[DocumentSource]:
-    return list(_discover_oauth_connectors().keys())
+    )
diff --git a/web/public/Egnyte.png b/web/public/Egnyte.png
new file mode 100644
index 0000000000000000000000000000000000000000..54eef07dc2829b473dbf1733667e10736c110cb6
GIT binary patch
literal 12716
[12,716 bytes of base85-encoded PNG data omitted]

[The file and hunk headers for the next change were garbled together with the
binary image data above and could not be recovered. The hunk below updates a
"Create New" button's onClick handler so that it only redirects when an OAuth
redirect URL is available and otherwise falls back to the in-page form.]

                           const redirectUrl = await getConnectorOauthRedirectUrl(connector);
-                          window.location.href = redirectUrl;
-
-                          // setCreateConnectorToggle(
-                          //   (createConnectorToggle) => !createConnectorToggle
-                          // )
+                          // if redirect is supported, just use it
+                          if (redirectUrl) {
+                            window.location.href = redirectUrl;
+                          } else {
+                            setCreateConnectorToggle(
+                              (createConnectorToggle) =>
+                                !createConnectorToggle
+                            );
+                          }
                         }}
                       >
                         Create New
diff --git a/web/src/components/credentials/CredentialSection.tsx b/web/src/components/credentials/CredentialSection.tsx
index 99c97ad589b..b20f2ea8185 100644
--- a/web/src/components/credentials/CredentialSection.tsx
+++ b/web/src/components/credentials/CredentialSection.tsx
@@ -28,6 +28,7 @@ import {
   ConfluenceCredentialJson,
   Credential,
 } from "@/lib/connectors/credentials";
+import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth";

 export default function CredentialSection({
   ccPair,
@@ -38,9 +39,14 @@ export default function CredentialSection({
   sourceType: ValidSources;
   refresh: () => void;
 }) {
-  const makeShowCreateCredential = () => {
-    setShowModifyCredential(false);
-    setShowCreateCredential(true);
+  const makeShowCreateCredential = async () => {
+    const redirectUrl = await getConnectorOauthRedirectUrl(sourceType);
+    if (redirectUrl) {
+      window.location.href = redirectUrl;
+    } else {
+      setShowModifyCredential(false);
+      setShowCreateCredential(true);
+    }
   };

   const { data: credentials } = useSWR<Credential<ConfluenceCredentialJson>[]>(
@@ -150,9 +156,6 @@ export default function CredentialSection({
           title="Update Credentials"
         >
           <ModifyCredential
-            showCreate={() => {
-              setShowCreateCredential(true);
-            }}
             close={closeModifyCredential}
             source={sourceType}
             attachedConnector={ccPair.connector}
diff --git a/web/src/components/credentials/actions/ModifyCredential.tsx b/web/src/components/credentials/actions/ModifyCredential.tsx
index 328ed00873c..54ca33bc657 100644
--- a/web/src/components/credentials/actions/ModifyCredential.tsx
+++ b/web/src/components/credentials/actions/ModifyCredential.tsx
@@ -144,15 +144,12 @@ export default function ModifyCredential({
   attachedConnector,
   credentials,
   editableCredentials,
-  source,
   defaultedCredential,
-  onSwap,
   onSwitch,
-  onCreateNew = () => null,
   onEditCredential,
   onDeleteCredential,
-  showCreate,
+  onCreateNew,
 }: {
   close?: () => void;
   showIfEmpty?: boolean;
@@ -161,13 +158,11 @@
   credentials: Credential[];
   editableCredentials: Credential[];
   source: ValidSources;
-  onSwitch?: (newCredential: Credential) => void;
   onSwap?: (newCredential: Credential, connectorId: number) => void;
   onCreateNew?: () => void;
   onDeleteCredential: (credential: Credential) => void;
   onEditCredential?: (credential: Credential) => void;
-  showCreate?: () => void;
 }) {
   const [selectedCredential, setSelectedCredential] =
     useState<Credential<any> | null>(null);
@@ -244,10 +239,10 @@ export default function ModifyCredential({
         {!showIfEmpty && (
-          {showCreate ? (
+          {onCreateNew ? (
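
Note for reviewers (not part of the patch): both web-side hunks above call
getConnectorOauthRedirectUrl from "@/lib/connectors/oauth", whose implementation
is not included in this diff. The sketch below only illustrates how such a helper
could be written against the new endpoints. The "/api/connector/oauth" route
prefix, the exact function signature, and the return-null-on-failure behavior are
assumptions; only the { redirect_url } response shape and the optional
desired_return_url query parameter come from the oauth_authorize changes above.

// Hypothetical sketch -- not part of this patch. Route prefix and error
// handling are assumptions; the response shape matches AuthorizeResponse.
export async function getConnectorOauthRedirectUrl(
  source: string
): Promise<string | null> {
  const response = await fetch(
    `/api/connector/oauth/authorize/${source}?desired_return_url=${encodeURIComponent(
      window.location.href
    )}`
  );
  if (!response.ok) {
    // No OAuth connector registered for this source (or the request failed):
    // callers fall back to the manual credential form.
    return null;
  }
  const data: { redirect_url: string } = await response.json();
  return data.redirect_url;
}

Under this reading, the provider later redirects back with code and state, the
callback handler above exchanges them for a credential, and the caller is sent
to the returned redirect_url, which now carries the appended credentialId.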