-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c6aa825
commit d184d49
Showing
6 changed files
with
256 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
import uuid | ||
import os.path | ||
from datetime import datetime | ||
from typing import List, Dict | ||
|
||
from agent.const import * | ||
from agent.utils import Config, CONFIG_HOME | ||
|
||
import chromadb | ||
from chromadb.utils import embedding_functions | ||
|
||
chromadb.logger.setLevel(chromadb.logging.ERROR) | ||
|
||
|
||
class Memory: | ||
def __init__( | ||
self, | ||
data_path: str = CONFIG_HOME | ||
): | ||
""" | ||
RAG for Termax: memory and external knowledge management. | ||
Args: | ||
data_path: the path to store the data. | ||
""" | ||
self.config = Config().read() | ||
self.client = chromadb.PersistentClient(path=os.path.join(data_path, DB_NAME)) | ||
|
||
def create_collection(self, collection_name: str, embedding_model: str = "text-embedding-ada-002"): | ||
""" | ||
create_collection: create a collection in the memory. | ||
:param collection_name: the name of the collection. | ||
:param embedding_model: the embedding model to use, default will use the embedding model from ChromaDB, | ||
if the OpenAI has been set in the configuration, it will use the OpenAI embedding model | ||
"text-embedding-ada-002". | ||
:return: the collection. | ||
""" | ||
# use the OpenAI embedding function if the openai section is set in the configuration. | ||
if self.config.get(LLM_TYPE_OPENAI, None): | ||
self.client.get_or_create_collection( | ||
collection_name, | ||
embedding_function=embedding_functions.OpenAIEmbeddingFunction( | ||
model_name=embedding_model, | ||
api_key=self.config[LLM_TYPE_OPENAI][CONFIG_SEC_API_KEY] | ||
) | ||
) | ||
else: | ||
self.client.get_or_create_collection(collection_name) | ||
|
||
def add_query( | ||
self, | ||
queries: List[Dict[str, str]], | ||
collection: str, | ||
idx: List[str] = None | ||
): | ||
""" | ||
add_query: add the queries to the memery. | ||
Args: | ||
queries: the queries to add to the memery. Should be in the format of | ||
{ | ||
"query": "the query", | ||
"response": "the response" | ||
} | ||
collection: the name of the collection to add the queries. | ||
idx: the ids of the queries, should be in the same length as the queries. | ||
If not provided, the ids will be generated by UUID. | ||
Return: A list of generated IDs. | ||
""" | ||
if idx: | ||
ids = idx | ||
else: | ||
ids = [str(uuid.uuid4()) for _ in range(len(queries))] | ||
|
||
query_list = [query['query'] for query in queries] | ||
added_time = datetime.now().isoformat() | ||
resp_list = [{'response': query['response'], 'created_at': added_time} for query in queries] | ||
# insert the record into the database | ||
self.client.get_or_create_collection(collection).add( | ||
documents=query_list, | ||
metadatas=resp_list, | ||
ids=ids | ||
) | ||
|
||
return ids | ||
|
||
def query(self, query_texts: List[str], collection: str, n_results: int = 5): | ||
""" | ||
query: query the memery. | ||
Args: | ||
query_texts: the query texts to search in the memery. | ||
collection: the name of the collection to search, default is the command history. | ||
n_results: the number of results to return. | ||
Returns: the top k results. | ||
""" | ||
return self.client.get_or_create_collection(collection).query(query_texts=query_texts, n_results=n_results) | ||
|
||
def peek(self, collection: str, n_results: int = 20): | ||
""" | ||
peek: peek the memery. | ||
Args: | ||
collection: the name of the collection to peek, default is the command history. | ||
n_results: the number of results to return. | ||
Returns: the top k results. | ||
""" | ||
return self.client.get_or_create_collection(collection).peek(limit=n_results) | ||
|
||
def get(self, collection: str, record_id: str = None, ): | ||
""" | ||
get: get the record by the id. | ||
Args: | ||
collection: the name of the collection to get the record. | ||
record_id: the id of the record. | ||
Returns: the record. | ||
""" | ||
collection = self.client.get_collection(collection) | ||
if not record_id: | ||
return collection.get() | ||
|
||
return collection.get(record_id) | ||
|
||
def delete(self, collection_name: str): | ||
""" | ||
delete: delete the memery collections. | ||
Args: | ||
collection_name: the name of the collection to delete. | ||
""" | ||
return self.client.delete_collection(name=collection_name) | ||
|
||
def count(self, collection_name: str): | ||
""" | ||
count: count the number of records in the memery. | ||
Args: | ||
collection_name: the name of the collection to count. | ||
""" | ||
|
||
return self.client.get_collection(name=collection_name).count() | ||
|
||
def reset(self): | ||
""" | ||
reset: reset the memory. | ||
Notice: You may need to set the environment variable `ALLOW_RESET` to `TRUE` to enable this function. | ||
""" | ||
self.client.reset() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .config import Config, CONFIG_HOME |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import os | ||
import configparser | ||
from pathlib import Path | ||
|
||
from agent.const import * | ||
|
||
CONFIG_HOME = os.path.join(str(Path.home()), ".mle") | ||
CONFIG_PATH = os.path.join(CONFIG_HOME, "config") | ||
|
||
|
||
class Config: | ||
""" | ||
Config: the overall system configuration for Termax. | ||
""" | ||
|
||
def __init__(self): | ||
self.home = CONFIG_HOME | ||
Path(self.home).mkdir(parents=True, exist_ok=True) | ||
|
||
self.config_path = CONFIG_PATH | ||
self.config = configparser.ConfigParser() | ||
self.config.read(self.config_path) | ||
|
||
def read(self): | ||
""" | ||
read: read the configuration file. | ||
Returns: a dictionary of the configuration. | ||
""" | ||
self.reload_config(CONFIG_PATH) | ||
config_dict = {} | ||
|
||
for section in self.config.sections(): | ||
options_dict = {option: self.config.get(section, option) for option in self.config.options(section)} | ||
config_dict[section] = options_dict | ||
|
||
return config_dict | ||
|
||
def reload_config(self, config_path): | ||
""" | ||
reload_config: The default configuration will load ~/.mle/config, if user want to specify | ||
customize, the method is required. | ||
@param config_path: the path of new configuration file. | ||
""" | ||
self.config.read(config_path) | ||
|
||
def load_openai_config(self): | ||
""" | ||
load_openai_config: load a OpenAI configuration when required. | ||
""" | ||
if self.config.has_section(LLM_TYPE_OPENAI): | ||
return self.config[LLM_TYPE_OPENAI] | ||
else: | ||
raise ValueError("there is no '[openai]' section found in the configuration file.") | ||
|
||
def write_general(self, config_dict: dict): | ||
""" | ||
write_general: write the general configuration. | ||
@param config_dict: the configuration dictionary. | ||
""" | ||
if not self.config.has_section(CONFIG_SEC_GENERAL): | ||
self.config.add_section(CONFIG_SEC_GENERAL) | ||
|
||
self.config[CONFIG_SEC_GENERAL] = config_dict | ||
|
||
# save the new configuration and reload. | ||
with open(self.config_path, 'w') as configfile: | ||
self.config.write(configfile) | ||
self.reload_config(self.config_path) | ||
|
||
def write_platform( | ||
self, | ||
config_dict: dict, | ||
platform: str = LLM_TYPE_OPENAI | ||
): | ||
""" | ||
write_platform: indicate and generate the platform related configuration. | ||
@param config_dict: the configuration dictionary. | ||
@param platform: the platform to configure. | ||
""" | ||
# create the configuration to connect with OpenAI. | ||
if not self.config.has_section(platform): | ||
self.config.add_section(platform) | ||
|
||
self.config[platform] = config_dict | ||
|
||
# save the new configuration and reload. | ||
with open(self.config_path, 'w') as configfile: | ||
self.config.write(configfile) | ||
self.reload_config(self.config_path) |