From 19d089e2da8479ec329d8c0d3990d98b366f815a Mon Sep 17 00:00:00 2001
From: mengliu1998 <604629@gmail.com>
Date: Sat, 18 May 2024 10:55:07 -0700
Subject: [PATCH] Modularize ICL and clean up

---
 eval/__init__.py                           |  0
 {core => eval}/evaluator.py                |  0
 icl/__init__.py                            |  0
 icl/retrieval_icl.py                       | 58 +++++++++++++++++++
 use_cases/{simple_icl.py => fewshot_qa.py} | 67 ++++++++--------------
 use_cases/rag_hotpotqa.py                  |  6 +-
 6 files changed, 84 insertions(+), 47 deletions(-)
 create mode 100644 eval/__init__.py
 rename {core => eval}/evaluator.py (100%)
 create mode 100644 icl/__init__.py
 create mode 100644 icl/retrieval_icl.py
 rename use_cases/{simple_icl.py => fewshot_qa.py} (61%)

diff --git a/eval/__init__.py b/eval/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/core/evaluator.py b/eval/evaluator.py
similarity index 100%
rename from core/evaluator.py
rename to eval/evaluator.py
diff --git a/icl/__init__.py b/icl/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/icl/retrieval_icl.py b/icl/retrieval_icl.py
new file mode 100644
index 00000000..60c25191
--- /dev/null
+++ b/icl/retrieval_icl.py
@@ -0,0 +1,58 @@
+from typing import List
+
+from core.data_classes import Document
+from core.retriever import Retriever
+from core.embedder import Embedder
+from core.data_components import (
+    RetrieverOutputToContextStr,
+    ToEmbeddings,
+)
+from core.db import LocalDocumentDB
+from core.component import Component, Sequential
+from core.document_splitter import DocumentSplitter
+from core.functional import generate_component_key
+
+
+class RetrievalICL(Component):
+    def __init__(
+        self,
+        retriever: Retriever,
+        retriever_output_processors: RetrieverOutputToContextStr,
+        text_splitter: DocumentSplitter,
+        vectorizer: Embedder,
+        db: LocalDocumentDB,
+    ):
+        super().__init__()
+        self.retriever = retriever
+        self.retriever_output_processors = retriever_output_processors
+
+        self.text_splitter = text_splitter
+        self.vectorizer = vectorizer
+        self.data_transformer = Sequential(
+            self.text_splitter,
+            ToEmbeddings(
+                vectorizer=self.vectorizer,
+            ),
+        )
+        self.data_transformer_key = generate_component_key(self.data_transformer)
+        self.db = db
+
+    def build_index(self, documents: List[Document]):
+        self.db.load_documents(documents)
+        self.map_key = self.db.map_data()
+        print(f"map_key: {self.map_key}")
+        self.data_key = self.db.transform_data(self.data_transformer)
+        print(f"data_key: {self.data_key}")
+        self.transformed_documents = self.db.get_transformed_data(self.data_key)
+        self.retriever.build_index_from_documents(self.transformed_documents)
+
+    def call(self, query: str, top_k: int) -> str:
+        retrieved_documents = self.retriever(query, top_k)
+        # fill in the actual documents for each retrieved output
+        for i, retriever_output in enumerate(retrieved_documents):
+            retrieved_documents[i].documents = [
+                self.transformed_documents[doc_index]
+                for doc_index in retriever_output.doc_indexes
+            ]
+        example_str = self.retriever_output_processors(retrieved_documents)
+        return example_str
diff --git a/use_cases/simple_icl.py b/use_cases/fewshot_qa.py
similarity index 61%
rename from use_cases/simple_icl.py
rename to use_cases/fewshot_qa.py
index 785e3083..665be1a9 100644
--- a/use_cases/simple_icl.py
+++ b/use_cases/fewshot_qa.py
@@ -1,7 +1,3 @@
-"""
-We just need a very basic generator that can be used to generate text from a prompt.
-""" - from typing import List from core.generator import Generator @@ -12,19 +8,18 @@ from core.data_components import ( ToEmbedderResponse, RetrieverOutputToContextStr, - ToEmbeddings, ) from core.db import LocalDocumentDB -from core.component import Component, Sequential +from core.component import Component from core.document_splitter import DocumentSplitter -from core.functional import generate_component_key +from icl.retrieval_icl import RetrievalICL import dotenv dotenv.load_dotenv(dotenv_path=".env", override=True) -class SimpleICL(Component): +class FewshotQA(Component): def __init__(self, task_desc: str): super().__init__() self.vectorizer_settings = { @@ -50,6 +45,7 @@ def __init__(self, task_desc: str): model_kwargs=model_kwargs, preset_prompt_kwargs=preset_prompt_kwargs, ) + text_splitter = DocumentSplitter( split_by=self.text_splitter_settings["split_by"], split_length=self.text_splitter_settings["chunk_size"], @@ -57,48 +53,30 @@ def __init__(self, task_desc: str): ) vectorizer = Embedder( model_client=OpenAIClient(), - # batch_size=self.vectorizer_settings["batch_size"], model_kwargs=self.vectorizer_settings["model_kwargs"], output_processors=ToEmbedderResponse(), ) - self.data_transformer = Sequential( - text_splitter, - ToEmbeddings( - vectorizer=vectorizer, - batch_size=self.vectorizer_settings["batch_size"], - ), - ) - self.data_transformer_key = generate_component_key(self.data_transformer) - self.retriever_icl = FAISSRetriever( + self.retriever = FAISSRetriever( top_k=self.retriever_settings["top_k"], dimensions=self.vectorizer_settings["model_kwargs"]["dimensions"], vectorizer=vectorizer, ) - self.retriever_output_processors = RetrieverOutputToContextStr(deduplicate=True) - self.db_icl = LocalDocumentDB() - def build_index(self, documents: List[Document]): - self.db_icl.load_documents(documents) - self.map_key = self.db_icl.map_data() - print(f"map_key: {self.map_key}") - self.data_key = self.db_icl.transform_data(self.data_transformer) - print(f"data_key: {self.data_key}") - self.transformed_documents = self.db_icl.get_transformed_data(self.data_key) - self.retriever_icl.build_index_from_documents(self.transformed_documents) + self.retriever_output_processors = RetrieverOutputToContextStr(deduplicate=True) + self.db = LocalDocumentDB() + self.retrieval_icl = RetrievalICL( + retriever=self.retriever, + retriever_output_processors=self.retriever_output_processors, + text_splitter=text_splitter, + vectorizer=vectorizer, + db=self.db, + ) - def get_few_shot_example_str(self, query: str, top_k: int) -> str: - retrieved_documents = self.retriever_icl(query, top_k) - # fill in the document - for i, retriever_output in enumerate(retrieved_documents): - retrieved_documents[i].documents = [ - self.transformed_documents[doc_index] - for doc_index in retriever_output.doc_indexes - ] - example_str = self.retriever_output_processors(retrieved_documents) - return example_str + def build_icl_index(self, documents: List[Document]): + self.retrieval_icl.build_index(documents) - def call(self, task_desc: str, query: str, top_k: int) -> str: - example_str = self.get_few_shot_example_str(query, top_k=top_k) + def call(self, query: str, top_k: int) -> str: + example_str = self.retrieval_icl(query, top_k=top_k) return ( self.generator.call( input=query, @@ -124,13 +102,14 @@ def call(self, task_desc: str, query: str, top_k: int) -> str: text="Review: The store is not clean and smells bad. 
         text="Review: The store is not clean and smells bad. Sentiment: Negative",
     )
 
-    simple_icl = SimpleICL(task_desc)
-    print(simple_icl)
-    simple_icl.build_index([example1, example2, example3, example4])
+    fewshot_qa = FewshotQA(task_desc)
+    # build the index for the retriever-based ICL
+    fewshot_qa.build_icl_index([example1, example2, example3, example4])
+    print(fewshot_qa)
     query = (
         "Review: The concert was a lot of fun and the band was energetic and engaging."
     )
     # top_k: how many examples you want to retrieve and show to the model
-    response, example_str = simple_icl.call(task_desc, query, top_k=2)
+    response, example_str = fewshot_qa(query, top_k=2)
     print(f"response: {response}")
     print(f"example_str: {example_str}")
diff --git a/use_cases/rag_hotpotqa.py b/use_cases/rag_hotpotqa.py
index 7df61668..7fc43cd1 100644
--- a/use_cases/rag_hotpotqa.py
+++ b/use_cases/rag_hotpotqa.py
@@ -10,7 +10,7 @@
 from core.string_parser import JsonParser
 from core.component import Sequential
 
-from core.evaluator import (
+from eval.evaluator import (
     RetrieverEvaluator,
     AnswerMacthEvaluator,
     LLMasJudge,
@@ -43,10 +43,10 @@ def get_supporting_sentences(
         settings = yaml.safe_load(file)
     print(settings)
 
-    # Load the dataset and select the first 10 as the showcase
+    # Load the dataset and select the first 5 as the showcase
     # More info about the HotpotQA dataset can be found at https://huggingface.co/datasets/hotpot_qa
     dataset = load_dataset(path="hotpot_qa", name="fullwiki")
-    dataset = dataset["train"].select(range(3))
+    dataset = dataset["train"].select(range(5))
 
     all_questions = []
     all_retrieved_context = []
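
Usage sketch (illustrative, not part of the patch): a minimal example of wiring
up the new RetrievalICL component on its own, following the signatures added in
icl/retrieval_icl.py. The OpenAIClient and FAISSRetriever import paths, the
model settings, the splitter values, and the example documents below are
assumptions for demonstration, not values from this commit.

    from core.data_classes import Document
    from core.data_components import RetrieverOutputToContextStr, ToEmbedderResponse
    from core.db import LocalDocumentDB
    from core.document_splitter import DocumentSplitter
    from core.embedder import Embedder
    from icl.retrieval_icl import RetrievalICL

    # NOTE: import paths assumed; mirror the imports used in use_cases/fewshot_qa.py.
    from core.openai_client import OpenAIClient
    from core.faiss_retriever import FAISSRetriever

    dimensions = 256  # assumed; must match the embedding size the model returns
    vectorizer = Embedder(
        model_client=OpenAIClient(),
        model_kwargs={"model": "text-embedding-3-small", "dimensions": dimensions},
        output_processors=ToEmbedderResponse(),
    )
    retriever = FAISSRetriever(top_k=2, dimensions=dimensions, vectorizer=vectorizer)
    retrieval_icl = RetrievalICL(
        retriever=retriever,
        retriever_output_processors=RetrieverOutputToContextStr(deduplicate=True),
        text_splitter=DocumentSplitter(split_by="word", split_length=400),  # assumed values
        vectorizer=vectorizer,
        db=LocalDocumentDB(),
    )

    # Index the labeled examples once, then pull back the top-k most similar
    # ones as a formatted few-shot string for the generator prompt.
    retrieval_icl.build_index(
        [
            Document(text="Review: Great pizza and friendly staff. Sentiment: Positive"),
            Document(text="Review: Cold food and slow service. Sentiment: Negative"),
        ]
    )
    example_str = retrieval_icl("Review: The soup was bland.", top_k=1)
    print(example_str)

The design point of the refactor is visible here: document loading, the
split-and-embed transform, and retriever index construction all live behind
RetrievalICL.build_index(), so a caller like FewshotQA keeps only the thin
build_icl_index()/call() wrappers around its generator.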