Skip to content

Commit

Permalink
finalize
Browse files Browse the repository at this point in the history
  • Loading branch information
pablonyx committed Nov 12, 2024
1 parent 7732ea9 commit 628259f
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 30 deletions.
28 changes: 15 additions & 13 deletions backend/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,11 @@ async def oauth_callback(
refresh_token: Optional[str] = None,
request: Optional[Request] = None,
*,
referral_source: Optional[str] = None,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
tenant_id = await get_or_create_tenant_id(account_email)
tenant_id = await get_or_create_tenant_id(account_email, referral_source)

if not tenant_id:
raise HTTPException(status_code=401, detail="User not found")
Expand Down Expand Up @@ -690,8 +691,6 @@ def generate_state_token(


# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91


def create_danswer_oauth_router(
oauth_client: BaseOAuth2,
backend: AuthenticationBackend,
Expand Down Expand Up @@ -741,15 +740,23 @@ def get_oauth_router(
response_model=OAuth2AuthorizeResponse,
)
async def authorize(
request: Request, scopes: List[str] = Query(None)
request: Request,
scopes: List[str] = Query(None),
) -> OAuth2AuthorizeResponse:
cookies = request.cookies
referral_source = cookies.get("referral_source", None)

if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
authorize_redirect_url = str(request.url_for(callback_route_name))

next_url = request.query_params.get("next", "/")
state_data: Dict[str, str] = {"next_url": next_url}
# Include the referral_source in the state_data
state_data: Dict[str, str] = {
"next_url": next_url,
"referral_source": referral_source or "default_referral",
}
state = generate_state_token(state_data, state_secret)
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
Expand Down Expand Up @@ -808,8 +815,9 @@ async def callback(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)

next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", "/")

# Authenticate user
# Proceed to authenticate or create the user
try:
user = await user_manager.oauth_callback(
oauth_client.name,
Expand All @@ -819,6 +827,7 @@ async def callback(
token.get("expires_at"),
token.get("refresh_token"),
request,
referral_source=referral_source,
associate_by_email=associate_by_email,
is_verified_by_default=is_verified_by_default,
)
Expand All @@ -845,13 +854,6 @@ async def callback(
for header_name, header_value in response.headers.items():
redirect_response.headers[header_name] = header_value

if hasattr(response, "body"):
redirect_response.body = response.body
if hasattr(response, "status_code"):
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type

return redirect_response

return router
2 changes: 1 addition & 1 deletion backend/danswer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def get_application() -> FastAPI:
tags=["users"],
)

if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD:
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
application,
Expand Down
27 changes: 27 additions & 0 deletions backend/ee/danswer/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import OpenID

from danswer.auth.users import auth_backend
Expand Down Expand Up @@ -59,6 +60,31 @@ def get_application() -> FastAPI:
if MULTI_TENANT:
add_tenant_id_middleware(application, logger)

if AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
application,
create_danswer_oauth_router(
oauth_client,
auth_backend,
USER_AUTH_SECRET,
associate_by_email=True,
is_verified_by_default=True,
# Points the user back to the login page
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
),
prefix="/auth/oauth",
tags=["auth"],
)

# Need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_logout_router(auth_backend),
prefix="/auth",
tags=["auth"],
)

if AUTH_TYPE == AuthType.OIDC:
include_router_with_global_prefix_prepended(
application,
Expand All @@ -73,6 +99,7 @@ def get_application() -> FastAPI:
prefix="/auth/oidc",
tags=["auth"],
)

# need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
application,
Expand Down
10 changes: 1 addition & 9 deletions web/src/app/auth/login/SignInButton.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@ import { FaGoogle } from "react-icons/fa";
export function SignInButton({
authorizeUrl,
authType,
referralSource,
}: {
authorizeUrl: string;
authType: AuthType;
referralSource?: string;
}) {
let button;
if (authType === "google_oauth" || authType === "cloud") {
Expand Down Expand Up @@ -40,21 +38,15 @@ export function SignInButton({

const url = new URL(authorizeUrl);

// if (referralSource) {
url.searchParams.set("referral_source", "docs");
// }

const finalAuthorizeUrl = url.toString();

console.log("finalAuthorizeUrl", finalAuthorizeUrl);

if (!button) {
throw new Error(`Unhandled authType: ${authType}`);
}

return (
<a
className="mx-auto mt-6 py-3 w-72 text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
className="mx-auto mt-6 py-3 w-full text-text-100 bg-accent flex rounded cursor-pointer hover:bg-indigo-800"
href={finalAuthorizeUrl}
>
{button}
Expand Down
24 changes: 19 additions & 5 deletions web/src/app/auth/signup/ReferralSourceSelector.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Label } from "@/components/admin/connectors/Field";

interface ReferralSourceSelectorProps {
defaultValue?: string;
Expand Down Expand Up @@ -43,14 +44,27 @@ const ReferralSourceSelector: React.FC<ReferralSourceSelectorProps> = ({
};

return (
<div className="w-full max-w-sm mx-auto mb-4">
<div className="w-full max-w-sm gap-y-2 flex flex-col mx-auto">
<Label className="text-text-950" small={false}>
<>
How did you hear about us?{" "}
<span className=" text-sm italic">(optional)</span>
</>
</Label>
<Select value={referralSource} onValueChange={handleChange}>
<SelectTrigger>
<SelectValue placeholder="How did you hear about us?" />
<SelectTrigger
id="referral-source"
className="w-full border-gray-300 rounded-md shadow-sm focus:border-indigo-500 focus:ring-indigo-500"
>
<SelectValue placeholder="Select an option" />
</SelectTrigger>
<SelectContent>
<SelectContent className="max-h-60 overflow-y-auto">
{referralOptions.map((option) => (
<SelectItem key={option.value} value={option.value}>
<SelectItem
key={option.value}
value={option.value}
className="py-2 px-3 hover:bg-indigo-100 cursor-pointer"
>
{option.label}
</SelectItem>
))}
Expand Down
9 changes: 8 additions & 1 deletion web/src/app/auth/signup/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import Link from "next/link";
import { SignInButton } from "../login/SignInButton";
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
import ReferralSourceSelector from "./ReferralSourceSelector";
import { Separator } from "@/components/ui/separator";

const Page = async () => {
// catch cases where the backend is completely unreachable here
Expand Down Expand Up @@ -63,8 +64,14 @@ const Page = async () => {
<h2 className="text-center text-xl text-strong font-bold">
{cloud ? "Complete your sign up" : "Sign Up for Danswer"}
</h2>
{cloud && (
<>
<div className="w-full flex flex-col items-center space-y-4 mb-4 mt-4">
<ReferralSourceSelector />
</div>
</>
)}

{cloud && <ReferralSourceSelector />}
{cloud && authUrl && (
<div className="w-full justify-center">
<SignInButton authorizeUrl={authUrl} authType="cloud" />
Expand Down
6 changes: 5 additions & 1 deletion web/src/lib/userSS.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ const getOIDCAuthUrlSS = async (nextUrl: string | null): Promise<string> => {
};

const getGoogleOAuthUrlSS = async (): Promise<string> => {
const res = await fetch(buildUrl(`/auth/oauth/authorize`));
const res = await fetch(buildUrl(`/auth/oauth/authorize`), {
headers: {
cookie: processCookies(await cookies()),
},
});
if (!res.ok) {
throw new Error("Failed to fetch data");
}
Expand Down

0 comments on commit 628259f

Please sign in to comment.