merge indexing and heartbeat callbacks (and associated lock reacquisi… #3178

Merged (2 commits) on Nov 22, 2024
Changes from 1 commit
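The diff below removes the old RunIndexingCallbackInterface and Heartbeat abstractions in favor of a single IndexingHeartbeatInterface, and has the new callback reacquire the Redis lock on every progress update. The merged interface itself lives in backend/danswer/indexing/indexing_heartbeat.py, which is not part of this diff; the following is a minimal sketch consistent with the call sites shown below (signatures inferred from callback.should_stop() and callback.progress(tag, amount); hypothetical reconstruction, not the file's actual contents):

from abc import ABC
from abc import abstractmethod


class IndexingHeartbeatInterface(ABC):
    """Single callback interface shared by the indexing, pruning, chunking and embedding paths."""

    @abstractmethod
    def should_stop(self) -> bool:
        """Return True if the caller should abort its loop."""

    @abstractmethod
    def progress(self, tag: str, amount: int) -> None:
        """Report that `amount` units of work completed at call site `tag`."""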
backend/danswer/background/celery/celery_utils.py (13 changes: 9 additions & 4 deletions)
@@ -4,7 +4,6 @@

 from sqlalchemy.orm import Session

-from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
 from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
 from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
     rate_limit_builder,
@@ -17,6 +16,7 @@
 from danswer.db.connector_credential_pair import get_connector_credential_pair
 from danswer.db.enums import TaskStatus
 from danswer.db.models import TaskQueueState
+from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from danswer.redis.redis_connector import RedisConnector
 from danswer.server.documents.models import DeletionAttemptSnapshot
 from danswer.utils.logger import setup_logger
@@ -78,7 +78,7 @@ def document_batch_to_ids(

 def extract_ids_from_runnable_connector(
     runnable_connector: BaseConnector,
-    callback: RunIndexingCallbackInterface | None = None,
+    callback: IndexingHeartbeatInterface | None = None,
 ) -> set[str]:
     """
     If the SlimConnector hasnt been implemented for the given connector, just pull
@@ -111,10 +111,15 @@ for doc_batch in doc_batch_generator:
     for doc_batch in doc_batch_generator:
         if callback:
             if callback.should_stop():
-                raise RuntimeError("Stop signal received")
-            callback.progress(len(doc_batch))
+                raise RuntimeError(
+                    "extract_ids_from_runnable_connector: Stop signal detected"
+                )

         all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))

+        if callback:
+            callback.progress("extract_ids_from_runnable_connector", len(doc_batch))
+
     return all_connector_doc_ids


backend/danswer/background/celery/tasks/indexing/tasks.py (15 changes: 10 additions & 5 deletions)
@@ -16,7 +16,6 @@
 from danswer.background.celery.apps.app_base import task_logger
 from danswer.background.indexing.job_client import SimpleJobClient
 from danswer.background.indexing.run_indexing import run_indexing_entrypoint
-from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
 from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
 from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
 from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -42,6 +41,7 @@
 from danswer.db.search_settings import get_current_search_settings
 from danswer.db.search_settings import get_secondary_search_settings
 from danswer.db.swap_index import check_index_swap
+from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
 from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
 from danswer.redis.redis_connector import RedisConnector
@@ -57,7 +57,7 @@
 logger = setup_logger()


-class RunIndexingCallback(RunIndexingCallbackInterface):
+class IndexingCallback(IndexingHeartbeatInterface):
     def __init__(
         self,
         stop_key: str,
@@ -73,22 +73,27 @@ def __init__(
         self.started: datetime = datetime.now(timezone.utc)
         self.redis_lock.reacquire()

+        self.last_tag: str = ""
         self.last_lock_reacquire: datetime = datetime.now(timezone.utc)

     def should_stop(self) -> bool:
         if self.redis_client.exists(self.stop_key):
             return True
         return False

-    def progress(self, amount: int) -> None:
+    def progress(self, tag: str, amount: int) -> None:
+        # logger.debug(f"IndexingCallback: tag={tag} amount={amount}")
+
         try:
             self.redis_lock.reacquire()
+            self.last_tag = tag
             self.last_lock_reacquire = datetime.now(timezone.utc)
         except LockError:
             logger.exception(
-                f"RunIndexingCallback - lock.reacquire exceptioned. "
+                f"IndexingCallback - lock.reacquire exceptioned. "
                 f"lock_timeout={self.redis_lock.timeout} "
                 f"start={self.started} "
+                f"last_tag={self.last_tag} "
                 f"last_reacquired={self.last_lock_reacquire} "
                 f"now={datetime.now(timezone.utc)}"
             )
@@ -619,7 +624,7 @@ def connector_indexing_task(
     )

     # define a callback class
-    callback = RunIndexingCallback(
+    callback = IndexingCallback(
         redis_connector.stop.fence_key,
         redis_connector_index.generator_progress_key,
         lock,
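The lock reacquisition referenced in the PR title happens inside IndexingCallback.progress() above: the Celery task holds a Redis lock with a finite timeout (CELERY_INDEXING_LOCK_TIMEOUT), and every progress call resets that TTL so the lock cannot expire while a long connector run is still making headway. A minimal standalone sketch of the same pattern with redis-py (the key name, timeout value, and placeholder workload are hypothetical):

import redis
from redis.exceptions import LockError

r = redis.Redis()  # hypothetical connection; the real tasks use the app's Redis client

# a short TTL stands in for CELERY_INDEXING_LOCK_TIMEOUT
lock = r.lock("hypothetical:indexing:lock", timeout=60)

if lock.acquire(blocking=False):
    try:
        for batch_num in range(100):  # stand-in for the long-running indexing loop
            _ = sum(range(10_000))  # placeholder per-batch work
            try:
                # reset the TTL back to `timeout`, mirroring IndexingCallback.progress()
                lock.reacquire()
            except LockError:
                # the lock expired or changed hands; the real callback logs and keeps going
                print(f"lost the lock at batch {batch_num}")
    finally:
        if lock.owned():
            lock.release()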
backend/danswer/background/celery/tasks/pruning/tasks.py (4 changes: 2 additions & 2 deletions)
@@ -12,7 +12,7 @@

 from danswer.background.celery.apps.app_base import task_logger
 from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
-from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
+from danswer.background.celery.tasks.indexing.tasks import IndexingCallback
 from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
 from danswer.configs.app_configs import JOB_TIMEOUT
 from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
@@ -277,7 +277,7 @@ def connector_pruning_generator_task(
         cc_pair.credential,
     )

-    callback = RunIndexingCallback(
+    callback = IndexingCallback(
         redis_connector.stop.fence_key,
         redis_connector.prune.generator_progress_key,
         lock,
backend/danswer/background/indexing/run_indexing.py (36 changes: 9 additions & 27 deletions)
@@ -1,7 +1,5 @@
 import time
 import traceback
-from abc import ABC
-from abc import abstractmethod
 from datetime import datetime
 from datetime import timedelta
 from datetime import timezone
@@ -31,7 +29,7 @@
 from danswer.db.models import IndexModelStatus
 from danswer.document_index.factory import get_default_document_index
 from danswer.indexing.embedder import DefaultIndexingEmbedder
-from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
+from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from danswer.indexing.indexing_pipeline import build_indexing_pipeline
 from danswer.utils.logger import setup_logger
 from danswer.utils.logger import TaskAttemptSingleton
@@ -42,19 +40,6 @@
 INDEXING_TRACER_NUM_PRINT_ENTRIES = 5


-class RunIndexingCallbackInterface(ABC):
-    """Defines a callback interface to be passed to
-    to run_indexing_entrypoint."""
-
-    @abstractmethod
-    def should_stop(self) -> bool:
-        """Signal to stop the looping function in flight."""
-
-    @abstractmethod
-    def progress(self, amount: int) -> None:
-        """Send progress updates to the caller."""
-
-
 def _get_connector_runner(
     db_session: Session,
     attempt: IndexAttempt,
@@ -106,7 +91,7 @@ def _run_indexing(
     db_session: Session,
     index_attempt: IndexAttempt,
     tenant_id: str | None,
-    callback: RunIndexingCallbackInterface | None = None,
+    callback: IndexingHeartbeatInterface | None = None,
 ) -> None:
     """
     1. Get documents which are either new or updated from specified application
@@ -138,13 +123,7 @@ def _run_indexing(

     embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
         search_settings=search_settings,
-        heartbeat=IndexingHeartbeat(
-            index_attempt_id=index_attempt.id,
-            db_session=db_session,
-            # let the world know we're still making progress after
-            # every 10 batches
-            freq=10,
-        ),
+        callback=callback,
     )

     indexing_pipeline = build_indexing_pipeline(
@@ -157,6 +136,7 @@ def _run_indexing(
         ),
         db_session=db_session,
         tenant_id=tenant_id,
+        callback=callback,
     )

     db_cc_pair = index_attempt.connector_credential_pair
@@ -228,7 +208,9 @@ def _run_indexing(
                 # contents still need to be initially pulled.
                 if callback:
                     if callback.should_stop():
-                        raise RuntimeError("Connector stop signal detected")
+                        raise RuntimeError(
+                            "_run_indexing: Connector stop signal detected"
+                        )

                 # TODO: should we move this into the above callback instead?
                 db_session.refresh(db_cc_pair)
@@ -289,7 +271,7 @@ def _run_indexing(
                 db_session.commit()

                 if callback:
-                    callback.progress(len(doc_batch))
+                    callback.progress("_run_indexing", len(doc_batch))

                 # This new value is updated every batch, so UI can refresh per batch update
                 update_docs_indexed(
@@ -419,7 +401,7 @@ def run_indexing_entrypoint(
     tenant_id: str | None,
     connector_credential_pair_id: int,
     is_ee: bool = False,
-    callback: RunIndexingCallbackInterface | None = None,
+    callback: IndexingHeartbeatInterface | None = None,
 ) -> None:
     try:
         if is_ee:
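One behavioral consequence of the changes in this file: the embedder used to receive an IndexingHeartbeat that only reported progress every 10 batches (freq=10), whereas the callback passed straight through now fires per batch. If that throttling were ever wanted again, a thin wrapper around the new interface could restore it; a hypothetical sketch, not part of this PR:

from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface


class ThrottledCallback(IndexingHeartbeatInterface):
    """Hypothetical helper: forwards progress() only every `freq` calls, like the removed freq=10 heartbeat."""

    def __init__(self, inner: IndexingHeartbeatInterface, freq: int = 10) -> None:
        self.inner = inner
        self.freq = freq
        self._calls = 0
        self._pending = 0

    def should_stop(self) -> bool:
        # stop checks are never throttled
        return self.inner.should_stop()

    def progress(self, tag: str, amount: int) -> None:
        self._calls += 1
        self._pending += amount
        if self._calls % self.freq == 0:
            # forward the accumulated amount, then start a new window
            self.inner.progress(tag, self._pending)
            self._pending = 0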
backend/danswer/indexing/chunker.py (17 changes: 11 additions & 6 deletions)
@@ -10,7 +10,7 @@
     get_metadata_keys_to_ignore,
 )
 from danswer.connectors.models import Document
-from danswer.indexing.indexing_heartbeat import Heartbeat
+from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from danswer.indexing.models import DocAwareChunk
 from danswer.natural_language_processing.utils import BaseTokenizer
 from danswer.utils.logger import setup_logger
@@ -125,7 +125,7 @@ def __init__(
         chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
         chunk_overlap: int = CHUNK_OVERLAP,
         mini_chunk_size: int = MINI_CHUNK_SIZE,
-        heartbeat: Heartbeat | None = None,
+        callback: IndexingHeartbeatInterface | None = None,
     ) -> None:
         from llama_index.text_splitter import SentenceSplitter

@@ -134,7 +134,7 @@ def __init__(
         self.enable_multipass = enable_multipass
         self.enable_large_chunks = enable_large_chunks
         self.tokenizer = tokenizer
-        self.heartbeat = heartbeat
+        self.callback = callback

         self.blurb_splitter = SentenceSplitter(
             tokenizer=tokenizer.tokenize,
@@ -356,9 +356,14 @@ def _handle_single_document(self, document: Document) -> list[DocAwareChunk]:
     def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
         final_chunks: list[DocAwareChunk] = []
         for document in documents:
-            final_chunks.extend(self._handle_single_document(document))
+            if self.callback:
+                if self.callback.should_stop():
+                    raise RuntimeError("Chunker.chunk: Stop signal detected")

-            if self.heartbeat:
-                self.heartbeat.heartbeat()
+            chunks = self._handle_single_document(document)
+            final_chunks.extend(chunks)
+
+            if self.callback:
+                self.callback.progress("Chunker.chunk", len(chunks))

         return final_chunks
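Because the chunker now takes the same interface as the Celery code paths, a tiny counting implementation is enough to drive Chunker.chunk (or extract_ids_from_runnable_connector) outside of Celery. A hypothetical test helper, assuming the interface sketched near the top of this page:

from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface


class CountingCallback(IndexingHeartbeatInterface):
    """Hypothetical helper: never requests a stop, just tallies reported work per call-site tag."""

    def __init__(self) -> None:
        self.totals: dict[str, int] = {}

    def should_stop(self) -> bool:
        return False

    def progress(self, tag: str, amount: int) -> None:
        self.totals[tag] = self.totals.get(tag, 0) + amount

After a run, totals might look like {"Chunker.chunk": 42}, one entry per tag string passed at the call sites above.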
backend/danswer/indexing/embedder.py (16 changes: 9 additions & 7 deletions)
@@ -2,7 +2,7 @@
 from abc import abstractmethod

 from danswer.db.models import SearchSettings
-from danswer.indexing.indexing_heartbeat import Heartbeat
+from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from danswer.indexing.models import ChunkEmbedding
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk
@@ -34,7 +34,7 @@ def __init__(
         api_url: str | None,
         api_version: str | None,
         deployment_name: str | None,
-        heartbeat: Heartbeat | None,
+        callback: IndexingHeartbeatInterface | None,
     ):
         self.model_name = model_name
         self.normalize = normalize
@@ -60,7 +60,7 @@ def __init__(
             server_host=INDEXING_MODEL_SERVER_HOST,
             server_port=INDEXING_MODEL_SERVER_PORT,
             retrim_content=True,
-            heartbeat=heartbeat,
+            callback=callback,
         )

     @abstractmethod
@@ -83,7 +83,7 @@ def __init__(
         api_url: str | None = None,
         api_version: str | None = None,
         deployment_name: str | None = None,
-        heartbeat: Heartbeat | None = None,
+        callback: IndexingHeartbeatInterface | None = None,
     ):
         super().__init__(
             model_name,
@@ -95,7 +95,7 @@ def __init__(
             api_url,
             api_version,
             deployment_name,
-            heartbeat,
+            callback,
         )

     @log_function_time()
@@ -201,7 +201,9 @@ def embed_chunks(

     @classmethod
     def from_db_search_settings(
-        cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None
+        cls,
+        search_settings: SearchSettings,
+        callback: IndexingHeartbeatInterface | None = None,
     ) -> "DefaultIndexingEmbedder":
         return cls(
             model_name=search_settings.model_name,
@@ -213,5 +215,5 @@ def from_db_search_settings(
             api_url=search_settings.api_url,
             api_version=search_settings.api_version,
             deployment_name=search_settings.deployment_name,
-            heartbeat=heartbeat,
+            callback=callback,
         )