
Commit f0328b0

fix the bug in hotpot qa

liyin2015 committed Jan 3, 2025
1 parent 540b161 commit f0328b0

Showing 14 changed files with 251 additions and 50 deletions.

8 changes: 4 additions & 4 deletions adalflow/adalflow/core/generator.py

@@ -74,7 +74,7 @@ class BackwardPassSetup(DataClass):
         metadata={"desc": "Threshold score to compute gradient for errors."},
     )
     compute_grad_for_errors_only: bool = field(
-        default=False, metadata={"desc": "Compute gradient for errors only."}
+        default=True, metadata={"desc": "Compute gradient for errors only."}
     )

@@ -562,11 +562,11 @@ def forward(
 
         def data_to_prompt_map_fn(data: Parameter) -> str:
             data: GeneratorOutput = data.data
-            if data.data is not None:
-                return data.data
+            # if data.data is not None:
+            #     return data.data
             if data.error is not None:
                 return f"Response: {data.raw_response} parsed with error: {data.error}"
-            return f"Response: {data.raw_response}"
+            return f" {data.raw_response}"
 
         # TODO: all parameter should just wrap the whole output.
         # this is for training.
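
The first hunk flips the default so that, unless a caller overrides it, gradients are computed only for error samples; the second hunk makes the prompt mapping fall back to the raw response instead of the parsed data. A minimal sketch (not part of this diff) of opting back out of the new default, using the same BackwardPassSetup import this commit adds to train_multi_hop_rag.py:

    # Sketch: restore the previous behavior (gradients for all samples, not just errors).
    from adalflow.core.generator import BackwardPassSetup

    setup = BackwardPassSetup(compute_grad_for_errors_only=False)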

8 changes: 4 additions & 4 deletions adalflow/adalflow/datasets/big_bench_hard.py

@@ -24,7 +24,7 @@ class BigBenchHard(Dataset):
     Size for each split:
     - train: 50 examples
-    - val: 50 examples
+    - val: 100 examples
     - test: 100 examples
 
     Args:

@@ -120,11 +120,11 @@ def _check_or_download_dataset(self, data_path: str = None, split: str = "train"
         ]
         val_examples = [
             {"x": ex["input"], "y": ex["target"], "id": str(uuid.uuid4())}
-            for ex in examples[50:100]
+            for ex in examples[50:150]
         ]
         test_examples = [
             {"x": ex["input"], "y": ex["target"], "id": str(uuid.uuid4())}
-            for ex in examples[150:250]
+            for ex in examples[150:]
         ]
         # ensure the

@@ -150,7 +150,7 @@ def get_default_task_instruction():
 if __name__ == "__main__":
     from adalflow.datasets.big_bench_hard import BigBenchHard
 
-    dataset = BigBenchHard(task_name="word_sorting", split="train")
+    dataset = BigBenchHard(task_name="object_counting", split="test")
     print(dataset[0:10])
     print(len(dataset))
     print(dataset.get_default_task_instruction())
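
For reference, the new slices line up with the updated docstring; a small sketch of the resulting split sizes (the train slice is inferred from the docstring, and each BIG-Bench-Hard task file is assumed to hold 250 examples, as the old 150:250 slice implied):

    train_examples = examples[:50]    # 50 examples
    val_examples = examples[50:150]   # 100 examples (was examples[50:100])
    test_examples = examples[150:]    # remaining 100 examples (was examples[150:250])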

30 changes: 24 additions & 6 deletions adalflow/adalflow/datasets/hotpot_qa.py

@@ -23,6 +23,14 @@ def __init__(
         size: int = None,
         **kwargs,
     ) -> None:
+        r"""
+        official_train: 15661
+        sampled_trainset: 11745
+        sampled_valset: 3916
+        test: 7405
+
+        You can specify the size of the dataset to load by setting the size parameter.
+        """
         if split not in ["train", "val", "test"]:
             raise ValueError("Split must be one of 'train', 'val', 'test'")

@@ -37,6 +45,7 @@ def __init__(
         data_path = prepare_dataset_path(self.root, self.task_name)
         # download and save
         split_csv_path = os.path.join(data_path, f"{split}.csv")
+        print(f"split_csv_path: {split_csv_path}")
         self._check_or_download_dataset(
             split_csv_path, split, only_hard_examples, keep_details
         )

@@ -103,7 +112,7 @@ def _check_or_download_dataset(
         elif keep_details == "dev_titles":
             keys = ["id", "question", "answer", "supporting_facts"]
 
-        official_train = []
+        official_train = []  # 15661
         for raw_example in hf_official_train:
             if raw_example["level"] == "hard":
                 example = {k: raw_example[k] for k in keys}

@@ -113,21 +122,25 @@ def _check_or_download_dataset(
                     del example["supporting_facts"]
 
                 official_train.append(example)
+        print(f"official_train: {len(official_train)}")
 
         rng = random.Random(0)
         rng.shuffle(official_train)
 
-        sampled_trainset = official_train[: len(official_train) * 75 // 100]
+        sampled_trainset = official_train[: len(official_train) * 75 // 100]  # 11745
+        print(f"sampled_trainset: {len(sampled_trainset)}")
 
-        sampled_valset = official_train[
+        sampled_valset = official_train[  # 3916
             len(official_train) * 75 // 100 :
         ]  # this is not the official dev set
 
+        print(f"sampled_valset: {len(sampled_valset)}")
+
         # for example in self._train:
         #     if keep_details == "dev_titles":
         #         del example["gold_titles"]
 
-        test = []
+        test = []  # 7405
         for raw_example in hf_official_dev:
             assert raw_example["level"] == "hard"
             example = {

@@ -140,20 +153,25 @@ def _check_or_download_dataset(
             test.append(example)
 
         keys = ["id", "question", "answer", "gold_titles"]
+        data_path_dir = os.path.dirname(data_path)
         # save to csv
         for split, examples in zip(
             ["train", "val", "test"],
             [sampled_trainset, sampled_valset, test],
         ):
             # target_path = prepare_dataset_path(self.root, task_name, split)
-            save_csv(examples, f=data_path, fieldnames=keys)
+            target_path = os.path.join(data_path_dir, f"{split}.csv")
+            save_csv(examples, f=target_path, fieldnames=keys)
+            print(f"saved {split} to {target_path}")
 
         if split == "train":
             return sampled_trainset
         elif split == "val":
             return sampled_valset
-        else:
+        elif split == "test":
            return test
+        else:
+            raise ValueError("Split must be one of 'train', 'val', 'test'")
 
     def __getitem__(self, index) -> DataClass:
         return self.data[index]
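
This is the bug the commit title refers to: the save loop previously passed the same data_path to save_csv for every split, so the train, val, and test examples were written over one another and the per-split CSV that __init__ looks for (f"{split}.csv") never existed. After the fix each split is written to its own file under the dataset directory, and an unknown split raises instead of silently falling through to the test branch. A minimal usage sketch (assuming the dataset class is exposed as HotPotQA, as the module path suggests; the sizes are illustrative):

    from adalflow.datasets.hotpot_qa import HotPotQA

    # Each split now loads from its own CSV under the dataset directory.
    train = HotPotQA(split="train", size=100)  # sampled from the official hard train set
    val = HotPotQA(split="val", size=50)       # held-out 25% of the official train set
    test = HotPotQA(split="test", size=50)     # official dev set, hard examples only
    print(len(train), len(val), len(test))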

4 changes: 2 additions & 2 deletions adalflow/adalflow/datasets/trec.py

@@ -59,7 +59,7 @@ def prepare_datasets():
     len_train_dataset = len(org_train_dataset)
 
     org_test_dataset = dataset["test"]
-    eval_size = 6 * num_classes
+    eval_size = 18 * num_classes
 
     class_sampler = ClassSampler(
         org_train_dataset.select(

@@ -100,7 +100,7 @@ def prepare_datasets():
     labels = torch.tensor(org_test_dataset["coarse_label"])
     class_weights = calculate_class_weights(labels)
 
-    test_size = eval_size * 4
+    test_size = eval_size * 2
     # weighted sampling on the test dataset
     test_dataset_split = sample_subset_dataset(
         org_test_dataset, test_size, class_weights
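
A quick size check (assuming TREC's 6 coarse classes, which is what the coarse_label field provides):

    num_classes = 6                # TREC coarse classes
    eval_size = 18 * num_classes   # 108 (previously 6 * 6 = 36)
    test_size = eval_size * 2      # 216 (previously 36 * 4 = 144)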

4 changes: 2 additions & 2 deletions adalflow/adalflow/optim/parameter.py

@@ -267,6 +267,8 @@ def __init__(
         self.successor_map_fn = successor_map_fn or {}
 
         def default_prompt_map_fn(param: Parameter):
+            # if isinstance(param.data, GeneratorOutput):
+            #     return param.data.raw_response
             return param.data
 
         self.data_in_prompt = data_in_prompt or default_prompt_map_fn

@@ -354,7 +356,6 @@ def get_gradient_and_context_text(self, skip_correct_sample: bool = False) -> str:
                 if g.score > 0.5:
                     continue
             lowest_score_gradients.append(g)
-            print(f"{i} Score: {g.score} for {g.name}, {type(g.score)}")
 
         gradient_context_combined_str = ""
         if lowest_score_gradients and len(lowest_score_gradients) > 0:

@@ -450,7 +451,6 @@ def get_gradients_component_schema(self, skip_correct_sample: bool = False) -> str:
                 if g.score > 0.5:
                     continue
             lowest_score_gradients.append(g)
-            print(f"{i} Score: {g.score} for {g.name}, {type(g.score)}")
 
         # Group gradients by `data_id` and calculate average scores
         grouped_gradients = defaultdict(

3 changes: 2 additions & 1 deletion adalflow/adalflow/optim/text_grad/backend_engine_prompt.py

@@ -18,7 +18,6 @@
 2. From <CONVERSATION></CONVERSATION> section, you can find how the variable is obtained and used.
 3. The variable might have other peers that are used together to instruct the language model. But only focus on the target variable.
 4. As there might be peers, and multi-components, it is possible that the feedback/error is not directly related to the variable itself.
-   In such cases, you can just say "There is no noticeable error".
 5. When you reason, really think about the variable's role in the component(infer from the CONVERSATION section) and the VARIABLE section before you provide feedback.
 6. Be specific, concise, critical, and direct.
 

@@ -41,6 +40,8 @@
 ##############################################
 # Loss Component
 ##############################################
+# In such cases, you can just say "There is no noticeable error".
+
 # 2. Feedback examples: "Since language models have the X failure mode...", "Adding X can fix this error because...", "Removing X can improve the objective function because...", "Changing X to Y would fix the mistake..."
 
 # Objective instruction for LLM as gradComponent with user custom instruction

64 changes: 57 additions & 7 deletions benchmarks/hotpot_qa/adal_exp/train_multi_hop_rag.py

@@ -67,13 +67,14 @@ def prepare_loss(self, sample: HotPotQAData, pred: adal.Parameter):
 
         # pred's full_response is the output of the task pipeline which is GeneratorOutput
         pred.eval_input = (
-            pred.full_response.data.answer
-            if pred.full_response
-            and pred.full_response.data
-            and pred.full_response.data.answer
+            pred.data.data.answer
+            if pred.data and pred.data.data and pred.data.data.answer
             else ""
         )
-        return self.loss_fn, {"kwargs": {"y": pred, "y_gt": y_gt}}
+        return self.loss_fn, {"kwargs": {"y": pred, "y_gt": y_gt}, "id": sample.id}
 
 
+from adalflow.core.generator import BackwardPassSetup
+
+
 # Note: diagnose is quite helpful, it helps you to quickly check if the evalfunction is the right metrics

@@ -110,13 +111,22 @@ def train(
     debug=False,
     resume_from_ckpt=None,
     exclude_input_fields_from_bootstrap_demos=True,
+    seed=None,
+    tg: bool = False,
+    max_proposals_per_step: int = 5,
 ):
     adal_component = MultiHopRAGAdal(
         **gpt_3_model,
         teacher_model_config=gpt_3_model,
         text_optimizer_model_config=gpt_4o_model,  # gpt3.5 is not enough to be used as a good optimizer, it struggles for long contenxt
         backward_engine_model_config=gpt_4o_model,
     )
+    backward_pass_setup = None
+    if tg:
+        backward_pass_setup = BackwardPassSetup(
+            all_pred_at_once=False,
+            compute_grad_for_errors_only=False,
+        )
     # print(adal_component)
     trainer = adal.Trainer(
         train_batch_size=train_batch_size,

@@ -131,16 +141,20 @@ def train(
         optimization_order=optimization_order,
         exclude_input_fields_from_bootstrap_demos=exclude_input_fields_from_bootstrap_demos,
         sequential_order=["text", "demo"],
+        max_proposals_per_step=max_proposals_per_step,
+        backward_pass_setup=backward_pass_setup,
     )
+    trainer.set_random_seed(seed)
     print(trainer)
 
     train_dataset, val_dataset, test_dataset = load_datasets()
-    trainer.fit(
+    ckpt, _ = trainer.fit(
         train_dataset=train_dataset,
         val_dataset=val_dataset,
         test_dataset=test_dataset,
         resume_from_ckpt=resume_from_ckpt,
     )
+    return ckpt
 
 
 if __name__ == "__main__":

@@ -150,17 +164,53 @@ def train(
 
     adal.setup_env()
 
+    import json
+
+    import random
+
+    random.seed(2025)
+
+    adal.setup_env()
+
+    import argparse
+
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--strategy", type=str, default="constrained")
+    parser.add_argument("--use_tg", action="store_true")
+    parser.add_argument("--max_proposals_per_step", type=int, default=5)
+    parser.add_argument(
+        "output_path", nargs="?", help="File path to save the checkpoint"
+    )
+
+    args = parser.parse_args()
+
+    set_strategy = args.strategy
+    set_output_path = args.output_path
+    use_tg = args.use_tg
+    max_proposals_per_step = args.max_proposals_per_step
+
     # task = MultiHopRAGAdal(**gpt_3_model)
     # print(task)
 
     # train_diagnose(**gpt_3_model)
 
     # train: 0.15 before the evaluator converted to lower and 0.4 after the conversion
-    train(
+    ckpt = train(
         debug=True,
         max_steps=12,
+        seed=2025,  # pass the numpy seed
+        tg=use_tg,
+        max_proposals_per_step=max_proposals_per_step,
         # resume_from_ckpt="/Users/liyin/.adalflow/ckpt/ValinaRAGAdal/random_max_steps_12_7c091_run_1.json",
     )
+    print(f"ckpt: {ckpt}")
+    if set_output_path:
+        with open(set_output_path, "w") as f:
+            json.dump({"ckpt": ckpt}, f)
+        print(f"Checkpoint saved to {set_output_path}")
+    else:
+        print("No file path provided for saving the checkpoint.")
 
     # notes for debug: if have nontype, delete all model cache and try again
     # raise ValueError(ValueError: score must be provided for each demo,
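
The script now takes its strategy, the TextGrad toggle, the proposal budget, and an optional output path from the command line, and records the checkpoint path returned by train() as JSON. A hypothetical invocation and follow-up read (the flag names come from the argparse block above; the file name is a placeholder):

    # Hypothetical run:
    #   python train_multi_hop_rag.py --strategy constrained --use_tg \
    #       --max_proposals_per_step 5 ckpt_out.json
    import json

    with open("ckpt_out.json") as f:      # the positional output_path passed above
        ckpt_path = json.load(f)["ckpt"]  # the script dumps {"ckpt": <checkpoint path>}
    print(f"checkpoint: {ckpt_path}")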
