Skip to content

Commit

Permalink
✨ feat: added RAG system
Browse files Browse the repository at this point in the history
  • Loading branch information
huangyz0918 committed Apr 17, 2024
1 parent c6aa825 commit d184d49
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 5 deletions.
12 changes: 8 additions & 4 deletions agent/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from rich.console import Console

import agent
from agent.utils import Config

# Module-level configuration handle. Constructing Config creates the
# per-user configuration directory if needed and loads the config file.
config = Config()
# avoid the tokenizers parallelism issue
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

Expand Down Expand Up @@ -48,16 +50,18 @@ def resolve_command(self, ctx, args):
@click.version_option(version=agent.__version__)
def cli():
    """
    MLE-Agent: The CLI tool to build machine learning projects.
    """
    # click uses the docstring above as the help text, so it is kept to the
    # one-line product description (the stale "Termax" tagline is dropped).
    # The group itself does no work; its subcommands carry the behavior.
    pass





@cli.command(default_command=True)
@click.argument('text', nargs=-1)
def ask(text):
    """
    ASK the agent a question to build an ML project.
    """
    # `text` is declared with nargs=-1, so click passes every word given on
    # the command line as one tuple of strings.
    console = Console()
    # For now the command only echoes what it received.
    console.log(text)
4 changes: 3 additions & 1 deletion agent/const.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Folder name of the persistent vector database, joined onto the data path.
DB_NAME = 'database'

# Section name of the general settings in the configuration file.
CONFIG_SEC_GENERAL = 'general'
# Option name under which a platform section stores its API key.
CONFIG_SEC_API_KEY = 'api_key'
CONFIG_LLM_LIST = { # with the default model.
Expand All @@ -10,4 +12,4 @@
}

# LLMs
# NOTE(review): CONFIG_SEC_OPENAI holds the same value as LLM_TYPE_OPENAI and
# looks like the older name for it -- confirm nothing still imports it.
CONFIG_SEC_OPENAI = 'openai'
# Identifier of the OpenAI platform; also its section name in the config file.
LLM_TYPE_OPENAI = 'openai'
Empty file added agent/function/data.py
Empty file.
149 changes: 149 additions & 0 deletions agent/prompt/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import uuid
import os.path
from datetime import datetime
from typing import List, Dict

from agent.const import *
from agent.utils import Config, CONFIG_HOME

import chromadb
from chromadb.utils import embedding_functions

# Only let ChromaDB log records of level ERROR or above through.
chromadb.logger.setLevel(chromadb.logging.ERROR)


class Memory:
    """
    Memory: a thin wrapper around a persistent ChromaDB client that stores
    query/response pairs in named collections and looks them up again.
    """

    # NOTE(review): only `create_collection` passes an embedding function; the
    # other methods fetch collections by name alone. Confirm that ChromaDB
    # applies the intended embedding function in those calls.

    def __init__(
            self,
            data_path: str = CONFIG_HOME
    ):
        """
        RAG for MLE-Agent: memory and external knowledge management.
        Args:
            data_path: the directory under which the database folder is stored.
        """
        self.config = Config().read()
        self.client = chromadb.PersistentClient(path=os.path.join(data_path, DB_NAME))

    def create_collection(self, collection_name: str, embedding_model: str = "text-embedding-ada-002"):
        """
        create_collection: get the named collection, creating it when it does not exist yet.
        :param collection_name: the name of the collection.
        :param embedding_model: the OpenAI embedding model to use. It only takes effect when the
         configuration has an OpenAI section with an API key; otherwise the collection is created
         without an explicit embedding function, i.e. with the ChromaDB default.
        :return: the collection.
        """
        # A missing section or a section without an API key both fall back to
        # the default embedding instead of failing with a KeyError.
        openai_config = self.config.get(LLM_TYPE_OPENAI) or {}
        api_key = openai_config.get(CONFIG_SEC_API_KEY)
        if not api_key:
            return self.client.get_or_create_collection(collection_name)

        return self.client.get_or_create_collection(
            collection_name,
            embedding_function=embedding_functions.OpenAIEmbeddingFunction(
                model_name=embedding_model,
                api_key=api_key
            )
        )

    def add_query(
            self,
            queries: List[Dict[str, str]],
            collection: str,
            idx: List[str] = None
    ):
        """
        add_query: add the queries to the memory.
        Args:
            queries: the queries to add to the memory. Each item should be in the format of
                {
                    "query": "the query",
                    "response": "the response"
                }
            collection: the name of the collection to add the queries to (created if missing).
            idx: the ids of the queries, must have the same length as the queries.
                If not provided (or empty), the ids will be generated by UUID.
        Return: A list with the ids of the stored records (empty when there is nothing to store).
        Raises:
            ValueError: if `idx` is given but its length differs from `queries`.
        """
        if not queries:
            # Nothing to store; skip the database call altogether.
            return []

        if idx:
            ids = list(idx)
            if len(ids) != len(queries):
                raise ValueError(
                    f"the number of ids ({len(ids)}) does not match "
                    f"the number of queries ({len(queries)})."
                )
        else:
            ids = [str(uuid.uuid4()) for _ in queries]

        # One timestamp is shared by every record of the batch.
        added_time = datetime.now().isoformat()
        documents = [item['query'] for item in queries]
        metadatas = [{'response': item['response'], 'created_at': added_time} for item in queries]

        # The query text is the stored document; the response travels along as metadata.
        self.client.get_or_create_collection(collection).add(
            documents=documents,
            metadatas=metadatas,
            ids=ids
        )

        return ids

    def query(self, query_texts: List[str], collection: str, n_results: int = 5):
        """
        query: query the memory.
        Args:
            query_texts: the query texts to search in the memory.
            collection: the name of the collection to search (created if missing).
            n_results: the number of results to return.
        Returns: whatever the collection's query call returns for the given texts.
        """
        return self.client.get_or_create_collection(collection).query(
            query_texts=query_texts,
            n_results=n_results
        )

    def peek(self, collection: str, n_results: int = 20):
        """
        peek: peek the memory.
        Args:
            collection: the name of the collection to peek (created if missing).
            n_results: the maximum number of records to return.
        Returns: whatever the collection's peek call returns.
        """
        return self.client.get_or_create_collection(collection).peek(limit=n_results)

    def get(self, collection: str, record_id: str = None):
        """
        get: get the record by the id.
        Args:
            collection: the name of the collection to get the record from (must already exist).
            record_id: the id of the record. If not provided, the whole collection is fetched.
        Returns: the record, or every record when no id is given.
        """
        target = self.client.get_collection(collection)
        if not record_id:
            return target.get()

        return target.get(record_id)

    def delete(self, collection_name: str):
        """
        delete: delete a memory collection.
        Args:
            collection_name: the name of the collection to delete.
        """
        return self.client.delete_collection(name=collection_name)

    def count(self, collection_name: str):
        """
        count: count the number of records in a memory collection.
        Args:
            collection_name: the name of the collection to count (must already exist).
        """
        return self.client.get_collection(name=collection_name).count()

    def reset(self):
        """
        reset: reset the memory.
        Notice: You may need to set the environment variable `ALLOW_RESET` to `TRUE` to enable this function.
        """
        self.client.reset()
1 change: 1 addition & 0 deletions agent/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .config import Config, CONFIG_HOME
95 changes: 95 additions & 0 deletions agent/utils/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import os
import configparser
from pathlib import Path

from agent.const import *

# Per-user directory (inside the home directory) that holds the agent's files.
CONFIG_HOME = str(Path.home() / ".mle")
# Location of the configuration file inside CONFIG_HOME.
CONFIG_PATH = str(Path(CONFIG_HOME) / "config")


class Config:
    """
    Config: the overall system configuration for MLE-Agent.
    """

    def __init__(self):
        # make sure the configuration home exists before anything reads or writes in it.
        self.home = CONFIG_HOME
        Path(self.home).mkdir(parents=True, exist_ok=True)

        self.config_path = CONFIG_PATH
        self.config = configparser.ConfigParser()
        self.config.read(self.config_path)

    def read(self):
        """
        read: read the configuration file.
        Returns: a dictionary of the configuration, mapping each section name
        to a dictionary of its options.
        """
        # reload from the same path the writers use, so reads and writes always agree.
        self.reload_config(self.config_path)

        return {
            section: dict(self.config.items(section))
            for section in self.config.sections()
        }

    def reload_config(self, config_path):
        """
        reload_config: The default configuration will load ~/.mle/config, if user want to specify
        customize, the method is required.
        NOTE(review): the file is read into the existing parser and `self.config_path` is left
        untouched, so later writes still go to the original path -- confirm this is intended.
        @param config_path: the path of new configuration file.
        """
        self.config.read(config_path)

    def load_openai_config(self):
        """
        load_openai_config: load a OpenAI configuration when required.
        Returns: the OpenAI section of the configuration.
        Raises: ValueError if the configuration has no OpenAI section.
        """
        if not self.config.has_section(LLM_TYPE_OPENAI):
            raise ValueError("there is no '[openai]' section found in the configuration file.")

        return self.config[LLM_TYPE_OPENAI]

    def _write_section(self, section: str, config_dict: dict):
        """
        _write_section: replace one section with the given options, then save and reload.
        @param section: the name of the section to write.
        @param config_dict: the configuration dictionary.
        """
        if not self.config.has_section(section):
            self.config.add_section(section)

        self.config[section] = config_dict

        # save the new configuration and reload.
        with open(self.config_path, 'w') as configfile:
            self.config.write(configfile)
        self.reload_config(self.config_path)

    def write_general(self, config_dict: dict):
        """
        write_general: write the general configuration.
        @param config_dict: the configuration dictionary.
        """
        self._write_section(CONFIG_SEC_GENERAL, config_dict)

    def write_platform(
            self,
            config_dict: dict,
            platform: str = LLM_TYPE_OPENAI
    ):
        """
        write_platform: indicate and generate the platform related configuration.
        @param config_dict: the configuration dictionary.
        @param platform: the platform to configure.
        """
        self._write_section(platform, config_dict)

0 comments on commit d184d49

Please sign in to comment.