diff --git a/backend/danswer/auth/schemas.py b/backend/danswer/auth/schemas.py index 1fe9d85f01d..9c81899a421 100644 --- a/backend/danswer/auth/schemas.py +++ b/backend/danswer/auth/schemas.py @@ -35,7 +35,6 @@ class UserCreate(schemas.BaseUserCreate): role: UserRole = UserRole.BASIC has_web_login: bool | None = True tenant_id: str | None = None - referral_source: str | None = None class UserUpdate(schemas.BaseUserUpdate): diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index a9ca07136ac..f92f55cdf32 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -1,5 +1,3 @@ -import base64 -import json import smtplib import uuid from collections.abc import AsyncGenerator @@ -224,14 +222,11 @@ async def create( request: Optional[Request] = None, ) -> User: # Print the cookies from the request + referral_source = None if request: - cookies = request.cookies - print("Cookies from request:", cookies) - else: - print("No request object available") - tenant_id = await get_or_create_tenant_id( - user_create.email, user_create.referral_source - ) + referral_source = request.cookies.get("referral_source", None) + + tenant_id = await get_or_create_tenant_id(user_create.email, referral_source) async with get_async_session_with_tenant(tenant_id) as db_session: token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) @@ -291,39 +286,7 @@ async def oauth_callback( associate_by_email: bool = False, is_verified_by_default: bool = False, ) -> models.UOAP: - print("within callback") - - if request: - cookies = request.cookies - print("Cookies from request:", cookies) - else: - print("No request object available") - - referral_source = request.query_params.get("referral_source") - print("referral_source") - - print(referral_source) - - referral_source = None - if request and request.query_params.get("state"): - try: - state_json = base64.b64decode(request.query_params["state"]).decode( - "utf-8" - ) - state_data = json.loads(state_json) - referral_source = state_data.get("referralSource") - except Exception as e: - # Log the error for debugging - logger.error(f"Error parsing state parameter: {str(e)}") - # Continue with the flow, treating it as if there was no referral source - - # Log the referral source for debugging - logger.info(f"zpzpzpReferral source: {referral_source}") - - print("referral_source") - print(referral_source) - return - tenant_id = await get_or_create_tenant_id(account_email, referral_source) + tenant_id = await get_or_create_tenant_id(account_email) if not tenant_id: raise HTTPException(status_code=401, detail="User not found") diff --git a/backend/ee/danswer/server/tenants/models.py b/backend/ee/danswer/server/tenants/models.py index df24ff6c32d..c372418f6a4 100644 --- a/backend/ee/danswer/server/tenants/models.py +++ b/backend/ee/danswer/server/tenants/models.py @@ -38,3 +38,4 @@ class ImpersonateRequest(BaseModel): class TenantCreationPayload(BaseModel): tenant_id: str email: str + referral_source: str | None = None diff --git a/backend/ee/danswer/server/tenants/provisioning.py b/backend/ee/danswer/server/tenants/provisioning.py index 5b951e968e7..32f95b1200d 100644 --- a/backend/ee/danswer/server/tenants/provisioning.py +++ b/backend/ee/danswer/server/tenants/provisioning.py @@ -70,9 +70,9 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str: tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4()) try: # Provision tenant on data plane - await provision_tenant(tenant_id, email, referral_source) + await provision_tenant(tenant_id, email) # Notify control plane - await notify_control_plane(tenant_id, email) + await notify_control_plane(tenant_id, email, referral_source) except Exception as e: logger.error(f"Tenant provisioning failed: {e}") await rollback_tenant_provisioning(tenant_id) @@ -119,14 +119,18 @@ async def provision_tenant(tenant_id: str, email: str) -> None: CURRENT_TENANT_ID_CONTEXTVAR.reset(token) -async def notify_control_plane(tenant_id: str, email: str) -> None: +async def notify_control_plane( + tenant_id: str, email: str, referral_source: str | None = None +) -> None: logger.info("Fetching billing information") token = generate_data_plane_token() headers = { "Authorization": f"Bearer {token}", "Content-Type": "application/json", } - payload = TenantCreationPayload(tenant_id=tenant_id, email=email) + payload = TenantCreationPayload( + tenant_id=tenant_id, email=email, referral_source=referral_source + ) async with aiohttp.ClientSession() as session: async with session.post( diff --git a/backend/tests/integration/multitenant_tests/tenants/test_tenant_creation.py b/backend/tests/integration/multitenant_tests/tenants/test_tenant_creation.py index 6088743e317..c2dbcc81790 100644 --- a/backend/tests/integration/multitenant_tests/tenants/test_tenant_creation.py +++ b/backend/tests/integration/multitenant_tests/tenants/test_tenant_creation.py @@ -11,7 +11,7 @@ # Test flow from creating tenant to registering as a user def test_tenant_creation(reset_multitenant: None) -> None: - TenantManager.create("tenant_dev", "test@test.com") + TenantManager.create("tenant_dev", "test@test.com", "Data Plane Registration") test_user: DATestUser = UserManager.create(name="test", email="test@test.com") assert UserManager.verify_role(test_user, UserRole.ADMIN) diff --git a/web/src/app/auth/login/EmailPasswordForm.tsx b/web/src/app/auth/login/EmailPasswordForm.tsx index 334c74d14f7..06053fae5f2 100644 --- a/web/src/app/auth/login/EmailPasswordForm.tsx +++ b/web/src/app/auth/login/EmailPasswordForm.tsx @@ -14,9 +14,11 @@ import { Spinner } from "@/components/Spinner"; export function EmailPasswordForm({ isSignup = false, shouldVerify, + referralSource, }: { isSignup?: boolean; shouldVerify?: boolean; + referralSource?: string; }) { const router = useRouter(); const { popup, setPopup } = usePopup(); @@ -39,7 +41,11 @@ export function EmailPasswordForm({ if (isSignup) { // login is fast, no need to show a spinner setIsWorking(true); - const response = await basicSignup(values.email, values.password); + const response = await basicSignup( + values.email, + values.password, + referralSource + ); if (!response.ok) { const errorDetail = (await response.json()).detail; diff --git a/web/src/app/auth/login/SignInButton.tsx b/web/src/app/auth/login/SignInButton.tsx index 735854ca6ed..c8212611fe3 100644 --- a/web/src/app/auth/login/SignInButton.tsx +++ b/web/src/app/auth/login/SignInButton.tsx @@ -1,4 +1,3 @@ -import Cookies from "js-cookie"; import { AuthType } from "@/lib/constants"; import { FaGoogle } from "react-icons/fa"; @@ -41,7 +40,9 @@ export function SignInButton({ const url = new URL(authorizeUrl); + // if (referralSource) { url.searchParams.set("referral_source", "docs"); + // } const finalAuthorizeUrl = url.toString(); @@ -51,19 +52,12 @@ export function SignInButton({ throw new Error(`Unhandled authType: ${authType}`); } - const handleClick = () => { - Cookies.set("referral_source", "docs", { - sameSite: "lax", - expires: new Date(Date.now() + 1000 * 60 * 60 * 24 * 30), - }); - window.location.href = finalAuthorizeUrl; - }; return ( - + ); } diff --git a/web/src/app/auth/signup/ReferralSourceSelector.tsx b/web/src/app/auth/signup/ReferralSourceSelector.tsx index 63212ab1575..e503bc96a01 100644 --- a/web/src/app/auth/signup/ReferralSourceSelector.tsx +++ b/web/src/app/auth/signup/ReferralSourceSelector.tsx @@ -11,12 +11,10 @@ import { interface ReferralSourceSelectorProps { defaultValue?: string; - onChange: (value: string) => void; } const ReferralSourceSelector: React.FC = ({ defaultValue, - onChange, }) => { const [referralSource, setReferralSource] = useState(defaultValue); @@ -36,7 +34,12 @@ const ReferralSourceSelector: React.FC = ({ const handleChange = (value: string) => { setReferralSource(value); - onChange(value); + const cookies = require("js-cookie"); + cookies.set("referral_source", value, { + expires: 365, + path: "/", + sameSite: "strict", + }); }; return ( diff --git a/web/src/app/auth/signup/SignupPage.tsx b/web/src/app/auth/signup/SignupPage.tsx deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/web/src/app/auth/signup/SignupPageContent.tsx b/web/src/app/auth/signup/SignupPageContent.tsx deleted file mode 100644 index 6b7179cab04..00000000000 --- a/web/src/app/auth/signup/SignupPageContent.tsx +++ /dev/null @@ -1,82 +0,0 @@ -"use client"; - -import { useState } from "react"; -import { AuthTypeMetadata } from "@/lib/userSS"; -import { HealthCheckBanner } from "@/components/health/healthcheck"; -import { EmailPasswordForm } from "../login/EmailPasswordForm"; -import Text from "@/components/ui/text"; -import Link from "next/link"; -import { SignInButton } from "../login/SignInButton"; -import AuthFlowContainer from "@/components/auth/AuthFlowContainer"; -import ReferralSourceSelector from "./ReferralSourceSelector"; - -interface SignupPageContentProps { - authTypeMetadata: AuthTypeMetadata | null; - cloud: boolean; - authUrl: string | null; - initialReferralSource?: string; -} - -const SignupPageContent: React.FC = ({ - authTypeMetadata, - cloud, - authUrl, - initialReferralSource, -}) => { - const [selectedReferralSource, setSelectedReferralSource] = useState( - initialReferralSource - ); - - return ( - - - - <> -
-
-

- {cloud ? "Complete your sign up" : "Sign Up for Danswer"} -

- - {cloud && ( - setSelectedReferralSource(value)} - /> - )} - - {cloud && authUrl && ( -
- -
-
- or -
-
-
- )} - - - -
- - Already have an account?{" "} - - Log In - - -
-
- -
- ); -}; - -export default SignupPageContent; diff --git a/web/src/app/auth/signup/page.tsx b/web/src/app/auth/signup/page.tsx index d124a1f7dc0..505fe17be6d 100644 --- a/web/src/app/auth/signup/page.tsx +++ b/web/src/app/auth/signup/page.tsx @@ -1,3 +1,4 @@ +import { HealthCheckBanner } from "@/components/health/healthcheck"; import { User } from "@/lib/types"; import { getCurrentUserSS, @@ -6,13 +7,14 @@ import { getAuthUrlSS, } from "@/lib/userSS"; import { redirect } from "next/navigation"; -import SignupPageContent from "./SignupPageContent"; - -const Page = async (props: { - searchParams?: Promise<{ [key: string]: string | string[] | undefined }>; -}) => { - const searchParams = await props.searchParams; +import { EmailPasswordForm } from "../login/EmailPasswordForm"; +import Text from "@/components/ui/text"; +import Link from "next/link"; +import { SignInButton } from "../login/SignInButton"; +import AuthFlowContainer from "@/components/auth/AuthFlowContainer"; +import ReferralSourceSelector from "./ReferralSourceSelector"; +const Page = async () => { // catch cases where the backend is completely unreachable here // without try / catch, will just raise an exception and the page // will not render @@ -51,18 +53,45 @@ const Page = async (props: { authUrl = await getAuthUrlSS(authTypeMetadata.authType, null); } - // Initialize referralSource with the value from URL parameters - const referralSource = searchParams - ? (searchParams.referral as string | undefined) - : undefined; - return ( - + + + + <> +
+
+

+ {cloud ? "Complete your sign up" : "Sign Up for Danswer"} +

+ + {cloud && } + {cloud && authUrl && ( +
+ +
+
+ or +
+
+
+ )} + + + +
+ + Already have an account?{" "} + + Log In + + +
+
+ +
); }; diff --git a/web/src/lib/user.ts b/web/src/lib/user.ts index 10d426c1ec9..d540366f519 100644 --- a/web/src/lib/user.ts +++ b/web/src/lib/user.ts @@ -43,7 +43,11 @@ export const basicLogin = async ( return response; }; -export const basicSignup = async (email: string, password: string) => { +export const basicSignup = async ( + email: string, + password: string, + referralSource?: string +) => { const response = await fetch("/api/auth/register", { method: "POST", credentials: "include", @@ -54,6 +58,7 @@ export const basicSignup = async (email: string, password: string) => { email, username: email, password, + referral_source: referralSource, }), }); return response;