diff --git a/backend/ee/danswer/seeding/load_docs.py b/backend/ee/danswer/seeding/load_docs.py index 1aa6c969f43..31047423c0d 100644 --- a/backend/ee/danswer/seeding/load_docs.py +++ b/backend/ee/danswer/seeding/load_docs.py @@ -1,10 +1,14 @@ import json import os +from typing import cast +from typing import List from cohere import Client from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY +Embedding = List[float] + def load_processed_docs(cohere_enabled: bool) -> list[dict]: base_path = os.path.join(os.getcwd(), "danswer", "seeding") @@ -13,27 +17,27 @@ def load_processed_docs(cohere_enabled: bool) -> list[dict]: initial_docs_path = os.path.join(base_path, "initial_docs_cohere.json") processed_docs = json.load(open(initial_docs_path)) - cohere_client = Client(COHERE_DEFAULT_API_KEY) + cohere_client = Client(api_key=COHERE_DEFAULT_API_KEY) embed_model = "embed-english-v3.0" for doc in processed_docs: title_embed_response = cohere_client.embed( - texts=[doc["title"]], model=embed_model, input_type="search_document" + texts=[doc["title"]], + model=embed_model, + input_type="search_document", ) content_embed_response = cohere_client.embed( - texts=[doc["content"]], model=embed_model, input_type="search_document" + texts=[doc["content"]], + model=embed_model, + input_type="search_document", ) - doc["title_embedding"] = ( - title_embed_response.embeddings[0] - if hasattr(title_embed_response, "embeddings") - else title_embed_response[0] - ) - doc["content_embedding"] = ( - content_embed_response.embeddings[0] - if hasattr(content_embed_response, "embeddings") - else content_embed_response[0] - ) + doc["title_embedding"] = cast( + List[Embedding], title_embed_response.embeddings + )[0] + doc["content_embedding"] = cast( + List[Embedding], content_embed_response.embeddings + )[0] else: initial_docs_path = os.path.join(base_path, "initial_docs.json") processed_docs = json.load(open(initial_docs_path))