From 27e8be0a51d11694074f47c8d510414af65537cb Mon Sep 17 00:00:00 2001
From: Li Yin
Date: Tue, 10 Dec 2024 10:09:28 -0800
Subject: [PATCH] add rag in the tutorials code

---
 tutorials/rag.ipynb     |  65 -------------------
 tutorials/rag/config.py |  23 +++++++
 tutorials/rag/rag.py    | 136 ++++++++++++++++++++++++++++++++++++++++
 3 files changed, 159 insertions(+), 65 deletions(-)
 delete mode 100644 tutorials/rag.ipynb
 create mode 100644 tutorials/rag/config.py
 create mode 100644 tutorials/rag/rag.py

diff --git a/tutorials/rag.ipynb b/tutorials/rag.ipynb
deleted file mode 100644
index b5163e51..00000000
--- a/tutorials/rag.ipynb
+++ /dev/null
@@ -1,65 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "We have already seen how the RAG is implemented in data.\n",
-    "In this note, we will focus more on how to make each component more configurable, \n",
-    "espeically the data processing pipeline to help us with experiments where we will see how useful they are in benchmarking."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# the data pipeline and the backend data processing\n",
-    "from adalflow.core.embedder import Embedder\n",
-    "from adalflow.core.types import ModelClientType\n",
-    "from adalflow.components.data_process import TextSplitter, ToEmbeddings\n",
-    "from adalflow.core.container import Sequential\n",
-    "\n",
-    "\n",
-    "def prepare_data_pipeline():\n",
-    "    model_kwargs = {\n",
-    "        \"model\": \"text-embedding-3-small\",\n",
-    "        \"dimensions\": 256,\n",
-    "        \"encoding_format\": \"float\",\n",
-    "    }\n",
-    "\n",
-    "    splitter_config = {\"split_by\": \"word\", \"split_length\": 50, \"split_overlap\": 10}\n",
-    "\n",
-    "    splitter = TextSplitter(**splitter_config)\n",
-    "    embedder = Embedder(\n",
-    "        model_client=ModelClientType.OPENAI(), model_kwargs=model_kwargs\n",
-    "    )\n",
-    "    embedder_transformer = ToEmbeddings(embedder, batch_size=2)\n",
-    "    data_transformer = Sequential(splitter, embedder_transformer)\n",
-    "    print(data_transformer)"
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "lightrag-project",
-   "language": "python",
-   "name": "light-rag-project"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.11.6"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/tutorials/rag/config.py b/tutorials/rag/config.py
new file mode 100644
index 00000000..2a6d383f
--- /dev/null
+++ b/tutorials/rag/config.py
@@ -0,0 +1,23 @@
+configs = {
+    "embedder": {
+        "batch_size": 100,
+        "model_kwargs": {
+            "model": "text-embedding-3-small",
+            "dimensions": 256,
+            "encoding_format": "float",
+        },
+    },
+    "retriever": {
+        "top_k": 2,
+    },
+    "generator": {
+        "model": "gpt-3.5-turbo",
+        "temperature": 0.3,
+        "stream": False,
+    },
+    "text_splitter": {
+        "split_by": "word",
+        "chunk_size": 400,
+        "chunk_overlap": 200,
+    },
+}
diff --git a/tutorials/rag/rag.py b/tutorials/rag/rag.py
new file mode 100644
index 00000000..ab248879
--- /dev/null
+++ b/tutorials/rag/rag.py
@@ -0,0 +1,136 @@
+"""RAG tutorial: FAISS retrieval over a LocalDB plus an OpenAI generator."""
+
+from typing import Optional, Any, List
+
+import adalflow as adal
+from adalflow.core.db import LocalDB
+
+from adalflow.core.types import ModelClientType
+
+from adalflow.core.string_parser import JsonParser
+from adalflow.components.retriever.faiss_retriever import FAISSRetriever
+from adalflow.components.data_process import (
+    RetrieverOutputToContextStr,
+    ToEmbeddings,
+    TextSplitter,
+)
+
+from adalflow.components.model_client import OpenAIClient
+
+from tutorials.rag.config import configs
+
+
+def prepare_data_pipeline():
+    """Build the document-processing pipeline described by ``configs``.
+
+    Returns:
+        An ``adal.Sequential`` that runs a ``TextSplitter`` followed by a
+        ``ToEmbeddings`` step backed by an OpenAI embedder.
+    """
+    splitter = TextSplitter(**configs["text_splitter"])
+    embedder = adal.Embedder(
+        model_client=ModelClientType.OPENAI(),
+        model_kwargs=configs["embedder"]["model_kwargs"],
+    )
+    embedder_transformer = ToEmbeddings(
+        embedder=embedder, batch_size=configs["embedder"]["batch_size"]
+    )
+    # Sequential chains the splitter and the embedder into one transformer.
+    return adal.Sequential(splitter, embedder_transformer)
+
+
+# The JSON example below is kept strictly valid (no trailing comma) because
+# the generator's reply is parsed with JsonParser.
+rag_prompt_task_desc = r"""
+You are a helpful assistant.
+
+Your task is to answer the query that may or may not come with context information.
+When context is provided, you should stick to the context and less on your prior knowledge to answer the query.
+
+Output JSON format:
+{
+    "answer": "The answer to the query"
+}"""
+
+
+class RAG(adal.Component):
+    """Retrieval-augmented generation over documents stored in a LocalDB.
+
+    A FAISS retriever is built over the stored document vectors; for each
+    query the retrieved documents are rendered into a context string and
+    passed to an OpenAI generator whose output goes through ``JsonParser``.
+    """
+
+    def __init__(self, index_path: str = "index.faiss"):
+        """Load the document store and build the retriever and generator.
+
+        Args:
+            index_path: Path handed to ``LocalDB.load_state``. The loaded
+                state must hold data transformed under the key
+                ``"data_transformer"``.
+        """
+        super().__init__()
+
+        self.db = LocalDB.load_state(index_path)
+
+        self.transformed_docs: List[adal.Document] = self.db.get_transformed_data(
+            "data_transformer"
+        )
+        embedder = adal.Embedder(
+            model_client=ModelClientType.OPENAI(),
+            model_kwargs=configs["embedder"]["model_kwargs"],
+        )
+        # map the documents to embeddings
+        self.retriever = FAISSRetriever(
+            **configs["retriever"],
+            embedder=embedder,
+            documents=self.transformed_docs,
+            document_map_func=lambda doc: doc.vector,
+        )
+        self.retriever_output_processors = RetrieverOutputToContextStr(deduplicate=True)
+
+        self.generator = adal.Generator(
+            prompt_kwargs={
+                "task_desc_str": rag_prompt_task_desc,
+            },
+            model_client=OpenAIClient(),
+            model_kwargs=configs["generator"],
+            output_processors=JsonParser(),
+        )
+
+    def generate(self, query: str, context: Optional[str] = None) -> Any:
+        """Run the generator on ``query`` with an optional context string.
+
+        Raises:
+            ValueError: If ``self.generator`` is falsy.
+        """
+        if not self.generator:
+            raise ValueError("Generator is not set")
+
+        prompt_kwargs = {
+            "context_str": context,
+            "input_str": query,
+        }
+        return self.generator(prompt_kwargs=prompt_kwargs)
+
+    def call(self, query: str) -> Any:
+        """Retrieve context for ``query`` and generate an answer.
+
+        Returns:
+            A ``(response, retrieved_documents)`` tuple: the generator output
+            and the retriever outputs with ``documents`` filled in.
+        """
+        retrieved_documents = self.retriever(query)
+        # fill in the documents referenced by each output's doc_indices
+        for retriever_output in retrieved_documents:
+            retriever_output.documents = [
+                self.transformed_docs[doc_index]
+                for doc_index in retriever_output.doc_indices
+            ]
+
+        print(f"retrieved_documents: \n {retrieved_documents}\n")
+        context_str = self.retriever_output_processors(retrieved_documents)
+
+        print(f"context_str: \n {context_str}\n")
+
+        return self.generate(query, context=context_str), retrieved_documents