Skip to content

Commit

Permalink
ensure when the generator fails in the middle of a map, use the raw_r…
Browse files Browse the repository at this point in the history
…esponse at least it is better than not having value especially be careful with None
  • Loading branch information
liyin2015 committed Dec 12, 2024
1 parent 0f96139 commit fbe3348
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 24 deletions.
22 changes: 13 additions & 9 deletions adalflow/adalflow/core/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import json
import re
import os
from pathlib import Path

from typing import Any, Dict, Optional, Union, Callable, Tuple, List
Expand Down Expand Up @@ -47,6 +48,8 @@

log = logging.getLogger(__name__)

DEBUG_MODE = os.environ.get("DEBUG_MODE", False)

PromptArgType = Dict[str, Union[str, Parameter]]


Expand Down Expand Up @@ -465,11 +468,11 @@ def forward(
unwrapped_prompt_kwargs[k] = v.map_to_successor(self)
else:
unwrapped_prompt_kwargs[k] = v

print(
f"unwrapped_prompt_kwargs: {unwrapped_prompt_kwargs}, model_kwargs: {model_kwargs}"
)
print(f"prompt template: {self.template}")
if DEBUG_MODE:
print(
f"unwrapped_prompt_kwargs: {unwrapped_prompt_kwargs}, model_kwargs: {model_kwargs}"
)
print(f"prompt template: {self.template}")

output: GeneratorOutputType = None
input_args = {}
Expand All @@ -478,10 +481,11 @@ def forward(
else:
if self.teacher_mode and not isinstance(self, BackwardEngine):
if not self._teacher:
print(
f"unwrapped_prompt_kwargs: {unwrapped_prompt_kwargs}, model_kwargs: {model_kwargs}"
)
print(f"names: {self.name}")
if DEBUG_MODE:
print(
f"unwrapped_prompt_kwargs: {unwrapped_prompt_kwargs}, model_kwargs: {model_kwargs}"
)
print(f"names: {self.name}")
raise ValueError("Teacher generator is not set.")
log.info(f"Using teacher: {self._teacher}")
input_args = {
Expand Down
2 changes: 2 additions & 0 deletions adalflow/adalflow/core/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def forward(
requires_opt=True,
param_type=ParameterType.HYPERPARAM,
)
if input is None:
raise ValueError("Input cannot be empty")
response = super().forward(input, top_k=top_k, **kwargs)
response.param_type = (
ParameterType.RETRIEVER_OUTPUT
Expand Down
13 changes: 12 additions & 1 deletion adalflow/adalflow/optim/few_shot/bootstrap_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,16 @@ def sample(
weighted: bool = True,
):
r"""Performs weighted sampling, ensure the score is in range [0, 1]. The higher score means better accuracy."""
# 1. sample from augmented demos
# 1. sample from augmented demos (from teacher)
# set weights to be score
# add 1 to all score to avoid negative weights
augmented_options = list(augmented_demos.values())

# get the teacher scores length and the augmented demos length
len_teacher_scores = len(self._teacher_scores)
len_augmented_options = len(augmented_options)
print(f"len_teacher_scores: {len_teacher_scores}")
print(f"len_augmented_options: {len_augmented_options}")
weights = None
if weighted:
weights: List[float] = []
Expand Down Expand Up @@ -229,6 +235,11 @@ def propose(self):
if demo_param.requires_opt:
augmented_demos = demo_param._traces
demos = demo_param._student_traces

if len(augmented_demos) != len(demos):
log.warning(
f"augmented and raw demos must have the same length, got {len(augmented_demos)} and {len(demos)} \n {augmented_demos} \n and student demos {demos}"
)
try:
sampled_augmented_demos, sampled_raw_demos = self.sample(
augmented_demos=augmented_demos,
Expand Down
6 changes: 6 additions & 0 deletions adalflow/adalflow/optim/trainer/adal.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,11 @@ def train_step(self, batch, batch_idx, num_workers: int = 2) -> List:
samples[i] = sample # Keep the sample order aligned
# check the ordering

if isinstance(y_pred, Parameter):
raise ValueError(f"y_pred_{i} is a Parameter, {y_pred}")

print(f"y_pred: {y_pred})")

assert (
y_pred.id == sample.id
), f"ID mismatch: {y_pred.id} != {sample.id}, type: {type(y_pred)}"
Expand Down Expand Up @@ -469,6 +474,7 @@ def validation_step(self, batch, batch_idx, num_workers: int = 2) -> List:
"""
# TODO: let use decide which mode to be
self.task.eval()
self.task.use_teacher(mode=False) # ensure the teacher is not used
completed_y_preds, completed_samples, index_to_score = self.pred_step(
batch, batch_idx, num_workers, running_eval=True, min_score=minimum_score
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"deprecated"
"""We will use dspy's retriever to keep that the same and only use our generator and optimizer"""

import dspy
Expand All @@ -22,9 +23,9 @@

def load_datasets():

trainset = HotPotQA(split="train", size=4) # 20
valset = HotPotQA(split="val", size=4) # 50
testset = HotPotQA(split="test", size=4) # to keep the same as the dspy #50
trainset = HotPotQA(split="train", size=20) # 20
valset = HotPotQA(split="val", size=50) # 50
testset = HotPotQA(split="test", size=50) # to keep the same as the dspy #50
print(f"trainset, valset: {len(trainset)}, {len(valset)}, example: {trainset[0]}")
return trainset, valset, testset

Expand Down
14 changes: 7 additions & 7 deletions benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from adalflow.core.retriever import Retriever

from benchmarks.hotpot_qa.adal_exp.build_vanilla_rag import DspyRetriever

from adalflow.utils.logger import printc

colbertv2_wiki17_abstracts = dspy.ColBERTv2(
url="http://20.102.90.50:2017/wiki17_abstracts"
Expand Down Expand Up @@ -71,7 +71,6 @@ def call(self, exisiting_list: List[str], new_list: List[str]) -> List[str]:
return [x for x in exisiting_list + new_list if not (x in seen or seen.add(x))]

def backward(self, *args, **kwargs):
from adalflow.utils.logger import printc

printc(f"DeduplicateList backward: {args}", "yellow")
return super().backward(*args, **kwargs)
Expand Down Expand Up @@ -297,6 +296,7 @@ def deduplicate(seq: list[str]) -> list[str]:
# TODO: simplify and avoid the need where users need to write two methods (call and forward)
def call(self, *, input: str, id: str = None) -> List[adal.RetrieverOutput]:
# assemble the foundamental building blocks
printc(f"question: {input}", "yellow")
out = self.forward(input=input, id=id)

if not isinstance(out, adal.Parameter):
Expand All @@ -306,8 +306,8 @@ def call(self, *, input: str, id: str = None) -> List[adal.RetrieverOutput]:

def forward(self, *, input: str, id: str = None) -> adal.Parameter:
# assemble the foundamental building blocks
printc(f"question: {input}", "yellow")
context = []
print(f"question: {input}")

queries: List[str] = []

Expand All @@ -326,7 +326,7 @@ def forward(self, *, input: str, id: str = None) -> adal.Parameter:
if x.full_response
and x.full_response.data
and x.full_response.data.query
else None
else x.data.raw_response if x.data and x.data.raw_response else None
)
print(f"query {i}: {success_map_fn(gen_out)}")

Expand All @@ -336,6 +336,9 @@ def forward(self, *, input: str, id: str = None) -> adal.Parameter:
successor=self.retrievers[i], map_fn=success_map_fn
)

if success_map_fn(gen_out) is None:
raise ValueError(f"The query is None, please check the generator {i}")

retrieve_out = self.retrievers[i].forward(input=gen_out, id=id)

def retrieve_out_map_fn(x: adal.Parameter):
Expand All @@ -362,14 +365,11 @@ def context_to_retrover_output(x):

context.data = context_to_retrover_output(context)

from adalflow.utils.logger import printc

printc(f"MultiHopRetriever2 grad fn: {context.grad_fn}", "yellow")

return context

def backward(self, *args, **kwargs):
from adalflow.utils.logger import printc

printc(f"MultiHopRetriever2 backward: {args}", "yellow")
super().backward(*args, **kwargs)
Expand Down
13 changes: 13 additions & 0 deletions benchmarks/hotpot_qa/adal_exp/build_vanilla_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def call(

k = top_k or self.top_k

if not input:
raise ValueError(f"Input cannot be empty, top_k: {k}")

output = self.dspy_retriever(query_or_queries=input, k=k)
# print(f"dsy_retriever output: {output}")
final_output: List[RetrieverOutput] = []
Expand Down Expand Up @@ -203,6 +206,16 @@ def call(self, question: str, id: str = None) -> adal.GeneratorOutput:
# print(f"retriever_out: {retriever_out}")
return output

# def call(self, *, question: str, id: str = None) -> adal.GeneratorOutput:
# self.train()
# out = self.forward(question=question, id=id)
# if not isinstance(out, adal.Parameter):
# raise ValueError(
# "This output should be a Parameter, please check the forward function"
# )
# self.eval()
# return out.data

# TODO: add id in the retriever output
def forward(self, question: str, id: str = None) -> adal.Parameter:
if not self.training:
Expand Down
9 changes: 6 additions & 3 deletions benchmarks/hotpot_qa/adal_exp/train_multi_hop_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from adalflow.eval.answer_match_acc import AnswerMatchAcc
from adalflow.datasets.types import HotPotQAData

from benchmarks.hotpot_qa.adal_train import load_datasets
from benchmarks.hotpot_qa._adal_train import load_datasets
from benchmarks.hotpot_qa.adal_exp.build_multi_hop_rag import MultiHopRAG
from use_cases.config import gpt_3_model, gpt_4o_model

Expand Down Expand Up @@ -157,7 +157,10 @@ def train(
# train: 0.15 before the evaluator converted to lower and 0.4 after the conversion
# TODO: test debug mode
train(
debug=True,
max_steps=5,
debug=False,
max_steps=12,
# resume_from_ckpt="/Users/liyin/.adalflow/ckpt/ValinaRAGAdal/random_max_steps_12_7c091_run_1.json",
)

# notes for debug: if have nontype, delete all model cache and try again
# raise ValueError(ValueError: score must be provided for each demo,
2 changes: 1 addition & 1 deletion benchmarks/hotpot_qa/adal_exp/train_vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from adalflow.eval.answer_match_acc import AnswerMatchAcc
from adalflow.datasets.types import HotPotQAData

from benchmarks.hotpot_qa.adal_train import load_datasets
from benchmarks.hotpot_qa._adal_train import load_datasets
from benchmarks.hotpot_qa.adal_exp.build_vanilla_rag import VanillaRAG
from use_cases.config import gpt_3_model, gpt_4o_model

Expand Down

0 comments on commit fbe3348

Please sign in to comment.