From 628259fa8660894492f8c4a80b43aab967a655bf Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Mon, 11 Nov 2024 16:42:43 -0800 Subject: [PATCH] finalize --- backend/danswer/auth/users.py | 28 ++++++++++--------- backend/danswer/main.py | 2 +- backend/ee/danswer/main.py | 27 ++++++++++++++++++ web/src/app/auth/login/SignInButton.tsx | 10 +------ .../auth/signup/ReferralSourceSelector.tsx | 24 ++++++++++++---- web/src/app/auth/signup/page.tsx | 9 +++++- web/src/lib/userSS.ts | 6 +++- 7 files changed, 76 insertions(+), 30 deletions(-) diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index cab28102ab4..af2395a1fa5 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -282,10 +282,11 @@ async def oauth_callback( refresh_token: Optional[str] = None, request: Optional[Request] = None, *, + referral_source: Optional[str] = None, associate_by_email: bool = False, is_verified_by_default: bool = False, ) -> models.UOAP: - tenant_id = await get_or_create_tenant_id(account_email) + tenant_id = await get_or_create_tenant_id(account_email, referral_source) if not tenant_id: raise HTTPException(status_code=401, detail="User not found") @@ -690,8 +691,6 @@ def generate_state_token( # refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91 - - def create_danswer_oauth_router( oauth_client: BaseOAuth2, backend: AuthenticationBackend, @@ -741,15 +740,23 @@ def get_oauth_router( response_model=OAuth2AuthorizeResponse, ) async def authorize( - request: Request, scopes: List[str] = Query(None) + request: Request, + scopes: List[str] = Query(None), ) -> OAuth2AuthorizeResponse: + cookies = request.cookies + referral_source = cookies.get("referral_source", None) + if redirect_url is not None: authorize_redirect_url = redirect_url else: authorize_redirect_url = str(request.url_for(callback_route_name)) next_url = request.query_params.get("next", "/") - state_data: Dict[str, str] = {"next_url": next_url} + # Include the referral_source in the state_data + state_data: Dict[str, str] = { + "next_url": next_url, + "referral_source": referral_source or "default_referral", + } state = generate_state_token(state_data, state_secret) authorization_url = await oauth_client.get_authorization_url( authorize_redirect_url, @@ -808,8 +815,9 @@ async def callback( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) next_url = state_data.get("next_url", "/") + referral_source = state_data.get("referral_source", "/") - # Authenticate user + # Proceed to authenticate or create the user try: user = await user_manager.oauth_callback( oauth_client.name, @@ -819,6 +827,7 @@ async def callback( token.get("expires_at"), token.get("refresh_token"), request, + referral_source=referral_source, associate_by_email=associate_by_email, is_verified_by_default=is_verified_by_default, ) @@ -845,13 +854,6 @@ async def callback( for header_name, header_value in response.headers.items(): redirect_response.headers[header_name] = header_value - if hasattr(response, "body"): - redirect_response.body = response.body - if hasattr(response, "status_code"): - redirect_response.status_code = response.status_code - if hasattr(response, "media_type"): - redirect_response.media_type = response.media_type - return redirect_response return router diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 06ce7bf4092..c59722f74ab 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -304,7 +304,7 @@ def get_application() -> FastAPI: tags=["users"], ) - if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD: + if AUTH_TYPE == AuthType.GOOGLE_OAUTH: oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET) include_router_with_global_prefix_prepended( application, diff --git a/backend/ee/danswer/main.py b/backend/ee/danswer/main.py index affa5fd1cde..7ccb46b8746 100644 --- a/backend/ee/danswer/main.py +++ b/backend/ee/danswer/main.py @@ -1,4 +1,5 @@ from fastapi import FastAPI +from httpx_oauth.clients.google import GoogleOAuth2 from httpx_oauth.clients.openid import OpenID from danswer.auth.users import auth_backend @@ -59,6 +60,31 @@ def get_application() -> FastAPI: if MULTI_TENANT: add_tenant_id_middleware(application, logger) + if AUTH_TYPE == AuthType.CLOUD: + oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET) + include_router_with_global_prefix_prepended( + application, + create_danswer_oauth_router( + oauth_client, + auth_backend, + USER_AUTH_SECRET, + associate_by_email=True, + is_verified_by_default=True, + # Points the user back to the login page + redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback", + ), + prefix="/auth/oauth", + tags=["auth"], + ) + + # Need basic auth router for `logout` endpoint + include_router_with_global_prefix_prepended( + application, + fastapi_users.get_logout_router(auth_backend), + prefix="/auth", + tags=["auth"], + ) + if AUTH_TYPE == AuthType.OIDC: include_router_with_global_prefix_prepended( application, @@ -73,6 +99,7 @@ def get_application() -> FastAPI: prefix="/auth/oidc", tags=["auth"], ) + # need basic auth router for `logout` endpoint include_router_with_global_prefix_prepended( application, diff --git a/web/src/app/auth/login/SignInButton.tsx b/web/src/app/auth/login/SignInButton.tsx index c8212611fe3..b06f9bad79b 100644 --- a/web/src/app/auth/login/SignInButton.tsx +++ b/web/src/app/auth/login/SignInButton.tsx @@ -4,11 +4,9 @@ import { FaGoogle } from "react-icons/fa"; export function SignInButton({ authorizeUrl, authType, - referralSource, }: { authorizeUrl: string; authType: AuthType; - referralSource?: string; }) { let button; if (authType === "google_oauth" || authType === "cloud") { @@ -40,21 +38,15 @@ export function SignInButton({ const url = new URL(authorizeUrl); - // if (referralSource) { - url.searchParams.set("referral_source", "docs"); - // } - const finalAuthorizeUrl = url.toString(); - console.log("finalAuthorizeUrl", finalAuthorizeUrl); - if (!button) { throw new Error(`Unhandled authType: ${authType}`); } return ( {button} diff --git a/web/src/app/auth/signup/ReferralSourceSelector.tsx b/web/src/app/auth/signup/ReferralSourceSelector.tsx index e503bc96a01..a552907fb9f 100644 --- a/web/src/app/auth/signup/ReferralSourceSelector.tsx +++ b/web/src/app/auth/signup/ReferralSourceSelector.tsx @@ -8,6 +8,7 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; +import { Label } from "@/components/admin/connectors/Field"; interface ReferralSourceSelectorProps { defaultValue?: string; @@ -43,14 +44,27 @@ const ReferralSourceSelector: React.FC = ({ }; return ( -
+
+