Skip to content

Commit

Permalink
Address comments
Browse files (browse the repository at this point in the history)
  • Loading branch information
Weves committed Dec 10, 2024
1 parent 42bb4aa commit 7d51171
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 15 deletions.
12 changes: 6 additions & 6 deletions backend/danswer/connectors/egnyte/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
- from danswer.file_processing.extract_file_text import check_file_ext_is_valid
from danswer.file_processing.extract_file_text import detect_encoding
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_processing.extract_file_text import get_file_ext
from danswer.file_processing.extract_file_text import is_text_file_extension
+ from danswer.file_processing.extract_file_text import is_valid_file_ext
from danswer.file_processing.extract_file_text import read_text_file
from danswer.utils.logger import setup_logger

Expand Down Expand Up @@ -99,7 +99,7 @@ def _process_egnyte_file(

file_name = file_metadata["name"]
extension = get_file_ext(file_name)
- if not check_file_ext_is_valid(extension):
+ if not is_valid_file_ext(extension):
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
return None

Expand Down Expand Up @@ -170,7 +170,7 @@ def oauth_id(cls) -> DocumentSource:
return DocumentSource.EGNYTE

@classmethod
- def redirect_uri(cls, base_domain: str, state: str) -> str:
+ def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
Expand All @@ -190,7 +190,7 @@ def redirect_uri(cls, base_domain: str, state: str) -> str:
)

@classmethod
- def code_to_token(cls, code: str) -> dict[str, Any]:
+ def oauth_code_to_token(cls, code: str) -> dict[str, Any]:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_CLIENT_SECRET:
Expand Down Expand Up @@ -342,8 +342,8 @@ def _process_files(
yield current_batch
current_batch = []

- except Exception as e:
- logger.error(f"Failed to process file {file['path']}: {str(e)}")
+ except Exception:
+ logger.exception(f"Failed to process file {file['path']}")
continue

if current_batch:
Expand Down
6 changes: 3 additions & 3 deletions backend/danswer/connectors/file/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.db.engine import get_session_with_tenant
- from danswer.file_processing.extract_file_text import check_file_ext_is_valid
from danswer.file_processing.extract_file_text import detect_encoding
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_processing.extract_file_text import get_file_ext
from danswer.file_processing.extract_file_text import is_text_file_extension
+ from danswer.file_processing.extract_file_text import is_valid_file_ext
from danswer.file_processing.extract_file_text import load_files_from_zip
from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.extract_file_text import read_text_file
Expand Down Expand Up @@ -50,7 +50,7 @@ def _read_files_and_metadata(
file_content, ignore_dirs=True
):
yield os.path.join(directory_path, file_info.filename), file, metadata
- elif check_file_ext_is_valid(extension):
+ elif is_valid_file_ext(extension):
yield file_name, file_content, metadata
else:
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
Expand All @@ -63,7 +63,7 @@ def _process_file(
pdf_pass: str | None = None,
) -> list[Document]:
extension = get_file_ext(file_name)
- if not check_file_ext_is_valid(extension):
+ if not is_valid_file_ext(extension):
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
return []

Expand Down
4 changes: 2 additions & 2 deletions backend/danswer/connectors/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ def oauth_id(cls) -> DocumentSource:

@classmethod
@abc.abstractmethod
- def redirect_uri(cls, base_domain: str, state: str) -> str:
+ def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
raise NotImplementedError

@classmethod
@abc.abstractmethod
- def code_to_token(cls, code: str) -> dict[str, Any]:
+ def oauth_code_to_token(cls, code: str) -> dict[str, Any]:
raise NotImplementedError


Expand Down
4 changes: 2 additions & 2 deletions backend/danswer/file_processing/extract_file_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_file_ext(file_path_or_name: str | Path) -> str:
return extension


- def check_file_ext_is_valid(ext: str) -> bool:
+ def is_valid_file_ext(ext: str) -> bool:
return ext in VALID_FILE_EXTENSIONS


Expand Down Expand Up @@ -364,7 +364,7 @@ def extract_file_text(
elif file_name is not None:
final_extension = get_file_ext(file_name)

- if check_file_ext_is_valid(final_extension):
+ if is_valid_file_ext(final_extension):
return extension_to_function.get(final_extension, file_io_to_text)(file)

# Either the file somehow has no name or the extension is not one that we recognize
Expand Down
6 changes: 4 additions & 2 deletions backend/danswer/server/documents/standard_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def oauth_authorize(
ex=_OAUTH_STATE_EXPIRATION_SECONDS,
)

- return AuthorizeResponse(redirect_url=connector_cls.redirect_uri(base_url, state))
+ return AuthorizeResponse(
+ redirect_url=connector_cls.oauth_authorization_url(base_url, state)
+ )


class CallbackResponse(BaseModel):
Expand Down Expand Up @@ -119,7 +121,7 @@ def oauth_callback(
raise HTTPException(status_code=400, detail="Invalid OAuth state")
original_url = original_url_bytes.decode("utf-8")

- token_info = connector_cls.code_to_token(code)
+ token_info = connector_cls.oauth_code_to_token(code)

# Create a new credential with the token info
credential_data = CredentialBase(
Expand Down

0 comments on commit 7d51171

Please sign in to comment.