From f25ad1edbf5d6c922febccc90d6784bd8251d957 Mon Sep 17 00:00:00 2001 From: "Richard Kuo (Danswer)" Date: Wed, 20 Nov 2024 12:44:37 -0800 Subject: [PATCH 1/2] merge indexing and heartbeat callbacks (and associated lock reacquisition). no db updates --- .../danswer/background/celery/celery_utils.py | 13 ++- .../background/celery/tasks/indexing/tasks.py | 15 ++-- .../background/celery/tasks/pruning/tasks.py | 4 +- .../background/indexing/run_indexing.py | 36 +++------ backend/danswer/indexing/chunker.py | 17 ++-- backend/danswer/indexing/embedder.py | 16 ++-- .../danswer/indexing/indexing_heartbeat.py | 71 +++++++++------- backend/danswer/indexing/indexing_pipeline.py | 12 +-- .../search_nlp_models.py | 14 ++-- .../tests/unit/danswer/indexing/conftest.py | 11 +-- .../unit/danswer/indexing/test_chunker.py | 2 +- .../unit/danswer/indexing/test_heartbeat.py | 80 ------------------- 12 files changed, 111 insertions(+), 180 deletions(-) delete mode 100644 backend/tests/unit/danswer/indexing/test_heartbeat.py diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py index c8a125d11b1..22142fee202 100644 --- a/backend/danswer/background/celery/celery_utils.py +++ b/backend/danswer/background/celery/celery_utils.py @@ -4,7 +4,6 @@ from sqlalchemy.orm import Session -from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE from danswer.connectors.cross_connector_utils.rate_limit_wrapper import ( rate_limit_builder, ) @@ -17,6 +16,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair from danswer.db.enums import TaskStatus from danswer.db.models import TaskQueueState +from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface from danswer.redis.redis_connector import RedisConnector from danswer.server.documents.models import DeletionAttemptSnapshot from danswer.utils.logger import setup_logger @@ -78,7 +78,7 @@ def document_batch_to_ids( def extract_ids_from_runnable_connector( runnable_connector: BaseConnector, - callback: RunIndexingCallbackInterface | None = None, + callback: IndexingHeartbeatInterface | None = None, ) -> set[str]: """ If the SlimConnector hasn't been implemented for the given connector, just pull @@ -111,10 +111,15 @@ def extract_ids_from_runnable_connector( for doc_batch in doc_batch_generator: if callback: if callback.should_stop(): - raise RuntimeError("Stop signal received") - callback.progress(len(doc_batch)) + raise RuntimeError( + "extract_ids_from_runnable_connector: Stop signal detected" + ) + all_connector_doc_ids.update(doc_batch_processing_func(doc_batch)) + if callback: + callback.progress("extract_ids_from_runnable_connector", len(doc_batch)) + return all_connector_doc_ids diff --git a/backend/danswer/background/celery/tasks/indexing/tasks.py b/backend/danswer/background/celery/tasks/indexing/tasks.py index 3bcb650e7c7..5e574944b21 100644 --- a/backend/danswer/background/celery/tasks/indexing/tasks.py +++ b/backend/danswer/background/celery/tasks/indexing/tasks.py @@ -16,7 +16,6 @@ from danswer.background.celery.apps.app_base import task_logger from danswer.background.indexing.job_client import SimpleJobClient from danswer.background.indexing.run_indexing import run_indexing_entrypoint -from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP from 
danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT @@ -42,6 +41,7 @@ from danswer.db.search_settings import get_current_search_settings from danswer.db.search_settings import get_secondary_search_settings from danswer.db.swap_index import check_index_swap +from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface from danswer.natural_language_processing.search_nlp_models import EmbeddingModel from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder from danswer.redis.redis_connector import RedisConnector @@ -57,7 +57,7 @@ logger = setup_logger() -class RunIndexingCallback(RunIndexingCallbackInterface): +class IndexingCallback(IndexingHeartbeatInterface): def __init__( self, stop_key: str, @@ -73,6 +73,7 @@ def __init__( self.started: datetime = datetime.now(timezone.utc) self.redis_lock.reacquire() + self.last_tag: str = "" self.last_lock_reacquire: datetime = datetime.now(timezone.utc) def should_stop(self) -> bool: @@ -80,15 +81,19 @@ def should_stop(self) -> bool: return True return False - def progress(self, amount: int) -> None: + def progress(self, tag: str, amount: int) -> None: + # logger.debug(f"IndexingCallback: tag={tag} amount={amount}") + try: self.redis_lock.reacquire() + self.last_tag = tag self.last_lock_reacquire = datetime.now(timezone.utc) except LockError: logger.exception( - f"RunIndexingCallback - lock.reacquire exceptioned. " + f"IndexingCallback - lock.reacquire exceptioned. " f"lock_timeout={self.redis_lock.timeout} " f"start={self.started} " + f"last_tag={self.last_tag} " f"last_reacquired={self.last_lock_reacquire} " f"now={datetime.now(timezone.utc)}" ) @@ -619,7 +624,7 @@ def connector_indexing_task( ) # define a callback class - callback = RunIndexingCallback( + callback = IndexingCallback( redis_connector.stop.fence_key, redis_connector_index.generator_progress_key, lock, diff --git a/backend/danswer/background/celery/tasks/pruning/tasks.py b/backend/danswer/background/celery/tasks/pruning/tasks.py index 6398e0a6cc2..67b781f228f 100644 --- a/backend/danswer/background/celery/tasks/pruning/tasks.py +++ b/backend/danswer/background/celery/tasks/pruning/tasks.py @@ -12,7 +12,7 @@ from danswer.background.celery.apps.app_base import task_logger from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector -from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback +from danswer.background.celery.tasks.indexing.tasks import IndexingCallback from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING from danswer.configs.app_configs import JOB_TIMEOUT from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT @@ -277,7 +277,7 @@ def connector_pruning_generator_task( cc_pair.credential, ) - callback = RunIndexingCallback( + callback = IndexingCallback( redis_connector.stop.fence_key, redis_connector.prune.generator_progress_key, lock, diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index 8f3ae65fa3d..699e4682caa 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -1,7 +1,5 @@ import time import traceback -from abc import ABC -from abc import abstractmethod from datetime import datetime from datetime import timedelta from datetime import timezone @@ -31,7 +29,7 @@ from danswer.db.models import IndexModelStatus from 
danswer.document_index.factory import get_default_document_index from danswer.indexing.embedder import DefaultIndexingEmbedder -from danswer.indexing.indexing_heartbeat import IndexingHeartbeat +from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface from danswer.indexing.indexing_pipeline import build_indexing_pipeline from danswer.utils.logger import setup_logger from danswer.utils.logger import TaskAttemptSingleton @@ -42,19 +40,6 @@ INDEXING_TRACER_NUM_PRINT_ENTRIES = 5 -class RunIndexingCallbackInterface(ABC): - """Defines a callback interface to be passed to - to run_indexing_entrypoint.""" - - @abstractmethod - def should_stop(self) -> bool: - """Signal to stop the looping function in flight.""" - - @abstractmethod - def progress(self, amount: int) -> None: - """Send progress updates to the caller.""" - - def _get_connector_runner( db_session: Session, attempt: IndexAttempt, @@ -106,7 +91,7 @@ def _run_indexing( db_session: Session, index_attempt: IndexAttempt, tenant_id: str | None, - callback: RunIndexingCallbackInterface | None = None, + callback: IndexingHeartbeatInterface | None = None, ) -> None: """ 1. Get documents which are either new or updated from specified application @@ -138,13 +123,7 @@ def _run_indexing( embedding_model = DefaultIndexingEmbedder.from_db_search_settings( search_settings=search_settings, - heartbeat=IndexingHeartbeat( - index_attempt_id=index_attempt.id, - db_session=db_session, - # let the world know we're still making progress after - # every 10 batches - freq=10, - ), + callback=callback, ) indexing_pipeline = build_indexing_pipeline( @@ -157,6 +136,7 @@ def _run_indexing( ), db_session=db_session, tenant_id=tenant_id, + callback=callback, ) db_cc_pair = index_attempt.connector_credential_pair @@ -228,7 +208,9 @@ def _run_indexing( # contents still need to be initially pulled. if callback: if callback.should_stop(): - raise RuntimeError("Connector stop signal detected") + raise RuntimeError( + "_run_indexing: Connector stop signal detected" + ) # TODO: should we move this into the above callback instead? 
db_session.refresh(db_cc_pair) @@ -289,7 +271,7 @@ def _run_indexing( db_session.commit() if callback: - callback.progress(len(doc_batch)) + callback.progress("_run_indexing", len(doc_batch)) # This new value is updated every batch, so UI can refresh per batch update update_docs_indexed( @@ -419,7 +401,7 @@ def run_indexing_entrypoint( tenant_id: str | None, connector_credential_pair_id: int, is_ee: bool = False, - callback: RunIndexingCallbackInterface | None = None, + callback: IndexingHeartbeatInterface | None = None, ) -> None: try: if is_ee: diff --git a/backend/danswer/indexing/chunker.py b/backend/danswer/indexing/chunker.py index 57f05a66e1e..35dd919af6b 100644 --- a/backend/danswer/indexing/chunker.py +++ b/backend/danswer/indexing/chunker.py @@ -10,7 +10,7 @@ get_metadata_keys_to_ignore, ) from danswer.connectors.models import Document -from danswer.indexing.indexing_heartbeat import Heartbeat +from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface from danswer.indexing.models import DocAwareChunk from danswer.natural_language_processing.utils import BaseTokenizer from danswer.utils.logger import setup_logger @@ -125,7 +125,7 @@ def __init__( chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE, chunk_overlap: int = CHUNK_OVERLAP, mini_chunk_size: int = MINI_CHUNK_SIZE, - heartbeat: Heartbeat | None = None, + callback: IndexingHeartbeatInterface | None = None, ) -> None: from llama_index.text_splitter import SentenceSplitter @@ -134,7 +134,7 @@ def __init__( self.enable_multipass = enable_multipass self.enable_large_chunks = enable_large_chunks self.tokenizer = tokenizer - self.heartbeat = heartbeat + self.callback = callback self.blurb_splitter = SentenceSplitter( tokenizer=tokenizer.tokenize, @@ -356,9 +356,14 @@ def _handle_single_document(self, document: Document) -> list[DocAwareChunk]: def chunk(self, documents: list[Document]) -> list[DocAwareChunk]: final_chunks: list[DocAwareChunk] = [] for document in documents: - final_chunks.extend(self._handle_single_document(document)) + if self.callback: + if self.callback.should_stop(): + raise RuntimeError("Chunker.chunk: Stop signal detected") - if self.heartbeat: - self.heartbeat.heartbeat() + chunks = self._handle_single_document(document) + final_chunks.extend(chunks) + + if self.callback: + self.callback.progress("Chunker.chunk", len(chunks)) return final_chunks diff --git a/backend/danswer/indexing/embedder.py b/backend/danswer/indexing/embedder.py index 1c11a01b390..2e975324186 100644 --- a/backend/danswer/indexing/embedder.py +++ b/backend/danswer/indexing/embedder.py @@ -2,7 +2,7 @@ from abc import abstractmethod from danswer.db.models import SearchSettings -from danswer.indexing.indexing_heartbeat import Heartbeat +from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface from danswer.indexing.models import ChunkEmbedding from danswer.indexing.models import DocAwareChunk from danswer.indexing.models import IndexChunk @@ -34,7 +34,7 @@ def __init__( api_url: str | None, api_version: str | None, deployment_name: str | None, - heartbeat: Heartbeat | None, + callback: IndexingHeartbeatInterface | None, ): self.model_name = model_name self.normalize = normalize @@ -60,7 +60,7 @@ def __init__( server_host=INDEXING_MODEL_SERVER_HOST, server_port=INDEXING_MODEL_SERVER_PORT, retrim_content=True, - heartbeat=heartbeat, + callback=callback, ) @abstractmethod @@ -83,7 +83,7 @@ def __init__( api_url: str | None = None, api_version: str | None = None, deployment_name: str | None = None, - 
heartbeat: Heartbeat | None = None, + callback: IndexingHeartbeatInterface | None = None, ): super().__init__( model_name, @@ -95,7 +95,7 @@ def __init__( api_url, api_version, deployment_name, - heartbeat, + callback, ) @log_function_time() @@ -201,7 +201,9 @@ def embed_chunks( @classmethod def from_db_search_settings( - cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None + cls, + search_settings: SearchSettings, + callback: IndexingHeartbeatInterface | None = None, ) -> "DefaultIndexingEmbedder": return cls( model_name=search_settings.model_name, @@ -213,5 +215,5 @@ def from_db_search_settings( api_url=search_settings.api_url, api_version=search_settings.api_version, deployment_name=search_settings.deployment_name, - heartbeat=heartbeat, + callback=callback, ) diff --git a/backend/danswer/indexing/indexing_heartbeat.py b/backend/danswer/indexing/indexing_heartbeat.py index c500a0ad559..195929d2654 100644 --- a/backend/danswer/indexing/indexing_heartbeat.py +++ b/backend/danswer/indexing/indexing_heartbeat.py @@ -1,41 +1,52 @@ -import abc -from typing import Any +from abc import ABC +from abc import abstractmethod -from sqlalchemy import func -from sqlalchemy.orm import Session - -from danswer.db.index_attempt import get_index_attempt from danswer.utils.logger import setup_logger +# from danswer.db.index_attempt import get_index_attempt + logger = setup_logger() -class Heartbeat(abc.ABC): - """Useful for any long-running work that goes through a bunch of items - and needs to occasionally give updates on progress. - e.g. chunking, embedding, updating vespa, etc.""" +# class Heartbeat(abc.ABC): +# """Useful for any long-running work that goes through a bunch of items +# and needs to occasionally give updates on progress. +# e.g. 
chunking, embedding, updating vespa, etc.""" + +# @abc.abstractmethod +# def heartbeat(self, metadata: Any = None) -> None: +# raise NotImplementedError + + +# class IndexingHeartbeat(Heartbeat): +# def __init__(self, index_attempt_id: int, db_session: Session, freq: int): +# self.cnt = 0 + +# self.index_attempt_id = index_attempt_id +# self.db_session = db_session +# self.freq = freq - @abc.abstractmethod - def heartbeat(self, metadata: Any = None) -> None: - raise NotImplementedError +# def heartbeat(self, metadata: Any = None) -> None: +# self.cnt += 1 +# if self.cnt % self.freq == 0: +# index_attempt = get_index_attempt( +# db_session=self.db_session, index_attempt_id=self.index_attempt_id +# ) +# if index_attempt: +# index_attempt.time_updated = func.now() +# self.db_session.commit() +# else: +# logger.error("Index attempt not found, this should not happen!") -class IndexingHeartbeat(Heartbeat): - def __init__(self, index_attempt_id: int, db_session: Session, freq: int): - self.cnt = 0 +class IndexingHeartbeatInterface(ABC): + """Defines a callback interface to be passed to + run_indexing_entrypoint.""" - self.index_attempt_id = index_attempt_id - self.db_session = db_session - self.freq = freq + @abstractmethod + def should_stop(self) -> bool: + """Signal to stop the looping function in flight.""" - def heartbeat(self, metadata: Any = None) -> None: - self.cnt += 1 - if self.cnt % self.freq == 0: - index_attempt = get_index_attempt( - db_session=self.db_session, index_attempt_id=self.index_attempt_id - ) - if index_attempt: - index_attempt.time_updated = func.now() - self.db_session.commit() - else: - logger.error("Index attempt not found, this should not happen!") + @abstractmethod + def progress(self, tag: str, amount: int) -> None: + """Send progress updates to the caller.""" diff --git a/backend/danswer/indexing/indexing_pipeline.py b/backend/danswer/indexing/indexing_pipeline.py index 28cfffcb0ca..b1ee8f4d944 100644 --- a/backend/danswer/indexing/indexing_pipeline.py +++ b/backend/danswer/indexing/indexing_pipeline.py @@ -34,7 +34,7 @@ from danswer.document_index.interfaces import DocumentMetadata from danswer.indexing.chunker import Chunker from danswer.indexing.embedder import IndexingEmbedder -from danswer.indexing.indexing_heartbeat import IndexingHeartbeat +from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface from danswer.indexing.models import DocAwareChunk from danswer.indexing.models import DocMetadataAwareIndexChunk from danswer.utils.logger import setup_logger @@ -414,6 +414,7 @@ def build_indexing_pipeline( ignore_time_skip: bool = False, attempt_id: int | None = None, tenant_id: str | None = None, + callback: IndexingHeartbeatInterface | None = None, ) -> IndexingPipelineProtocol: """Builds a pipeline which takes in a list (batch) of docs and indexes them.""" search_settings = get_current_search_settings(db_session) @@ -440,13 +441,8 @@ def build_indexing_pipeline( tokenizer=embedder.embedding_model.tokenizer, enable_multipass=multipass, enable_large_chunks=enable_large_chunks, - # after every doc, update status in case there are a bunch of - # really long docs - heartbeat=IndexingHeartbeat( - index_attempt_id=attempt_id, db_session=db_session, freq=1 - ) - if attempt_id - else None, + # after every doc, update status in case there are a bunch of really long docs + callback=callback, ) return partial( diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py 
b/backend/danswer/natural_language_processing/search_nlp_models.py index d75fce304d6..9a3d575c0f9 100644 --- a/backend/danswer/natural_language_processing/search_nlp_models.py +++ b/backend/danswer/natural_language_processing/search_nlp_models.py @@ -16,7 +16,7 @@ ) from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.db.models import SearchSettings -from danswer.indexing.indexing_heartbeat import Heartbeat +from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface from danswer.natural_language_processing.utils import get_tokenizer from danswer.natural_language_processing.utils import tokenizer_trim_content from danswer.utils.logger import setup_logger @@ -99,7 +99,7 @@ def __init__( api_url: str | None, provider_type: EmbeddingProvider | None, retrim_content: bool = False, - heartbeat: Heartbeat | None = None, + callback: IndexingHeartbeatInterface | None = None, api_version: str | None = None, deployment_name: str | None = None, ) -> None: @@ -116,7 +116,7 @@ def __init__( self.tokenizer = get_tokenizer( model_name=model_name, provider_type=provider_type ) - self.heartbeat = heartbeat + self.callback = callback model_server_url = build_model_server_url(server_host, server_port) self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed" @@ -160,6 +160,10 @@ def _batch_encode_texts( embeddings: list[Embedding] = [] for idx, text_batch in enumerate(text_batches, start=1): + if self.callback: + if self.callback.should_stop(): + raise RuntimeError("_batch_encode_texts detected stop signal") + logger.debug(f"Encoding batch {idx} of {len(text_batches)}") embed_request = EmbedRequest( model_name=self.model_name, @@ -179,8 +183,8 @@ def _batch_encode_texts( response = self._make_model_server_request(embed_request) embeddings.extend(response.embeddings) - if self.heartbeat: - self.heartbeat.heartbeat() + if self.callback: + self.callback.progress("_batch_encode_texts", 1) return embeddings def encode( diff --git a/backend/tests/unit/danswer/indexing/conftest.py b/backend/tests/unit/danswer/indexing/conftest.py index 36e5659143f..193e53b828d 100644 --- a/backend/tests/unit/danswer/indexing/conftest.py +++ b/backend/tests/unit/danswer/indexing/conftest.py @@ -1,15 +1,16 @@ -from typing import Any - import pytest -from danswer.indexing.indexing_heartbeat import Heartbeat +from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface -class MockHeartbeat(Heartbeat): +class MockHeartbeat(IndexingHeartbeatInterface): def __init__(self) -> None: self.call_count = 0 - def heartbeat(self, metadata: Any = None) -> None: + def should_stop(self) -> bool: + return False + + def progress(self, tag: str, amount: int) -> None: self.call_count += 1 diff --git a/backend/tests/unit/danswer/indexing/test_chunker.py b/backend/tests/unit/danswer/indexing/test_chunker.py index 71c3bbd886f..065af49b95d 100644 --- a/backend/tests/unit/danswer/indexing/test_chunker.py +++ b/backend/tests/unit/danswer/indexing/test_chunker.py @@ -74,7 +74,7 @@ def test_chunker_heartbeat( chunker = Chunker( tokenizer=embedder.embedding_model.tokenizer, enable_multipass=False, - heartbeat=mock_heartbeat, + callback=mock_heartbeat, ) chunks = chunker.chunk([document]) diff --git a/backend/tests/unit/danswer/indexing/test_heartbeat.py b/backend/tests/unit/danswer/indexing/test_heartbeat.py deleted file mode 100644 index a59fac81283..00000000000 --- a/backend/tests/unit/danswer/indexing/test_heartbeat.py +++ /dev/null @@ -1,80 +0,0 @@ -from unittest.mock import MagicMock 
-from unittest.mock import patch - -import pytest -from sqlalchemy.orm import Session - -from danswer.db.index_attempt import IndexAttempt -from danswer.indexing.indexing_heartbeat import IndexingHeartbeat - - -@pytest.fixture -def mock_db_session() -> MagicMock: - return MagicMock(spec=Session) - - -@pytest.fixture -def mock_index_attempt() -> MagicMock: - return MagicMock(spec=IndexAttempt) - - -def test_indexing_heartbeat( - mock_db_session: MagicMock, mock_index_attempt: MagicMock -) -> None: - with patch( - "danswer.indexing.indexing_heartbeat.get_index_attempt" - ) as mock_get_index_attempt: - mock_get_index_attempt.return_value = mock_index_attempt - - heartbeat = IndexingHeartbeat( - index_attempt_id=1, db_session=mock_db_session, freq=5 - ) - - # Test that heartbeat doesn't update before freq is reached - for _ in range(4): - heartbeat.heartbeat() - - mock_db_session.commit.assert_not_called() - - # Test that heartbeat updates when freq is reached - heartbeat.heartbeat() - - mock_get_index_attempt.assert_called_once_with( - db_session=mock_db_session, index_attempt_id=1 - ) - assert mock_index_attempt.time_updated is not None - mock_db_session.commit.assert_called_once() - - # Reset mock calls - mock_db_session.reset_mock() - mock_get_index_attempt.reset_mock() - - # Test that heartbeat updates again after freq more calls - for _ in range(5): - heartbeat.heartbeat() - - mock_get_index_attempt.assert_called_once() - mock_db_session.commit.assert_called_once() - - -def test_indexing_heartbeat_not_found(mock_db_session: MagicMock) -> None: - with patch( - "danswer.indexing.indexing_heartbeat.get_index_attempt" - ) as mock_get_index_attempt, patch( - "danswer.indexing.indexing_heartbeat.logger" - ) as mock_logger: - mock_get_index_attempt.return_value = None - - heartbeat = IndexingHeartbeat( - index_attempt_id=1, db_session=mock_db_session, freq=1 - ) - - heartbeat.heartbeat() - - mock_get_index_attempt.assert_called_once_with( - db_session=mock_db_session, index_attempt_id=1 - ) - mock_logger.error.assert_called_once_with( - "Index attempt not found, this should not happen!" 
- ) - mock_db_session.commit.assert_not_called() From 3b3008ff29ce236430a25aafac4a0fb3ea2ce769 Mon Sep 17 00:00:00 2001 From: "Richard Kuo (Danswer)" Date: Thu, 21 Nov 2024 11:24:38 -0800 Subject: [PATCH 2/2] review fixes --- .../background/celery/tasks/indexing/tasks.py | 2 - .../danswer/indexing/indexing_heartbeat.py | 37 ------------------- 2 files changed, 39 deletions(-) diff --git a/backend/danswer/background/celery/tasks/indexing/tasks.py b/backend/danswer/background/celery/tasks/indexing/tasks.py index 5e574944b21..a353e7fe215 100644 --- a/backend/danswer/background/celery/tasks/indexing/tasks.py +++ b/backend/danswer/background/celery/tasks/indexing/tasks.py @@ -82,8 +82,6 @@ def should_stop(self) -> bool: return False def progress(self, tag: str, amount: int) -> None: - # logger.debug(f"IndexingCallback: tag={tag} amount={amount}") - try: self.redis_lock.reacquire() self.last_tag = tag diff --git a/backend/danswer/indexing/indexing_heartbeat.py b/backend/danswer/indexing/indexing_heartbeat.py index 195929d2654..fe5f83d0b86 100644 --- a/backend/danswer/indexing/indexing_heartbeat.py +++ b/backend/danswer/indexing/indexing_heartbeat.py @@ -1,43 +1,6 @@ from abc import ABC from abc import abstractmethod -from danswer.utils.logger import setup_logger - -# from danswer.db.index_attempt import get_index_attempt - -logger = setup_logger() - - -# class Heartbeat(abc.ABC): -# """Useful for any long-running work that goes through a bunch of items -# and needs to occasionally give updates on progress. -# e.g. chunking, embedding, updating vespa, etc.""" - -# @abc.abstractmethod -# def heartbeat(self, metadata: Any = None) -> None: -# raise NotImplementedError - - -# class IndexingHeartbeat(Heartbeat): -# def __init__(self, index_attempt_id: int, db_session: Session, freq: int): -# self.cnt = 0 - -# self.index_attempt_id = index_attempt_id -# self.db_session = db_session -# self.freq = freq - -# def heartbeat(self, metadata: Any = None) -> None: -# self.cnt += 1 -# if self.cnt % self.freq == 0: -# index_attempt = get_index_attempt( -# db_session=self.db_session, index_attempt_id=self.index_attempt_id -# ) -# if index_attempt: -# index_attempt.time_updated = func.now() -# self.db_session.commit() -# else: -# logger.error("Index attempt not found, this should not happen!") - class IndexingHeartbeatInterface(ABC): """Defines a callback interface to be passed to
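For reference, the consolidated callback contract both commits converge on is small enough to show in one self-contained sketch. The ConsoleCallback class and process_batches loop below are illustrative only and are not part of the patch; the real IndexingCallback in tasks.py additionally reacquires a Redis lock on each progress call and records the last tag for its error log.

from abc import ABC
from abc import abstractmethod


class IndexingHeartbeatInterface(ABC):
    """Defines a callback interface to be passed to
    run_indexing_entrypoint."""

    @abstractmethod
    def should_stop(self) -> bool:
        """Signal to stop the looping function in flight."""

    @abstractmethod
    def progress(self, tag: str, amount: int) -> None:
        """Send progress updates to the caller."""


class ConsoleCallback(IndexingHeartbeatInterface):
    """Illustrative implementation: prints instead of touching Redis or the DB."""

    def should_stop(self) -> bool:
        return False

    def progress(self, tag: str, amount: int) -> None:
        # `tag` names the call site (e.g. "Chunker.chunk" or
        # "_batch_encode_texts"), which is what IndexingCallback
        # surfaces when lock reacquisition fails.
        print(f"{tag}: processed {amount} item(s)")


def process_batches(
    batches: list[list[str]],
    callback: IndexingHeartbeatInterface | None = None,
) -> None:
    # Mirrors the caller pattern in extract_ids_from_runnable_connector:
    # check the stop signal before each batch, report progress after it.
    for batch in batches:
        if callback:
            if callback.should_stop():
                raise RuntimeError("process_batches: Stop signal detected")
        # ... real per-batch work would go here ...
        if callback:
            callback.progress("process_batches", len(batch))


if __name__ == "__main__":
    process_batches([["a", "b"], ["c"]], ConsoleCallback())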