# import required dependencies
# https://docs.chainlit.io/integrations/langchain
import os
import warnings

import chainlit as cl
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import Ollama
from langchain_community.vectorstores import Chroma
from langchain_core.prompts import PromptTemplate

from helpers import load_env

warnings.simplefilter("ignore")

attrs = load_env()
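
# Assumption: load_env comes from this repo's local helpers module (not shown
# here) and presumably returns a dict of settings read from the environment or
# a .env file. Based on its use below, it must provide at least:
#   attrs["OLLAMA_MODEL"]    # e.g. "codellama"  -- example value, not from the source
#   attrs["VECTOR_DB_PATH"]  # e.g. "./chroma_db" -- example value, not from the source
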
template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Helpful Answer:"""
prompt = PromptTemplate.from_template(template)
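
# At query time the RetrievalQA chain built below fills {context} with the
# retrieved chunks and {question} with the user's message.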


def load_model():
    """Load the Ollama LLM, streaming generated tokens to stdout as they arrive."""
    return Ollama(
        model=attrs["OLLAMA_MODEL"],
        verbose=True,
        callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
    )
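
# Note: StreamingStdOutCallbackHandler streams tokens to the *server* console,
# not to the Chainlit UI; per-message streaming to the UI is handled by the
# cl.AsyncLangchainCallbackHandler set up in main() below.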


def load_vector_db():
    """Open the persisted Chroma store, embedding queries with the Jina code model."""
    chroma_db_dir_default = attrs["VECTOR_DB_PATH"]
    chroma_db_dir = os.environ.get("VECTOR_DB_PATH", chroma_db_dir_default)
    print(f"Loading Chroma DB from {chroma_db_dir}")
    return Chroma(
        persist_directory=chroma_db_dir,
        embedding_function=HuggingFaceEmbeddings(
            model_name="jinaai/jina-embeddings-v2-base-code",
            model_kwargs={"trust_remote_code": True},
        ),
    )
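
# Note: queries must be embedded with the same model that built the index
# (here jinaai/jina-embeddings-v2-base-code); trust_remote_code=True is needed
# because this model ships custom modeling code on the Hugging Face Hub.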


def retrieval_qa_chain(llm, vectorstore):
    """Wire the LLM and retriever into a RetrievalQA chain using the prompt above."""
    return RetrievalQA.from_chain_type(
        llm,
        retriever=vectorstore.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 10},
        ),
        chain_type_kwargs={"prompt": prompt},
        return_source_documents=True,
    )
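
# The retriever fetches the 10 most similar chunks and the default "stuff"
# chain type places them in the prompt's {context} slot; raising k trades a
# longer prompt for more recall. return_source_documents=True keeps the
# retrieved chunks in the response so main() can render them as side elements.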


def qa_bot():
    llm = load_model()
    vectorstore = load_vector_db()
    return retrieval_qa_chain(llm, vectorstore)


@cl.on_chat_start
async def start():
    """
    Initializes the bot when a new chat starts.

    This asynchronous function creates a new instance of the retrieval QA bot,
    sends a welcome message, and stores the bot instance in the user's session.
    """
    chain = qa_bot()
    welcome_message = cl.Message(content="Starting the bot...")
    await welcome_message.send()
    welcome_message.content = (
        "Hi, Welcome to Granny RAG. Ask me anything about your code!"
    )
    await welcome_message.update()
    cl.user_session.set("chain", chain)
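
# The chain is stored in cl.user_session, which is scoped to the individual
# chat session, so each connected user gets an independent QA chain.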


@cl.on_message
async def main(message: cl.Message):
    """
    Processes incoming chat messages.

    This asynchronous function retrieves the QA bot instance from the user's
    session, sets up a callback handler for the bot's response, and executes
    the bot's call method with the given message and callback. The bot's answer
    and source documents are then extracted from the response.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler()
    cb.answer_reached = True
    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["result"]
    source_documents = res["source_documents"]
    text_elements: list[cl.Text] = []
    if source_documents:
        for source_doc in source_documents:
            source_name = str(source_doc.metadata["source"])
            language = source_doc.metadata.get("language", "text")
            # Create the text element referenced in the message
            text_elements.append(
                cl.Text(
                    content=source_doc.page_content,
                    name=source_name,
                    display="side",
                    language=language,
                )
            )
        source_names = [text_el.name for text_el in text_elements]
        if source_names:
            answer += "\nSources:\n"
            answer += ",\n ".join(source_names)
        else:
            answer += "\nNo sources found"
    await cl.Message(content=answer, elements=text_elements).send()
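
# To launch the app (Chainlit's standard entry point):
#   chainlit run main.py
# Chainlit serves the chat UI (by default at http://localhost:8000) and
# dispatches to the @cl.on_chat_start / @cl.on_message handlers above.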