diff --git a/google_sheets/app.py b/google_sheets/app.py index 7eb5f15..a0b21ad 100644 --- a/google_sheets/app.py +++ b/google_sheets/app.py @@ -3,7 +3,7 @@ import urllib.parse from os import environ from pathlib import Path -from typing import Annotated, Any, Dict, List, Tuple, Union +from typing import Annotated, Any, Dict, List, Union import httpx from asyncify import asyncify @@ -14,7 +14,7 @@ from prisma.errors import RecordNotFoundError from . import __version__ -from .db_helpers import get_db_connection, get_wasp_db_url +from .db_helpers import get_db_connection __all__ = ["app"] @@ -47,23 +47,7 @@ } -async def get_user_id_chat_uuid_from_chat_id( - chat_id: Union[int, str], -) -> Tuple[int, str]: - wasp_db_url = get_wasp_db_url() - async with get_db_connection(db_url=wasp_db_url) as db: - chat = await db.query_first( - f'SELECT * from "Chat" where id={chat_id}' # nosec: [B608] - ) - if not chat: - raise HTTPException(status_code=404, detail=f"chat {chat} not found") - user_id = chat["userId"] - chat_uuid = chat["uuid"] - return user_id, chat_uuid - - async def is_authenticated_for_ads(user_id: int) -> bool: - await get_user(user_id=user_id) async with get_db_connection() as db: data = await db.gauth.find_unique(where={"user_id": user_id}) @@ -77,7 +61,6 @@ async def is_authenticated_for_ads(user_id: int) -> bool: async def get_login_url( request: Request, user_id: int = Query(title="User ID"), - conv_id: int = Query(title="Conversation ID"), force_new_login: bool = Query(title="Force new login", default=False), ) -> Dict[str, str]: if not force_new_login: @@ -89,7 +72,7 @@ async def get_login_url( f"{oauth2_settings['auth_uri']}?client_id={oauth2_settings['clientId']}" f"&redirect_uri={oauth2_settings['redirectUri']}&response_type=code" f"&scope={urllib.parse.quote_plus('email https://www.googleapis.com/auth/spreadsheets https://www.googleapis.com/auth/drive.metadata.readonly')}" - f"&access_type=offline&prompt=consent&state={conv_id}" + f"&access_type=offline&prompt=consent&state={user_id}" ) markdown_url = f"To navigate Google Ads waters, I require access to your account. Please [click here]({google_oauth_url}) to grant permission." return {"login_url": markdown_url} @@ -105,9 +88,9 @@ async def get_login_success() -> Dict[str, str]: async def login_callback( code: str = Query(title="Authorization Code"), state: str = Query(title="State") ) -> RedirectResponse: - chat_id = state - user_id, chat_uuid = await get_user_id_chat_uuid_from_chat_id(chat_id) - user = await get_user(user_id=user_id) + if not state.isdigit(): + raise HTTPException(status_code=400, detail="User ID must be an integer") + user_id = int(state) token_request_data = { "code": code, @@ -135,10 +118,10 @@ async def login_callback( user_info = userinfo_response.json() async with get_db_connection() as db: await db.gauth.upsert( - where={"user_id": user["id"]}, + where={"user_id": user_id}, data={ "create": { - "user_id": user["id"], + "user_id": user_id, "creds": json.dumps(token_data), "info": json.dumps(user_info), }, @@ -157,19 +140,7 @@ async def login_callback( return RedirectResponse(url=f"{base_url}/login/success") -async def get_user(user_id: Union[int, str]) -> Any: - wasp_db_url = get_wasp_db_url() - async with get_db_connection(db_url=wasp_db_url) as db: - user = await db.query_first( - f'SELECT * from "User" where id={user_id}' # nosec: [B608] - ) - if not user: - raise HTTPException(status_code=404, detail=f"user_id {user_id} not found") - return user - - async def load_user_credentials(user_id: Union[int, str]) -> Any: - await get_user(user_id=user_id) async with get_db_connection() as db: try: data = await db.gauth.find_unique_or_raise(where={"user_id": user_id}) # type: ignore[typeddict-item] diff --git a/google_sheets/db_helpers.py b/google_sheets/db_helpers.py index d99f06a..1572277 100644 --- a/google_sheets/db_helpers.py +++ b/google_sheets/db_helpers.py @@ -23,12 +23,3 @@ async def get_db_connection( yield db finally: await db.disconnect() - - -def get_wasp_db_url() -> str: - curr_db_url = environ.get("DATABASE_URL") - wasp_db_name = environ.get("WASP_DB_NAME", "waspdb") - wasp_db_url = curr_db_url.replace(curr_db_url.split("/")[-1], wasp_db_name) # type: ignore[union-attr] - if "connect_timeout" not in wasp_db_url: - wasp_db_url += "?connect_timeout=60" - return wasp_db_url diff --git a/tests/app/test_app.py b/tests/app/test_app.py index 2b831e6..1df2241 100644 --- a/tests/app/test_app.py +++ b/tests/app/test_app.py @@ -74,15 +74,6 @@ def test_openapi(self) -> None: "required": True, "schema": {"type": "integer", "title": "User ID"}, }, - { - "name": "conv_id", - "in": "query", - "required": True, - "schema": { - "type": "integer", - "title": "Conversation ID", - }, - }, { "name": "force_new_login", "in": "query", diff --git a/tests/app/test_db_helpers.py b/tests/app/test_db_helpers.py deleted file mode 100644 index 2d56522..0000000 --- a/tests/app/test_db_helpers.py +++ /dev/null @@ -1,17 +0,0 @@ -import os -from unittest.mock import patch - -from google_sheets.db_helpers import get_wasp_db_url - - -def test_get_wasp_db_url() -> None: - root_db_url = "db://user:pass@localhost:5432" # pragma: allowlist secret - env_vars = { - "DATABASE_URL": f"{root_db_url}/dbname", - "WASP_DB_NAME": "waspdb", - } - with patch.dict(os.environ, env_vars, clear=True): - wasp_db_url = get_wasp_db_url() - excepted = f"{root_db_url}/waspdb?connect_timeout=60" - - assert wasp_db_url == excepted