Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Weves committed Nov 24, 2024
1 parent efe60ec commit 3c31451
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
8 changes: 3 additions & 5 deletions backend/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.invited_users import get_invited_users
Expand Down Expand Up @@ -83,7 +82,6 @@
from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
Expand Down Expand Up @@ -920,8 +918,8 @@ async def callback(
return router


def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
async def api_key_dep(
request: Request, async_db_session: AsyncSession = Depends(get_async_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None
Expand All @@ -931,7 +929,7 @@ def api_key_dep(
raise HTTPException(status_code=401, detail="Missing API key")

if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)

if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")
Expand Down
6 changes: 4 additions & 2 deletions backend/ee/danswer/db/saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ async def get_saml_account(
return result.scalar_one_or_none()


def expire_saml_account(saml_account: SamlAccount, db_session: Session) -> None:
async def expire_saml_account(
saml_account: SamlAccount, async_db_session: AsyncSession
) -> None:
saml_account.expires_at = func.now()
db_session.commit()
await async_db_session.commit()
13 changes: 9 additions & 4 deletions backend/ee/danswer/server/saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from fastapi_users.password import PasswordHelper
from onelogin.saml2.auth import OneLogin_Saml2_Auth # type: ignore
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from danswer.auth.schemas import UserCreate
Expand Down Expand Up @@ -170,15 +171,19 @@ async def saml_login_callback(


@router.post("/logout")
def saml_logout(
async def saml_logout(
request: Request,
db_session: Session = Depends(get_session),
async_db_session: AsyncSession = Depends(get_async_session),
) -> None:
saved_cookie = extract_hashed_cookie(request)

if saved_cookie:
saml_account = get_saml_account(cookie=saved_cookie, db_session=db_session)
saml_account = await get_saml_account(
cookie=saved_cookie, async_db_session=async_db_session
)
if saml_account:
expire_saml_account(saml_account, db_session)
await expire_saml_account(
saml_account=saml_account, async_db_session=async_db_session
)

return

0 comments on commit 3c31451

Please sign in to comment.