From cb6fad26b8ec9f77a7bc82a94da8e6748bbc20f0 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Mon, 11 Nov 2024 16:43:23 -0800 Subject: [PATCH 1/7] cloud auth referral source --- backend/danswer/auth/users.py | 34 ++++---- backend/danswer/main.py | 2 +- backend/ee/danswer/main.py | 27 +++++++ backend/ee/danswer/server/tenants/models.py | 1 + .../ee/danswer/server/tenants/provisioning.py | 18 +++-- .../common_utils/managers/tenant.py | 2 + .../syncing/test_search_permissions.py | 4 +- .../tenants/test_tenant_creation.py | 2 +- .../status/CCPairIndexingStatusTable.tsx | 7 +- web/src/app/admin/users/page.tsx | 2 +- web/src/app/auth/login/EmailPasswordForm.tsx | 8 +- web/src/app/auth/login/SignInButton.tsx | 8 +- .../auth/signup/ReferralSourceSelector.tsx | 77 +++++++++++++++++++ web/src/app/auth/signup/page.tsx | 9 +++ web/src/components/admin/ClientLayout.tsx | 2 +- web/src/lib/user.ts | 7 +- web/src/lib/userSS.ts | 6 +- 17 files changed, 181 insertions(+), 35 deletions(-) create mode 100644 web/src/app/auth/signup/ReferralSourceSelector.tsx diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 73f1ec18484..067acbf653e 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -228,12 +228,16 @@ async def create( safe: bool = False, request: Optional[Request] = None, ) -> User: + referral_source = None + if request: + referral_source = request.cookies.get("referral_source", None) + tenant_id = await fetch_ee_implementation_or_noop( "danswer.server.tenants.provisioning", "get_or_create_tenant_id", async_return_default_schema, )( - email=user_create.email, + email=user_create.email, referral_source=referral_source, ) async with get_async_session_with_tenant(tenant_id) as db_session: @@ -291,6 +295,7 @@ async def oauth_callback( refresh_token: Optional[str] = None, request: Optional[Request] = None, *, + referral_source: Optional[str] = None, associate_by_email: bool = False, is_verified_by_default: bool = False, ) -> models.UOAP: @@ -299,7 +304,7 @@ async def oauth_callback( "get_or_create_tenant_id", async_return_default_schema, )( - email=account_email, + email=account_email, referral_source=referral_source, ) if not tenant_id: @@ -711,8 +716,6 @@ def generate_state_token( # refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91 - - def create_danswer_oauth_router( oauth_client: BaseOAuth2, backend: AuthenticationBackend, @@ -762,15 +765,23 @@ def get_oauth_router( response_model=OAuth2AuthorizeResponse, ) async def authorize( - request: Request, scopes: List[str] = Query(None) + request: Request, + scopes: List[str] = Query(None), ) -> OAuth2AuthorizeResponse: + cookies = request.cookies + referral_source = cookies.get("referral_source", None) + if redirect_url is not None: authorize_redirect_url = redirect_url else: authorize_redirect_url = str(request.url_for(callback_route_name)) next_url = request.query_params.get("next", "/") - state_data: Dict[str, str] = {"next_url": next_url} + # Include the referral_source in the state_data + state_data: Dict[str, str] = { + "next_url": next_url, + "referral_source": referral_source or "default_referral", + } state = generate_state_token(state_data, state_secret) authorization_url = await oauth_client.get_authorization_url( authorize_redirect_url, @@ -829,8 +840,9 @@ async def callback( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) next_url = state_data.get("next_url", "/") + referral_source = state_data.get("referral_source", "/") - # Authenticate user + # Proceed to authenticate or create the user try: user = await user_manager.oauth_callback( oauth_client.name, @@ -840,6 +852,7 @@ async def callback( token.get("expires_at"), token.get("refresh_token"), request, + referral_source=referral_source, associate_by_email=associate_by_email, is_verified_by_default=is_verified_by_default, ) @@ -866,13 +879,6 @@ async def callback( for header_name, header_value in response.headers.items(): redirect_response.headers[header_name] = header_value - if hasattr(response, "body"): - redirect_response.body = response.body - if hasattr(response, "status_code"): - redirect_response.status_code = response.status_code - if hasattr(response, "media_type"): - redirect_response.media_type = response.media_type - return redirect_response return router diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 0aff801c8f4..ff2185dab7c 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -315,7 +315,7 @@ def get_application() -> FastAPI: tags=["users"], ) - if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD: + if AUTH_TYPE == AuthType.GOOGLE_OAUTH: oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET) include_router_with_global_prefix_prepended( application, diff --git a/backend/ee/danswer/main.py b/backend/ee/danswer/main.py index d09a2893d69..96655af2acd 100644 --- a/backend/ee/danswer/main.py +++ b/backend/ee/danswer/main.py @@ -1,4 +1,5 @@ from fastapi import FastAPI +from httpx_oauth.clients.google import GoogleOAuth2 from httpx_oauth.clients.openid import OpenID from danswer.auth.users import auth_backend @@ -59,6 +60,31 @@ def get_application() -> FastAPI: if MULTI_TENANT: add_tenant_id_middleware(application, logger) + if AUTH_TYPE == AuthType.CLOUD: + oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET) + include_router_with_global_prefix_prepended( + application, + create_danswer_oauth_router( + oauth_client, + auth_backend, + USER_AUTH_SECRET, + associate_by_email=True, + is_verified_by_default=True, + # Points the user back to the login page + redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback", + ), + prefix="/auth/oauth", + tags=["auth"], + ) + + # Need basic auth router for `logout` endpoint + include_router_with_global_prefix_prepended( + application, + fastapi_users.get_logout_router(auth_backend), + prefix="/auth", + tags=["auth"], + ) + if AUTH_TYPE == AuthType.OIDC: include_router_with_global_prefix_prepended( application, @@ -73,6 +99,7 @@ def get_application() -> FastAPI: prefix="/auth/oidc", tags=["auth"], ) + # need basic auth router for `logout` endpoint include_router_with_global_prefix_prepended( application, diff --git a/backend/ee/danswer/server/tenants/models.py b/backend/ee/danswer/server/tenants/models.py index df24ff6c32d..c372418f6a4 100644 --- a/backend/ee/danswer/server/tenants/models.py +++ b/backend/ee/danswer/server/tenants/models.py @@ -38,3 +38,4 @@ class ImpersonateRequest(BaseModel): class TenantCreationPayload(BaseModel): tenant_id: str email: str + referral_source: str | None = None diff --git a/backend/ee/danswer/server/tenants/provisioning.py b/backend/ee/danswer/server/tenants/provisioning.py index e956cf4359c..32f95b1200d 100644 --- a/backend/ee/danswer/server/tenants/provisioning.py +++ b/backend/ee/danswer/server/tenants/provisioning.py @@ -41,7 +41,9 @@ logger = logging.getLogger(__name__) -async def get_or_create_tenant_id(email: str) -> str: +async def get_or_create_tenant_id( + email: str, referral_source: str | None = None +) -> str: """Get existing tenant ID for an email or create a new tenant if none exists.""" if not MULTI_TENANT: return POSTGRES_DEFAULT_SCHEMA @@ -51,7 +53,7 @@ async def get_or_create_tenant_id(email: str) -> str: except exceptions.UserNotExists: # If tenant does not exist and in Multi tenant mode, provision a new tenant try: - tenant_id = await create_tenant(email) + tenant_id = await create_tenant(email, referral_source) except Exception as e: logger.error(f"Tenant provisioning failed: {e}") raise HTTPException(status_code=500, detail="Failed to provision tenant.") @@ -64,13 +66,13 @@ async def get_or_create_tenant_id(email: str) -> str: return tenant_id -async def create_tenant(email: str) -> str: +async def create_tenant(email: str, referral_source: str | None = None) -> str: tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4()) try: # Provision tenant on data plane await provision_tenant(tenant_id, email) # Notify control plane - await notify_control_plane(tenant_id, email) + await notify_control_plane(tenant_id, email, referral_source) except Exception as e: logger.error(f"Tenant provisioning failed: {e}") await rollback_tenant_provisioning(tenant_id) @@ -117,14 +119,18 @@ async def provision_tenant(tenant_id: str, email: str) -> None: CURRENT_TENANT_ID_CONTEXTVAR.reset(token) -async def notify_control_plane(tenant_id: str, email: str) -> None: +async def notify_control_plane( + tenant_id: str, email: str, referral_source: str | None = None +) -> None: logger.info("Fetching billing information") token = generate_data_plane_token() headers = { "Authorization": f"Bearer {token}", "Content-Type": "application/json", } - payload = TenantCreationPayload(tenant_id=tenant_id, email=email) + payload = TenantCreationPayload( + tenant_id=tenant_id, email=email, referral_source=referral_source + ) async with aiohttp.ClientSession() as session: async with session.post( diff --git a/backend/tests/integration/common_utils/managers/tenant.py b/backend/tests/integration/common_utils/managers/tenant.py index 76fd16471f8..fc411018df7 100644 --- a/backend/tests/integration/common_utils/managers/tenant.py +++ b/backend/tests/integration/common_utils/managers/tenant.py @@ -28,10 +28,12 @@ class TenantManager: def create( tenant_id: str | None = None, initial_admin_email: str | None = None, + referral_source: str | None = None, ) -> dict[str, str]: body = { "tenant_id": tenant_id, "initial_admin_email": initial_admin_email, + "referral_source": referral_source, } token = generate_auth_token() diff --git a/backend/tests/integration/multitenant_tests/syncing/test_search_permissions.py b/backend/tests/integration/multitenant_tests/syncing/test_search_permissions.py index 454b02412d4..fead77387f6 100644 --- a/backend/tests/integration/multitenant_tests/syncing/test_search_permissions.py +++ b/backend/tests/integration/multitenant_tests/syncing/test_search_permissions.py @@ -14,12 +14,12 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None: # Create Tenant 1 and its Admin User - TenantManager.create("tenant_dev1", "test1@test.com") + TenantManager.create("tenant_dev1", "test1@test.com", "Data Plane Registration") test_user1: DATestUser = UserManager.create(name="test1", email="test1@test.com") assert UserManager.verify_role(test_user1, UserRole.ADMIN) # Create Tenant 2 and its Admin User - TenantManager.create("tenant_dev2", "test2@test.com") + TenantManager.create("tenant_dev2", "test2@test.com", "Data Plane Registration") test_user2: DATestUser = UserManager.create(name="test2", email="test2@test.com") assert UserManager.verify_role(test_user2, UserRole.ADMIN) diff --git a/backend/tests/integration/multitenant_tests/tenants/test_tenant_creation.py b/backend/tests/integration/multitenant_tests/tenants/test_tenant_creation.py index 6088743e317..c2dbcc81790 100644 --- a/backend/tests/integration/multitenant_tests/tenants/test_tenant_creation.py +++ b/backend/tests/integration/multitenant_tests/tenants/test_tenant_creation.py @@ -11,7 +11,7 @@ # Test flow from creating tenant to registering as a user def test_tenant_creation(reset_multitenant: None) -> None: - TenantManager.create("tenant_dev", "test@test.com") + TenantManager.create("tenant_dev", "test@test.com", "Data Plane Registration") test_user: DATestUser = UserManager.create(name="test", email="test@test.com") assert UserManager.verify_role(test_user, UserRole.ADMIN) diff --git a/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx b/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx index a014ce105bb..faea0de39b3 100644 --- a/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx +++ b/web/src/app/admin/indexing/status/CCPairIndexingStatusTable.tsx @@ -194,7 +194,9 @@ function ConnectorRow({ return ( { router.push(`/admin/connector/${ccPairsIndexingStatus.cc_pair_id}`); @@ -434,7 +436,6 @@ export function CCPairIndexingStatusTable({ {!shouldExpand ? "Collapse All" : "Expand All"} - {sortedSources .filter( @@ -454,14 +455,12 @@ export function CCPairIndexingStatusTable({ return (
- toggleSource(source)} /> - {connectorsToggled[source] && ( <> diff --git a/web/src/app/admin/users/page.tsx b/web/src/app/admin/users/page.tsx index 44397515181..8dbc6c308d5 100644 --- a/web/src/app/admin/users/page.tsx +++ b/web/src/app/admin/users/page.tsx @@ -188,7 +188,7 @@ const AddUserButton = ({ }; return ( <> -