diff --git a/backend/danswer/seeding/load_docs.py b/backend/danswer/seeding/load_docs.py index 243cad8862a..1567f7f6bbb 100644 --- a/backend/danswer/seeding/load_docs.py +++ b/backend/danswer/seeding/load_docs.py @@ -3,7 +3,6 @@ import os from typing import cast -from cohere import Client from sqlalchemy.orm import Session from danswer.access.models import default_public_access @@ -33,7 +32,7 @@ from danswer.server.documents.models import ConnectorBase from danswer.utils.logger import setup_logger from danswer.utils.retry_wrapper import retry_builder -from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY +from danswer.utils.variable_functionality import fetch_versioned_implementation logger = setup_logger() @@ -92,6 +91,18 @@ def _create_indexable_chunks( return list(ids_to_documents.values()), chunks +# Cohere is used in EE version +def load_processed_docs(cohere_enabled: bool) -> list[dict]: + initial_docs_path = os.path.join( + os.getcwd(), + "danswer", + "seeding", + "initial_docs.json", + ) + processed_docs = json.load(open(initial_docs_path)) + return processed_docs + + def seed_initial_documents( db_session: Session, tenant_id: str | None, cohere_enabled: bool = False ) -> None: @@ -177,32 +188,10 @@ def seed_initial_documents( last_successful_index_time=last_index_time, ) cc_pair_id = cast(int, result.data) - - if cohere_enabled: - initial_docs_path = os.path.join( - os.getcwd(), "danswer", "seeding", "initial_docs_cohere.json" - ) - - cohere_client = Client(COHERE_DEFAULT_API_KEY) - processed_docs = json.load(open(initial_docs_path)) - for doc in processed_docs: - title_embedding = cohere_client.embed( - texts=[doc["title"]], model="embed-english-v3.0" - ).embeddings[0] - content_embedding = cohere_client.embed( - texts=[doc["content"]], model="embed-english-v3.0" - ).embeddings[0] - doc["title_embedding"] = title_embedding - doc["content_embedding"] = content_embedding - - else: - initial_docs_path = os.path.join( - os.getcwd(), - "danswer", - "seeding", - "initial_docs.json", - ) - processed_docs = json.load(open(initial_docs_path)) + processed_docs = fetch_versioned_implementation( + "danswer.seeding.load_docs", + "load_processed_docs", + )(cohere_enabled) docs, chunks = _create_indexable_chunks(processed_docs, tenant_id) diff --git a/backend/ee/danswer/seeding/load_docs.py b/backend/ee/danswer/seeding/load_docs.py new file mode 100644 index 00000000000..92bda7590db --- /dev/null +++ b/backend/ee/danswer/seeding/load_docs.py @@ -0,0 +1,30 @@ +import json +import os + +from cohere import Client + +from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY + + +def load_processed_docs(cohere_enabled: bool) -> list[dict]: + base_path = os.path.join(os.getcwd(), "danswer", "seeding") + + if cohere_enabled and COHERE_DEFAULT_API_KEY: + initial_docs_path = os.path.join(base_path, "initial_docs_cohere.json") + processed_docs = json.load(open(initial_docs_path)) + + cohere_client = Client(COHERE_DEFAULT_API_KEY) + embed_model = "embed-english-v3.0" + + for doc in processed_docs: + doc["title_embedding"] = cohere_client.embed( + texts=[doc["title"]], model=embed_model + ).embeddings[0] + doc["content_embedding"] = cohere_client.embed( + texts=[doc["content"]], model=embed_model + ).embeddings[0] + else: + initial_docs_path = os.path.join(base_path, "initial_docs.json") + processed_docs = json.load(open(initial_docs_path)) + + return processed_docs