From 27e8be0a51d11694074f47c8d510414af65537cb Mon Sep 17 00:00:00 2001
From: Li Yin
Date: Tue, 10 Dec 2024 10:09:28 -0800
Subject: [PATCH] add RAG to the tutorials code
---
 tutorials/rag.ipynb     |  65 -------------------------
 tutorials/rag/config.py |  26 ++++++++++
 tutorials/rag/rag.py    | 128 ++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 154 insertions(+), 65 deletions(-)
delete mode 100644 tutorials/rag.ipynb
create mode 100644 tutorials/rag/config.py
create mode 100644 tutorials/rag/rag.py
diff --git a/tutorials/rag.ipynb b/tutorials/rag.ipynb
deleted file mode 100644
index b5163e51..00000000
--- a/tutorials/rag.ipynb
+++ /dev/null
@@ -1,65 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
-    "We have already seen how RAG is implemented.\n",
-    "In this note, we will focus more on how to make each component more configurable,\n",
-    "especially the data processing pipeline, to help us with experiments where we will see how useful it is in benchmarking."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "# the data pipeline and the backend data processing\n",
- "from adalflow.core.embedder import Embedder\n",
- "from adalflow.core.types import ModelClientType\n",
- "from adalflow.components.data_process import TextSplitter, ToEmbeddings\n",
- "from adalflow.core.container import Sequential\n",
- "\n",
- "\n",
- "def prepare_data_pipeline():\n",
- " model_kwargs = {\n",
- " \"model\": \"text-embedding-3-small\",\n",
- " \"dimensions\": 256,\n",
- " \"encoding_format\": \"float\",\n",
- " }\n",
- "\n",
- " splitter_config = {\"split_by\": \"word\", \"split_length\": 50, \"split_overlap\": 10}\n",
- "\n",
- " splitter = TextSplitter(**splitter_config)\n",
- " embedder = Embedder(\n",
- " model_client=ModelClientType.OPENAI(), model_kwargs=model_kwargs\n",
- " )\n",
- " embedder_transformer = ToEmbeddings(embedder, batch_size=2)\n",
- " data_transformer = Sequential(splitter, embedder_transformer)\n",
- " print(data_transformer)"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "lightrag-project",
- "language": "python",
- "name": "light-rag-project"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.11.6"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/tutorials/rag/config.py b/tutorials/rag/config.py
new file mode 100644
index 00000000..2a6d383f
--- /dev/null
+++ b/tutorials/rag/config.py
@@ -0,0 +1,26 @@
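+# Central configuration for the RAG tutorial, consumed by tutorials/rag/rag.py:
+# "embedder" and "text_splitter" drive the indexing pipeline, while "retriever"
+# is unpacked into FAISSRetriever and "generator" becomes the model_kwargs.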
+configs = {
+ "embedder": {
+ "batch_size": 100,
+ "model_kwargs": {
+ "model": "text-embedding-3-small",
+ "dimensions": 256,
+ "encoding_format": "float",
+ },
+ },
+ "retriever": {
+ "top_k": 2,
+ },
+ "generator": {
+ "model": "gpt-3.5-turbo",
+ "temperature": 0.3,
+ "stream": False,
+ },
+ "text_splitter": {
+ "split_by": "word",
+ "chunk_size": 400,
+ "chunk_overlap": 200,
+ },
+}
diff --git a/tutorials/rag/rag.py b/tutorials/rag/rag.py
new file mode 100644
index 00000000..ab248879
--- /dev/null
+++ b/tutorials/rag/rag.py
@@ -0,0 +1,128 @@
+from typing import Optional, Any, List
+
+import adalflow as adal
+from adalflow.core.db import LocalDB
+
+from adalflow.core.types import ModelClientType
+
+from adalflow.core.string_parser import JsonParser
+from adalflow.components.retriever.faiss_retriever import FAISSRetriever
+from adalflow.components.data_process import (
+ RetrieverOutputToContextStr,
+ ToEmbeddings,
+ TextSplitter,
+)
+
+from adalflow.components.model_client import OpenAIClient
+
+from tutorials.rag.config import configs
+
+
+def prepare_data_pipeline():
+ splitter = TextSplitter(**configs["text_splitter"])
+ embedder = adal.Embedder(
+ model_client=ModelClientType.OPENAI(),
+ model_kwargs=configs["embedder"]["model_kwargs"],
+ )
+ embedder_transformer = ToEmbeddings(
+ embedder=embedder, batch_size=configs["embedder"]["batch_size"]
+ )
+ data_transformer = adal.Sequential(
+ splitter, embedder_transformer
+    )  # Sequential chains the splitter and the embedding transformer
+ return data_transformer
+
+
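+# A minimal sketch of the indexing step that could produce the "index.faiss"
+# state loaded by RAG below. This is an assumption: it presumes LocalDB exposes
+# load/transform/save_state as the write-side counterparts of the
+# load_state/get_transformed_data calls used in RAG, and that `documents` is a
+# list of adal.Document prepared by the caller.
+def build_index(documents: List[adal.Document], index_path: str = "index.faiss"):
+    db = LocalDB()
+    db.load(documents)  # register the raw documents
+    # split and embed, storing results under the same key that RAG reads back
+    db.transform(prepare_data_pipeline(), key="data_transformer")
+    db.save_state(filepath=index_path)  # persist so RAG can load_state it
+
+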
+rag_prompt_task_desc = r"""
+You are a helpful assistant.
+
+Your task is to answer the query that may or may not come with context information.
+When context is provided, you should stick to the context and rely less on your prior knowledge to answer the query.
+
+Output JSON format:
+{
+    "answer": "The answer to the query"
+}"""
+
+
+class RAG(adal.Component):
+
+ def __init__(self, index_path: str = "index.faiss"):
+ super().__init__()
+
+ self.db = LocalDB.load_state(index_path)
+
+ self.transformed_docs: List[adal.Document] = self.db.get_transformed_data(
+ "data_transformer"
+ )
+ embedder = adal.Embedder(
+ model_client=ModelClientType.OPENAI(),
+ model_kwargs=configs["embedder"]["model_kwargs"],
+ )
+        # document_map_func maps each transformed Document to its embedding vector
+ self.retriever = FAISSRetriever(
+ **configs["retriever"],
+ embedder=embedder,
+ documents=self.transformed_docs,
+ document_map_func=lambda doc: doc.vector,
+ )
+ self.retriever_output_processors = RetrieverOutputToContextStr(deduplicate=True)
+
+ self.generator = adal.Generator(
+ prompt_kwargs={
+ "task_desc_str": rag_prompt_task_desc,
+ },
+ model_client=OpenAIClient(),
+ model_kwargs=configs["generator"],
+ output_processors=JsonParser(),
+ )
+
+ def generate(self, query: str, context: Optional[str] = None) -> Any:
+ if not self.generator:
+ raise ValueError("Generator is not set")
+
+ prompt_kwargs = {
+ "context_str": context,
+ "input_str": query,
+ }
+ response = self.generator(prompt_kwargs=prompt_kwargs)
+ return response
+
+ def call(self, query: str) -> Any:
+ retrieved_documents = self.retriever(query)
+        # fill in the actual documents from the retrieved indices
+ for i, retriever_output in enumerate(retrieved_documents):
+ retrieved_documents[i].documents = [
+ self.transformed_docs[doc_index]
+ for doc_index in retriever_output.doc_indices
+ ]
+
+ print(f"retrieved_documents: \n {retrieved_documents}\n")
+ context_str = self.retriever_output_processors(retrieved_documents)
+
+ print(f"context_str: \n {context_str}\n")
+
+ return self.generate(query, context=context_str), retrieved_documents
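+
+
+if __name__ == "__main__":
+    # Usage sketch: assumes an index built beforehand (for example with
+    # build_index above) and an OPENAI_API_KEY available in the environment.
+    rag = RAG(index_path="index.faiss")
+    response, retrieved = rag.call("What is AdalFlow?")  # hypothetical query
+    print(response)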