Commit dd82e37

doc_sync is refactored

hagen-danswer committed Nov 7, 2024
1 parent 2b1dbde commit dd82e37
Showing 28 changed files with 778 additions and 375 deletions.
7 changes: 7 additions & 0 deletions backend/danswer/access/models.py
@@ -16,6 +16,13 @@ class ExternalAccess:
is_public: bool


@dataclass(frozen=True)
class DocumentExternalAccess:
external_access: ExternalAccess
# The document ID
doc_id: str


@dataclass(frozen=True)
class DocumentAccess(ExternalAccess):
# User emails for Danswer users, None indicates admin
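The new frozen dataclass simply pairs an ExternalAccess snapshot with the document it applies to. A construction sketch follows; the ExternalAccess fields other than is_public are not shown in this hunk and are assumed here:

```python
from dataclasses import dataclass

# Stand-in definitions for illustration; the ExternalAccess fields other
# than is_public are assumptions, since this hunk only shows is_public.
@dataclass(frozen=True)
class ExternalAccess:
    external_user_emails: set[str]      # assumed field
    external_user_group_ids: set[str]   # assumed field
    is_public: bool

@dataclass(frozen=True)
class DocumentExternalAccess:
    external_access: ExternalAccess
    doc_id: str  # the document ID

# Pair a permission snapshot with the document it applies to
access = ExternalAccess(
    external_user_emails={"user@example.com"},
    external_user_group_ids=set(),
    is_public=False,
)
doc_access = DocumentExternalAccess(external_access=access, doc_id="doc-123")
```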
7 changes: 7 additions & 0 deletions backend/danswer/background/celery/apps/app_base.py
@@ -19,6 +19,7 @@
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorDocPermSyncs
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.redis.redis_pool import get_redis_client
@@ -132,6 +133,12 @@ def on_task_postrun(
RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r)
return

if task_id.startswith(RedisConnectorDocPermSyncs.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorDocPermSyncs.remove_from_taskset(int(cc_pair_id), task_id, r)
return


def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
"""The first signal sent on celery worker startup"""
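The new branch mirrors the prune branch above it: match the subtask prefix, recover the cc_pair id, and drop the finished task from the Redis taskset. A self-contained sketch of that dispatch pattern, with hypothetical prefix and task-id formats (the real constants live on the Redis helper classes and are not shown in this diff):

```python
# Hypothetical format for illustration: task ids look like
# "<prefix><cc_pair_id>/<uuid>".
SUBTASK_PREFIX = "connectordocpermsync_"  # assumed value

def get_id_from_task_id(task_id: str) -> str | None:
    # assumed parsing; the real helper is RedisConnector.get_id_from_task_id
    rest = task_id.removeprefix(SUBTASK_PREFIX)
    cc_pair_id = rest.split("/", 1)[0]
    return cc_pair_id if cc_pair_id.isdigit() else None

# An in-memory set standing in for the Redis taskset of outstanding subtasks.
taskset: set[str] = {"connectordocpermsync_12/abc"}

def on_task_postrun(task_id: str) -> None:
    if task_id.startswith(SUBTASK_PREFIX):
        cc_pair_id = get_id_from_task_id(task_id)
        if cc_pair_id is not None:
            taskset.discard(task_id)  # remove_from_taskset, minus Redis
        return

on_task_postrun("connectordocpermsync_12/abc")
assert not taskset
```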
3 changes: 3 additions & 0 deletions backend/danswer/background/celery/apps/primary.py
@@ -20,6 +20,7 @@
from danswer.db.engine import SqlEngine
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorDocPermSyncs
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_connector_stop import RedisConnectorStop
@@ -131,6 +132,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:

RedisConnectorStop.reset_all(r)

RedisConnectorDocPermSyncs.reset_all(r)


@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
2 changes: 1 addition & 1 deletion backend/danswer/background/celery/celery_utils.py
@@ -81,7 +81,7 @@ def extract_ids_from_runnable_connector(
callback: RunIndexingCallbackInterface | None = None,
) -> set[str]:
"""
If the PruneConnector hasn't been implemented for the given connector, just pull
If the SlimConnector hasn't been implemented for the given connector, just pull
all docs using the load_from_state and grab out the IDs.
Optionally, a callback can be passed to handle the length of each document batch.
6 changes: 6 additions & 0 deletions backend/danswer/background/celery/tasks/beat_schedule.py
@@ -41,6 +41,12 @@
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-doc-permissions-sync",
"task": "check_for_doc_permissions_sync",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]


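This entry has Celery beat enqueue check_for_doc_permissions_sync every 10 seconds at high priority. A sketch of how such a list of entries maps onto Celery's native beat_schedule dict; the broker URL and priority value are placeholders, and the conversion shown is assumed (Danswer does its own elsewhere, not in this diff):

```python
from datetime import timedelta
from celery import Celery

app = Celery("danswer_sketch", broker="redis://localhost:6379/0")  # placeholder

tasks_to_schedule = [
    {
        "name": "check-for-doc-permissions-sync",
        "task": "check_for_doc_permissions_sync",
        "schedule": timedelta(seconds=10),
        "options": {"priority": 1},  # stand-in for DanswerCeleryPriority.HIGH
    },
]

# Celery's native format keys entries by name.
app.conf.beat_schedule = {
    entry["name"]: {key: value for key, value in entry.items() if key != "name"}
    for entry in tasks_to_schedule
}
```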
@@ -143,6 +143,12 @@ def try_generate_document_cc_pair_cleanup_tasks(
f"cc_pair={cc_pair_id}"
)

if redis_connector.permissions.fenced:
raise TaskDependencyError(
f"Connector deletion - Delayed (permissions in progress): "
f"cc_pair={cc_pair_id}"
)

# add tasks to celery and build up the task set to monitor in redis
redis_connector.delete.taskset_clear()

77 changes: 35 additions & 42 deletions backend/danswer/background/celery/tasks/pruning/tasks.py
@@ -38,6 +38,35 @@
logger = setup_logger()


def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if pruning is due."""

# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
if not cc_pair.connector.prune_freq:
return False

# skip pruning if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False

# skip pruning if the next scheduled prune time hasn't been reached yet
last_pruned = cc_pair.last_pruned
if not last_pruned:
if not cc_pair.last_successful_index_time:
# if we've never indexed, we can't prune
return False

# if never pruned, use the last time the connector indexed successfully
last_pruned = cc_pair.last_successful_index_time

next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return False

return True


@shared_task(
name="check_for_pruning",
soft_time_limit=JOB_TIMEOUT,
@@ -69,7 +98,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
if not cc_pair:
continue

if not is_pruning_due(cc_pair, db_session, r):
if not _is_pruning_due(cc_pair):
continue

tasks_created = try_creating_prune_generator_task(
@@ -90,47 +119,6 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
lock_beat.release()


def is_pruning_due(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
) -> bool:
"""Returns an int if pruning is triggered.
The int represents the number of prune tasks generated (in this case, only one
because the task is a long running generator task.)
Returns None if no pruning is triggered (due to not being needed or
other reasons such as simultaneous pruning restrictions.
Checks for scheduling related conditions, then delegates the rest of the checks to
try_creating_prune_generator_task.
"""

# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
if not cc_pair.connector.prune_freq:
return False

# skip pruning if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False

# skip pruning if the next scheduled prune time hasn't been reached yet
last_pruned = cc_pair.last_pruned
if not last_pruned:
if not cc_pair.last_successful_index_time:
# if we've never indexed, we can't prune
return False

# if never pruned, use the last time the connector indexed successfully
last_pruned = cc_pair.last_successful_index_time

next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return False

return True


def try_creating_prune_generator_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
@@ -172,6 +160,11 @@ def try_creating_prune_generator_task(
if redis_connector.delete.fenced: # skip pruning if the cc_pair is deleting
return None

if (
redis_connector.permissions.fenced
):  # skip pruning if a doc permissions sync is in progress
return None

db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
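The new private helper keeps only the scheduling checks. A self-contained sketch of the same timing logic, using a simplified stand-in for the SQLAlchemy ConnectorCredentialPair model (field subset and the "active" boolean are assumptions standing in for the real status enum):

```python
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone

# Simplified stand-in for the SQLAlchemy model; only the fields the check needs.
@dataclass
class CCPairSketch:
    prune_freq: int | None  # seconds between prunes; None/0 disables pruning
    active: bool
    last_pruned: datetime | None
    last_successful_index_time: datetime | None

def is_pruning_due(cc_pair: CCPairSketch) -> bool:
    if not cc_pair.prune_freq or not cc_pair.active:
        return False
    # if never pruned, fall back to the last successful index time
    last = cc_pair.last_pruned or cc_pair.last_successful_index_time
    if last is None:
        return False  # never indexed -> nothing to prune yet
    return datetime.now(timezone.utc) >= last + timedelta(seconds=cc_pair.prune_freq)

# Pruned two hours ago with a one-hour frequency -> due again.
pair = CCPairSketch(
    prune_freq=3600,
    active=True,
    last_pruned=datetime.now(timezone.utc) - timedelta(hours=2),
    last_successful_index_time=None,
)
assert is_pruning_due(pair)
```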
44 changes: 44 additions & 0 deletions backend/danswer/background/celery/tasks/vespa/tasks.py
@@ -27,6 +27,7 @@
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import mark_ccpair_as_permissions_synced
from danswer.db.connector import mark_ccpair_as_pruned
from danswer.db.connector_credential_pair import add_deletion_failure_message
from danswer.db.connector_credential_pair import (
@@ -58,6 +59,7 @@
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorDocPermSyncs
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
@@ -546,6 +548,42 @@ def monitor_ccpair_pruning_taskset(
redis_connector.prune.set_fence(False)


def monitor_ccpair_permissions_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
)
return

cc_pair_id = int(cc_pair_id_str)

redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.permissions.fenced:
return

initial = redis_connector.permissions.generator_complete
if initial is None:
return

remaining = redis_connector.permissions.get_remaining()
task_logger.info(
f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
)
if remaining > 0:
return

mark_ccpair_as_permissions_synced(int(cc_pair_id), db_session)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")

redis_connector.permissions.taskset_clear()
redis_connector.permissions.generator_clear()
redis_connector.permissions.set_fence(False)


def monitor_ccpair_indexing_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
@@ -741,6 +779,12 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)

lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorDocPermSyncs.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)

# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
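monitor_ccpair_permissions_taskset follows the same fence → generator → taskset lifecycle as the pruning monitor above it. A minimal in-memory sketch of that lifecycle; the class and field names here are hypothetical, and the real state lives in Redis under RedisConnector.permissions:

```python
# In-memory stand-in for the Redis-backed permissions sync state.
class PermSyncState:
    def __init__(self) -> None:
        self.fenced = False                         # set_fence(True) when a sync starts
        self.generator_complete: int | None = None  # initial subtask count, once known
        self.remaining = 0                          # live size of the Redis taskset

def monitor(state: PermSyncState) -> str:
    if not state.fenced:
        return "no sync in flight"
    if state.generator_complete is None:
        return "generator still enumerating subtasks"
    if state.remaining > 0:
        return f"{state.remaining} subtasks outstanding"
    # all subtasks done: mark the cc_pair synced in Postgres, then clear
    # the taskset, the generator flag, and finally the fence
    state.generator_complete = None
    state.fenced = False
    return "sync complete"

state = PermSyncState()
state.fenced, state.generator_complete, state.remaining = True, 5, 0
assert monitor(state) == "sync complete" and not state.fenced
```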
9 changes: 0 additions & 9 deletions backend/danswer/background/task_utils.py
@@ -14,15 +14,6 @@
from danswer.db.tasks import register_task


def name_cc_prune_task(
connector_id: int | None = None, credential_id: int | None = None
) -> str:
task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
if not connector_id or not credential_id:
task_name = "prune_connector_credential_pair"
return task_name


T = TypeVar("T", bound=Callable)


6 changes: 6 additions & 0 deletions backend/danswer/configs/constants.py
@@ -80,6 +80,8 @@
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min

CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 300 # 5 min

DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"


@@ -212,6 +214,7 @@ class DanswerCeleryQueues:
CONNECTOR_DELETION = "connector_deletion"
CONNECTOR_PRUNING = "connector_pruning"
CONNECTOR_INDEXING = "connector_indexing"
CONNECTOR_DOC_PERMISSIONS_SYNC = "connector_doc_permissions_sync"


class DanswerRedisLocks:
@@ -220,6 +223,9 @@ class DanswerRedisLocks:
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = (
"da_lock:check_connector_doc_permissions_sync_beat"
)
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"

PRUNING_LOCK_PREFIX = "da_lock:pruning"
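A sketch of how a beat task might guard itself with the new lock name, following the pattern implied by the other CHECK_*_BEAT_LOCK constants; the redis connection is a placeholder, the timeout mirrors CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT above, and the real tasks also periodically reacquire their locks (not shown):

```python
from redis import Redis

r = Redis()  # placeholder connection
lock = r.lock(
    "da_lock:check_connector_doc_permissions_sync_beat",
    timeout=300,  # mirrors CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
)
if lock.acquire(blocking=False):
    try:
        pass  # enumerate cc_pairs and kick off per-pair sync tasks here
    finally:
        lock.release()
```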
12 changes: 12 additions & 0 deletions backend/danswer/db/connector.py
@@ -282,3 +282,15 @@ def mark_ccpair_as_pruned(cc_pair_id: int, db_session: Session) -> None:

cc_pair.last_pruned = datetime.now(timezone.utc)
db_session.commit()


def mark_ccpair_as_permissions_synced(cc_pair_id: int, db_session: Session) -> None:
stmt = select(ConnectorCredentialPair).where(
ConnectorCredentialPair.id == cc_pair_id
)
cc_pair = db_session.scalar(stmt)
if cc_pair is None:
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")

cc_pair.last_time_perm_sync = datetime.now(timezone.utc)
db_session.commit()
3 changes: 3 additions & 0 deletions backend/danswer/db/models.py
@@ -420,6 +420,9 @@ class ConnectorCredentialPair(Base):
last_time_perm_sync: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
last_time_external_group_sync: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
# Time finished, not used for calculating backend jobs which uses time started (created)
last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), default=None
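The new last_time_external_group_sync column implies a schema migration. A hypothetical Alembic sketch of what that migration would look like; the revision ids and table name are assumptions, and the commit's actual migration is among the 28 changed files but not shown here:

```python
import sqlalchemy as sa
from alembic import op

# Made-up revision identifiers for illustration only.
revision = "ffffffffffff"
down_revision = "eeeeeeeeeeee"

def upgrade() -> None:
    op.add_column(
        "connector_credential_pair",  # assumed table name
        sa.Column(
            "last_time_external_group_sync",
            sa.DateTime(timezone=True),
            nullable=True,
        ),
    )

def downgrade() -> None:
    op.drop_column("connector_credential_pair", "last_time_external_group_sync")
```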
15 changes: 15 additions & 0 deletions backend/danswer/db/users.py
@@ -97,3 +97,18 @@ def batch_add_non_web_user_if_not_exists__no_commit(
db_session.flush() # generate ids

return found_users + new_users


def batch_add_non_web_user_if_not_exists(
db_session: Session, emails: list[str]
) -> list[User]:
found_users, missing_user_emails = get_users_by_emails(db_session, emails)

new_users: list[User] = []
for email in missing_user_emails:
new_users.append(_generate_non_web_user(email=email))

db_session.add_all(new_users)
db_session.commit()

return found_users + new_users
2 changes: 2 additions & 0 deletions backend/danswer/redis/redis_connector.py
@@ -1,6 +1,7 @@
import redis

from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorDocPermSyncs
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_connector_stop import RedisConnectorStop
@@ -19,6 +20,7 @@ def __init__(self, tenant_id: str | None, id: int) -> None:
self.stop = RedisConnectorStop(tenant_id, id, self.redis)
self.prune = RedisConnectorPrune(tenant_id, id, self.redis)
self.delete = RedisConnectorDelete(tenant_id, id, self.redis)
self.permissions = RedisConnectorDocPermSyncs(tenant_id, id, self.redis)

def new_index(self, search_settings_id: int) -> RedisConnectorIndex:
return RedisConnectorIndex(
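With self.permissions wired into the facade, callers reach doc-permissions sync state the same way they reach prune and delete state. A usage sketch, with made-up tenant and cc_pair ids:

```python
from danswer.redis.redis_connector import RedisConnector

redis_connector = RedisConnector(tenant_id=None, id=42)  # made-up cc_pair id

# Mirrors the guards added elsewhere in this commit: a fenced doc-permissions
# sync blocks deletion and pruning of the same cc_pair until it clears.
if redis_connector.permissions.fenced:
    print("doc permissions sync in flight; skip pruning/deletion")
```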