Modularize ICL and clean up
mengliu1998 committed May 18, 2024
1 parent c487598 commit 19d089e
Showing 6 changed files with 84 additions and 47 deletions.
Empty file added: eval/__init__.py
File renamed without changes: core/evaluator.py → eval/evaluator.py
Empty file added: icl/__init__.py
58 changes: 58 additions & 0 deletions icl/retrieval_icl.py
@@ -0,0 +1,58 @@
from typing import List

from core.data_classes import Document
from core.retriever import Retriever
from core.embedder import Embedder
from core.data_components import (
    RetrieverOutputToContextStr,
    ToEmbeddings,
)
from core.db import LocalDocumentDB
from core.component import Component, Sequential
from core.document_splitter import DocumentSplitter
from core.functional import generate_component_key


class RetrievalICL(Component):
    def __init__(
        self,
        retriever: Retriever,
        retriever_output_processors: RetrieverOutputToContextStr,
        text_splitter: DocumentSplitter,
        vectorizer: Embedder,
        db: LocalDocumentDB,
    ):
        super().__init__()
        self.retriever = retriever
        self.retriever_output_processors = retriever_output_processors

        self.text_splitter = text_splitter
        self.vectorizer = vectorizer
        self.data_transformer = Sequential(
            self.text_splitter,
            ToEmbeddings(
                vectorizer=self.vectorizer,
            ),
        )
        self.data_transformer_key = generate_component_key(self.data_transformer)
        self.db = db

    def build_index(self, documents: List[Document]):
        self.db.load_documents(documents)
        self.map_key = self.db.map_data()
        print(f"map_key: {self.map_key}")
        self.data_key = self.db.transform_data(self.data_transformer)
        print(f"data_key: {self.data_key}")
        self.transformed_documents = self.db.get_transformed_data(self.data_key)
        self.retriever.build_index_from_documents(self.transformed_documents)

    def call(self, query: str, top_k: int) -> str:
        retrieved_documents = self.retriever(query, top_k)
        # fill each retriever output with the documents its indexes point to
        for i, retriever_output in enumerate(retrieved_documents):
            retrieved_documents[i].documents = [
                self.transformed_documents[doc_index]
                for doc_index in retriever_output.doc_indexes
            ]
        example_str = self.retriever_output_processors(retrieved_documents)
        return example_str
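
For orientation, here is a minimal standalone sketch of the new RetrievalICL component, mirroring the wiring that use_cases/fewshot_qa.py performs below. The FAISSRetriever and OpenAIClient import paths and the embedding/splitter settings are illustrative assumptions, not values fixed by this commit:

# Hypothetical standalone use of RetrievalICL; import paths and settings
# marked "assumed" below are not part of this commit.
from components.api_client import OpenAIClient  # assumed import path
from components.retriever import FAISSRetriever  # assumed import path

from core.data_classes import Document
from core.data_components import RetrieverOutputToContextStr, ToEmbedderResponse
from core.db import LocalDocumentDB
from core.document_splitter import DocumentSplitter
from core.embedder import Embedder
from icl.retrieval_icl import RetrievalICL

# Embedding model name, dimensions, and chunk sizes are assumed values.
vectorizer = Embedder(
    model_client=OpenAIClient(),
    model_kwargs={"model": "text-embedding-3-small", "dimensions": 256},
    output_processors=ToEmbedderResponse(),
)
retrieval_icl = RetrievalICL(
    retriever=FAISSRetriever(top_k=2, dimensions=256, vectorizer=vectorizer),
    retriever_output_processors=RetrieverOutputToContextStr(deduplicate=True),
    text_splitter=DocumentSplitter(split_by="word", split_length=100, split_overlap=20),
    vectorizer=vectorizer,
    db=LocalDocumentDB(),
)
retrieval_icl.build_index(
    [Document(text="Review: Great soundtrack and pacing. Sentiment: Positive")]
)
# Components are callable; this dispatches to RetrievalICL.call.
print(retrieval_icl("Review: The show dragged on. Sentiment:", top_k=1))

Bundling the splitter, embedder, db, retriever, and output processor behind one component is what lets FewshotQA below shrink to simple delegation.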
67 changes: 23 additions & 44 deletions use_cases/simple_icl.py → use_cases/fewshot_qa.py
@@ -1,7 +1,3 @@
-"""
-We just need a very basic generator that can be used to generate text from a prompt.
-"""
-
from typing import List

from core.generator import Generator
@@ -12,19 +8,18 @@
from core.data_components import (
    ToEmbedderResponse,
    RetrieverOutputToContextStr,
-    ToEmbeddings,
)
from core.db import LocalDocumentDB
-from core.component import Component, Sequential
+from core.component import Component
from core.document_splitter import DocumentSplitter
-from core.functional import generate_component_key
+from icl.retrieval_icl import RetrievalICL

import dotenv

dotenv.load_dotenv(dotenv_path=".env", override=True)


-class SimpleICL(Component):
+class FewshotQA(Component):
    def __init__(self, task_desc: str):
        super().__init__()
        self.vectorizer_settings = {
@@ -50,55 +45,38 @@ def __init__(self, task_desc: str):
            model_kwargs=model_kwargs,
            preset_prompt_kwargs=preset_prompt_kwargs,
        )

        text_splitter = DocumentSplitter(
            split_by=self.text_splitter_settings["split_by"],
            split_length=self.text_splitter_settings["chunk_size"],
            split_overlap=self.text_splitter_settings["chunk_overlap"],
        )
        vectorizer = Embedder(
            model_client=OpenAIClient(),
            # batch_size=self.vectorizer_settings["batch_size"],
            model_kwargs=self.vectorizer_settings["model_kwargs"],
            output_processors=ToEmbedderResponse(),
        )
-        self.data_transformer = Sequential(
-            text_splitter,
-            ToEmbeddings(
-                vectorizer=vectorizer,
-                batch_size=self.vectorizer_settings["batch_size"],
-            ),
-        )
-        self.data_transformer_key = generate_component_key(self.data_transformer)
-        self.retriever_icl = FAISSRetriever(
+        self.retriever = FAISSRetriever(
            top_k=self.retriever_settings["top_k"],
            dimensions=self.vectorizer_settings["model_kwargs"]["dimensions"],
            vectorizer=vectorizer,
        )
-        self.retriever_output_processors = RetrieverOutputToContextStr(deduplicate=True)
-        self.db_icl = LocalDocumentDB()

-    def build_index(self, documents: List[Document]):
-        self.db_icl.load_documents(documents)
-        self.map_key = self.db_icl.map_data()
-        print(f"map_key: {self.map_key}")
-        self.data_key = self.db_icl.transform_data(self.data_transformer)
-        print(f"data_key: {self.data_key}")
-        self.transformed_documents = self.db_icl.get_transformed_data(self.data_key)
-        self.retriever_icl.build_index_from_documents(self.transformed_documents)
+        self.retriever_output_processors = RetrieverOutputToContextStr(deduplicate=True)
+        self.db = LocalDocumentDB()
+        self.retrieval_icl = RetrievalICL(
+            retriever=self.retriever,
+            retriever_output_processors=self.retriever_output_processors,
+            text_splitter=text_splitter,
+            vectorizer=vectorizer,
+            db=self.db,
+        )

-    def get_few_shot_example_str(self, query: str, top_k: int) -> str:
-        retrieved_documents = self.retriever_icl(query, top_k)
-        # fill in the document
-        for i, retriever_output in enumerate(retrieved_documents):
-            retrieved_documents[i].documents = [
-                self.transformed_documents[doc_index]
-                for doc_index in retriever_output.doc_indexes
-            ]
-        example_str = self.retriever_output_processors(retrieved_documents)
-        return example_str
+    def build_icl_index(self, documents: List[Document]):
+        self.retrieval_icl.build_index(documents)

-    def call(self, task_desc: str, query: str, top_k: int) -> str:
-        example_str = self.get_few_shot_example_str(query, top_k=top_k)
+    def call(self, query: str, top_k: int) -> str:
+        example_str = self.retrieval_icl(query, top_k=top_k)
        return (
            self.generator.call(
                input=query,
@@ -124,13 +102,14 @@ def call(self, task_desc: str, query: str, top_k: int) -> str:
    text="Review: The store is not clean and smells bad. Sentiment: Negative",
)

-simple_icl = SimpleICL(task_desc)
-print(simple_icl)
-simple_icl.build_index([example1, example2, example3, example4])
+fewshot_qa = FewshotQA(task_desc)
+# build the index for the retriever-based ICL
+fewshot_qa.build_icl_index([example1, example2, example3, example4])
+print(fewshot_qa)
query = (
    "Review: The concert was a lot of fun and the band was energetic and engaging."
)
# top_k: how many examples to retrieve and show to the model
-response, example_str = simple_icl.call(task_desc, query, top_k=2)
+response, example_str = fewshot_qa(query, top_k=2)
print(f"response: {response}")
print(f"example_str: {example_str}")
6 changes: 3 additions & 3 deletions use_cases/rag_hotpotqa.py
@@ -10,7 +10,7 @@

from core.string_parser import JsonParser
from core.component import Sequential
from core.evaluator import (
from eval.evaluator import (
RetrieverEvaluator,
AnswerMacthEvaluator,
LLMasJudge,
@@ -43,10 +43,10 @@ def get_supporting_sentences(
    settings = yaml.safe_load(file)
print(settings)

-# Load the dataset and select the first 10 as the showcase
+# Load the dataset and select the first 5 as the showcase
# More info about the HotpotQA dataset can be found at https://huggingface.co/datasets/hotpot_qa
dataset = load_dataset(path="hotpot_qa", name="fullwiki")
-dataset = dataset["train"].select(range(3))
+dataset = dataset["train"].select(range(5))

all_questions = []
all_retrieved_context = []

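A closing note on the eval/ move: core/evaluator.py was renamed to eval/evaluator.py without content changes (per the rename entry above), so downstream code only needs a new import path. Nothing below is assumed beyond the class names already shown in the diff:

# Before this commit:
# from core.evaluator import RetrieverEvaluator, AnswerMacthEvaluator, LLMasJudge
# After this commit (module contents unchanged, only the package moved):
from eval.evaluator import RetrieverEvaluator, AnswerMacthEvaluator, LLMasJudge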