Commit d339567
before reorg the code
liyin2015 committed May 23, 2024
1 parent e7da017
Showing 9 changed files with 120 additions and 29 deletions.
2 changes: 2 additions & 0 deletions components/api_client/anthropic_client.py
@@ -28,6 +28,7 @@ def __init__(self, api_key: Optional[str] = None):
self._api_key = api_key
self.sync_client = self._init_sync_client()
self.async_client = None # only initialize if the async call is called
+ self.tested_llm_models = ["claude-3-opus-20240229"]

def _init_sync_client(self):
api_key = self._api_key or os.getenv("ANTHROPIC_API_KEY")
@@ -57,6 +58,7 @@ def convert_input_to_api_kwargs(
api_kwargs["messages"] = [
{"role": "user", "content": input},
]
+ # api_kwargs["messages"] = [{"role": "system", "content": system_input}]
if system_input and system_input != "":
api_kwargs["system"] = system_input
else:
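With this change, a non-empty system_input is sent as Anthropic's top-level system field rather than as a {"role": "system"} message (hence the commented-out line). A minimal sketch of the resulting request shape; the model name and prompt strings are illustrative:

# Sketch of the kwargs this client hands to Anthropic's Messages API.
api_kwargs = {
    "model": "claude-3-opus-20240229",
    "max_tokens": 1024,
    "system": "You are a classifier.",  # from system_input
    "messages": [{"role": "user", "content": "Where is Las Vegas ?"}],
}
# response = sync_client.messages.create(**api_kwargs)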
4 changes: 2 additions & 2 deletions components/api_client/google_client.py
@@ -62,9 +62,9 @@ def convert_input_to_api_kwargs(
assert isinstance(input, Sequence), "input must be a sequence of text"
final_model_kwargs["input"] = input
elif model_type == ModelType.LLM:
prompt: str = f"{system_input}\n\nUser query: {input}\n You:"
# prompt: str = f"{system_input}\n\nUser query: {input}\n You:"

final_model_kwargs["prompt"] = prompt
final_model_kwargs["prompt"] = system_input
else:
raise ValueError(f"model_type {model_type} is not supported")
return final_model_kwargs
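After this change the Google client no longer composes a prompt itself; the fully rendered prompt is expected to arrive as system_input (presumably assembled upstream by the Prompt template) and is forwarded as-is. A hedged sketch of that contract; the function name and strings are illustrative:

def to_gemini_kwargs(system_input: str) -> dict:
    # system_input is assumed to already be the complete rendered prompt.
    return {"prompt": system_input}

print(to_gemini_kwargs("You are a classifier.\n\nUser query: Where is Las Vegas ?"))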
8 changes: 4 additions & 4 deletions components/api_client/groq_client.py
@@ -39,10 +39,10 @@ def __init__(self, api_key: Optional[str] = None):
"developer": "Meta",
"context_size": "8192",
},
"llama2-70b-4096": {
"developer": "Meta",
"context_size": "4096",
},
# "llama2-70b-4096": {
# "developer": "Meta",
# "context_size": "4096",
# },
"mixtral-8x7b-32768": {
"developer": "Mistral",
"context_size": "32768",
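The llama2-70b-4096 entry is commented out of the model table, presumably because Groq retired it. A hypothetical guard that fails fast on unsupported model names, built on an illustrative copy of the dict above:

# Illustrative copy of the client's supported-model table.
model_lists = {
    "llama3-8b-8192": {"developer": "Meta", "context_size": "8192"},
    "mixtral-8x7b-32768": {"developer": "Mistral", "context_size": "32768"},
}

def validate_model(name: str) -> None:
    # Hypothetical helper, not part of the client.
    if name not in model_lists:
        raise ValueError(f"{name!r} is not supported; choose from {list(model_lists)}")

validate_model("mixtral-8x7b-32768")  # passes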
10 changes: 10 additions & 0 deletions core/default_prompt_template.py
@@ -1,3 +1,13 @@
+ LIGHTRAG_DEFAULT_PROMPT_ARGS = [
+ "task_desc_str",
+ "output_format_str",
+ "tools_str",
+ "examples_str",
+ "chat_history_str",
+ "context_str",
+ "steps_str",
+ ]
+
DEFAULT_LIGHTRAG_SYSTEM_PROMPT = r"""{# task desc #}
{% if task_desc_str %}
{{task_desc_str}}
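The template guards each section with an {% if %} block like the one shown, so any prompt arg left unset simply renders to nothing. A minimal runnable Jinja2 sketch of that behavior, reduced to the task_desc_str section with an illustrative string:

from jinja2 import Template

# Mirrors the {% if task_desc_str %} guard above; other sections behave alike.
t = Template(r"""{% if task_desc_str %}
{{task_desc_str}}
{% endif %}""")

print(t.render(task_desc_str="You are a classifier."))  # section rendered
print(t.render())  # arg unset, so the section is omitted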
1 change: 1 addition & 0 deletions prompts/outputs.py
@@ -117,6 +117,7 @@ def __init__(
r"""
Args:
data_class_for_yaml (Type): The dataclass to extract the schema for the YAML output.
+ example (Type, optional): The example dataclass object to show in the prompt. Defaults to None.
yaml_output_format_template (str, optional): The template for the YAML output format. Defaults to YAML_OUTPUT_FORMAT.
output_processors (Component, optional): The output processors to parse the YAML string to JSON object. Defaults to YAMLParser().
"""
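A hedged usage sketch of the documented arguments; the dataclass and example instance below are hypothetical, not from the repo:

from dataclasses import dataclass

from prompts.outputs import YAMLOutputParser

@dataclass
class ClassOutput:  # hypothetical output schema
    class_name: str
    class_index: int

# The new `example` argument supplies a concrete instance to show in the prompt.
parser = YAMLOutputParser(
    data_class_for_yaml=ClassOutput,
    example=ClassOutput(class_name="LOC", class_index=4),
)
print(parser.format_instructions())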
48 changes: 48 additions & 0 deletions use_cases/classification/optimized_cot.txt
@@ -0,0 +1,48 @@
{
"prog": {
"lm": null,
"traces": [],
"train": [],
"demos": [
{
"augmented": true,
"question": "How many miles is it from NY to Austria ?",
"rationale": "produce the class_index. We see that the question is asking for a numeric value, specifically the distance in miles between two locations.",
"class_name": "NUM",
"class_index": "5"
},
{
"augmented": true,
"question": "Where is Los Vegas ?",
"rationale": "produce the class_index. We see that the question is asking for the location of a place, which falls under the LOC class.",
"class_name": "LOC",
"class_index": "4"
},
{
"augmented": true,
"question": "Who was Sherlock Holmes 's archenemy ?",
"rationale": "produce the class_index. We are looking for a specific person who is an archenemy of Sherlock Holmes, which falls under the category of human beings.",
"class_name": "HUM",
"class_index": "3"
},
{
"augmented": true,
"question": "Who was Shakespeare 's Moorish general ?",
"rationale": "produce the class_index. We are looking for a specific person related to Shakespeare's works, so this question falls under the HUM class.",
"class_name": "HUM",
"class_index": "3"
},
{
"augmented": true,
"question": "What type of exercise burns the most calories ?",
"rationale": "produce the class_index. We can see that the question is asking for a type of exercise, which falls under the category of a specific entity or concept related to physical activity.",
"class_name": "Entity",
"class_index": "1"
}
],
"signature_instructions": "You are a classifier. Given a Question, you need to classify it into one of the following classes:\n Format: class_index. class_name, class_description\n 0. ABBR, Abbreviation\n 1. ENTY, Entity\n 2. DESC, Description and abstract concept\n 3. HUM, Human being\n 4. LOC, Location\n 5. NUM, Numeric value",
"signature_prefix": "Class Index:",
"extended_signature_instructions": "You are a classifier. Given a Question, you need to classify it into one of the following classes:\n Format: class_index. class_name, class_description\n 0. ABBR, Abbreviation\n 1. ENTY, Entity\n 2. DESC, Description and abstract concept\n 3. HUM, Human being\n 4. LOC, Location\n 5. NUM, Numeric value",
"extended_signature_prefix": "Class Index:"
}
}
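This file is the state that DSPy persists via save() in train_dspy.py below. Assuming the same CoT module definition is in scope, restoring and querying it is a short sketch (the expected output is taken from the demos above):

# Hypothetical round-trip for the saved program.
cot = CoT()  # the DSPy program compiled in train_dspy.py
cot.load("use_cases/classification/optimized_cot.txt")
pred = cot(question="How many miles is it from NY to Austria ?")
print(pred.class_index)  # "5" (NUM), per the first demo above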
51 changes: 41 additions & 10 deletions use_cases/classification/task.py
@@ -1,6 +1,11 @@
from core.component import Component, Sequential
from core.generator import Generator
- from components.api_client import GroqAPIClient, OpenAIClient
+ from components.api_client import (
+ GroqAPIClient,
+ OpenAIClient,
+ GoogleGenAIClient,
+ AnthropicAPIClient,
+ )
from core.prompt_builder import Prompt
from prompts.outputs import YAMLOutputParser
from core.string_parser import JsonParser
@@ -73,17 +78,43 @@ def __init__(
)
output_str = yaml_parser.format_instructions()
logger.debug(f"output_str: {output_str}")
+ groq_model_kwargs = {
+ "model": "mixtral-8x7b-32768",  # "llama3-8b-8192",
+ "temperature": 0.0,
+ "top_p": 1,
+ "frequency_penalty": 0,
+ "presence_penalty": 0,
+ "n": 1,
+ }
+ openai_model_kwargs = {
+ "model": "gpt-4-turbo",
+ "temperature": 0.0,
+ "top_p": 1,
+ "frequency_penalty": 0,
+ "presence_penalty": 0,
+ "n": 1,
+ }
+ google_model_kwargs = {
+ "model": "gemini-1.5-pro-latest",
+ "temperature": 0.0,
+ "top_p": 1,
+ # "frequency_penalty": 0,
+ # "presence_penalty": 0,
+ # "n": 1,
+ }
+ anthropic_model_kwargs = {
+ "model": "claude-3-opus-20240229",
+ "temperature": 0.0,
+ "top_p": 1,
+ # "frequency_penalty": 0,
+ # "presence_penalty": 0,
+ # "n": 1,
+ "max_tokens": 1024,
+ }

self.generator = Generator(
- model_client=OpenAIClient,
- model_kwargs={
- "model": "gpt-3.5-turbo",
- "temperature": 0.0,
- "top_p": 1,
- "frequency_penalty": 0,
- "presence_penalty": 0,
- "n": 1,
- },
+ model_client=GoogleGenAIClient,
+ model_kwargs=google_model_kwargs,
template=TEMPLATE,
preset_prompt_kwargs={
"task_desc_str": task_desc_str,
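With per-provider kwargs factored out, switching backends reduces to swapping one client/kwargs pair. A hedged sketch using the imports and dicts from this diff; the registry itself is illustrative:

# Hypothetical provider registry over the kwargs dicts defined above.
providers = {
    "groq": (GroqAPIClient, groq_model_kwargs),
    "openai": (OpenAIClient, openai_model_kwargs),
    "google": (GoogleGenAIClient, google_model_kwargs),
    "anthropic": (AnthropicAPIClient, anthropic_model_kwargs),
}
client_cls, kwargs = providers["google"]
generator = Generator(model_client=client_cls, model_kwargs=kwargs, template=TEMPLATE)

Note the provider quirks already visible above: Anthropic requires max_tokens, while the OpenAI-style penalty and n arguments are commented out for Google and Anthropic.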
5 changes: 4 additions & 1 deletion use_cases/classification/train.py
@@ -153,6 +153,9 @@ def random_shots(self, shots: int) -> Sequence[str]:
return samples_str

def eval(self):
r"""
TODO: automatically tracking the average inference time
"""
responses = []
targets = []
num_invalid = 0
@@ -208,4 +211,4 @@ def train(self, shots: int) -> None:
num_classes=6, train_dataset=train_dataset, eval_dataset=eval_dataset
)
print(trainer)
- trainer.train(0)
+ trainer.train(5)
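A minimal sketch of the average-inference-time tracking the eval() TODO asks for; the helper and its names are hypothetical, assuming one generator call per example:

import time

def timed_calls(call, inputs):
    # Hypothetical helper: times each inference and reports the mean latency.
    responses, latencies = [], []
    for x in inputs:
        start = time.perf_counter()
        responses.append(call(x))
        latencies.append(time.perf_counter() - start)
    avg = sum(latencies) / max(len(latencies), 1)
    print(f"avg inference time: {avg:.3f}s over {len(latencies)} calls")
    return responses

timed_calls(lambda q: q.upper(), ["Where is Las Vegas ?"])  # demo with a dummy call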
20 changes: 8 additions & 12 deletions use_cases/classification/train_dspy.py
@@ -1,22 +1,13 @@
from use_cases.classification.task import TRECClassifier
from use_cases.classification.eval import ClassifierEvaluator
from core.component import Component
- from use_cases.classification.data import (
- TrecDataset,
- ToSampleStr,
- dataset,
- _COARSE_LABELS_DESC,
- _COARSE_LABELS,
- _FINE_LABELS,
- extract_class_label,
- )
from torch.utils.data import DataLoader
import random
from use_cases.classification.task_dspy import TrecClassifier


- from typing import Any, Optional, Sequence, Dict
from torch.utils.data.sampler import Sampler, SubsetRandomSampler, RandomSampler
+ from typing import Sequence


class ExampleOptimizer(Component):
@@ -179,13 +170,18 @@ def train(self):

# Set up the optimizer: we want to "bootstrap" (i.e., self-generate) 4-shot examples of our CoT program.
config = dict(
- max_bootstrapped_demos=5, max_labeled_demos=5, num_candidate_programs=2
+ max_bootstrapped_demos=5, max_labeled_demos=5  # , num_candidate_programs=2
)
# Optimize! Use the `gsm8k_metric` here. In general, the metric is going to tell the optimizer how well it's doing.
- teleprompter = BootstrapFewShotWithRandomSearch(metric=acc_metric, **config)
+ teleprompter = BootstrapFewShot(metric=acc_metric, **config)

optimized_cot = teleprompter.compile(CoT(), trainset=self.train_example_set)
print(optimized_cot)
+ # get current path
+ import os
+
+ path = os.getcwd() + "/use_cases/classification/optimized_cot.txt"
+ optimized_cot.save(path)
metrics = self.eval_baseline(optimized_cot)
print(metrics)
return metrics
