Skip to content

Commit

Permalink
maybe these fixes work
Browse files Browse the repository at this point in the history
  • Loading branch information
hagen-danswer committed Nov 4, 2024
1 parent 84aed73 commit cd9ead8
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 30 deletions.
55 changes: 40 additions & 15 deletions backend/tests/daily/connectors/gmail/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,52 @@
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from tests.load_env_vars import load_env_vars


def load_env_vars(env_file: str = ".env") -> None:
current_dir = os.path.dirname(os.path.abspath(__file__))
env_path = os.path.join(current_dir, env_file)
try:
with open(env_path, "r") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#"):
key, value = line.split("=", 1)
os.environ[key] = value.strip()
print("Successfully loaded environment variables")
except FileNotFoundError:
print(f"File {env_file} not found")
# def load_env_vars(env_file: str = ".env") -> None:
# current_dir = os.path.dirname(os.path.abspath(__file__))
# env_path = os.path.join(current_dir, env_file)
# try:
# with open(env_path, "r") as f:
# for line in f:
# line = line.strip()
# if line and not line.startswith("#"):
# key, value = line.split("=", 1)
# os.environ[key] = value.strip()
# print("Successfully loaded environment variables")
# except FileNotFoundError:
# print(f"File {env_file} not found")


# Load environment variables at the module level
load_env_vars()


def parse_credentials(env_str: str) -> dict:
"""
Parse a double-escaped JSON string from environment variables into a Python dictionary.
Args:
env_str (str): The double-escaped JSON string from environment variables
Returns:
dict: Parsed OAuth credentials
"""
# first try normally
try:
return json.loads(env_str)
except Exception:
# First, try remove extra escaping backslashes
unescaped = env_str.replace('\\"', '"')

# remove leading / trailing quotes
unescaped = unescaped.strip('"')

# Now parse the JSON
return json.loads(unescaped)


@pytest.fixture
def google_gmail_oauth_connector_factory() -> Callable[..., GmailConnector]:
def _connector_factory(
Expand All @@ -43,7 +68,7 @@ def _connector_factory(
connector = GmailConnector()

json_string = os.environ["GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR"]
refried_json_string = json.loads(json_string)
refried_json_string = json.loads(parse_credentials(json_string))

credentials_json = {
DB_CREDENTIALS_DICT_TOKEN_KEY: refried_json_string,
Expand All @@ -64,7 +89,7 @@ def _connector_factory(
connector = GmailConnector()

json_string = os.environ["GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR"]
refried_json_string = json.loads(json_string)
refried_json_string = json.loads(parse_credentials(json_string))

# Load Service Account Credentials
connector.load_credentials(
Expand Down
30 changes: 15 additions & 15 deletions backend/tests/daily/connectors/google_drive/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@
from danswer.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)


def load_env_vars(env_file: str = ".env") -> None:
current_dir = os.path.dirname(os.path.abspath(__file__))
env_path = os.path.join(current_dir, env_file)
try:
with open(env_path, "r") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#"):
key, value = line.split("=", 1)
os.environ[key] = value.strip()
print("Successfully loaded environment variables")
except FileNotFoundError:
print(f"File {env_file} not found")
from tests.load_env_vars import load_env_vars

# def load_env_vars(env_file: str = ".env") -> None:
# current_dir = os.path.dirname(os.path.abspath(__file__))
# env_path = os.path.join(current_dir, env_file)
# try:
# with open(env_path, "r") as f:
# for line in f:
# line = line.strip()
# if line and not line.startswith("#"):
# key, value = line.split("=", 1)
# os.environ[key] = value.strip()
# print("Successfully loaded environment variables")
# except FileNotFoundError:
# print(f"File {env_file} not found")


# Load environment variables at the module level
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def test_all_permissions(
+ list(range(45, 50)) # Folder 2
+ list(range(50, 55)) # Folder 2_1
+ list(range(55, 60)) # Folder 2_2
+ [61] # Sections
)

# Should get everything
Expand Down

0 comments on commit cd9ead8

Please sign in to comment.