Merge pull request #269 from leeeizhang/lei/enhance-memory-api
[MRG] clean chromadb and add manage api for lancedb memory
huangyz0918 authored Nov 19, 2024
2 parents dd7da4e + 4505f7f commit ea7393b
Showing 2 changed files with 69 additions and 153 deletions.
220 changes: 68 additions & 152 deletions mle/utils/memory.py
@@ -1,162 +1,11 @@
import uuid
import os.path
from datetime import datetime
from typing import List, Dict, Optional

import lancedb
from lancedb.embeddings import get_registry

import chromadb
from chromadb.utils import embedding_functions

from mle.utils import get_config

chromadb.logger.setLevel(chromadb.logging.ERROR)


class ChromaDBMemory:

def __init__(
self,
project_path: str,
embedding_model: str = "text-embedding-ada-002"
):
"""
Memory: memory and external knowledge management.
Args:
project_path: the path to store the data.
embedding_model: the embedding model to use. By default, ChromaDB's built-in embedding model is used;
if OpenAI is set as the platform in the configuration, the OpenAI embedding model
"text-embedding-ada-002" is used instead.
"""
self.db_name = '.mle'
self.collection_name = 'memory'
self.client = chromadb.PersistentClient(path=os.path.join(project_path, self.db_name))

config = get_config(project_path)
# use the OpenAI embedding function if the openai section is set in the configuration.
if config['platform'] == 'OpenAI':
self.client.get_or_create_collection(
self.collection_name,
embedding_function=embedding_functions.OpenAIEmbeddingFunction(
model_name=embedding_model,
api_key=config['api_key']
)
)
else:
self.client.get_or_create_collection(self.collection_name)

def add_query(
self,
queries: List[Dict[str, str]],
collection: str = None,
idx: List[str] = None
):
"""
add_query: add the queries to the memory.
Args:
queries: the queries to add to the memory. Should be in the format of
{
"query": "the query",
"response": "the response"
}
collection: the name of the collection to add the queries to.
idx: the ids of the queries; should be the same length as the queries.
If not provided, the ids will be generated with UUID.
Returns: A list of generated IDs.
"""
if idx:
ids = idx
else:
ids = [str(uuid.uuid4()) for _ in range(len(queries))]

if not collection:
collection = self.collection_name

query_list = [query['query'] for query in queries]
added_time = datetime.now().isoformat()
resp_list = [{'response': query['response'], 'created_at': added_time} for query in queries]
# insert the record into the database
self.client.get_or_create_collection(collection).add(
documents=query_list,
metadatas=resp_list,
ids=ids
)

return ids

def query(self, query_texts: List[str], collection: str = None, n_results: int = 5):
"""
query: query the memory.
Args:
query_texts: the query texts to search in the memory.
collection: the name of the collection to search.
n_results: the number of results to return.
Returns: the top k results.
"""
if not collection:
collection = self.collection_name
return self.client.get_or_create_collection(collection).query(query_texts=query_texts, n_results=n_results)

def peek(self, collection: str = None, n_results: int = 20):
"""
peek: peek the memory.
Args:
collection: the name of the collection to peek.
n_results: the number of results to return.
Returns: the top k results.
"""
if not collection:
collection = self.collection_name
return self.client.get_or_create_collection(collection).peek(limit=n_results)

def get(self, collection: str = None, record_id: str = None):
"""
get: get the record by the id.
Args:
record_id: the id of the record.
collection: the name of the collection to get the record.
Returns: the record.
"""
if not collection:
collection = self.collection_name
collection = self.client.get_collection(collection)
if not record_id:
return collection.get()

return collection.get(record_id)

def delete(self, collection_name=None):
"""
delete: delete the memory collections.
Args:
collection_name: the name of the collection to delete.
"""
if not collection_name:
collection_name = self.collection_name
return self.client.delete_collection(name=collection_name)

def count(self, collection_name=None):
"""
count: count the number of records in the memory.
Args:
collection_name: the name of the collection to count.
"""
if not collection_name:
collection_name = self.collection_name
return self.client.get_collection(name=collection_name).count()

def reset(self):
"""
reset: reset the memory.
Notice: You may need to set the environment variable `ALLOW_RESET` to `TRUE` to enable this function.
"""
self.client.reset()


class LanceDBMemory:

@@ -221,7 +70,8 @@ def add(
]

if table_name not in self.client.table_names():
self.client.create_table(table_name, data=data)
table = self.client.create_table(table_name, data=data)
table.create_fts_index("id")
else:
self.client.open_table(table_name).add(data=data)
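The new create_fts_index("id") call is what enables the management methods added below: exact-match lookups on id and on metadata fields go through LanceDB's full-text search path instead of a vector query. A minimal sketch of the same pattern against the raw LanceDB client (the path, vector values, and metadata fields are illustrative, not taken from this repository):

import lancedb

db = lancedb.connect("/tmp/mle-memory-demo")  # illustrative path
rows = [
    {"id": "demo-1", "vector": [0.1, 0.2], "text": "hello", "metadata": {"tag": "demo"}},  # illustrative record
]
table = db.create_table("memory", data=rows)  # schema is inferred from the sample rows
table.create_fts_index("id")  # full-text index on the id column; backed by the tantivy package pinned in requirements.txt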

@@ -247,6 +97,56 @@ def query(self, query_texts: List[str], table_name: Optional[str] = None, n_resu
results = [table.search(query).limit(n_results).to_list() for query in query_embeds]
return results

def list_all_keys(self, table_name: Optional[str] = None):
"""
Lists all IDs in the specified memory table.
Args:
table_name (Optional[str]): The name of the table to list IDs from. Defaults to the instance's table name.
Returns:
List[str]: A list of all IDs in the table.
"""
table_name = table_name or self.table_name
table = self.client.open_table(table_name)
return [item["id"] for item in table.search(query_type="fts").to_list()]
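A hedged usage sketch of the new listing API, assuming a LanceDBMemory instance named memory has already been constructed and populated (the constructor is not shown in this diff):

ids = memory.list_all_keys()  # every id in the default table
print(f"{len(ids)} records in the default table")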

def get(self, record_id: str, table_name: Optional[str] = None):
"""
Retrieves a record by its ID from the specified memory table.
Args:
record_id (str): The ID of the record to retrieve.
table_name (Optional[str]): The name of the table to query. Defaults to the instance's table name.
Returns:
List[dict]: A list containing the matching record, or an empty list if not found.
"""
table_name = table_name or self.table_name
table = self.client.open_table(table_name)
return table.search(query_type="fts") \
.where(f"id = '{record_id}'") \
.limit(1).to_list()
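Fetching a single record by its id might then look like the following; the UUID value is hypothetical:

record = memory.get("9b1deec4-0000-0000-0000-000000000000")  # hypothetical UUID
if record:
    print(record[0])  # the full row dict, including its metadata fields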

def get_by_metadata(self, key: str, value: str, table_name: Optional[str] = None, n_results: int = 5):
"""
Retrieves records matching a specific metadata key-value pair.
Args:
key (str): The metadata key to filter by.
value (str): The value of the metadata key to filter by.
table_name (Optional[str]): The name of the table to query. Defaults to the instance's table name.
n_results (int): The maximum number of results to retrieve. Defaults to 5.
Returns:
List[dict]: A list of records matching the metadata criteria.
"""
table_name = table_name or self.table_name
table = self.client.open_table(table_name)
return table.search(query_type="fts") \
.where(f"metadata.{key} = '{value}'") \
.limit(n_results).to_list()
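And a sketch of filtering on a metadata field; the key and value are hypothetical and only need to exist in the stored metadata struct:

drafts = memory.get_by_metadata("status", "draft", n_results=10)  # hypothetical key/value
for row in drafts:
    print(row["id"])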

def delete(self, record_id: str, table_name: Optional[str] = None) -> bool:
"""
Deletes a record from the specified memory table.
@@ -262,6 +162,22 @@ def delete(self, record_id: str, table_name: Optional[str] = None) -> bool:
table = self.client.open_table(table_name)
return table.delete(f"id = '{record_id}'")

def delete_by_metadata(self, key: str, value: str, table_name: Optional[str] = None):
"""
Deletes records from the specified memory table based on a metadata key-value pair.
Args:
key (str): The metadata key to filter by.
value (str): The value of the metadata key to filter by.
table_name (Optional[str]): The name of the table to delete records from. Defaults to the instance's table name.
Returns:
bool: True if deletion was successful, False otherwise.
"""
table_name = table_name or self.table_name
table = self.client.open_table(table_name)
return table.delete(f"metadata.{key} = '{value}'")

def drop(self, table_name: Optional[str] = None) -> bool:
"""
Drops (deletes) the specified memory table.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -7,7 +7,6 @@ kaggle
fastapi
uvicorn
requests
chromadb
GitPython
tree-sitter==0.21.3
onnxruntime
@@ -23,3 +22,4 @@ google-api-python-client~=2.143.0
google-auth-httplib2~=0.2.0
google-auth-oauthlib~=1.2.1
lancedb~=0.15.0
tantivy~=0.22.0
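The new tantivy pin supports the create_fts_index call added in memory.py; a minimal import check, assuming both pinned packages are installed:

import lancedb  # pinned at ~0.15.0 above
import tantivy  # pinned at ~0.22.0 above; an ImportError here means the FTS backend is missing

print("LanceDB full-text search backend available")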
