From 3409e768f6cc753d8966253030982d7490520138 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Mon, 11 Nov 2024 14:59:21 -0800 Subject: [PATCH] update --- backend/ee/danswer/seeding/load_docs.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/backend/ee/danswer/seeding/load_docs.py b/backend/ee/danswer/seeding/load_docs.py index 92bda7590db..1aa6c969f43 100644 --- a/backend/ee/danswer/seeding/load_docs.py +++ b/backend/ee/danswer/seeding/load_docs.py @@ -17,12 +17,23 @@ def load_processed_docs(cohere_enabled: bool) -> list[dict]: embed_model = "embed-english-v3.0" for doc in processed_docs: - doc["title_embedding"] = cohere_client.embed( - texts=[doc["title"]], model=embed_model - ).embeddings[0] - doc["content_embedding"] = cohere_client.embed( - texts=[doc["content"]], model=embed_model - ).embeddings[0] + title_embed_response = cohere_client.embed( + texts=[doc["title"]], model=embed_model, input_type="search_document" + ) + content_embed_response = cohere_client.embed( + texts=[doc["content"]], model=embed_model, input_type="search_document" + ) + + doc["title_embedding"] = ( + title_embed_response.embeddings[0] + if hasattr(title_embed_response, "embeddings") + else title_embed_response[0] + ) + doc["content_embedding"] = ( + content_embed_response.embeddings[0] + if hasattr(content_embed_response, "embeddings") + else content_embed_response[0] + ) else: initial_docs_path = os.path.join(base_path, "initial_docs.json") processed_docs = json.load(open(initial_docs_path))