Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix API keys for MIT users #3237

Merged
merged 2 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions backend/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession

from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.invited_users import get_invited_users
Expand Down Expand Up @@ -80,8 +80,8 @@
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
Expand Down Expand Up @@ -609,7 +609,7 @@ async def logout(
async def optional_user_(
request: Request,
user: User | None,
db_session: Session,
async_db_session: AsyncSession,
) -> User | None:
"""NOTE: `request` and `db_session` are not used here, but are included
for the EE version of this function."""
Expand All @@ -618,13 +618,21 @@ async def optional_user_(

async def optional_user(
request: Request,
db_session: Session = Depends(get_session),
async_db_session: AsyncSession = Depends(get_async_session),
user: User | None = Depends(optional_fastapi_current_user),
) -> User | None:
versioned_fetch_user = fetch_versioned_implementation(
"danswer.auth.users", "optional_user_"
)
return await versioned_fetch_user(request, user, db_session)
user = await versioned_fetch_user(request, user, async_db_session)

# check if an API key is present
if user is None:
hashed_api_key = get_hashed_api_key_from_request(request)
if hashed_api_key:
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)

return user


async def double_check_user(
Expand Down Expand Up @@ -910,8 +918,8 @@ async def callback(
return router


def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
async def api_key_dep(
request: Request, async_db_session: AsyncSession = Depends(get_async_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None
Expand All @@ -921,7 +929,7 @@ def api_key_dep(
raise HTTPException(status_code=401, detail="Missing API key")

if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)

if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")
Expand Down
17 changes: 10 additions & 7 deletions backend/danswer/db/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session

Expand Down Expand Up @@ -45,14 +46,16 @@ def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]:
]


def fetch_user_for_api_key(hashed_api_key: str, db_session: Session) -> User | None:
api_key = db_session.scalar(
select(ApiKey).where(ApiKey.hashed_api_key == hashed_api_key)
async def fetch_user_for_api_key(
hashed_api_key: str, async_db_session: AsyncSession
) -> User | None:
"""NOTE: this is async, since it's used during auth
(which is necessarily async due to FastAPI Users)"""
return await async_db_session.scalar(
select(User)
.join(ApiKey, ApiKey.user_id == User.id)
.where(ApiKey.hashed_api_key == hashed_api_key)
)
if api_key is None:
return None

return db_session.scalar(select(User).where(User.id == api_key.user_id)) # type: ignore


def get_api_key_fake_email(
Expand Down
16 changes: 5 additions & 11 deletions backend/ee/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
from fastapi import HTTPException
from fastapi import Request
from fastapi import status
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession

from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.users import current_admin_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import SUPER_CLOUD_API_KEY
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.constants import AuthType
from danswer.db.api_key import fetch_user_for_api_key
from danswer.db.models import User
from danswer.utils.logger import setup_logger
from ee.danswer.db.saml import get_saml_account
Expand All @@ -28,22 +26,18 @@ def verify_auth_setting() -> None:
async def optional_user_(
request: Request,
user: User | None,
db_session: Session,
async_db_session: AsyncSession,
) -> User | None:
# Check if the user has a session cookie from SAML
if AUTH_TYPE == AuthType.SAML:
saved_cookie = extract_hashed_cookie(request)

if saved_cookie:
saml_account = get_saml_account(cookie=saved_cookie, db_session=db_session)
saml_account = await get_saml_account(
cookie=saved_cookie, async_db_session=async_db_session
)
user = saml_account.user if saml_account else None

# check if an API key is present
if user is None:
hashed_api_key = get_hashed_api_key_from_request(request)
if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)

return user


Expand Down
15 changes: 11 additions & 4 deletions backend/ee/danswer/db/saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlalchemy import and_
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
Expand Down Expand Up @@ -44,7 +45,11 @@ def upsert_saml_account(
return saml_acc.expires_at


def get_saml_account(cookie: str, db_session: Session) -> SamlAccount | None:
async def get_saml_account(
cookie: str, async_db_session: AsyncSession
) -> SamlAccount | None:
"""NOTE: this is async, since it's used during auth
(which is necessarily async due to FastAPI Users)"""
stmt = (
select(SamlAccount)
.join(User, User.id == SamlAccount.user_id) # type: ignore
Expand All @@ -56,10 +61,12 @@ def get_saml_account(cookie: str, db_session: Session) -> SamlAccount | None:
)
)

result = db_session.execute(stmt)
result = await async_db_session.execute(stmt)
return result.scalar_one_or_none()


def expire_saml_account(saml_account: SamlAccount, db_session: Session) -> None:
async def expire_saml_account(
saml_account: SamlAccount, async_db_session: AsyncSession
) -> None:
saml_account.expires_at = func.now()
db_session.commit()
await async_db_session.commit()
13 changes: 9 additions & 4 deletions backend/ee/danswer/server/saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from fastapi_users.password import PasswordHelper
from onelogin.saml2.auth import OneLogin_Saml2_Auth # type: ignore
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from danswer.auth.schemas import UserCreate
Expand Down Expand Up @@ -170,15 +171,19 @@ async def saml_login_callback(


@router.post("/logout")
def saml_logout(
async def saml_logout(
request: Request,
db_session: Session = Depends(get_session),
async_db_session: AsyncSession = Depends(get_async_session),
) -> None:
saved_cookie = extract_hashed_cookie(request)

if saved_cookie:
saml_account = get_saml_account(cookie=saved_cookie, db_session=db_session)
saml_account = await get_saml_account(
cookie=saved_cookie, async_db_session=async_db_session
)
if saml_account:
expire_saml_account(saml_account, db_session)
await expire_saml_account(
saml_account=saml_account, async_db_session=async_db_session
)

return
Loading