From ac54120ed0733a403a6de44e28ec933fb85bdf94 Mon Sep 17 00:00:00 2001
From: Weves
Date: Mon, 13 Jan 2025 11:30:44 -0800
Subject: [PATCH] Various fixes/improvements to document counting

---
 .../onyx/background/indexing/run_indexing.py | 32 ++++---
 backend/onyx/db/document.py                  | 39 +++++++--
 backend/onyx/db/models.py                    |  6 ++
 backend/onyx/indexing/indexing_pipeline.py   | 83 ++++++++++++-------
 4 files changed, 112 insertions(+), 48 deletions(-)

diff --git a/backend/onyx/background/indexing/run_indexing.py b/backend/onyx/background/indexing/run_indexing.py
index 623eb8edbe3..b0ee39883ae 100644
--- a/backend/onyx/background/indexing/run_indexing.py
+++ b/backend/onyx/background/indexing/run_indexing.py
@@ -97,10 +97,17 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
     for doc in doc_batch:
         cleaned_doc = doc.model_copy()
 
+        # Postgres cannot handle NUL characters in text fields
         if "\x00" in cleaned_doc.id:
             logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
             cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
 
+        if cleaned_doc.title and "\x00" in cleaned_doc.title:
+            logger.warning(
+                f"NUL characters found in document title: {cleaned_doc.title}"
+            )
+            cleaned_doc.title = cleaned_doc.title.replace("\x00", "")
+
         if "\x00" in cleaned_doc.semantic_identifier:
             logger.warning(
                 f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
@@ -116,6 +123,12 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
                 )
                 section.link = section.link.replace("\x00", "")
 
+            if section.text and "\x00" in section.text:
+                logger.warning(
+                    f"NUL characters found in document text for document: {cleaned_doc.id}"
+                )
+                section.text = section.text.replace("\x00", "")
+
         cleaned_batch.append(cleaned_doc)
 
     return cleaned_batch
@@ -234,8 +247,6 @@ def _run_indexing(
         tenant_id=tenant_id,
     )
 
-    all_connector_doc_ids: set[str] = set()
-
     tracer_counter = 0
     if INDEXING_TRACER_INTERVAL > 0:
         tracer.snap()
@@ -290,27 +301,23 @@ def _run_indexing(
                 index_attempt_md.batch_num = batch_num + 1  # use 1-index for this
 
                 # real work happens here!
-                new_docs, total_batch_chunks = indexing_pipeline(
+                index_pipeline_result = indexing_pipeline(
                     document_batch=doc_batch_cleaned,
                     index_attempt_metadata=index_attempt_md,
                 )
 
                 batch_num += 1
-                net_doc_change += new_docs
-                chunk_count += total_batch_chunks
-                document_count += len(doc_batch_cleaned)
-                all_connector_doc_ids.update(doc.id for doc in doc_batch_cleaned)
+                net_doc_change += index_pipeline_result.new_docs
+                chunk_count += index_pipeline_result.total_chunks
+                document_count += index_pipeline_result.total_docs
 
                 # commit transaction so that the `update` below begins
                 # with a brand new transaction. Postgres uses the start
                 # of the transactions when computing `NOW()`, so if we have
-                # a long running transaction, the `time_updated` field will
+                # a long running transaction, the `time_updated` field will
                 # be inaccurate
                 db_session.commit()
 
-                if callback:
-                    callback.progress("_run_indexing", len(doc_batch_cleaned))
-
                 # This new value is updated every batch, so UI can refresh per batch update
                 update_docs_indexed(
                     db_session=db_session,
@@ -320,6 +327,9 @@ def _run_indexing(
                     docs_removed_from_index=0,
                 )
 
+                if callback:
+                    callback.progress("_run_indexing", len(doc_batch_cleaned))
+
                 tracer_counter += 1
                 if (
                     INDEXING_TRACER_INTERVAL > 0
diff --git a/backend/onyx/db/document.py b/backend/onyx/db/document.py
index 27778fe0686..082ee3f94a9 100644
--- a/backend/onyx/db/document.py
+++ b/backend/onyx/db/document.py
@@ -1,6 +1,7 @@
 import contextlib
 import time
 from collections.abc import Generator
+from collections.abc import Iterable
 from collections.abc import Sequence
 from datetime import datetime
 from datetime import timezone
@@ -13,6 +14,7 @@
 from sqlalchemy import Select
 from sqlalchemy import select
 from sqlalchemy import tuple_
+from sqlalchemy import update
 from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.engine.util import TransactionalContext
 from sqlalchemy.exc import OperationalError
@@ -226,10 +228,13 @@ def get_document_counts_for_cc_pairs(
             func.count(),
         )
         .where(
-            tuple_(
-                DocumentByConnectorCredentialPair.connector_id,
-                DocumentByConnectorCredentialPair.credential_id,
-            ).in_(cc_ids)
+            and_(
+                tuple_(
+                    DocumentByConnectorCredentialPair.connector_id,
+                    DocumentByConnectorCredentialPair.credential_id,
+                ).in_(cc_ids),
+                DocumentByConnectorCredentialPair.has_been_indexed.is_(True),
+            )
         )
         .group_by(
             DocumentByConnectorCredentialPair.connector_id,
@@ -382,18 +387,40 @@ def upsert_document_by_connector_credential_pair(
                     id=doc_id,
                     connector_id=connector_id,
                     credential_id=credential_id,
+                    has_been_indexed=False,
                 )
             )
             for doc_id in document_ids
         ]
     )
-    # for now, there are no columns to update. If more metadata is added, then this
-    # needs to change to an `on_conflict_do_update`
+    # this must be `on_conflict_do_nothing` rather than `on_conflict_do_update`
+    # since we don't want to update the `has_been_indexed` field for documents
+    # that already exist
     on_conflict_stmt = insert_stmt.on_conflict_do_nothing()
     db_session.execute(on_conflict_stmt)
     db_session.commit()
 
 
+def mark_document_as_indexed_for_cc_pair__no_commit(
+    db_session: Session,
+    connector_id: int,
+    credential_id: int,
+    document_ids: Iterable[str],
+) -> None:
+    """Should be called only after a successful index operation for a batch."""
+    db_session.execute(
+        update(DocumentByConnectorCredentialPair)
+        .where(
+            and_(
+                DocumentByConnectorCredentialPair.connector_id == connector_id,
+                DocumentByConnectorCredentialPair.credential_id == credential_id,
+                DocumentByConnectorCredentialPair.id.in_(document_ids),
+            )
+        )
+        .values(has_been_indexed=True)
+    )
+
+
 def update_docs_updated_at__no_commit(
     ids_to_new_updated_at: dict[str, datetime],
     db_session: Session,
diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py
index 241c137cd5c..8759d84e21f 100644
--- a/backend/onyx/db/models.py
+++ b/backend/onyx/db/models.py
@@ -935,6 +935,12 @@ class DocumentByConnectorCredentialPair(Base):
         ForeignKey("credential.id"), primary_key=True
     )
 
+    # used to better keep track of document counts at a connector level
+    # e.g. if a document is added as part of permission syncing, it should
+    # not be counted as part of the connector's document count until
+    # the actual indexing is complete
+    has_been_indexed: Mapped[bool] = mapped_column(Boolean)
+
     connector: Mapped[Connector] = relationship(
         "Connector", back_populates="documents_by_connector"
     )
diff --git a/backend/onyx/indexing/indexing_pipeline.py b/backend/onyx/indexing/indexing_pipeline.py
index 74c6c08fd8a..2a06daeeb20 100644
--- a/backend/onyx/indexing/indexing_pipeline.py
+++ b/backend/onyx/indexing/indexing_pipeline.py
@@ -21,6 +21,7 @@
 from onyx.connectors.models import IndexAttemptMetadata
 from onyx.db.document import fetch_chunk_counts_for_documents
 from onyx.db.document import get_documents_by_ids
+from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
 from onyx.db.document import prepare_to_modify_documents
 from onyx.db.document import update_docs_chunk_count__no_commit
 from onyx.db.document import update_docs_last_modified__no_commit
@@ -55,12 +56,20 @@ class DocumentBatchPrepareContext(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
 
+class IndexingPipelineResult(BaseModel):
+    new_docs: int
+    # NOTE: need total_docs, since the pipeline can skip some docs
+    # (e.g. not even insert them into Postgres)
+    total_docs: int
+    total_chunks: int
+
+
 class IndexingPipelineProtocol(Protocol):
     def __call__(
         self,
         document_batch: list[Document],
         index_attempt_metadata: IndexAttemptMetadata,
-    ) -> tuple[int, int]:
+    ) -> IndexingPipelineResult:
         ...
 
 
@@ -147,10 +156,12 @@ def index_doc_batch_with_handler(
     db_session: Session,
     ignore_time_skip: bool = False,
     tenant_id: str | None = None,
-) -> tuple[int, int]:
-    r = (0, 0)
+) -> IndexingPipelineResult:
+    index_pipeline_result = IndexingPipelineResult(
+        new_docs=0, total_docs=len(document_batch), total_chunks=0
+    )
     try:
-        r = index_doc_batch(
+        index_pipeline_result = index_doc_batch(
             chunker=chunker,
             embedder=embedder,
             document_index=document_index,
@@ -203,7 +214,7 @@ def index_doc_batch_with_handler(
     else:
         pass
 
-    return r
+    return index_pipeline_result
 
 
 def index_doc_batch_prepare(
@@ -227,6 +238,15 @@ def index_doc_batch_prepare(
         if not ignore_time_skip
         else documents
     )
+    if len(updatable_docs) != len(documents):
+        updatable_doc_ids = [doc.id for doc in updatable_docs]
+        skipped_doc_ids = [
+            doc.id for doc in documents if doc.id not in updatable_doc_ids
+        ]
+        logger.info(
+            f"Skipping {len(skipped_doc_ids)} documents "
+            f"because they are up to date. Skipped doc IDs: {skipped_doc_ids}"
+        )
 
     # for all updatable docs, upsert into the DB
     # Does not include doc_updated_at which is also used to indicate a successful update
@@ -263,21 +283,6 @@ def index_doc_batch_prepare(
 def filter_documents(document_batch: list[Document]) -> list[Document]:
     documents: list[Document] = []
     for document in document_batch:
-        # Remove any NUL characters from title/semantic_id
-        # This is a known issue with the Zendesk connector
-        # Postgres cannot handle NUL characters in text fields
-        if document.title:
-            document.title = document.title.replace("\x00", "")
-        if document.semantic_identifier:
-            document.semantic_identifier = document.semantic_identifier.replace(
-                "\x00", ""
-            )
-
-        # Remove NUL characters from all sections
-        for section in document.sections:
-            if section.text is not None:
-                section.text = section.text.replace("\x00", "")
-
         empty_contents = not any(section.text.strip() for section in document.sections)
         if (
             (not document.title or not document.title.strip())
@@ -333,7 +338,7 @@ def index_doc_batch(
     ignore_time_skip: bool = False,
     tenant_id: str | None = None,
     filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
-) -> tuple[int, int]:
+) -> IndexingPipelineResult:
     """Takes different pieces of the indexing pipeline and applies it to a batch of documents
     Note that the documents should already be batched at this point so that it does not inflate the memory
     requirements"""
@@ -359,7 +364,9 @@ def index_doc_batch(
         db_session=db_session,
     )
     if not ctx:
-        return 0, 0
+        return IndexingPipelineResult(
+            new_docs=0, total_docs=len(filtered_documents), total_chunks=0
+        )
 
     logger.debug("Starting chunking")
     chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
@@ -425,7 +432,8 @@ def index_doc_batch(
     ]
 
     logger.debug(
-        f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in access_aware_chunks]}"
+        "Indexing the following chunks: "
+        f"{[chunk.to_short_descriptor() for chunk in access_aware_chunks]}"
     )
     # A document will not be spread across different batches, so all the
     # documents with chunks in this set, are fully represented by the chunks
@@ -440,14 +448,17 @@ def index_doc_batch(
         ),
     )
 
-    successful_doc_ids = [record.document_id for record in insertion_records]
-    successful_docs = [
-        doc for doc in ctx.updatable_docs if doc.id in successful_doc_ids
-    ]
+    successful_doc_ids = {record.document_id for record in insertion_records}
+    if successful_doc_ids != set(updatable_ids):
+        raise RuntimeError(
+            f"Some documents were not successfully indexed. "
+            f"Updatable IDs: {updatable_ids}, "
+            f"Successful IDs: {successful_doc_ids}"
+        )
 
     last_modified_ids = []
     ids_to_new_updated_at = {}
-    for doc in successful_docs:
+    for doc in ctx.updatable_docs:
         last_modified_ids.append(doc.id)
         # doc_updated_at is the source's idea (on the other end of the connector)
         # of when the doc was last modified
@@ -469,11 +480,21 @@ def index_doc_batch(
         db_session=db_session,
     )
 
+    # these documents can now be counted as part of the CC Pairs
+    # document count, so we need to mark them as indexed
+    mark_document_as_indexed_for_cc_pair__no_commit(
+        connector_id=index_attempt_metadata.connector_id,
+        credential_id=index_attempt_metadata.credential_id,
+        document_ids=successful_doc_ids,
+        db_session=db_session,
+    )
+
     db_session.commit()
 
-    result = (
-        len([r for r in insertion_records if r.already_existed is False]),
-        len(access_aware_chunks),
+    result = IndexingPipelineResult(
+        new_docs=len([r for r in insertion_records if r.already_existed is False]),
+        total_docs=len(filtered_documents),
+        total_chunks=len(access_aware_chunks),
     )
 
     return result