Various fixes/improvements to document counting
Weves committed Jan 15, 2025
1 parent b195773 commit ac54120
Showing 4 changed files with 112 additions and 48 deletions.
32 changes: 21 additions & 11 deletions backend/onyx/background/indexing/run_indexing.py
@@ -97,10 +97,17 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
for doc in doc_batch:
cleaned_doc = doc.model_copy()

# Postgres cannot handle NUL characters in text fields
if "\x00" in cleaned_doc.id:
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")

if cleaned_doc.title and "\x00" in cleaned_doc.title:
logger.warning(
f"NUL characters found in document title: {cleaned_doc.title}"
)
cleaned_doc.title = cleaned_doc.title.replace("\x00", "")

if "\x00" in cleaned_doc.semantic_identifier:
logger.warning(
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
@@ -116,6 +123,12 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
)
section.link = section.link.replace("\x00", "")

if section.text and "\x00" in section.text:
logger.warning(
f"NUL characters found in document text for document: {cleaned_doc.id}"
)
section.text = section.text.replace("\x00", "")

cleaned_batch.append(cleaned_doc)

return cleaned_batch
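
Postgres rejects NUL (`\x00`) bytes in text columns, which is why the hunk above scrubs every field before it is persisted. A minimal standalone sketch of the same pattern; the helper name and document values are illustrative, not from the codebase:

```python
import logging

logger = logging.getLogger(__name__)


def strip_nul(value: str, field_name: str, doc_id: str) -> str:
    """Remove NUL bytes that Postgres text columns cannot store."""
    if "\x00" in value:
        # Log before mutating so the offending document can be traced.
        logger.warning("NUL characters found in %s for document: %s", field_name, doc_id)
        return value.replace("\x00", "")
    return value


# Example: without this, psycopg2 typically raises a ValueError
# ("A string literal cannot contain NUL (0x00) characters") when binding the value.
clean_title = strip_nul("Zendesk ticket\x00 export", "title", "doc-123")
assert "\x00" not in clean_title
```
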
@@ -234,8 +247,6 @@ def _run_indexing(
tenant_id=tenant_id,
)

all_connector_doc_ids: set[str] = set()

tracer_counter = 0
if INDEXING_TRACER_INTERVAL > 0:
tracer.snap()
@@ -290,27 +301,23 @@ def _run_indexing(
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this

# real work happens here!
new_docs, total_batch_chunks = indexing_pipeline(
index_pipeline_result = indexing_pipeline(
document_batch=doc_batch_cleaned,
index_attempt_metadata=index_attempt_md,
)

batch_num += 1
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch_cleaned)
all_connector_doc_ids.update(doc.id for doc in doc_batch_cleaned)
net_doc_change += index_pipeline_result.new_docs
chunk_count += index_pipeline_result.total_chunks
document_count += index_pipeline_result.total_docs

# commit transaction so that the `update` below begins
# with a brand new transaction. Postgres uses the start
# of the transactions when computing `NOW()`, so if we have
# a long running transaction, the `time_updated` field will
# be inaccurate
db_session.commit()

if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))

# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
db_session=db_session,
@@ -320,6 +327,9 @@
docs_removed_from_index=0,
)

if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))

tracer_counter += 1
if (
INDEXING_TRACER_INTERVAL > 0
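
The comment about committing before `update_docs_indexed` relies on a Postgres detail: `now()` is frozen at the start of the current transaction, while `clock_timestamp()` tracks wall-clock time. A rough sketch of that behavior, assuming a reachable Postgres instance (the DSN below is a placeholder, not from this repository):

```python
from sqlalchemy import create_engine, text

# Placeholder DSN -- point this at a real Postgres instance to run the sketch.
engine = create_engine("postgresql+psycopg2://user:pass@localhost/onyx")

with engine.connect() as conn:
    with conn.begin():
        first = conn.execute(text("SELECT now(), clock_timestamp()")).one()
        # ... imagine a long-running indexing batch here ...
        second = conn.execute(text("SELECT now(), clock_timestamp()")).one()
        # now() is identical in both rows (transaction start time),
        # while clock_timestamp() has advanced.
        assert first[0] == second[0]
        assert second[1] >= first[1]

    # After a commit, the next transaction gets a fresh now(), which is why
    # _run_indexing commits before the update that sets `time_updated`.
```
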
39 changes: 33 additions & 6 deletions backend/onyx/db/document.py
@@ -1,6 +1,7 @@
import contextlib
import time
from collections.abc import Generator
from collections.abc import Iterable
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
@@ -13,6 +14,7 @@
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import tuple_
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.engine.util import TransactionalContext
from sqlalchemy.exc import OperationalError
@@ -226,10 +228,13 @@ def get_document_counts_for_cc_pairs(
func.count(),
)
.where(
tuple_(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
).in_(cc_ids)
and_(
tuple_(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
).in_(cc_ids),
DocumentByConnectorCredentialPair.has_been_indexed.is_(True),
)
)
.group_by(
DocumentByConnectorCredentialPair.connector_id,
@@ -382,18 +387,40 @@ def upsert_document_by_connector_credential_pair(
id=doc_id,
connector_id=connector_id,
credential_id=credential_id,
has_been_indexed=False,
)
)
for doc_id in document_ids
]
)
# for now, there are no columns to update. If more metadata is added, then this
# needs to change to an `on_conflict_do_update`
# this must be `on_conflict_do_nothing` rather than `on_conflict_do_update`
# since we don't want to update the `has_been_indexed` field for documents
# that already exist
on_conflict_stmt = insert_stmt.on_conflict_do_nothing()
db_session.execute(on_conflict_stmt)
db_session.commit()


def mark_document_as_indexed_for_cc_pair__no_commit(
db_session: Session,
connector_id: int,
credential_id: int,
document_ids: Iterable[str],
) -> None:
"""Should be called only after a successful index operation for a batch."""
db_session.execute(
update(DocumentByConnectorCredentialPair)
.where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
DocumentByConnectorCredentialPair.id.in_(document_ids),
)
)
.values(has_been_indexed=True)
)


def update_docs_updated_at__no_commit(
ids_to_new_updated_at: dict[str, datetime],
db_session: Session,
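
The `on_conflict_do_nothing` choice above matters because a document row can be upserted again (e.g. by permission syncing) after it has already been indexed, and that re-upsert must not reset `has_been_indexed` to `False`. A self-contained sketch of the behavior, using SQLite's equivalent construct and a simplified stand-in table purely for illustration (the real code uses the `postgresql` dialect and `DocumentByConnectorCredentialPair`):

```python
from sqlalchemy import (
    Boolean,
    Column,
    MetaData,
    String,
    Table,
    create_engine,
    select,
    update,
)
from sqlalchemy.dialects.sqlite import insert  # postgresql dialect in the real code

metadata = MetaData()
# Illustrative stand-in, not the real schema.
doc_by_cc_pair = Table(
    "document_by_cc_pair",
    metadata,
    Column("id", String, primary_key=True),
    Column("has_been_indexed", Boolean, nullable=False),
)

engine = create_engine("sqlite://")
metadata.create_all(engine)

with engine.begin() as conn:
    # First upsert: the row exists in the relational DB, but indexing hasn't finished yet.
    conn.execute(
        insert(doc_by_cc_pair)
        .values(id="doc-1", has_been_indexed=False)
        .on_conflict_do_nothing()
    )
    # Indexing succeeds, so the flag is flipped (what
    # mark_document_as_indexed_for_cc_pair__no_commit does for a whole batch).
    conn.execute(
        update(doc_by_cc_pair)
        .where(doc_by_cc_pair.c.id == "doc-1")
        .values(has_been_indexed=True)
    )
    # A later re-upsert (e.g. from permission syncing) must NOT reset the flag,
    # which is why on_conflict_do_nothing is used instead of on_conflict_do_update.
    conn.execute(
        insert(doc_by_cc_pair)
        .values(id="doc-1", has_been_indexed=False)
        .on_conflict_do_nothing()
    )
    still_indexed = conn.execute(
        select(doc_by_cc_pair.c.has_been_indexed).where(doc_by_cc_pair.c.id == "doc-1")
    ).scalar_one()
    assert still_indexed  # the re-upsert did not clobber it
```
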
6 changes: 6 additions & 0 deletions backend/onyx/db/models.py
@@ -935,6 +935,12 @@ class DocumentByConnectorCredentialPair(Base):
ForeignKey("credential.id"), primary_key=True
)

# used to better keep track of document counts at a connector level
# e.g. if a document is added as part of permission syncing, it should
# not be counted as part of the connector's document count until
# the actual indexing is complete
has_been_indexed: Mapped[bool] = mapped_column(Boolean)

connector: Mapped[Connector] = relationship(
"Connector", back_populates="documents_by_connector"
)
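
None of the four files shown includes a schema migration, but a non-nullable `has_been_indexed` column would normally need one. A hypothetical sketch of what such a migration could look like; the table name, revision identifiers, and the decision to backfill existing rows as `True` are all assumptions, not taken from this commit:

```python
"""Hypothetical Alembic migration -- not part of the files shown in this commit."""
import sqlalchemy as sa
from alembic import op

# Placeholder revision identifiers.
revision = "add_has_been_indexed"
down_revision = "previous_revision_id"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "document_by_connector_credential_pair",  # assumed table name
        sa.Column(
            "has_been_indexed",
            sa.Boolean(),
            nullable=False,
            # Assumption: rows that already exist were written by older code that
            # only inserted after successful indexing, so backfill them as True
            # to keep existing connector document counts stable.
            server_default=sa.true(),
        ),
    )


def downgrade() -> None:
    op.drop_column("document_by_connector_credential_pair", "has_been_indexed")
```
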
83 changes: 52 additions & 31 deletions backend/onyx/indexing/indexing_pipeline.py
@@ -21,6 +21,7 @@
from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.document import fetch_chunk_counts_for_documents
from onyx.db.document import get_documents_by_ids
from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
from onyx.db.document import prepare_to_modify_documents
from onyx.db.document import update_docs_chunk_count__no_commit
from onyx.db.document import update_docs_last_modified__no_commit
@@ -55,12 +56,20 @@ class DocumentBatchPrepareContext(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)


class IndexingPipelineResult(BaseModel):
new_docs: int
# NOTE: need total_docs, since the pipeline can skip some docs
# (e.g. not even insert them into Postgres)
total_docs: int
total_chunks: int


class IndexingPipelineProtocol(Protocol):
def __call__(
self,
document_batch: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
) -> tuple[int, int]:
) -> IndexingPipelineResult:
...


@@ -147,10 +156,12 @@ def index_doc_batch_with_handler(
db_session: Session,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
) -> tuple[int, int]:
r = (0, 0)
) -> IndexingPipelineResult:
index_pipeline_result = IndexingPipelineResult(
new_docs=0, total_docs=len(document_batch), total_chunks=0
)
try:
r = index_doc_batch(
index_pipeline_result = index_doc_batch(
chunker=chunker,
embedder=embedder,
document_index=document_index,
@@ -203,7 +214,7 @@ def index_doc_batch_with_handler(
else:
pass

return r
return index_pipeline_result


def index_doc_batch_prepare(
@@ -227,6 +238,15 @@ def index_doc_batch_prepare(
if not ignore_time_skip
else documents
)
if len(updatable_docs) != len(documents):
updatable_doc_ids = [doc.id for doc in updatable_docs]
skipped_doc_ids = [
doc.id for doc in documents if doc.id not in updatable_doc_ids
]
logger.info(
f"Skipping {len(skipped_doc_ids)} documents "
f"because they are up to date. Skipped doc IDs: {skipped_doc_ids}"
)

# for all updatable docs, upsert into the DB
# Does not include doc_updated_at which is also used to indicate a successful update
@@ -263,21 +283,6 @@
def filter_documents(document_batch: list[Document]) -> list[Document]:
documents: list[Document] = []
for document in document_batch:
# Remove any NUL characters from title/semantic_id
# This is a known issue with the Zendesk connector
# Postgres cannot handle NUL characters in text fields
if document.title:
document.title = document.title.replace("\x00", "")
if document.semantic_identifier:
document.semantic_identifier = document.semantic_identifier.replace(
"\x00", ""
)

# Remove NUL characters from all sections
for section in document.sections:
if section.text is not None:
section.text = section.text.replace("\x00", "")

empty_contents = not any(section.text.strip() for section in document.sections)
if (
(not document.title or not document.title.strip())
@@ -333,7 +338,7 @@ def index_doc_batch(
ignore_time_skip: bool = False,
tenant_id: str | None = None,
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
) -> tuple[int, int]:
) -> IndexingPipelineResult:
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements
@@ -359,7 +364,9 @@
db_session=db_session,
)
if not ctx:
return 0, 0
return IndexingPipelineResult(
new_docs=0, total_docs=len(filtered_documents), total_chunks=0
)

logger.debug("Starting chunking")
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
@@ -425,7 +432,8 @@ def index_doc_batch(
]

logger.debug(
f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in access_aware_chunks]}"
"Indexing the following chunks: "
f"{[chunk.to_short_descriptor() for chunk in access_aware_chunks]}"
)
# A document will not be spread across different batches, so all the
# documents with chunks in this set, are fully represented by the chunks
@@ -440,14 +448,17 @@
),
)

successful_doc_ids = [record.document_id for record in insertion_records]
successful_docs = [
doc for doc in ctx.updatable_docs if doc.id in successful_doc_ids
]
successful_doc_ids = {record.document_id for record in insertion_records}
if successful_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Successful IDs: {successful_doc_ids}"
)

last_modified_ids = []
ids_to_new_updated_at = {}
for doc in successful_docs:
for doc in ctx.updatable_docs:
last_modified_ids.append(doc.id)
# doc_updated_at is the source's idea (on the other end of the connector)
# of when the doc was last modified
@@ -469,11 +480,21 @@
db_session=db_session,
)

# these documents can now be counted as part of the CC Pairs
# document count, so we need to mark them as indexed
mark_document_as_indexed_for_cc_pair__no_commit(
connector_id=index_attempt_metadata.connector_id,
credential_id=index_attempt_metadata.credential_id,
document_ids=successful_doc_ids,
db_session=db_session,
)

db_session.commit()

result = (
len([r for r in insertion_records if r.already_existed is False]),
len(access_aware_chunks),
result = IndexingPipelineResult(
new_docs=len([r for r in insertion_records if r.already_existed is False]),
total_docs=len(filtered_documents),
total_chunks=len(access_aware_chunks),
)

return result
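
To illustrate how callers consume the new return type, here is a minimal, runnable sketch in the spirit of `IndexingPipelineProtocol`, with a stub pipeline and per-batch accumulation that mirrors the counters kept in `_run_indexing`. The stub, its metadata parameter, and its numbers are illustrative only:

```python
from pydantic import BaseModel


class IndexingPipelineResult(BaseModel):
    new_docs: int
    # total_docs can exceed new_docs and the number of docs actually written,
    # since the pipeline may skip documents entirely.
    total_docs: int
    total_chunks: int


def fake_indexing_pipeline(
    document_batch: list[str],
    index_attempt_metadata: dict | None = None,  # stand-in for IndexAttemptMetadata
) -> IndexingPipelineResult:
    """Stub pipeline: pretends half the batch is new and each doc yields 3 chunks."""
    return IndexingPipelineResult(
        new_docs=len(document_batch) // 2,
        total_docs=len(document_batch),
        total_chunks=len(document_batch) * 3,
    )


# Per-batch accumulation, mirroring the counters in _run_indexing.
net_doc_change = chunk_count = document_count = 0
for batch in (["a", "b"], ["c"]):
    result = fake_indexing_pipeline(batch)
    net_doc_change += result.new_docs
    chunk_count += result.total_chunks
    document_count += result.total_docs

assert (net_doc_change, chunk_count, document_count) == (1, 9, 3)
```
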
