Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add referral source to cloud on data plane #3096

Merged
merged 7 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions backend/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,17 @@ async def create(
safe: bool = False,
request: Optional[Request] = None,
) -> User:
referral_source = None
if request is not None:
referral_source = request.cookies.get("referral_source", None)

tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user_create.email,
referral_source=referral_source,
)

async with get_async_session_with_tenant(tenant_id) as db_session:
Expand Down Expand Up @@ -294,12 +299,17 @@ async def oauth_callback(
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
referral_source = None
if request:
referral_source = getattr(request.state, "referral_source", None)

tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=account_email,
referral_source=referral_source,
)

if not tenant_id:
Expand Down Expand Up @@ -711,8 +721,6 @@ def generate_state_token(


# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91


def create_danswer_oauth_router(
oauth_client: BaseOAuth2,
backend: AuthenticationBackend,
Expand Down Expand Up @@ -762,15 +770,22 @@ def get_oauth_router(
response_model=OAuth2AuthorizeResponse,
)
async def authorize(
request: Request, scopes: List[str] = Query(None)
request: Request,
scopes: List[str] = Query(None),
) -> OAuth2AuthorizeResponse:
referral_source = request.cookies.get("referral_source", None)

if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
authorize_redirect_url = str(request.url_for(callback_route_name))

next_url = request.query_params.get("next", "/")
state_data: Dict[str, str] = {"next_url": next_url}

state_data: Dict[str, str] = {
"next_url": next_url,
"referral_source": referral_source or "default_referral",
}
state = generate_state_token(state_data, state_secret)
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
Expand Down Expand Up @@ -829,8 +844,11 @@ async def callback(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)

next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)

# Authenticate user
request.state.referral_source = referral_source

# Proceed to authenticate or create the user
try:
user = await user_manager.oauth_callback(
oauth_client.name,
Expand Down Expand Up @@ -872,7 +890,6 @@ async def callback(
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type

return redirect_response

return router
Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def get_application() -> FastAPI:
tags=["users"],
)

if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD:
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
application,
Expand Down
27 changes: 27 additions & 0 deletions backend/ee/danswer/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import OpenID

from danswer.auth.users import auth_backend
Expand Down Expand Up @@ -59,6 +60,31 @@ def get_application() -> FastAPI:
if MULTI_TENANT:
add_tenant_id_middleware(application, logger)

if AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
application,
create_danswer_oauth_router(
oauth_client,
auth_backend,
USER_AUTH_SECRET,
associate_by_email=True,
is_verified_by_default=True,
# Points the user back to the login page
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
),
prefix="/auth/oauth",
tags=["auth"],
)

# Need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_logout_router(auth_backend),
prefix="/auth",
tags=["auth"],
)

if AUTH_TYPE == AuthType.OIDC:
include_router_with_global_prefix_prepended(
application,
Expand All @@ -73,6 +99,7 @@ def get_application() -> FastAPI:
prefix="/auth/oidc",
tags=["auth"],
)

# need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
application,
Expand Down
1 change: 1 addition & 0 deletions backend/ee/danswer/server/tenants/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ class ImpersonateRequest(BaseModel):
class TenantCreationPayload(BaseModel):
tenant_id: str
email: str
referral_source: str | None = None
18 changes: 12 additions & 6 deletions backend/ee/danswer/server/tenants/provisioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
logger = logging.getLogger(__name__)


async def get_or_create_tenant_id(email: str) -> str:
async def get_or_create_tenant_id(
email: str, referral_source: str | None = None
) -> str:
"""Get existing tenant ID for an email or create a new tenant if none exists."""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
Expand All @@ -51,7 +53,7 @@ async def get_or_create_tenant_id(email: str) -> str:
except exceptions.UserNotExists:
# If tenant does not exist and in Multi tenant mode, provision a new tenant
try:
tenant_id = await create_tenant(email)
tenant_id = await create_tenant(email, referral_source)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
Expand All @@ -64,13 +66,13 @@ async def get_or_create_tenant_id(email: str) -> str:
return tenant_id


async def create_tenant(email: str) -> str:
async def create_tenant(email: str, referral_source: str | None = None) -> str:
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
try:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane
await notify_control_plane(tenant_id, email)
await notify_control_plane(tenant_id, email, referral_source)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)
Expand Down Expand Up @@ -117,14 +119,18 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


async def notify_control_plane(tenant_id: str, email: str) -> None:
async def notify_control_plane(
tenant_id: str, email: str, referral_source: str | None = None
) -> None:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
payload = TenantCreationPayload(tenant_id=tenant_id, email=email)
payload = TenantCreationPayload(
tenant_id=tenant_id, email=email, referral_source=referral_source
)

async with aiohttp.ClientSession() as session:
async with session.post(
Expand Down
2 changes: 2 additions & 0 deletions backend/tests/integration/common_utils/managers/tenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@ class TenantManager:
def create(
tenant_id: str | None = None,
initial_admin_email: str | None = None,
referral_source: str | None = None,
) -> dict[str, str]:
body = {
"tenant_id": tenant_id,
"initial_admin_email": initial_admin_email,
"referral_source": referral_source,
}

token = generate_auth_token()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

def test_multi_tenant_access_control(reset_multitenant: None) -> None:
# Create Tenant 1 and its Admin User
TenantManager.create("tenant_dev1", "[email protected]")
TenantManager.create("tenant_dev1", "[email protected]", "Data Plane Registration")
test_user1: DATestUser = UserManager.create(name="test1", email="[email protected]")
assert UserManager.verify_role(test_user1, UserRole.ADMIN)

# Create Tenant 2 and its Admin User
TenantManager.create("tenant_dev2", "[email protected]")
TenantManager.create("tenant_dev2", "[email protected]", "Data Plane Registration")
test_user2: DATestUser = UserManager.create(name="test2", email="[email protected]")
assert UserManager.verify_role(test_user2, UserRole.ADMIN)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

# Test flow from creating tenant to registering as a user
def test_tenant_creation(reset_multitenant: None) -> None:
TenantManager.create("tenant_dev", "[email protected]")
TenantManager.create("tenant_dev", "[email protected]", "Data Plane Registration")
test_user: DATestUser = UserManager.create(name="test", email="[email protected]")

assert UserManager.verify_role(test_user, UserRole.ADMIN)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ function ConnectorRow({
return (
<TableRow
className={`hover:bg-hover-light ${
invisible ? "invisible !h-0 !-mb-10" : "!border !border-border"
invisible
? "invisible !h-0 !-mb-10 !border-none"
: "!border !border-border"
} w-full cursor-pointer relative `}
onClick={() => {
router.push(`/admin/connector/${ccPairsIndexingStatus.cc_pair_id}`);
Expand Down Expand Up @@ -434,7 +436,6 @@ export function CCPairIndexingStatusTable({
{!shouldExpand ? "Collapse All" : "Expand All"}
</Button>
</div>

<TableBody>
{sortedSources
.filter(
Expand All @@ -454,14 +455,12 @@ export function CCPairIndexingStatusTable({
return (
<React.Fragment key={ind}>
<br className="mt-4" />

<SummaryRow
source={source}
summary={groupSummaries[source]}
isOpen={connectorsToggled[source] || false}
onToggle={() => toggleSource(source)}
/>

{connectorsToggled[source] && (
<>
<TableRow className="border border-border">
Expand Down
2 changes: 1 addition & 1 deletion web/src/app/admin/users/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ const AddUserButton = ({
};
return (
<>
<Button className="w-fit" onClick={() => setModal(true)}>
<Button className="my-auto w-fit" onClick={() => setModal(true)}>
<div className="flex">
<FiPlusSquare className="my-auto mr-2" />
Invite Users
Expand Down
8 changes: 7 additions & 1 deletion web/src/app/auth/login/EmailPasswordForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ import { Spinner } from "@/components/Spinner";
export function EmailPasswordForm({
isSignup = false,
shouldVerify,
referralSource,
}: {
isSignup?: boolean;
shouldVerify?: boolean;
referralSource?: string;
}) {
const router = useRouter();
const { popup, setPopup } = usePopup();
Expand All @@ -39,7 +41,11 @@ export function EmailPasswordForm({
if (isSignup) {
// login is fast, no need to show a spinner
setIsWorking(true);
const response = await basicSignup(values.email, values.password);
const response = await basicSignup(
values.email,
values.password,
referralSource
);

if (!response.ok) {
const errorDetail = (await response.json()).detail;
Expand Down
8 changes: 6 additions & 2 deletions web/src/app/auth/login/SignInButton.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,18 @@ export function SignInButton({
);
}

const url = new URL(authorizeUrl);

const finalAuthorizeUrl = url.toString();

if (!button) {
throw new Error(`Unhandled authType: ${authType}`);
}

return (
<a
className="mx-auto mt-6 py-3 w-72 text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
href={authorizeUrl}
className="mx-auto mt-6 py-3 w-full text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
href={finalAuthorizeUrl}
>
{button}
</a>
Expand Down
Loading
Loading