Skip to content

Commit

Permalink
Clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
mengliu1998 committed May 18, 2024
1 parent 2d78c12 commit c487598
Showing 1 changed file with 12 additions and 18 deletions.
30 changes: 12 additions & 18 deletions use_cases/simple_icl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,18 @@
ToEmbeddings,
)
from core.db import LocalDocumentDB

from core.component import Component

# TODO: make the environment variable loading more robust, and let users specify the .env path
import dotenv


from core.component import Component, Sequential
from core.document_splitter import DocumentSplitter
from core.component import Sequential

from core.functional import generate_component_key

dotenv.load_dotenv()
import dotenv

dotenv.load_dotenv(dotenv_path=".env", override=True)


class SimpleICL(Component):
def __init__(self, task_desc: str):
super().__init__()
model_kwargs = {"model": "gpt-3.5-turbo"}
preset_prompt_kwargs = {"task_desc_str": task_desc}
self.vectorizer_settings = {
"batch_size": 100,
"model_kwargs": {
Expand All @@ -51,12 +43,13 @@ def __init__(self, task_desc: str):
"chunk_size": 400,
"chunk_overlap": 200,
}
model_kwargs = {"model": "gpt-3.5-turbo"}
preset_prompt_kwargs = {"task_desc_str": task_desc}
self.generator = Generator(
model_client=OpenAIClient(),
model_kwargs=model_kwargs,
preset_prompt_kwargs=preset_prompt_kwargs,
)
self.generator.print_prompt()
text_splitter = DocumentSplitter(
split_by=self.text_splitter_settings["split_by"],
split_length=self.text_splitter_settings["chunk_size"],
Expand Down Expand Up @@ -93,7 +86,6 @@ def build_index(self, documents: List[Document]):
self.transformed_documents = self.db_icl.get_transformed_data(self.data_key)
self.retriever_icl.build_index_from_documents(self.transformed_documents)

### TODO: use retriever to get the few shot
def get_few_shot_example_str(self, query: str, top_k: int) -> str:
retrieved_documents = self.retriever_icl(query, top_k)
# fill in the document
Expand All @@ -102,13 +94,11 @@ def get_few_shot_example_str(self, query: str, top_k: int) -> str:
self.transformed_documents[doc_index]
for doc_index in retriever_output.doc_indexes
]
# convert all the documents to context string

example_str = self.retriever_output_processors(retrieved_documents)
return example_str

def call(self, task_desc: str, query: str, top_k: int) -> str:
example_str = self.get_few_shot_example_str(query, top_k=2)
example_str = self.get_few_shot_example_str(query, top_k=top_k)
return (
self.generator.call(
input=query,
Expand All @@ -130,13 +120,17 @@ def call(self, task_desc: str, query: str, top_k: int) -> str:
example3 = Document(
text="Review: What a fantastic movie! Had a great time and would watch it again! Sentiment: Positive",
)
example4 = Document(
text="Review: The store is not clean and smells bad. Sentiment: Negative",
)

simple_icl = SimpleICL(task_desc)
print(simple_icl)
simple_icl.build_index([example1, example2, example3])
simple_icl.build_index([example1, example2, example3, example4])
query = (
"Review: The concert was a lot of fun and the band was energetic and engaging."
)
# tok_k: how many examples you want retrieve to show to the model
response, example_str = simple_icl.call(task_desc, query, top_k=2)
print(f"response: {response}")
print(f"example_str: {example_str}")

0 comments on commit c487598

Please sign in to comment.