Skip to content

Commit

Permalink
Remove ee (#3093)
Browse files Browse the repository at this point in the history
* move api key to non-ee

* finalize previous migration

* move token rate limit to non-ee

* general cleanup

* update

* update

* finalize

* finalize

* ensure callable

* k
  • Loading branch information
pablonyx authored Nov 9, 2024
1 parent 4fb65dc commit 9272d6e
Show file tree
Hide file tree
Showing 39 changed files with 362 additions and 258 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pydantic import BaseModel

from danswer.auth.schemas import UserRole
from ee.danswer.configs.app_configs import API_KEY_HASH_ROUNDS
from danswer.configs.app_configs import API_KEY_HASH_ROUNDS


_API_KEY_HEADER_NAME = "Authorization"
Expand Down
58 changes: 52 additions & 6 deletions backend/danswer/auth/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session

from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
Expand All @@ -74,6 +75,7 @@
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.api_key import fetch_user_for_api_key
from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
Expand All @@ -89,9 +91,9 @@
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from danswer.utils.variable_functionality import fetch_versioned_implementation
from ee.danswer.server.tenants.provisioning import get_or_create_tenant_id
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
from shared_configs.configs import async_return_default_schema
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

Expand Down Expand Up @@ -221,7 +223,13 @@ async def create(
safe: bool = False,
request: Optional[Request] = None,
) -> User:
tenant_id = await get_or_create_tenant_id(user_create.email)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user_create.email,
)

async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
Expand Down Expand Up @@ -281,7 +289,13 @@ async def oauth_callback(
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
tenant_id = await get_or_create_tenant_id(account_email)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=account_email,
)

if not tenant_id:
raise HTTPException(status_code=401, detail="User not found")
Expand Down Expand Up @@ -419,7 +433,13 @@ async def authenticate(
email = credentials.username

# Get tenant_id from mapping table
tenant_id = get_tenant_id_for_email(email)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=email,
)
if not tenant_id:
# User not found in mapping
self.password_helper.hash(credentials.password)
Expand Down Expand Up @@ -477,7 +497,14 @@ async def get_user_manager(
# This strategy is used to add tenant_id to the JWT token
class TenantAwareJWTStrategy(JWTStrategy):
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
tenant_id = get_tenant_id_for_email(user.email)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user.email,
)

data = {
"sub": str(user.id),
"aud": self.token_audience,
Expand Down Expand Up @@ -851,3 +878,22 @@ async def callback(
return redirect_response

return router


def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None

hashed_api_key = get_hashed_api_key_from_request(request)
if not hashed_api_key:
raise HTTPException(status_code=401, detail="Missing API key")

if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)

if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")

return user
4 changes: 4 additions & 0 deletions backend/danswer/background/task_name_builders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def name_sync_external_doc_permissions_task(
cc_pair_id: int, tenant_id: str | None = None
) -> str:
return f"sync_external_doc_permissions_task__{cc_pair_id}"
10 changes: 10 additions & 0 deletions backend/danswer/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,3 +493,13 @@
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["[email protected]"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")


#####
# API Key Configs
#####
# refers to the rounds described here: https://passlib.readthedocs.io/en/stable/lib/passlib.hash.sha256_crypt.html
_API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
API_KEY_HASH_ROUNDS = (
int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session

from danswer.auth.api_key import ApiKeyDescriptor
from danswer.auth.api_key import build_displayable_api_key
from danswer.auth.api_key import generate_api_key
from danswer.auth.api_key import hash_api_key
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.models import ApiKey
from danswer.db.models import User
from ee.danswer.auth.api_key import ApiKeyDescriptor
from ee.danswer.auth.api_key import build_displayable_api_key
from ee.danswer.auth.api_key import generate_api_key
from ee.danswer.auth.api_key import hash_api_key
from ee.danswer.server.api_key.models import APIKeyArgs
from danswer.server.api_key.models import APIKeyArgs
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

Expand Down
2 changes: 1 addition & 1 deletion backend/danswer/db/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserRole
from danswer.db.api_key import get_api_key_email_pattern
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.models import AccessToken
Expand All @@ -22,7 +23,6 @@
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from ee.danswer.db.api_key import get_api_key_email_pattern


def get_default_admin_user_emails() -> list[str]:
Expand Down
15 changes: 11 additions & 4 deletions backend/danswer/db/connector_credential_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from danswer.db.models import UserRole
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.sync_params import check_if_valid_sync_source
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop


logger = setup_logger()

Expand Down Expand Up @@ -351,7 +351,11 @@ def add_credential_to_connector(
raise HTTPException(status_code=404, detail="Connector does not exist")

if access_type == AccessType.SYNC:
if not check_if_valid_sync_source(connector.source):
if not fetch_ee_implementation_or_noop(
"danswer.external_permissions.sync_params",
"check_if_valid_sync_source",
noop_return_value=True,
)(connector.source):
raise HTTPException(
status_code=400,
detail=f"Connector of type {connector.source} does not support SYNC access type",
Expand Down Expand Up @@ -438,7 +442,10 @@ def remove_credential_from_connector(
)

if association is not None:
delete_user__ext_group_for_cc_pair__no_commit(
fetch_ee_implementation_or_noop(
"danswer.db.external_perm",
"delete_user__ext_group_for_cc_pair__no_commit",
)(
db_session=db_session,
cc_pair_id=association.id,
)
Expand Down
111 changes: 111 additions & 0 deletions backend/danswer/db/token_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from collections.abc import Sequence

from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.configs.constants import TokenRateLimitScope
from danswer.db.models import TokenRateLimit
from danswer.db.models import TokenRateLimit__UserGroup
from danswer.server.token_rate_limits.models import TokenRateLimitArgs


def fetch_all_user_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
ordered: bool = True,
) -> Sequence[TokenRateLimit]:
query = select(TokenRateLimit).where(
TokenRateLimit.scope == TokenRateLimitScope.USER
)

if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))

if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())

return db_session.scalars(query).all()


def fetch_all_global_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
ordered: bool = True,
) -> Sequence[TokenRateLimit]:
query = select(TokenRateLimit).where(
TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
)

if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))

if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())

token_rate_limits = db_session.scalars(query).all()
return token_rate_limits


def insert_user_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.USER,
)
db_session.add(token_limit)
db_session.commit()

return token_limit


def insert_global_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.GLOBAL,
)
db_session.add(token_limit)
db_session.commit()

return token_limit


def update_token_rate_limit(
db_session: Session,
token_rate_limit_id: int,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
if token_limit is None:
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")

token_limit.enabled = token_rate_limit_settings.enabled
token_limit.token_budget = token_rate_limit_settings.token_budget
token_limit.period_hours = token_rate_limit_settings.period_hours
db_session.commit()

return token_limit


def delete_token_rate_limit(
db_session: Session,
token_rate_limit_id: int,
) -> None:
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
if token_limit is None:
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")

db_session.query(TokenRateLimit__UserGroup).filter(
TokenRateLimit__UserGroup.rate_limit_id == token_rate_limit_id
).delete()

db_session.delete(token_limit)
db_session.commit()
10 changes: 5 additions & 5 deletions backend/danswer/one_shot_answer/answer_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop

logger = setup_logger()

Expand Down Expand Up @@ -125,11 +125,11 @@ def stream_answer_objects(
)

temporary_persona: Persona | None = None

if query_req.persona_config is not None:
new_persona = create_temporary_persona(
db_session=db_session, persona_config=query_req.persona_config, user=user
)
temporary_persona = new_persona
temporary_persona = fetch_ee_implementation_or_noop(
"danswer.server.query_and_chat.utils", "create_temporary_persona", None
)(db_session=db_session, persona_config=query_req.persona_config, user=user)

persona = temporary_persona if temporary_persona else chat_session.persona

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from sqlalchemy.orm import Session

from danswer.auth.users import current_admin_user
from danswer.db.api_key import ApiKeyDescriptor
from danswer.db.api_key import fetch_api_keys
from danswer.db.api_key import insert_api_key
from danswer.db.api_key import regenerate_api_key
from danswer.db.api_key import remove_api_key
from danswer.db.api_key import update_api_key
from danswer.db.engine import get_session
from danswer.db.models import User
from ee.danswer.db.api_key import ApiKeyDescriptor
from ee.danswer.db.api_key import fetch_api_keys
from ee.danswer.db.api_key import insert_api_key
from ee.danswer.db.api_key import regenerate_api_key
from ee.danswer.db.api_key import remove_api_key
from ee.danswer.db.api_key import update_api_key
from ee.danswer.server.api_key.models import APIKeyArgs
from danswer.server.api_key.models import APIKeyArgs


router = APIRouter(prefix="/admin/api-key")
Expand Down
File renamed without changes.
11 changes: 9 additions & 2 deletions backend/danswer/server/auth_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from danswer.auth.users import current_user_with_expired_token
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.server.danswer_api.ingestion import api_key_dep
from ee.danswer.auth.users import current_cloud_superuser
from ee.danswer.server.tenants.access import control_plane_dep
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop


PUBLIC_ENDPOINT_SPECS = [
Expand Down Expand Up @@ -81,6 +80,14 @@ def check_router_auth(
(1) have auth enabled OR
(2) are explicitly marked as a public endpoint
"""

control_plane_dep = fetch_ee_implementation_or_noop(
"danswer.server.tenants.access", "control_plane_dep"
)
current_cloud_superuser = fetch_ee_implementation_or_noop(
"danswer.auth.users", "current_cloud_superuser"
)

for route in application.routes:
# explicitly marked as public
if is_route_in_spec_list(route, public_endpoint_specs):
Expand Down
Loading

0 comments on commit 9272d6e

Please sign in to comment.