Source code for components.model_client.anthropic_client

+def get_first_message_content(completion: Message) -> str:
+    r"""When we only need the content of the first message.
+    It is the default parser for chat completion."""
+    return completion.content[0].text
+
+
+__all__ = ["AnthropicAPIClient", "get_first_message_content"]
+
+
+# NOTE: using a customized parser might make the new component more complex when we have to handle a callable
+
+[docs]
+classAnthropicAPIClient(ModelClient):
+ __doc__=r"""A component wrapper for the Anthropic API client.
+
+ Visit https://docs.anthropic.com/en/docs/intro-to-claude for more api details.
+ """
+
+ def__init__(
+ self,
+ api_key:Optional[str]=None,
+ chat_completion_parser:Callable[[Message],Any]=None,
+ ):
+r"""It is recommended to set the ANTHROPIC_API_KEY environment variable instead of passing it as an argument."""
+ super().__init__()
+ self._api_key=api_key
+ self.sync_client=self.init_sync_client()
+ self.async_client=None# only initialize if the async call is called
+ self.tested_llm_models=["claude-3-opus-20240229"]
+ self.chat_completion_parser=(
+ chat_completion_parserorget_first_message_content
+ )
+
+
+    def init_sync_client(self):
+        api_key = self._api_key or os.getenv("ANTHROPIC_API_KEY")
+        if not api_key:
+            raise ValueError("Environment variable ANTHROPIC_API_KEY must be set")
+        return anthropic.Anthropic(api_key=api_key)
+
+    def init_async_client(self):
+        api_key = self._api_key or os.getenv("ANTHROPIC_API_KEY")
+        if not api_key:
+            raise ValueError("Environment variable ANTHROPIC_API_KEY must be set")
+        return anthropic.AsyncAnthropic(api_key=api_key)
+
+    # TODO: potentially use <SYS></SYS> to separate the system and user messages. This requires the user to follow it; if it is not found, we will only use the user message.
+
+    def convert_inputs_to_api_kwargs(
+        self,
+        input: Optional[Any] = None,
+        model_kwargs: Dict = {},
+        model_type: ModelType = ModelType.UNDEFINED,
+    ) -> dict:
+        r"""The Anthropic messages API separates the system and the user messages.
+
+        As we focus on one prompt, we have to use the user message as the input.
+
+        api: https://docs.anthropic.com/en/api/messages
+        """
+        api_kwargs = model_kwargs.copy()
+        if model_type == ModelType.LLM:
+            api_kwargs["messages"] = [
+                {"role": "user", "content": input},
+            ]
+            # if input and input != "":
+            #     api_kwargs["system"] = input
+        else:
+            raise ValueError(f"Model type {model_type} not supported")
+        return api_kwargs
+
+
+
+    @backoff.on_exception(
+        backoff.expo,
+        (
+            APITimeoutError,
+            InternalServerError,
+            RateLimitError,
+            UnprocessableEntityError,
+            BadRequestError,
+        ),
+        max_time=5,
+    )
+    def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
+        """api_kwargs is the combined input and model_kwargs."""
+        if model_type == ModelType.EMBEDDER:
+            raise ValueError(f"Model type {model_type} not supported")
+        elif model_type == ModelType.LLM:
+            return self.sync_client.messages.create(**api_kwargs)
+        else:
+            raise ValueError(f"model_type {model_type} is not supported")
+
+
+
+    @backoff.on_exception(
+        backoff.expo,
+        (
+            APITimeoutError,
+            InternalServerError,
+            RateLimitError,
+            UnprocessableEntityError,
+            BadRequestError,
+        ),
+        max_time=5,
+    )
+    async def acall(
+        self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED
+    ):
+        """api_kwargs is the combined input and model_kwargs."""
+        if self.async_client is None:
+            self.async_client = self.init_async_client()
+        if model_type == ModelType.EMBEDDER:
+            raise ValueError(f"Model type {model_type} not supported")
+        elif model_type == ModelType.LLM:
+            return await self.async_client.messages.create(**api_kwargs)
+        else:
+            raise ValueError(f"model_type {model_type} is not supported")
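+
+
+# A minimal usage sketch (not part of the source above): calling the client
+# directly, assuming the anthropic package is installed and ANTHROPIC_API_KEY
+# is set; the model name is one of the tested models listed in __init__.
+if __name__ == "__main__":
+    client = AnthropicAPIClient()
+    kwargs = client.convert_inputs_to_api_kwargs(
+        input="What is the capital of France?",
+        model_kwargs={"model": "claude-3-opus-20240229", "max_tokens": 512},
+        model_type=ModelType.LLM,
+    )
+    completion = client.call(api_kwargs=kwargs, model_type=ModelType.LLM)
+    print(client.chat_completion_parser(completion))  # content of the first message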

Source code for components.model_client.cohere_client

+class CohereAPIClient(ModelClient):
+    __doc__ = r"""A component wrapper for the Cohere API.
+
+    Visit https://docs.cohere.com/ for more api details.
+
+    References:
+        - Cohere reranker: https://docs.cohere.com/reference/rerank
+
+    Tested Cohere models: 6/16/2024
+        - rerank-english-v3.0, rerank-multilingual-v3.0, rerank-english-v2.0, rerank-multilingual-v2.0
+
+    .. note::
+        For all ModelClient integrations, such as CohereAPIClient, if you want to subclass CohereAPIClient, you need to import it from the module directly:
+
+        ``from lightrag.components.model_client.cohere_client import CohereAPIClient``
+
+        instead of using the lazy import with:
+
+        ``from lightrag.components.model_client import CohereAPIClient``
+    """
+
+    def __init__(self, api_key: Optional[str] = None):
+        r"""It is recommended to set the COHERE_API_KEY environment variable instead of passing it as an argument.
+
+        Args:
+            api_key (Optional[str], optional): Cohere API key. Defaults to None.
+        """
+        super().__init__()
+        self._api_key = api_key
+        self.init_sync_client()
+
+        self.async_client = None  # only initialize if the async call is called
+
+
+    def init_sync_client(self):
+        api_key = self._api_key or os.getenv("COHERE_API_KEY")
+        if not api_key:
+            raise ValueError("Environment variable COHERE_API_KEY must be set")
+        self.sync_client = cohere.Client(api_key=api_key)
+
+    def init_async_client(self):
+        api_key = self._api_key or os.getenv("COHERE_API_KEY")
+        if not api_key:
+            raise ValueError("Environment variable COHERE_API_KEY must be set")
+        self.async_client = cohere.AsyncClient(api_key=api_key)
+
+
+
+    def convert_inputs_to_api_kwargs(
+        self,
+        input: Optional[Any] = None,  # for the retriever, it is a list of strings.
+        model_kwargs: Dict = {},
+        model_type: ModelType = ModelType.UNDEFINED,
+    ) -> Dict:
+        r"""
+        For the rerank model, expect model_kwargs to have the following keys:
+            model: str,
+            query: str,
+            documents: List[str],
+            top_n: int,
+        """
+        final_model_kwargs = model_kwargs.copy()
+        if model_type == ModelType.RERANKER:
+            final_model_kwargs["query"] = input
+            if "model" not in final_model_kwargs:
+                raise ValueError("model must be specified")
+            if "documents" not in final_model_kwargs:
+                raise ValueError("documents must be specified")
+            if "top_k" not in final_model_kwargs:
+                raise ValueError("top_k must be specified")
+
+            # convert top_k to the API-specific name, which is top_n
+            final_model_kwargs["top_n"] = final_model_kwargs.pop("top_k")
+            return final_model_kwargs
+        else:
+            raise ValueError(f"model_type {model_type} is not supported")
+
+
+
+    @backoff.on_exception(
+        backoff.expo,
+        (
+            BadRequestError,
+            InternalServerError,
+        ),
+        max_time=5,
+    )
+    def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
+        assert (
+            "model" in api_kwargs
+        ), f"model must be specified in api_kwargs: {api_kwargs}"
+        if (
+            model_type == ModelType.RERANKER
+        ):  # query -> scores and indices for the top_k documents, returned as a tuple
+            response = self.sync_client.rerank(**api_kwargs)
+            top_k_scores = [result.relevance_score for result in response.results]
+            top_k_indices = [result.index for result in response.results]
+            return top_k_indices, top_k_scores
+        else:
+            raise ValueError(f"model_type {model_type} is not supported")
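+
+    # A minimal usage sketch, assuming COHERE_API_KEY is set (not part of the source):
+    #
+    #   client = CohereAPIClient()
+    #   kwargs = client.convert_inputs_to_api_kwargs(
+    #       input="What is the capital of France?",
+    #       model_kwargs={
+    #           "model": "rerank-english-v3.0",
+    #           "documents": [
+    #               "Paris is the capital of France.",
+    #               "Berlin is a city in Germany.",
+    #           ],
+    #           "top_k": 1,
+    #       },
+    #       model_type=ModelType.RERANKER,
+    #   )
+    #   indices, scores = client.call(api_kwargs=kwargs, model_type=ModelType.RERANKER)
+    #   # indices and scores are parallel lists for the reranked documents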

Source code for components.model_client.groq_client

+class GroqAPIClient(ModelClient):
+    __doc__ = r"""A component wrapper for the Groq API client.
+
+    Visit https://console.groq.com/docs/ for more api details.
+    Check https://console.groq.com/docs/models for the available models.
+
+    Tested Groq models: 4/22/2024
+    - llama3-8b-8192
+    - llama3-70b-8192
+    - mixtral-8x7b-32768
+    - gemma-7b-it
+    """
+
+    def __init__(self, api_key: Optional[str] = None):
+        r"""It is recommended to set the GROQ_API_KEY environment variable instead of passing it as an argument.
+
+        Args:
+            api_key (Optional[str], optional): Groq API key. Defaults to None.
+        """
+        super().__init__()
+        self._api_key = api_key
+        self.init_sync_client()
+
+        self.async_client = None  # only initialize if the async call is called
+
+
+    def init_sync_client(self):
+        api_key = self._api_key or os.getenv("GROQ_API_KEY")
+        if not api_key:
+            raise ValueError("Environment variable GROQ_API_KEY must be set")
+        self.sync_client = Groq(api_key=api_key)
+
+    def init_async_client(self):
+        api_key = self._api_key or os.getenv("GROQ_API_KEY")
+        if not api_key:
+            raise ValueError("Environment variable GROQ_API_KEY must be set")
+        self.async_client = AsyncGroq(api_key=api_key)
+
+    def parse_chat_completion(self, completion: Any) -> str:
+        """Parse the completion to a string output."""
+        return completion.choices[0].message.content
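+
+    # A minimal sketch, assuming GROQ_API_KEY is set. The call/acall methods are
+    # not shown in this excerpt; the sync client is a groq.Groq instance, whose
+    # chat API mirrors OpenAI's, so a direct call would look like:
+    #
+    #   client = GroqAPIClient()
+    #   completion = client.sync_client.chat.completions.create(
+    #       model="llama3-8b-8192",  # one of the tested models above
+    #       messages=[{"role": "user", "content": "Hello"}],
+    #   )
+    #   print(client.parse_chat_completion(completion))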
Source code for components.model_client.openai_client
+"""OpenAI ModelClient integration."""
+
+importos
+fromtypingimportDict,Sequence,Optional,List,Any,TypeVar,Callable
+
+importlogging
+importbackoff
+
+
+fromlightrag.core.model_clientimportModelClient
+fromlightrag.core.typesimportModelType,EmbedderOutput,TokenLogProb
+fromlightrag.components.model_client.utilsimportparse_embedding_response
+
+# optional import
+fromlightrag.utils.lazy_importimportsafe_import,OptionalPackages
+
+
+openai=safe_import(OptionalPackages.OPENAI.value[0],OptionalPackages.OPENAI.value[1])
+
+fromopenaiimportOpenAI,AsyncOpenAI
+fromopenaiimport(
+ APITimeoutError,
+ InternalServerError,
+ RateLimitError,
+ UnprocessableEntityError,
+ BadRequestError,
+)
+fromopenai.typesimportCompletion,CreateEmbeddingResponse
+
+
+log=logging.getLogger(__name__)
+T=TypeVar("T")
+
+
+# completion parsing functions; you can combine them into one single chat completion parser
+
+def get_first_message_content(completion: Completion) -> str:
+    r"""When we only need the content of the first message.
+    It is the default parser for chat completion."""
+    return completion.choices[0].message.content
+
+
+def get_all_messages_content(completion: Completion) -> List[str]:
+    r"""When n > 1, get the content of all the messages."""
+    return [c.message.content for c in completion.choices]
+
+
+
+
+def get_probabilities(completion: Completion) -> List[List[TokenLogProb]]:
+    r"""Get the probabilities of each token in the completion."""
+    log_probs = []
+    for c in completion.choices:
+        content = c.logprobs.content
+        log.debug(content)
+        log_probs_for_choice = []
+        for openai_token_logprob in content:
+            token = openai_token_logprob.token
+            logprob = openai_token_logprob.logprob
+            log_probs_for_choice.append(TokenLogProb(token=token, logprob=logprob))
+        log_probs.append(log_probs_for_choice)
+    return log_probs
+
+
+
+
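+# A minimal sketch of picking a non-default parser for the client defined below:
+# get_all_messages_content fits model_kwargs with n > 1, and get_probabilities
+# requires logprobs=True in model_kwargs.
+#
+#   client = OpenAIClient(chat_completion_parser=get_all_messages_content)
+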
+class OpenAIClient(ModelClient):
+    __doc__ = r"""A component wrapper for the OpenAI API client.
+
+    Supports both the embedding and chat completion APIs.
+
+    Users can (1) simply use the ``Embedder`` and ``Generator`` components by passing OpenAIClient() as the model_client,
+    or (2) use this as an example to create their own API client, or extend this class (copying and modifying the code) in their own project.
+
+    Note:
+        We suggest users not use `response_format` to enforce the output data type, or `tools` and `tool_choice` in model_kwargs when calling the API.
+        We do not know how OpenAI does the formatting or what prompt they have added.
+        Instead:
+        - use :ref:`OutputParser<components-output_parsers>` for response parsing and formatting.
+
+    Args:
+        api_key (Optional[str], optional): OpenAI API key. Defaults to None.
+        chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion into a str. Defaults to None.
+            The default is `get_first_message_content`.
+
+    References:
+        - Embeddings models: https://platform.openai.com/docs/guides/embeddings
+        - Chat models: https://platform.openai.com/docs/guides/text-generation
+        - OpenAI docs: https://platform.openai.com/docs/introduction
+    """
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        chat_completion_parser: Callable[[Completion], Any] = None,
+    ):
+        r"""It is recommended to set the OPENAI_API_KEY environment variable instead of passing it as an argument.
+
+        Args:
+            api_key (Optional[str], optional): OpenAI API key. Defaults to None.
+        """
+        super().__init__()
+        self._api_key = api_key
+        self.sync_client = self.init_sync_client()
+        self.async_client = None  # only initialize if the async call is called
+        self.chat_completion_parser = (
+            chat_completion_parser or get_first_message_content
+        )
+
+
+    def init_sync_client(self):
+        api_key = self._api_key or os.getenv("OPENAI_API_KEY")
+        if not api_key:
+            raise ValueError("Environment variable OPENAI_API_KEY must be set")
+        return OpenAI(api_key=api_key)
+
+    def init_async_client(self):
+        api_key = self._api_key or os.getenv("OPENAI_API_KEY")
+        if not api_key:
+            raise ValueError("Environment variable OPENAI_API_KEY must be set")
+        return AsyncOpenAI(api_key=api_key)
+
+    def parse_chat_completion(self, completion: Completion) -> Any:
+        """Parse the completion to a str."""
+        log.debug(f"completion: {completion}")
+        return self.chat_completion_parser(completion)
+
+    def parse_embedding_response(
+        self, response: CreateEmbeddingResponse
+    ) -> EmbedderOutput:
+        r"""Parse the embedding response to a structure LightRAG components can understand.
+
+        Should be called in ``Embedder``.
+        """
+        try:
+            return parse_embedding_response(response)
+        except Exception as e:
+            log.error(f"Error parsing the embedding response: {e}")
+            return EmbedderOutput(data=[], error=str(e), raw_response=response)
+
+
+
+    def convert_inputs_to_api_kwargs(
+        self,
+        input: Optional[Any] = None,
+        model_kwargs: Dict = {},
+        model_type: ModelType = ModelType.UNDEFINED,
+    ) -> Dict:
+        r"""
+        Specify the API input type and output api_kwargs that will be used in the _call and _acall methods.
+        Convert the Component's standard input and model_kwargs (and the system_input for chat models) into the API-specific format.
+        """
+        final_model_kwargs = model_kwargs.copy()
+        if model_type == ModelType.EMBEDDER:
+            if isinstance(input, str):
+                input = [input]
+            # the API expects a sequence of texts
+            if not isinstance(input, Sequence):
+                raise TypeError("input must be a sequence of text")
+            final_model_kwargs["input"] = input
+        elif model_type == ModelType.LLM:
+            # convert input to messages
+            messages: List[Dict[str, str]] = []
+            if input is not None and input != "":
+                messages.append({"role": "system", "content": input})
+            final_model_kwargs["messages"] = messages
+        else:
+            raise ValueError(f"model_type {model_type} is not supported")
+        return final_model_kwargs
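+
+    # A minimal sketch of the embedder path above: a single string becomes a list.
+    # The model name here is only an example.
+    #
+    #   client.convert_inputs_to_api_kwargs(
+    #       input="hello world",
+    #       model_kwargs={"model": "text-embedding-3-small"},
+    #       model_type=ModelType.EMBEDDER,
+    #   )
+    #   # -> {"model": "text-embedding-3-small", "input": ["hello world"]}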
+    def to_dict(self) -> Dict[str, Any]:
+        r"""Convert the component to a dictionary."""
+        # TODO: do not exclude these, but record whether to recreate the clients
+        exclude = [
+            "sync_client",
+            "async_client",
+        ]  # unserializable objects
+        output = super().to_dict(exclude=exclude)
+        return output
\ No newline at end of file
diff --git a/_modules/components/model_client/utils.html b/_modules/components/model_client/utils.html
index b688cbc0..39272ee1 100644
--- a/_modules/components/model_client/utils.html
+++ b/_modules/components/model_client/utils.html
@@ -404,7 +404,6 @@

Source code for components.retriever.faiss_retriever

+class FAISSRetriever(
+    Retriever[FAISSRetrieverDocumentEmbeddingType, FAISSRetrieverQueryType]
+):
+    __doc__ = r"""Semantic search/embedding-based retriever using FAISS.
+
+    To use the retriever, you can either pass the index embeddings from the :meth:`__init__` or use the :meth:`build_index_from_documents` method.
+
+    Args:
+        embedder (Embedder, optional): The embedder component to use for converting the queries in string format to embeddings.
+            Ensure the vectorizer is exactly the same as the one used to create the embeddings in the index.
+        top_k (int, optional): Number of chunks to retrieve. Defaults to 5.
+        dimensions (Optional[int], optional): Dimension of the embeddings. Defaults to None. It can automatically infer the dimensions from the first chunk.
+        documents (Optional[FAISSRetrieverDocumentType], optional): List of embeddings. Format can be List[List[float]] or List[np.ndarray]. Defaults to None.
+        metric (Literal["cosine", "euclidean", "prob"], optional): The metric to use for the retrieval. Defaults to "prob", which converts cosine similarity to a probability.
+
+    How FAISS works:
+
+    The retriever uses an in-memory FAISS index to retrieve the top k chunks:
+    - d: dimension of the vectors
+    - xb: number of vectors to put in the index
+    - xq: number of queries
+    The data type dtype must be float32.
+
+    Note: When the number of chunks is less than top_k, the last columns will be -1.
+
+    Other index options:
+    - faiss.IndexFlatL2: L2 or Euclidean distance, [-inf, inf]
+    - faiss.IndexFlatIP: Inner product of embeddings (the inner product of normalized vectors is the cosine similarity, [-1, 1])
+
+    We choose cosine similarity and convert it to the range [0, 1] by adding 1 and dividing by 2, to simulate a probability in [0, 1].
+
+    References:
+    - FAISS: https://github.com/facebookresearch/faiss
+    """
+
+    def __init__(
+        self,
+        embedder: Optional[Embedder] = None,
+        top_k: int = 5,
+        dimensions: Optional[int] = None,
+        documents: Optional[Any] = None,
+        document_map_func: Optional[
+            Callable[[Any], FAISSRetrieverDocumentEmbeddingType]
+        ] = None,
+        metric: Literal["cosine", "euclidean", "prob"] = "prob",
+    ):
+        super().__init__()
+
+        self.reset_index()
+
+        self.dimensions = dimensions
+        self.embedder = embedder  # used to vectorize the queries
+        self.top_k = top_k
+        self.metric = metric
+        if self.metric == "cosine" or self.metric == "prob":
+            self._faiss_index_type = faiss.IndexFlatIP
+            self._needs_normalized_embeddings = True
+        elif self.metric == "euclidean":
+            self._faiss_index_type = faiss.IndexFlatL2
+            self._needs_normalized_embeddings = False
+        else:
+            raise ValueError(f"Invalid metric: {self.metric}")
+
+        if documents:
+            self.documents = documents
+            self.build_index_from_documents(documents, document_map_func)
+
+
+
+
+    def _preprare_faiss_index_from_np_array(self, xb: np.ndarray):
+        r"""Prepare the faiss index from the numpy array."""
+        if not self.dimensions:
+            self.dimensions = self.xb.shape[1]
+        else:
+            assert (
+                self.dimensions == self.xb.shape[1]
+            ), f"Dimension mismatch: {self.dimensions} != {self.xb.shape[1]}"
+        self.total_documents = xb.shape[0]
+
+        self.index = self._faiss_index_type(self.dimensions)
+        self.index.add(xb)
+        self.indexed = True
+
+
+    def build_index_from_documents(
+        self,
+        documents: Sequence[Any],
+        document_map_func: Optional[
+            Callable[[Any], FAISSRetrieverDocumentEmbeddingType]
+        ] = None,
+    ):
+        r"""Build index from embeddings.
+
+        Args:
+            documents: List of embeddings. Format can be List[List[float]] or List[np.ndarray].
+
+        If you are using the Document format, pass them as [doc.vector for doc in documents].
+        """
+        if document_map_func:
+            assert callable(document_map_func), "document_map_func should be callable"
+            documents = [document_map_func(doc) for doc in documents]
+        try:
+            self.documents = documents
+
+            # convert to numpy array
+            if not isinstance(documents, np.ndarray) and isinstance(
+                documents[0], Sequence
+            ):
+                # ensure all the embeddings are of the same size
+                assert all(
+                    len(doc) == len(documents[0]) for doc in documents
+                ), "All embeddings should be of the same size"
+                self.xb = np.array(documents, dtype=np.float32)
+            else:
+                self.xb = documents
+            if self._needs_normalized_embeddings:
+                first_vector = self.xb[0]
+                if not is_normalized(first_vector):
+                    log.warning(
+                        "Embeddings are not normalized, normalizing the embeddings"
+                    )
+                    self.xb = normalize_np_array(self.xb)
+
+            self._preprare_faiss_index_from_np_array(self.xb)
+            log.info(f"Index built with {self.total_documents} chunks")
+        except Exception as e:
+            log.error(f"Error building index: {e}, resetting the index")
+            # reset the index
+            self.reset_index()
+            raise e
+
+
+    def _convert_cosine_similarity_to_probability(self, D: np.ndarray) -> np.ndarray:
+        D = (D + 1) / 2
+        D = np.round(D, 3)
+        return D
+
+    def _to_retriever_output(
+        self, Ind: np.ndarray, D: np.ndarray
+    ) -> RetrieverOutputType:
+        r"""Convert the indices and distances to the RetrieverOutputType format."""
+        output: RetrieverOutputType = []
+        # Step 1: filter out the -1 columns along with their scores when top_k > len(chunks)
+        if -1 in Ind:
+            valid_columns = ~np.any(Ind == -1, axis=0)
+
+            D = D[:, valid_columns]
+            Ind = Ind[:, valid_columns]
+        # Step 2: process rows (one query at a time)
+        for row in zip(Ind, D):
+            indices, distances = row
+            # convert from numpy to list
+            retrieved_documents_indices = indices.tolist()
+            retrieved_documents_scores = distances.tolist()
+            output.append(
+                RetrieverOutput(
+                    doc_indices=retrieved_documents_indices,
+                    doc_scores=retrieved_documents_scores,
+                )
+            )
+
+        return output
+
+
+    def retrieve_embedding_queries(
+        self,
+        input: FAISSRetrieverQueriesEmbeddingType,
+        top_k: Optional[int] = None,
+    ) -> RetrieverOutputType:
+        if not self.indexed or self.index.ntotal == 0:
+            raise ValueError(
+                "Index is empty. Please set the chunks to build the index from"
+            )
+        # if the input is a List, convert it to a numpy array
+        try:
+            if not isinstance(input, np.ndarray):
+                xq = np.array(input, dtype=np.float32)
+            else:
+                xq = input
+        except Exception as e:
+            log.error(f"Error converting input to numpy array: {e}")
+            raise e
+
+        D, Ind = self.index.search(xq, top_k if top_k else self.top_k)
+        if self.metric == "prob":
+            D = self._convert_cosine_similarity_to_probability(D)
+        output: RetrieverOutputType = self._to_retriever_output(Ind, D)
+        return output
+
+
+
+    def retrieve_string_queries(
+        self,
+        input: Union[str, List[str]],
+        top_k: Optional[int] = None,
+    ) -> RetrieverOutputType:
+        r"""Retrieve the top k chunks given the query or queries in string format.
+
+        Args:
+            input: The query or list of queries in string format. Note: ensure the maximum number of queries fits into the embedder.
+            top_k: The number of chunks to retrieve. When top_k is not provided, it will use the default top_k set during initialization.
+        """
+        if not self.indexed or self.index.ntotal == 0:
+            raise ValueError(
+                "Index is empty. Please set the chunks to build the index from"
+            )
+        queries = [input] if isinstance(input, str) else input
+        # filter out empty queries
+        valid_queries: List[str] = []
+        record_map: Dict[int, int] = (
+            {}
+        )  # final index : the position in the initial queries
+        for i, q in enumerate(queries):
+            if not q:
+                log.warning("Empty query found, skipping")
+                continue
+            valid_queries.append(q)
+            record_map[len(valid_queries) - 1] = i
+        # embed the queries; assume the length fits into a batch.
+        try:
+            embeddings: EmbedderOutputType = self.embedder(valid_queries)
+            queries_embeddings: List[List[float]] = [
+                data.embedding for data in embeddings.data
+            ]
+        except Exception as e:
+            log.error(f"Error embedding queries: {e}")
+            raise e
+        xq = np.array(queries_embeddings, dtype=np.float32)
+        D, Ind = self.index.search(xq, top_k if top_k else self.top_k)
+        D = self._convert_cosine_similarity_to_probability(D)
+
+        output: RetrieverOutputType = [
+            RetrieverOutput(doc_indices=[], query=query) for query in queries
+        ]
+        retrieved_output: RetrieverOutputType = self._to_retriever_output(Ind, D)
+
+        # fill in the doc_indices and scores for the valid queries
+        for i, per_query_output in enumerate(retrieved_output):
+            initial_index = record_map[i]
+            output[initial_index].doc_indices = per_query_output.doc_indices
+            output[initial_index].doc_scores = per_query_output.doc_scores
+
+        return output
+
+
+    @overload
+    def call(
+        self,
+        input: FAISSRetrieverQueriesEmbeddingType,
+        top_k: Optional[int] = None,
+    ) -> RetrieverOutputType:
+        r"""Retrieve the top k chunks given the query or queries in embedding format."""
+        ...
+
+    @overload
+    def call(
+        self,
+        input: FAISSRetrieverQueriesStrType,
+        top_k: Optional[int] = None,
+    ) -> RetrieverOutputType:
+        r"""Retrieve the top k chunks given the query or queries in string format."""
+        ...
+
+    def call(
+        self,
+        input: FAISSRetrieverQueriesType,
+        top_k: Optional[int] = None,
+    ) -> RetrieverOutputType:
+        r"""Retrieve the top k chunks given the query or queries in embedding or string format."""
+        assert (
+            self.indexed
+        ), "Index is not built. Please build the index using build_index_from_documents"
+        if isinstance(input, str) or (
+            isinstance(input, Sequence) and isinstance(input[0], str)
+        ):
+            assert self.embedder, "Embedder is not provided"
+            return self.retrieve_string_queries(input, top_k)
+        else:
+            return self.retrieve_embedding_queries(input, top_k)
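+
+
+# A minimal usage sketch (not part of the source): building an index from raw
+# embeddings and querying with an embedding. With the default "prob" metric the
+# vectors are normalized internally and scores are mapped toward [0, 1].
+if __name__ == "__main__":
+    docs = [[0.1, 0.9, 0.0], [0.8, 0.1, 0.1], [0.4, 0.4, 0.2]]
+    retriever = FAISSRetriever(top_k=2, documents=docs)
+    result = retriever.call(input=[[0.1, 0.8, 0.1]])  # one query embedding
+    print(result[0].doc_indices, result[0].doc_scores)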
Source code for components.retriever.postgres_retriever
+"""Leverage a postgres database to store and retrieve documents."""
+
+fromtypingimportList,Optional,Any
+fromenumimportEnum
+importnumpyasnp
+importlogging
+
+fromlightrag.core.retrieverimport(
+ Retriever,
+)
+fromlightrag.core.embedderimportEmbedder
+
+fromlightrag.core.typesimport(
+ RetrieverOutput,
+ RetrieverStrQueryType,
+ RetrieverStrQueriesType,
+ Document,
+)
+fromlightrag.database.sqlalchemy.sqlachemy_managerimportDatabaseManager
+
+log=logging.getLogger(__name__)
+
+
+
+class DistanceToOperator(Enum):
+    __doc__ = r"""Enum for the distance-to operator.
+
+    About pgvector:
+
+    1. L2 distance: <->, inner product (<#>), cosine distance (<=>), and L1 distance (<+>, added in 0.7.0)
+    """
+    L2 = "<->"
+    INNER_PRODUCT = (
+        "<#>"  # cosine similarity when the vector is normalized, in range [-1, 1]
+    )
+    COSINE = "<=>"  # cosine distance, in range [0, 1] = 1 - cosine_similarity
+    L1 = "<+>"
+
+
+
+
+class PostgresRetriever(Retriever[Any, RetrieverStrQueryType]):
+    __doc__ = r"""Use a postgres database to store and retrieve documents.
+
+    Users can follow this example to customize the query, or additionally ask it to output the score along with the indices.
+
+    Args:
+        top_k (Optional[int], optional): top k documents to fetch. Defaults to 1.
+        database_url (str): the database url to connect to. Defaults to postgresql://postgres:password@localhost:5432/vector_db.
+
+    References:
+        [1] pgvector extension: https://github.com/pgvector/pgvector
+    """
+
+    def __init__(
+        self,
+        embedder: Embedder,
+        top_k: Optional[int] = 1,
+        database_url: str = None,
+        table_name: str = "document",
+        distance_operator: DistanceToOperator = DistanceToOperator.INNER_PRODUCT,
+    ):
+        super().__init__()
+        self.top_k = top_k
+        self.table_name = table_name
+        db_name = "vector_db"
+        self.database_url = (
+            database_url or f"postgresql://postgres:password@localhost:5432/{db_name}"
+        )
+        self.db_manager = DatabaseManager(self.database_url)
+        self.embedder = embedder
+        self.distance_operator = distance_operator
+        self.db_score_prob_fun_map = {
+            DistanceToOperator.COSINE: self._convert_cosine_distance_to_probability,
+            DistanceToOperator.L2: self._convert_l2_distance_to_probability,
+            DistanceToOperator.INNER_PRODUCT: self._convert_cosine_similarity_to_probability,
+        }
+        self.score_prob_fun = (
+            self.db_score_prob_fun_map[self.distance_operator]
+            if self.distance_operator in self.db_score_prob_fun_map
+            else None
+        )
+
+
+    @classmethod
+    def format_vector_search_query(
+        cls,
+        table_name: str,
+        vector_column: str,
+        query_embedding: List[float],
+        top_k: int,
+        distance_operator: DistanceToOperator,
+        sort_desc: bool = True,
+    ) -> str:
+        """
+        Formats a SQL query string to select all columns from a table, order the results
+        by the distance or similarity score to a provided embedding, and also return
+        that score.
+
+        Args:
+            table_name (str): The name of the table to query.
+            vector_column (str): The name of the column containing the vector data.
+            query_embedding (list or str): The embedding vector to compare against.
+            top_k (int): The number of top results to return.
+
+        Returns:
+            str: A formatted SQL query string that includes the score.
+        """
+
+        # Convert the list embedding to a string format suitable for SQL
+        if isinstance(query_embedding, list):
+            embedding_str = str(query_embedding).replace(
+                " ", ""
+            )  # Remove spaces for cleaner SQL
+        else:
+            embedding_str = query_embedding
+
+        # Determine the sorting order
+        order_by = "DESC" if sort_desc else "ASC"
+
+        # SQL query that includes the score in the selected columns
+        sql_query = f"""
+        SELECT *, ({vector_column} {distance_operator.value} '{embedding_str}') AS score
+        FROM {table_name}
+        ORDER BY score {order_by}
+        LIMIT {top_k};
+        """
+        return sql_query
+
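+    # A sketch of the SQL this produces (values are examples only):
+    #
+    #   PostgresRetriever.format_vector_search_query(
+    #       table_name="document",
+    #       vector_column="vector",
+    #       query_embedding=[0.1, 0.2, 0.3],
+    #       top_k=5,
+    #       distance_operator=DistanceToOperator.COSINE,
+    #       sort_desc=False,  # smaller cosine distance = more similar
+    #   )
+    #   # SELECT *, (vector <=> '[0.1,0.2,0.3]') AS score
+    #   # FROM document ORDER BY score ASC LIMIT 5;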
+
+
+    def retrieve_by_sql(self, query: str) -> List[str]:
+        """Retrieve documents from the postgres database."""
+        results = self.db_manager.execute_query(query)
+        log.debug(results)
+        return results
+
+
+    def _convert_cosine_similarity_to_probability(
+        self, cosine_similarity: List[float]
+    ) -> List[float]:
+        """Convert cosine similarity to probability."""
+        return [(1 + sim) / 2 for sim in cosine_similarity]
+
+    def _convert_l2_distance_to_probability(
+        self, l2_distance: List[float]
+    ) -> List[float]:
+        """Convert L2 distance to probability.
+
+        Note:
+            Ensure the vector is normalized so that the l2_distance will be in range [0, 2].
+        """
+        distance = np.array(l2_distance)
+        # clip to ensure the distance is in range [0, 2]
+        distance = np.clip(distance, 0, 2)
+        # convert to probability
+        prob_score = 1 - distance / 2
+        return prob_score.tolist()
+
+    def _convert_cosine_distance_to_probability(
+        self, cosine_distance: List[float]
+    ) -> List[float]:
+        """Convert cosine distance to probability."""
+        return [(1 - d) for d in cosine_distance]
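+
+    # A numeric sketch of the conversions above:
+    #   cosine similarity 0.8 -> (1 + 0.8) / 2 = 0.9
+    #   L2 distance 0.6 (normalized vectors) -> 1 - 0.6 / 2 = 0.7
+    #   cosine distance 0.2 -> 1 - 0.2 = 0.8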
+
+
optional_package (OptionalPackages): The optional package to import; it helps define the package name and error message. """
-    def __init__(self, import_path: str, optional_package: OptionalPackages):
+    def __init__(
+        self, import_path: str, optional_package: OptionalPackages, *args, **kwargs
+    ):
+        if args or kwargs:
+            raise TypeError(
+                "LazyImport does not support subclassing or additional arguments. "
+                "Import the class directly from its specific module instead. For example, "
+                "from lightrag.components.model_client.cohere_client import CohereAPIClient "
+                "instead of using the lazy import with: from lightrag.components.model_client import CohereAPIClient"
+            )
+        self.import_path = import_path
+        self.optional_package = optional_package
+        self.module = None
@@ -516,10 +526,34 @@
Source code for utils.lazy_import
-def safe_import(module_name, install_message):
+def safe_import(module_name: str, install_message: str) -> ModuleType:
+    """Safely import a module and raise an ImportError with the install message if the module is not found. Mainly used internally to import optional packages only when needed.
+
+ Example:
+
+ 1. Tests
+
+ .. code-block:: python
+
+ try:
+ numpy = safe_import("numpy", "Please install numpy with: pip install numpy")
+ print(numpy.__version__)
+ except ImportError as e:
+ print(e)
+
+ When numpy is not installed, it will raise an ImportError with the install message.
+ When numpy is installed, it will print the numpy version.
+
+ 2. Use it to delay the import of optional packages in the library.
+
+ .. code-block:: python
+
+ from lightrag.utils.lazy_import import safe_import, OptionalPackages
+
+ numpy = safe_import(OptionalPackages.NUMPY.value[0], OptionalPackages.NUMPY.value[1])
+
"""try:returnimportlib.import_module(module_name)
diff --git a/_modules/utils/setup_env.html b/_modules/utils/setup_env.html
index 65b46ec9..38244eea 100644
--- a/_modules/utils/setup_env.html
+++ b/_modules/utils/setup_env.html
@@ -404,6 +404,10 @@
diff --git a/_sources/apis/components/components.model_client.anthropic_client.rst.txt b/_sources/apis/components/components.model_client.anthropic_client.rst.txt
index f3af8985..bd2f49f8 100644
--- a/_sources/apis/components/components.model_client.anthropic_client.rst.txt
+++ b/_sources/apis/components/components.model_client.anthropic_client.rst.txt
@@ -7,3 +7,17 @@ anthropic_client
:members:
:undoc-members:
:show-inheritance:
+
+
+ .. rubric:: Functions
+
+ .. autosummary::
+
+ get_first_message_content
+
+ .. rubric:: Classes
+
+ .. autosummary::
+
+ AnthropicAPIClient
+
diff --git a/_sources/apis/components/components.model_client.cohere_client.rst.txt b/_sources/apis/components/components.model_client.cohere_client.rst.txt
index 762def3a..be9569be 100644
--- a/_sources/apis/components/components.model_client.cohere_client.rst.txt
+++ b/_sources/apis/components/components.model_client.cohere_client.rst.txt
@@ -7,3 +7,11 @@ cohere_client
:members:
:undoc-members:
:show-inheritance:
+
+
+ .. rubric:: Classes
+
+ .. autosummary::
+
+ CohereAPIClient
+
diff --git a/_sources/apis/components/components.model_client.groq_client.rst.txt b/_sources/apis/components/components.model_client.groq_client.rst.txt
index f7e0925a..6dd213a2 100644
--- a/_sources/apis/components/components.model_client.groq_client.rst.txt
+++ b/_sources/apis/components/components.model_client.groq_client.rst.txt
@@ -7,3 +7,11 @@ groq_client
:members:
:undoc-members:
:show-inheritance:
+
+
+ .. rubric:: Classes
+
+ .. autosummary::
+
+ GroqAPIClient
+
diff --git a/_sources/apis/components/components.model_client.openai_client.rst.txt b/_sources/apis/components/components.model_client.openai_client.rst.txt
index 1c00b378..4af3f5ae 100644
--- a/_sources/apis/components/components.model_client.openai_client.rst.txt
+++ b/_sources/apis/components/components.model_client.openai_client.rst.txt
@@ -7,3 +7,19 @@ openai_client
:members:
:undoc-members:
:show-inheritance:
+
+
+ .. rubric:: Functions
+
+ .. autosummary::
+
+ get_all_messages_content
+ get_first_message_content
+ get_probabilities
+
+ .. rubric:: Classes
+
+ .. autosummary::
+
+ OpenAIClient
+
diff --git a/_sources/apis/components/components.retriever.faiss_retriever.rst.txt b/_sources/apis/components/components.retriever.faiss_retriever.rst.txt
index a9ed294b..cd22e668 100644
--- a/_sources/apis/components/components.retriever.faiss_retriever.rst.txt
+++ b/_sources/apis/components/components.retriever.faiss_retriever.rst.txt
@@ -7,3 +7,11 @@ faiss_retriever
:members:
:undoc-members:
:show-inheritance:
+
+
+ .. rubric:: Classes
+
+ .. autosummary::
+
+ FAISSRetriever
+
diff --git a/_sources/apis/components/components.retriever.postgres_retriever.rst.txt b/_sources/apis/components/components.retriever.postgres_retriever.rst.txt
index cdbb70d0..8aa1d9eb 100644
--- a/_sources/apis/components/components.retriever.postgres_retriever.rst.txt
+++ b/_sources/apis/components/components.retriever.postgres_retriever.rst.txt
@@ -7,3 +7,12 @@ postgres_retriever
:members:
:undoc-members:
:show-inheritance:
+
+
+ .. rubric:: Classes
+
+ .. autosummary::
+
+ DistanceToOperator
+ PostgresRetriever
+
diff --git a/_sources/developer_notes/lightrag_design_philosophy.rst.txt b/_sources/developer_notes/lightrag_design_philosophy.rst.txt
index f2fb932c..e7242a0d 100644
--- a/_sources/developer_notes/lightrag_design_philosophy.rst.txt
+++ b/_sources/developer_notes/lightrag_design_philosophy.rst.txt
@@ -6,11 +6,11 @@ Right from the begining, `LightRAG` follows three fundamental principles.
Principle 1: Simplicity over Complexity
-----------------------------------------------------------------------
-We put these three hard rules while designing LightRAG:
+ We put these three hard rules while designing LightRAG:
- Every layer of abstraction needs to be adjusted and overall we do not allow more than 3 layers of abstraction.
- We minimize the lines of code instead of maximizing the lines of code.
-- Go *deep* and *wide* in order to *simplify*. The clarity we achieve is not the result of being easy, but the result of being deep.
+- Go *deep* and *wide* in order to *simplify*. The clarity we achieve is not the result of being easy.
diff --git a/_sources/get_started/lightrag_in_10_mins.rst.txt b/_sources/get_started/lightrag_in_10_mins.rst.txt
index 1bdec107..d046c77f 100644
--- a/_sources/get_started/lightrag_in_10_mins.rst.txt
+++ b/_sources/get_started/lightrag_in_10_mins.rst.txt
@@ -1,13 +1,15 @@
LightRAG in 10 minutes
=============================
-[Li]
+Coming soon...
-Will use end to end trec classifier as an example to demonstrate:
+We will showcase a use case end-to-end, including task pipeline, configuration, logging and tracing, evaluation, and optimization.
-1. Look at the data and task, create `data class`, `prompt`, and `task`, and set up `log` and tracing.
-2. Create datasets with `train`, `eval`, and `test` splits.
-3. Eval zero-shot with manual prompts.
+.. Will use end to end trec classifier as an example to demonstrate:
+.. 1. Look at the data and task, create `data class`, `prompt`, and `task`, and set up `log` and tracing.
+.. 2. Create datasets with `train`, `eval`, and `test` splits.
+.. 3. Eval zero-shot with manual prompts.
-The content will be from `/use_cases/classification/readme.md`.
\ No newline at end of file
+
+.. The content will be from `/use_cases/classification/readme.md`.
diff --git a/_sources/index.rst.txt b/_sources/index.rst.txt
index 64cb9026..58f21f96 100644
--- a/_sources/index.rst.txt
+++ b/_sources/index.rst.txt
@@ -26,7 +26,7 @@
- LightRAG helps developers with both building and optimizing Retriever-Agent-Generator (RAG) pipelines.
+ LightRAG helps developers with both building and optimizing Retriever-Agent-Generator pipelines.
It is light, modular, and robust, with a 100% readable codebase.
Subclasses use this to call the API with the sync client.
+model_type: this decides which API to use, such as chat.completions or embeddings for OpenAI.
+api_kwargs: all the arguments that the API call needs; the subclass should implement this method.
+
Additionally, in the subclass you can implement the error handling and retry logic here. See OpenAIClient for an example.