From a04537f0cf6a47eb3f499343457df11b99f691ea Mon Sep 17 00:00:00 2001
From: "leizhang.real@gmail.com"
Date: Tue, 19 Nov 2024 14:05:50 +0000
Subject: [PATCH 1/2] clean chromadb and add more lancedb api

---
 mle/utils/memory.py | 220 ++++++++++++++------------------------
 requirements.txt    |   1 -
 2 files changed, 68 insertions(+), 153 deletions(-)

diff --git a/mle/utils/memory.py b/mle/utils/memory.py
index b536ba6..52af976 100644
--- a/mle/utils/memory.py
+++ b/mle/utils/memory.py
@@ -1,162 +1,11 @@
 import uuid
-import os.path
-from datetime import datetime
 from typing import List, Dict, Optional
 
 import lancedb
 from lancedb.embeddings import get_registry
 
-import chromadb
-from chromadb.utils import embedding_functions
-
 from mle.utils import get_config
 
-chromadb.logger.setLevel(chromadb.logging.ERROR)
-
-
-class ChromaDBMemory:
-
-    def __init__(
-            self,
-            project_path: str,
-            embedding_model: str = "text-embedding-ada-002"
-    ):
-        """
-        Memory: memory and external knowledge management.
-        Args:
-            project_path: the path to store the data.
-            embedding_model: the embedding model to use, default will use the embedding model from ChromaDB,
-                if the OpenAI has been set in the configuration, it will use the OpenAI embedding model
-                "text-embedding-ada-002".
-        """
-        self.db_name = '.mle'
-        self.collection_name = 'memory'
-        self.client = chromadb.PersistentClient(path=os.path.join(project_path, self.db_name))
-
-        config = get_config(project_path)
-        # use the OpenAI embedding function if the openai section is set in the configuration.
-        if config['platform'] == 'OpenAI':
-            self.client.get_or_create_collection(
-                self.collection_name,
-                embedding_function=embedding_functions.OpenAIEmbeddingFunction(
-                    model_name=embedding_model,
-                    api_key=config['api_key']
-                )
-            )
-        else:
-            self.client.get_or_create_collection(self.collection_name)
-
-    def add_query(
-            self,
-            queries: List[Dict[str, str]],
-            collection: str = None,
-            idx: List[str] = None
-    ):
-        """
-        add_query: add the queries to the memery.
-        Args:
-            queries: the queries to add to the memery. Should be in the format of
-            {
-                "query": "the query",
-                "response": "the response"
-            }
-            collection: the name of the collection to add the queries.
-            idx: the ids of the queries, should be in the same length as the queries.
-                If not provided, the ids will be generated by UUID.
-
-        Return: A list of generated IDs.
-        """
-        if idx:
-            ids = idx
-        else:
-            ids = [str(uuid.uuid4()) for _ in range(len(queries))]
-
-        if not collection:
-            collection = self.collection_name
-
-        query_list = [query['query'] for query in queries]
-        added_time = datetime.now().isoformat()
-        resp_list = [{'response': query['response'], 'created_at': added_time} for query in queries]
-        # insert the record into the database
-        self.client.get_or_create_collection(collection).add(
-            documents=query_list,
-            metadatas=resp_list,
-            ids=ids
-        )
-
-        return ids
-
-    def query(self, query_texts: List[str], collection: str = None, n_results: int = 5):
-        """
-        query: query the memery.
-        Args:
-            query_texts: the query texts to search in the memery.
-            collection: the name of the collection to search.
-            n_results: the number of results to return.
-
-        Returns: the top k results.
-        """
-        if not collection:
-            collection = self.collection_name
-        return self.client.get_or_create_collection(collection).query(query_texts=query_texts, n_results=n_results)
-
-    def peek(self, collection: str = None, n_results: int = 20):
-        """
-        peek: peek the memery.
-        Args:
-            collection: the name of the collection to peek.
-            n_results: the number of results to return.
-
-        Returns: the top k results.
-        """
-        if not collection:
-            collection = self.collection_name
-        return self.client.get_or_create_collection(collection).peek(limit=n_results)
-
-    def get(self, collection: str = None, record_id: str = None):
-        """
-        get: get the record by the id.
-        Args:
-            record_id: the id of the record.
-            collection: the name of the collection to get the record.
-
-        Returns: the record.
-        """
-        if not collection:
-            collection = self.collection_name
-        collection = self.client.get_collection(collection)
-        if not record_id:
-            return collection.get()
-
-        return collection.get(record_id)
-
-    def delete(self, collection_name=None):
-        """
-        delete: delete the memery collections.
-        Args:
-            collection_name: the name of the collection to delete.
-        """
-        if not collection_name:
-            collection_name = self.collection_name
-        return self.client.delete_collection(name=collection_name)
-
-    def count(self, collection_name=None):
-        """
-        count: count the number of records in the memery.
-        Args:
-            collection_name: the name of the collection to count.
-        """
-        if not collection_name:
-            collection_name = self.collection_name
-        return self.client.get_collection(name=collection_name).count()
-
-    def reset(self):
-        """
-        reset: reset the memory.
-        Notice: You may need to set the environment variable `ALLOW_RESET` to `TRUE` to enable this function.
-        """
-        self.client.reset()
-
 
 
 class LanceDBMemory:
@@ -221,7 +70,8 @@ def add(
         ]
 
         if table_name not in self.client.table_names():
-            self.client.create_table(table_name, data=data)
+            table = self.client.create_table(table_name, data=data)
+            table.create_fts_index("id")
         else:
             self.client.open_table(table_name).add(data=data)
 
@@ -247,6 +97,56 @@ def query(self, query_texts: List[str], table_name: Optional[str] = None, n_resu
         results = [table.search(query).limit(n_results).to_list() for query in query_embeds]
         return results
 
+    def list_all_keys(self, table_name: Optional[str] = None):
+        """
+        Lists all IDs in the specified memory table.
+
+        Args:
+            table_name (Optional[str]): The name of the table to list IDs from. Defaults to the instance's table name.
+
+        Returns:
+            List[str]: A list of all IDs in the table.
+        """
+        table_name = table_name or self.table_name
+        table = self.client.open_table(table_name)
+        return [item["id"] for item in table.search(query_type="fts").to_list()]
+
+    def get(self, record_id: str, table_name: Optional[str] = None):
+        """
+        Retrieves a record by its ID from the specified memory table.
+
+        Args:
+            record_id (str): The ID of the record to retrieve.
+            table_name (Optional[str]): The name of the table to query. Defaults to the instance's table name.
+
+        Returns:
+            List[dict]: A list containing the matching record, or an empty list if not found.
+        """
+        table_name = table_name or self.table_name
+        table = self.client.open_table(table_name)
+        return table.search(query_type="fts") \
+            .where(f"id = '{record_id}'") \
+            .limit(1).to_list()
+
+    def get_by_metadata(self, key: str, value: str, table_name: Optional[str] = None, n_results: int = 5):
+        """
+        Retrieves records matching a specific metadata key-value pair.
+
+        Args:
+            key (str): The metadata key to filter by.
+            value (str): The value of the metadata key to filter by.
+            table_name (Optional[str]): The name of the table to query. Defaults to the instance's table name.
+            n_results (int): The maximum number of results to retrieve. Defaults to 5.
+
+        Returns:
+            List[dict]: A list of records matching the metadata criteria.
+        """
+        table_name = table_name or self.table_name
+        table = self.client.open_table(table_name)
+        return table.search(query_type="fts") \
+            .where(f"metadata.{key} = '{value}'") \
+            .limit(n_results).to_list()
+
     def delete(self, record_id: str, table_name: Optional[str] = None) -> bool:
         """
         Deletes a record from the specified memory table.
 
         Args:
             record_id (str): The ID of the record to delete.
             table_name (Optional[str]): The name of the table to delete the record from. Defaults to the instance's table name.
 
         Returns:
             bool: True if deletion was successful, False otherwise.
         """
         table_name = table_name or self.table_name
@@ -262,6 +162,22 @@ def delete(self, record_id: str, table_name: Optional[str] = None) -> bool:
         table = self.client.open_table(table_name)
         return table.delete(f"id = '{record_id}'")
 
+    def delete_by_metadata(self, key: str, value: str, table_name: Optional[str] = None):
+        """
+        Deletes records from the specified memory table based on a metadata key-value pair.
+
+        Args:
+            key (str): The metadata key to filter by.
+            value (str): The value of the metadata key to filter by.
+            table_name (Optional[str]): The name of the table to delete records from. Defaults to the instance's table name.
+
+        Returns:
+            bool: True if deletion was successful, False otherwise.
+        """
+        table_name = table_name or self.table_name
+        table = self.client.open_table(table_name)
+        return table.delete(f"metadata.{key} = '{value}'")
+
     def drop(self, table_name: Optional[str] = None) -> bool:
         """
         Drops (deletes) the specified memory table.
diff --git a/requirements.txt b/requirements.txt
index 0a8bf0c..07ba3a0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,7 +7,6 @@ kaggle
 fastapi
 uvicorn
 requests
-chromadb
 GitPython
 tree-sitter==0.21.3
 onnxruntime

From 4505f7f6f15418e91bba239d9b4967f2c7b591eb Mon Sep 17 00:00:00 2001
From: "leizhang.real@gmail.com"
Date: Tue, 19 Nov 2024 14:07:51 +0000
Subject: [PATCH 2/2] add missing dependency

---
 requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements.txt b/requirements.txt
index 07ba3a0..6ec76fd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,3 +22,4 @@ google-api-python-client~=2.143.0
 google-auth-httplib2~=0.2.0
 google-auth-oauthlib~=1.2.1
 lancedb~=0.15.0
+tantivy~=0.22.0
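
Usage note (not part of the patch): the sketch below shows how the new full-text-search helpers from PATCH 1/2 (list_all_keys, get, get_by_metadata, delete_by_metadata) are meant to be called together. The LanceDBMemory constructor and the add() signature are not shown in this diff, so the forms used here (project_path argument, texts/metadata keyword lists) are assumptions for illustration only; adjust them to the actual class.

# Minimal sketch, assuming LanceDBMemory(project_path=...) and
# add(texts=..., metadata=...) roughly match the existing class API.
from mle.utils.memory import LanceDBMemory

memory = LanceDBMemory(project_path=".")              # assumed constructor
ids = memory.add(                                     # assumed add() signature
    texts=["how to tune the learning rate?"],
    metadata=[{"task": "tuning"}],
)

print(memory.list_all_keys())                         # every record id, via the FTS index on "id"
print(memory.get(ids[0]))                             # exact lookup by id
print(memory.get_by_metadata("task", "tuning"))       # filter on metadata.task
memory.delete_by_metadata("task", "tuning")           # bulk delete on the same filter

Because these helpers issue query_type="fts" searches and add() now calls create_fts_index("id") only when a table is first created, tables created before this patch will lack the index; the FTS path is also presumably why tantivy~=0.22.0 is added in PATCH 2/2.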