Skip to content

Commit

Permalink
Fix API keys for MIT users (#3237)
Browse files Browse the repository at this point in the history
  • Loading branch information
Weves authored Nov 25, 2024
1 parent 86d8666 commit 7573416
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 34 deletions.
24 changes: 16 additions & 8 deletions backend/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession

from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.invited_users import get_invited_users
Expand Down Expand Up @@ -80,8 +80,8 @@
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
Expand Down Expand Up @@ -609,7 +609,7 @@ async def logout(
async def optional_user_(
request: Request,
user: User | None,
db_session: Session,
async_db_session: AsyncSession,
) -> User | None:
"""NOTE: `request` and `db_session` are not used here, but are included
for the EE version of this function."""
Expand All @@ -618,13 +618,21 @@ async def optional_user_(

async def optional_user(
request: Request,
db_session: Session = Depends(get_session),
async_db_session: AsyncSession = Depends(get_async_session),
user: User | None = Depends(optional_fastapi_current_user),
) -> User | None:
versioned_fetch_user = fetch_versioned_implementation(
"danswer.auth.users", "optional_user_"
)
return await versioned_fetch_user(request, user, db_session)
user = await versioned_fetch_user(request, user, async_db_session)

# check if an API key is present
if user is None:
hashed_api_key = get_hashed_api_key_from_request(request)
if hashed_api_key:
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)

return user


async def double_check_user(
Expand Down Expand Up @@ -910,8 +918,8 @@ async def callback(
return router


def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
async def api_key_dep(
request: Request, async_db_session: AsyncSession = Depends(get_async_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None
Expand All @@ -921,7 +929,7 @@ def api_key_dep(
raise HTTPException(status_code=401, detail="Missing API key")

if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)

if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")
Expand Down
17 changes: 10 additions & 7 deletions backend/danswer/db/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session

Expand Down Expand Up @@ -45,14 +46,16 @@ def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]:
]


def fetch_user_for_api_key(hashed_api_key: str, db_session: Session) -> User | None:
api_key = db_session.scalar(
select(ApiKey).where(ApiKey.hashed_api_key == hashed_api_key)
async def fetch_user_for_api_key(
hashed_api_key: str, async_db_session: AsyncSession
) -> User | None:
"""NOTE: this is async, since it's used during auth
(which is necessarily async due to FastAPI Users)"""
return await async_db_session.scalar(
select(User)
.join(ApiKey, ApiKey.user_id == User.id)
.where(ApiKey.hashed_api_key == hashed_api_key)
)
if api_key is None:
return None

return db_session.scalar(select(User).where(User.id == api_key.user_id)) # type: ignore


def get_api_key_fake_email(
Expand Down
16 changes: 5 additions & 11 deletions backend/ee/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
from fastapi import HTTPException
from fastapi import Request
from fastapi import status
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession

from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.users import current_admin_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import SUPER_CLOUD_API_KEY
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.constants import AuthType
from danswer.db.api_key import fetch_user_for_api_key
from danswer.db.models import User
from danswer.utils.logger import setup_logger
from ee.danswer.db.saml import get_saml_account
Expand All @@ -28,22 +26,18 @@ def verify_auth_setting() -> None:
async def optional_user_(
request: Request,
user: User | None,
db_session: Session,
async_db_session: AsyncSession,
) -> User | None:
# Check if the user has a session cookie from SAML
if AUTH_TYPE == AuthType.SAML:
saved_cookie = extract_hashed_cookie(request)

if saved_cookie:
saml_account = get_saml_account(cookie=saved_cookie, db_session=db_session)
saml_account = await get_saml_account(
cookie=saved_cookie, async_db_session=async_db_session
)
user = saml_account.user if saml_account else None

# check if an API key is present
if user is None:
hashed_api_key = get_hashed_api_key_from_request(request)
if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)

return user


Expand Down
15 changes: 11 additions & 4 deletions backend/ee/danswer/db/saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlalchemy import and_
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
Expand Down Expand Up @@ -44,7 +45,11 @@ def upsert_saml_account(
return saml_acc.expires_at


def get_saml_account(cookie: str, db_session: Session) -> SamlAccount | None:
async def get_saml_account(
cookie: str, async_db_session: AsyncSession
) -> SamlAccount | None:
"""NOTE: this is async, since it's used during auth
(which is necessarily async due to FastAPI Users)"""
stmt = (
select(SamlAccount)
.join(User, User.id == SamlAccount.user_id) # type: ignore
Expand All @@ -56,10 +61,12 @@ def get_saml_account(cookie: str, db_session: Session) -> SamlAccount | None:
)
)

result = db_session.execute(stmt)
result = await async_db_session.execute(stmt)
return result.scalar_one_or_none()


def expire_saml_account(saml_account: SamlAccount, db_session: Session) -> None:
async def expire_saml_account(
saml_account: SamlAccount, async_db_session: AsyncSession
) -> None:
saml_account.expires_at = func.now()
db_session.commit()
await async_db_session.commit()
13 changes: 9 additions & 4 deletions backend/ee/danswer/server/saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from fastapi_users.password import PasswordHelper
from onelogin.saml2.auth import OneLogin_Saml2_Auth # type: ignore
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from danswer.auth.schemas import UserCreate
Expand Down Expand Up @@ -170,15 +171,19 @@ async def saml_login_callback(


@router.post("/logout")
def saml_logout(
async def saml_logout(
request: Request,
db_session: Session = Depends(get_session),
async_db_session: AsyncSession = Depends(get_async_session),
) -> None:
saved_cookie = extract_hashed_cookie(request)

if saved_cookie:
saml_account = get_saml_account(cookie=saved_cookie, db_session=db_session)
saml_account = await get_saml_account(
cookie=saved_cookie, async_db_session=async_db_session
)
if saml_account:
expire_saml_account(saml_account, db_session)
await expire_saml_account(
saml_account=saml_account, async_db_session=async_db_session
)

return

0 comments on commit 7573416

Please sign in to comment.