diff --git a/adalflow/adalflow/components/agent/react.py b/adalflow/adalflow/components/agent/react.py index abe0d78d..de4a8a12 100644 --- a/adalflow/adalflow/components/agent/react.py +++ b/adalflow/adalflow/components/agent/react.py @@ -49,8 +49,14 @@ """ +# - In this case, you are working as a multi-hop retriever and your answer in finish MUST be verbatim short factoid responses from retrieved context. +# - Answer with only the exact answer phrase, not a full sentence or paragraph. + DEFAULT_REACT_AGENT_SYSTEM_PROMPT = r""" {{react_agent_task_desc}} + +- You have a maximum of {{max_steps}} steps to complete the task. Plan your steps carefully. + {# Tools #} {% if tools %} @@ -95,6 +101,8 @@ "Action": "{{history.action.action}}", {% endif %} "Observation": "{{history.observation}}" + +Current Step/Max Step: {{step_history|length + 1}} / {{max_steps}} ------------------------ {% endfor %} @@ -114,6 +122,14 @@ def map_step_history_to_prompt(x: Parameter) -> str: return "\n".join(output) +def map_step_history_list_to_prompt(x: Parameter) -> str: + output = [] + for i, step in enumerate(x.data.step_history): + step_str = f"Step {i + 1}.\n" + output.append(step_str + step.to_prompt_str()) + return "\n".join(output) + + class AppendStepHistory(GradComponent): def __init__(self): super().__init__() @@ -283,6 +299,7 @@ def __init__( requires_opt=True, ), "context_variables": self.context_variables, + "max_steps": self.max_steps, } self.planner = Generator( template=template, @@ -326,7 +343,7 @@ def llm_tool(input: str, **kwargs) -> str: return None def finish(answer: str, **kwargs) -> str: - """Finish the task with answer.""" + """Finish the task with verbatim short factoid responses from retrieved context.""" return answer self._finish = finish @@ -676,8 +693,16 @@ def _get_answer( return None last_step: StepOutput = None - if isinstance(step_history, Parameter): + if isinstance( + step_history, Parameter + ): # change the step history at the last step try: + output = ReActOutput( + step_history=step_history.data, + answer=str(step_history.data[-1].observation), + ) + step_history.data = output + step_history.data_in_prompt = map_step_history_list_to_prompt return step_history except Exception as e: @@ -687,7 +712,7 @@ def _get_answer( last_step = step_history[-1] # printc(f"last_step: {last_step}", color="yellow") - return last_step.observation + return str(last_step.observation) def call(self, *args, **kwargs) -> ReActOutput: output = self.bicall(*args, **kwargs) diff --git a/adalflow/adalflow/core/generator.py b/adalflow/adalflow/core/generator.py index b40c34f1..5636b78b 100644 --- a/adalflow/adalflow/core/generator.py +++ b/adalflow/adalflow/core/generator.py @@ -67,7 +67,7 @@ @dataclass class BackwardPassSetup(DataClass): all_pred_at_once: bool = field( - default=True, metadata={"desc": "Backward all predecessors at once."} + default=False, metadata={"desc": "Backward all predecessors at once."} ) threshold_score_to_compute_grad_for_errors: float = field( default=0.9, diff --git a/adalflow/adalflow/core/retriever.py b/adalflow/adalflow/core/retriever.py index a6ed6e39..3778fdd8 100644 --- a/adalflow/adalflow/core/retriever.py +++ b/adalflow/adalflow/core/retriever.py @@ -135,7 +135,7 @@ def forward( ) response.trace_forward_pass( input_args={"input": input, "top_k": top_k}, - full_response=response, + full_response=response.data, id=self.id, name=self.name, ) diff --git a/adalflow/adalflow/core/types.py b/adalflow/adalflow/core/types.py index eceac117..0ef80390 100644 --- a/adalflow/adalflow/core/types.py +++ b/adalflow/adalflow/core/types.py @@ -281,19 +281,25 @@ class RetrieverOutput(DataClass): It is up to the subclass of Retriever to specify the type of query and document. """ - doc_indices: List[int] = field(metadata={"desc": "List of document indices"}) - doc_scores: Optional[List[float]] = field( + id: str = field(default=None, metadata={"desc": "The unique id of the output"}) + + doc_indices: List[int] = field( + default=required_field, metadata={"desc": "List of document indices"} + ) + doc_scores: List[float] = field( default=None, metadata={"desc": "List of document scores"} ) - query: Optional[RetrieverQueryType] = field( + query: RetrieverQueryType = field( default=None, metadata={"desc": "The query used to retrieve the documents"} ) - documents: Optional[List[RetrieverDocumentType]] = field( + documents: List[RetrieverDocumentType] = field( default=None, metadata={"desc": "List of retrieved documents"} ) -RetrieverOutputType = List[RetrieverOutput] # so to support multiple queries at once +RetrieverOutputType = Union[ + List[RetrieverOutput], RetrieverOutput +] # so to support multiple queries at once ####################################################################################### diff --git a/adalflow/adalflow/datasets/hotpot_qa.py b/adalflow/adalflow/datasets/hotpot_qa.py index c91d9d06..22919e77 100644 --- a/adalflow/adalflow/datasets/hotpot_qa.py +++ b/adalflow/adalflow/datasets/hotpot_qa.py @@ -1,6 +1,6 @@ import random import os -from typing import Literal +from typing import Literal, List from adalflow.utils.lazy_import import safe_import, OptionalPackages @@ -198,10 +198,11 @@ def _check_or_download_dataset( # target_path = prepare_dataset_path(self.root, task_name, split) target_path = os.path.join(data_path_dir, f"{split}.json") # filter the examples with only the keys - save_examples = [] + save_examples: List[HotPotQAData] = [] for example in examples: save_example = {k: example[k] for k in keys if k in example} - save_examples.append(save_example) + save_example = HotPotQAData.from_dict(save_example) + save_examples.append(save_example.to_dict()) save_json(save_examples, f=target_path) if split == "train": print(f"train example: {examples[0]}") diff --git a/adalflow/adalflow/datasets/types.py b/adalflow/adalflow/datasets/types.py index bd7ee4dd..99878902 100644 --- a/adalflow/adalflow/datasets/types.py +++ b/adalflow/adalflow/datasets/types.py @@ -2,8 +2,6 @@ from dataclasses import dataclass, field from typing import Dict from adalflow.core.base_data_class import DataClass -import json -from typing import Any @dataclass @@ -40,26 +38,26 @@ class HotPotQAData(Example): default=None, ) - @staticmethod - def from_dict(d: Dict[str, Any]) -> "HotPotQAData": - # Preprocess gold_titles - if "gold_titles" in d and isinstance(d["gold_titles"], str): - try: - d["gold_titles"] = json.loads(d["gold_titles"]) - except json.JSONDecodeError: - # Replace single quotes with double quotes - fixed_str = d["gold_titles"].replace("'", '"') - d["gold_titles"] = set(json.loads(fixed_str)) + # @staticmethod + # def from_dict(d: Dict[str, Any]) -> "HotPotQAData": + # # Preprocess gold_titles + # if "gold_titles" in d and isinstance(d["gold_titles"], str): + # try: + # d["gold_titles"] = json.loads(d["gold_titles"]) + # except json.JSONDecodeError: + # # Replace single quotes with double quotes + # fixed_str = d["gold_titles"].replace("'", '"') + # d["gold_titles"] = set(json.loads(fixed_str)) - # Preprocess context - if "context" in d and isinstance(d["context"], str): - try: - d["context"] = json.loads(d["context"]) - except json.JSONDecodeError: - fixed_str = d["context"].replace("'", '"') - d["context"] = json.loads(fixed_str) + # # Preprocess context + # if "context" in d and isinstance(d["context"], str): + # try: + # d["context"] = json.loads(d["context"]) + # except json.JSONDecodeError: + # fixed_str = d["context"].replace("'", '"') + # d["context"] = json.loads(fixed_str) - return HotPotQAData(**d) + # return HotPotQAData(**d) @dataclass @@ -80,3 +78,31 @@ class TrecData(BaseData): __input_fields__ = ["question"] # follow this order too. __output_fields__ = ["class_name", "class_index"] + + +if __name__ == "__main__": + # test the hotpotqa data + data = HotPotQAData( + question="What is the capital of France?", + answer="Paris", + gold_titles=set(["Paris", "France"]), + context={"Paris": "The capital of France"}, + ) + + data_dict = data.to_dict() + print("data_dict", data_dict) + data = HotPotQAData.from_dict(data_dict) + print("data", data) + + from adalflow.utils.file_io import save_json, load_json + + # save json + save_json(data_dict, f="task.json") + # load json + data_dict_loaded = load_json(f="task.json") + + print("data_dict_loaded", data_dict_loaded) + + # restore the data + data_restored = HotPotQAData.from_dict(data_dict_loaded) + print("data_restored", data_restored) diff --git a/adalflow/adalflow/eval/__init__.py b/adalflow/adalflow/eval/__init__.py index 67de685c..1d9ecd08 100644 --- a/adalflow/adalflow/eval/__init__.py +++ b/adalflow/adalflow/eval/__init__.py @@ -1,5 +1,5 @@ from .answer_match_acc import AnswerMatchAcc -from .retriever_recall import RetrieverRecall +from .retriever_recall import RetrieverEvaluator from .llm_as_judge import LLMasJudge, DEFAULT_LLM_EVALUATOR_PROMPT from .g_eval import ( GEvalJudgeEvaluator, @@ -10,7 +10,7 @@ __all__ = [ "AnswerMatchAcc", - "RetrieverRecall", + "RetrieverEvaluator", "LLMasJudge", "DEFAULT_LLM_EVALUATOR_PROMPT", "GEvalJudgeEvaluator", diff --git a/adalflow/adalflow/eval/answer_match_acc.py b/adalflow/adalflow/eval/answer_match_acc.py index 83fe77a8..03da6cfa 100644 --- a/adalflow/adalflow/eval/answer_match_acc.py +++ b/adalflow/adalflow/eval/answer_match_acc.py @@ -3,52 +3,7 @@ from typing import List, Literal from adalflow.eval.base import BaseEvaluator, EvaluationResult from adalflow.optim.parameter import Parameter - -import re - -import string -from collections import Counter - - -def normalize_answer(s): - - def remove_articles(text): - return re.sub(r"\b(a|an|the)\b", " ", text) - - def white_space_fix(text): - return " ".join(text.split()) - - def remove_punc(text): - exclude = set(string.punctuation) - return "".join(ch for ch in text if ch not in exclude) - - def lower(text): - return text.lower() - - return white_space_fix(remove_articles(remove_punc(lower(s)))) - - -def f1_score(prediction, ground_truth): - prediction_tokens = normalize_answer(prediction).split() - ground_truth_tokens = normalize_answer(ground_truth).split() - - common = Counter(prediction_tokens) & Counter(ground_truth_tokens) - num_same = sum(common.values()) - - if len(prediction_tokens) == len(ground_truth_tokens) == 0: - # Unlike most tasks, QReCC and SQuAD-2.0 assign 1.0 in this edge case. We don't for uniformity. - print( - "\n#> F1 Metric: Rare edge case of len(prediction_tokens) == len(ground_truth_tokens) == 0.\n" - ) - - if num_same == 0: - return 0 - - precision = 1.0 * num_same / len(prediction_tokens) - recall = 1.0 * num_same / len(ground_truth_tokens) - f1 = (2 * precision * recall) / (precision + recall) - - return f1 +from adalflow.eval.utils import normalize_answer, f1_score class AnswerMatchAcc(BaseEvaluator): diff --git a/adalflow/adalflow/eval/retriever_recall.py b/adalflow/adalflow/eval/retriever_recall.py index 9abe6d52..c433dc65 100644 --- a/adalflow/adalflow/eval/retriever_recall.py +++ b/adalflow/adalflow/eval/retriever_recall.py @@ -1,16 +1,23 @@ """Retriever Recall @k metric.""" -from typing import List, Union +from typing import List, Dict from adalflow.eval.base import BaseEvaluator, EvaluationResult +from adalflow.eval.utils import normalize_answer -class RetrieverRecall(BaseEvaluator): - __doc__ = r"""Recall@k measures the ratio of the number of relevant context strings in the top-k retrieved context to the total number of ground truth relevant context strings. +class RetrieverEvaluator(BaseEvaluator): + __doc__ = r"""Return Recall@k and Precision@k. + + Recall@k = Number of relevant retrieved documents/ Total number of relevant documents (len(gt_contexts)) + Precision@k = Number of relevant retrieved documents/ Total number of retrieved documents (len(retrieved_contexts)) + In our implementation, we use exact string matching between each gt context and the joined retrieved context string. You can use the longest common subsequence (LCS) or other similarity metrics(or embedding based) to decide if it is a match or not. + You can also pass ids of retrieved and the reference. + If you do not even have the ground truth context, but only grounth truth answers, you can consider using RAGAS framework for now. It computes the recall as: @@ -43,36 +50,55 @@ class RetrieverRecall(BaseEvaluator): def __init__(self): super().__init__() - def _compute_single_item( - self, retrieved_context: str, gt_context: Union[str, List[str]] - ) -> float: + def compute_single_item( + self, retrieved_context: List[str], gt_context: List[str] + ) -> Dict[str, float]: r""" Compute the recall of the retrieved context for a single query. Args: - retrieved_context (str): Retrieved context string. - gt_context (Union[str, List[str]]): Context string or list of context strings to compare against. + retrieved_context (List[str]): List of retrieved context strings. + gt_context (List[str]): List of ground truth context strings. Returns: float: Recall value. """ - if isinstance(gt_context, str): - gt_context = [gt_context] - recalled = 0 - for gt_context_sentence in gt_context: - if gt_context_sentence in retrieved_context: - recalled += 1 - return recalled / len(gt_context) + # 1 normalize the text + normalized_retrieved_context = [ + normalize_answer(doc) for doc in retrieved_context + ] + + normalized_gt_context = [normalize_answer(doc) for doc in gt_context] + + set_retrieved = set(normalized_retrieved_context) + set_gt = set(normalized_gt_context) + + # 2 calculate the recall with intersection + + recall = len(set_gt.intersection(set_retrieved)) / len(set_gt) + precision = len(set_gt.intersection(set_retrieved)) / len(set_retrieved) + + return {"recall": recall, "precision": precision} + + # if isinstance(gt_context, str): + # gt_context = [gt_context] + # recalled = 0 + # for gt_context_sentence in gt_context: + # normalized_gt_context = normalize_answer(gt_context_sentence) + # normalized_retrieved_context = normalize_answer(retrieved_context) + # if normalized_gt_context in normalized_retrieved_context: + # recalled += 1 + # return recalled / len(gt_context) def compute( self, - retrieved_contexts: Union[List[str], List[List[str]]], + retrieved_contexts: List[List[str]], gt_contexts: List[List[str]], ) -> EvaluationResult: r""" Compute the recall of the retrieved context for a list of queries. Args: - retrieved_contexts (Union[List[str], List[List[str]]): List of retrieved context strings. Using List[str] we assume you have joined all the context sentences into one string. + retrieved_context: List of retrieved context strings. gt_contexts ( List[List[str]]): List of ground truth context strings. Returns: @@ -84,15 +110,53 @@ def compute( raise ValueError( "The number of retrieved context lists and ground truth context lists should be the same." ) - k = len(retrieved_contexts) - recall_list = [] + k = len(retrieved_contexts[0]) + metric_list = [] for retrieved_context, gt_context in zip(retrieved_contexts, gt_contexts): - if isinstance(retrieved_context, list): - retrieved_context = " ".join(retrieved_context) - recall = self._compute_single_item(retrieved_context, gt_context) - recall_list.append(recall) - - avg_score = sum(recall_list) / len(recall_list) - return EvaluationResult( - avg_score, recall_list, additional_info={"type": f"RetrieverRecall@{k}"} + # if isinstance(retrieved_context, list): + # retrieved_context = " ".join(retrieved_context) + metric = self.compute_single_item(retrieved_context, gt_context) + metric_list.append(metric) + + # average through each key value + + avg_recall = sum([metric["recall"] for metric in metric_list]) / len( + metric_list ) + avg_precision = sum([metric["precision"] for metric in metric_list]) / len( + metric_list + ) + + return { + "avg_recall": avg_recall, + "avg_precision": avg_precision, + "recall_list": [metric["recall"] for metric in metric_list], + "precision_list": [metric["precision"] for metric in metric_list], + "top_k": k, + } + + # return EvaluationResult( + # avg_score, recall_list, additional_info={"type": f"RetrieverRecall@{k}"} + # ) + + +if __name__ == "__main__": + from adalflow.datasets import HotPotQA, HotPotQAData + + train_dataset = HotPotQA(split="train", size=10) + data: HotPotQAData = train_dataset[0] + gold_titles = data.gold_titles + context_titles = data.context["title"] + print(f"gold_titles: {gold_titles}, context_titles: {context_titles}") + print(f"train: {len(train_dataset)}, example: {train_dataset[0]}") + + # compute the recall and precision for 10 items + retriever_eval = RetrieverEvaluator() + + gt_contexts = [list(data.gold_titles) for data in train_dataset[:10]] + + retrieved_contexts = [list(data.context["title"]) for data in train_dataset[:10]] + + result = retriever_eval.compute(retrieved_contexts, gt_contexts) + + print(f"result: {result}") diff --git a/adalflow/adalflow/eval/utils.py b/adalflow/adalflow/eval/utils.py new file mode 100644 index 00000000..babf5b78 --- /dev/null +++ b/adalflow/adalflow/eval/utils.py @@ -0,0 +1,48 @@ +# from hotpotqa github +import re + +import string +from collections import Counter + + +def normalize_answer(s): + + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def f1_score(y: str, y_gt: str) -> float: + if not isinstance(y, str) or not isinstance(y_gt, str): + raise ValueError(f"y: {y},{type(y)}, y_gt: {y_gt},{type(y_gt)} must be string.") + prediction_tokens = normalize_answer(y).split() + ground_truth_tokens = normalize_answer(y_gt).split() + + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + + if len(prediction_tokens) == len(ground_truth_tokens) == 0: + # Unlike most tasks, QReCC and SQuAD-2.0 assign 1.0 in this edge case. We don't for uniformity. + print( + "\n#> F1 Metric: Rare edge case of len(prediction_tokens) == len(ground_truth_tokens) == 0.\n" + ) + + if num_same == 0: + return 0 + + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + + return f1 diff --git a/adalflow/adalflow/utils/file_io.py b/adalflow/adalflow/utils/file_io.py index 124c6225..2b05fc5e 100644 --- a/adalflow/adalflow/utils/file_io.py +++ b/adalflow/adalflow/utils/file_io.py @@ -5,16 +5,17 @@ from typing import Mapping, Any, Optional, List, Dict -from adalflow.utils.serialization import ( - to_dict, - serialize, -) +from adalflow.utils.serialization import to_dict, serialize, _deserialize_object_hook log = logging.getLogger(__name__) def save_json(obj: Mapping[str, Any], f: str = "task.json") -> None: - """Save the object to a json file. + """Customized Save the object to a json file. + + Support Set. + We encourage users first save the data as DataClass using to_dict, + and then load it back to DataClass using from_dict. Args: obj (Mapping[str, Any]): The object to be saved. @@ -29,6 +30,15 @@ def save_json(obj: Mapping[str, Any], f: str = "task.json") -> None: raise IOError(f"Error saving object to JSON file {f}: {e}") +# def standard_save_json(obj: Mapping[str, Any], f: str = "task.json") -> None: +# os.makedirs(os.path.dirname(f) or ".", exist_ok=True) +# try: +# with open(f, "w") as file: +# json.dump(obj, file, indent=4) +# except IOError as e: +# raise IOError(f"Error saving object to JSON file {f}: {e}") + + def save_csv( obj: List[Dict[str, Any]], f: str = "task.csv", fieldnames: List[str] = None ) -> None: @@ -91,20 +101,42 @@ def save(obj: Mapping[str, Any], f: str = "task") -> None: raise Exception(f"Error saving object to json and pickle files: {e}") -def load_json(f: str = "task.json") -> Optional[Mapping[str, Any]]: - r"""Load the object from a json file. +# def load_json(f: str = "task.json") -> Optional[Mapping[str, Any]]: +# r"""Load the object from a json file. + +# Args: +# f (str, optional): The file name. Defaults to "task". +# """ +# if not os.path.exists(f): +# log.warning(f"File {f} does not exist.") +# return None +# try: +# with open(f, "r") as file: +# return json.load(file) +# except Exception as e: +# raise Exception(f"Error loading object from JSON file {f}: {e}") + + +def load_json(f: str) -> Any: + """Customized Load a JSON file and deserialize it. Args: - f (str, optional): The file name. Defaults to "task". + f (str): The file name of the JSON file to load. + + Returns: + Any: The deserialized Python object. """ if not os.path.exists(f): - log.warning(f"File {f} does not exist.") - return None + raise FileNotFoundError(f"JSON file not found: {f}") + try: with open(f, "r") as file: - return json.load(file) + data = json.load(file, object_hook=_deserialize_object_hook) + return data + except json.JSONDecodeError as e: + raise ValueError(f"Error decoding JSON file {f}: {e}") except Exception as e: - raise Exception(f"Error loading object from JSON file {f}: {e}") + raise IOError(f"Error loading JSON file {f}: {e}") def load_pickle(f: str = "task.pickle") -> Optional[Mapping[str, Any]]: diff --git a/adalflow/adalflow/utils/serialization.py b/adalflow/adalflow/utils/serialization.py index 5cb1dd27..bd92402d 100644 --- a/adalflow/adalflow/utils/serialization.py +++ b/adalflow/adalflow/utils/serialization.py @@ -58,6 +58,14 @@ def default(o: Any) -> Union[Dict[str, Any], str]: except Exception as e: log.error(f"Error serializing object {o}: {e}") pass + # handle set + elif isinstance(o, set): + return {"type": type(o).__name__, "data": list(o)} + else: + return {"type": type(o).__name__, "data": str(o)} + # raise NotImplementedError( + # f"Object of type {o.__class__.__name__} is not JSON serializable: {o}" + # ) elif obj_type == ObjectTypes.TYPE: log.debug(f"Object {o} is a type of {o.__name__}") try: @@ -101,6 +109,11 @@ def _deserialize_object_hook(d: Dict[str, Any]) -> Any: """Hook to deserialize objects based on their type.""" if "type" in d and "data" in d: class_name = d["type"] + if class_name == "set": + return set(d["data"]) + + # deseralize customized types + # TODO: all customized data types need to be saved class_type = EntityMapping.get(class_name) if class_type: return class_type.from_dict(d) diff --git a/benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py b/benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py index 5e318de9..ea25f6de 100644 --- a/benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py +++ b/benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py @@ -53,8 +53,13 @@ class QueryRewritterData(adal.DataClass): {% endif %} -Context: {{context}} Question: {{question}} +{% if last_query is not none %} +Last Query: {{last_query}} +{% endif %} +{% if context is not none %} +Context from last search query: {{context}} +{% endif %} """ @@ -76,160 +81,6 @@ def backward(self, *args, **kwargs): return super().backward(*args, **kwargs) -# User customize an auto-grad operator -# Need this to be a GradComponent - - -# NOTE: deprecated -# class MultiHopRetriever(adal.Retriever): -# def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2): -# super().__init__() - -# self.passages_per_hop = passages_per_hop -# self.max_hops = max_hops - -# self.data_parser = adal.DataClassParser( -# data_class=QueryRewritterData, return_data_class=True, format_type="json" -# ) - -# # Grad Component -# self.query_generators: List[adal.Generator] = [] -# for i in range(self.max_hops): -# self.query_generators.append( -# adal.Generator( -# name=f"query_generator_{i}", -# model_client=model_client, -# model_kwargs=model_kwargs, -# prompt_kwargs={ -# "few_shot_demos": Parameter( -# name="few_shot_demos_1", -# data=None, -# role_desc="To provide few shot demos to the language model", -# requires_opt=True, -# param_type=ParameterType.DEMOS, -# ), -# "task_desc_str": Parameter( -# name="task_desc_str", -# data="""Write a simple search query that will help answer a complex question. - -# You will receive a context(may contain relevant facts) and a question. -# Think step by step.""", -# role_desc="Task description for the language model", -# requires_opt=True, -# param_type=ParameterType.PROMPT, -# ), -# "output_format_str": self.data_parser.get_output_format_str(), -# }, -# template=query_template, -# output_processors=self.data_parser, -# use_cache=True, -# ) -# ) -# self.retriever = DspyRetriever(top_k=passages_per_hop) -# self.deduplicater = DeduplicateList() - -# @staticmethod -# def context_to_str(context: List[str]) -> str: -# return "\n".join(context) - -# @staticmethod -# def deduplicate(seq: list[str]) -> list[str]: -# """ -# Source: https://stackoverflow.com/a/480227/1493011 -# """ - -# seen = set() -# return [x for x in seq if not (x in seen or seen.add(x))] - -# def call(self, *, question: str, id: str = None) -> adal.RetrieverOutput: -# context = [] -# print(f"question: {question}") -# for i in range(self.max_hops): -# gen_out = self.query_generators[i]( -# prompt_kwargs={ -# "context": self.context_to_str(context), -# "question": question, -# }, -# id=id, -# ) - -# query = gen_out.data.query if gen_out.data and gen_out.data.query else None - -# print(f"query {i}: {query}") - -# retrieve_out = self.retriever.call(input=query) -# passages = retrieve_out[0].documents -# context = self.deduplicate(context + passages) -# out = [adal.RetrieverOutput(documents=context, query=query, doc_indices=[])] -# return out - -# def forward(self, *, question: str, id: str = None) -> adal.Parameter: -# # assemble the foundamental building blocks -# context = [] -# print(f"question: {question}") -# # 1. make question a parameter as generator does not have it yet -# # can create the parameter at the leaf, but not the intermediate nodes -# question_param = adal.Parameter( -# name="question", -# data=question, -# role_desc="The question to be answered", -# requires_opt=True, -# param_type=ParameterType.INPUT, -# ) -# context_param = adal.Parameter( -# name="context", -# data=context, -# role_desc="The context to be used for the query", -# requires_opt=True, -# param_type=ParameterType.INPUT, -# ) -# context_param.add_successor_map_fn( -# successor=self.query_generators[0], -# map_fn=lambda x: self.context_to_str(x.data), -# ) - -# for i in range(self.max_hops): - -# gen_out = self.query_generators[i].forward( -# prompt_kwargs={ -# "context": context_param, -# "question": question_param, -# }, -# id=id, -# ) - -# success_map_fn = lambda x: ( # noqa E731 -# x.full_response.data.query -# if x.full_response -# and x.full_response.data -# and x.full_response.data.query -# else None -# ) -# print(f"query {i}: {success_map_fn(gen_out)}") - -# gen_out.add_successor_map_fn( -# successor=self.retriever, map_fn=success_map_fn -# ) - -# retrieve_out = self.retriever.forward(input=gen_out) - -# def retrieve_out_map_fn(x: adal.Parameter): -# return x.data[0].documents if x.data and x.data[0].documents else [] - -# print(f"retrieve_out: {retrieve_out}") - -# retrieve_out.add_successor_map_fn( -# successor=self.deduplicater, map_fn=retrieve_out_map_fn -# ) - -# context_param = self.deduplicater.forward( -# exisiting_list=context_param, new_list=retrieve_out -# ) - -# context_param.param_type = ParameterType.RETRIEVER_OUTPUT - -# return context_param - query_generator_task_desc = """Write a simple search query that will help answer a complex question. You will receive a context(may contain relevant facts) and a question. @@ -379,19 +230,31 @@ def context_to_retrover_output(x): return context -task_desc_str = """Write a simple search query that will help answer a complex question. +# task_desc_str = """Write a simple search query that will help answer a complex question. -You will receive a context(may contain relevant facts) and a question. +# You will receive a context(may contain relevant facts) and a question. +# Think step by step.""" + +task_desc_str = """ +You will receive an original question, last search query, and the retrieved context from the last search query. +Write the next search query to help retrieve all relevant context to answer the original question. Think step by step.""" -task_desc_str = """ You are a query assistant that helps search all relevant context to answer a multi-hop question. -You will a question, and existing context(may contain relevant facts along with its sub-questions). -Write a new simple search query to help retrieve the relevant context to answer the question. -Think step by step.""" +trained_task_desc_strs = [ + "You are tasked with formulating precise search queries using the original question, last search query, and its retrieved context. Prioritize identifying, emphasizing, and explicitly including all crucial entities, relationships, and geographical details mentioned in the question. Ensure comprehensive retrieval by focusing on key elements such as specific individuals (e.g., 'Kyrie Irving'), roles, or contextual details required for accuracy. Demonstrate reasoning by cross-referencing multiple sources and provide clear examples where necessary. Adapt queries to capture all nuances effectively for improved relevance and accuracy. Think step by step.", + "You will receive an original question, the last search query, and the retrieved context from that search. Write the next search query to ensure comprehensive retrieval of all relevant context needed to answer the original question. Emphasize identifying, precisely including, and verifying specific key entities, historical events, and factual names directly linked to the question within the context. Explicitly use the context to confirm and match critical entities to improve recall and ensure consistency with the targeted entities. Avoid irrelevant inclusions or false positives by cross-referencing data and verifying alignment accurately. Think step by step.", +] + + +# task_desc_str = """ You are a query assistant that helps search all relevant context to answer a multi-hop question. +# You will a question, and existing context(may contain relevant facts along with its sub-questions). +# Write a new simple search query to help retrieve the relevant context to answer the question. +# Think step by step.""" -class MultiHopRetriever(adal.Retriever): + +class MultiHopRetriever(adal.Component): def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2): super().__init__() @@ -406,6 +269,7 @@ def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2): self.query_generators: adal.ComponentList[adal.Generator] = adal.ComponentList() self.retrievers: List[Retriever] = [] self.deduplicaters: List[adal.GradComponent] = [] + for i in range(self.max_hops): self.query_generators.append( adal.Generator( @@ -422,10 +286,8 @@ def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2): # ), "task_desc_str": Parameter( name="task_desc_str", - data="""Write a simple search query that will help answer a complex question. - -You will receive a context(may contain relevant facts) and a question. -Think step by step.""", + # data=task_desc_str, + data=trained_task_desc_strs[i], role_desc="Task description for the language model", requires_opt=True, param_type=ParameterType.PROMPT, @@ -453,29 +315,52 @@ def deduplicate(seq: list[str]) -> list[str]: seen = set() return [x for x in seq if not (x in seen or seen.add(x))] - # TODO: simplify and avoid the need where users need to write two methods (call and forward) - def call(self, *, input: str, id: str = None) -> List[adal.RetrieverOutput]: - # assemble the foundamental building blocks - printc(f"question: {input}", "yellow") - out = self.forward(input=input, id=id) + def call(self, *, input: str, id: str = None) -> adal.RetrieverOutput: + context = [] + queries: List[str] = [] + last_query = None + for i in range(self.max_hops): + gen_out = self.query_generators[i]( + prompt_kwargs={ + "context": context, + "question": input, + "last_query": last_query, + }, + id=id, + ) - if not isinstance(out, adal.Parameter): - raise ValueError("The output should be a parameter") + query = gen_out.data.query if gen_out.data and gen_out.data.query else input - return out.data # or full response its up to users + # print(f"query {i}: {query}") + + retrieve_out = self.retrievers[i](input=query, id=id) + + passages = retrieve_out.documents + context = self.deduplicate(context + passages) + queries.append(query) + last_query = query + out = adal.RetrieverOutput( + documents=context, query=queries, doc_indices=[], id=id + ) + printc(f"queries: {queries}", "yellow") + return out def forward(self, *, input: str, id: str = None) -> adal.Parameter: # assemble the foundamental building blocks printc(f"question: {input}", "yellow") - context = [] + # context = [] queries: List[str] = [] + context = [] + last_query = None + for i in range(self.max_hops): gen_out: Parameter = self.query_generators[i].forward( prompt_kwargs={ - "context": context, # can be a list or a parameter + "context": context, + "last_query": last_query, "question": adal.Parameter( name="question", data=input, @@ -488,17 +373,11 @@ def forward(self, *, input: str, id: str = None) -> adal.Parameter: ) success_map_fn = lambda x: ( # noqa E731 - x.full_response.data.query - if x.full_response - and x.full_response.data - and x.full_response.data.query - else ( - x.full_response.raw_response - if x.full_response and x.full_response.raw_response - else None - ) + x.data.data.query + if x.data and x.data.data and x.data.data.query + else (x.data.raw_response if x.data and x.data.raw_response else None) ) - print(f"query {i}: {success_map_fn(gen_out)}") + # printc(f"query {i}: {success_map_fn(gen_out)}") queries.append(success_map_fn(gen_out)) @@ -512,7 +391,7 @@ def forward(self, *, input: str, id: str = None) -> adal.Parameter: retrieve_out = self.retrievers[i].forward(input=gen_out, id=id) def retrieve_out_map_fn(x: adal.Parameter): - return x.data[0].documents if x.data and x.data[0].documents else [] + return x.data.documents if x.data and x.data.documents else [] # print(f"retrieve_out: {retrieve_out}") @@ -523,18 +402,24 @@ def retrieve_out_map_fn(x: adal.Parameter): context = self.deduplicaters[i].forward( exisiting_list=context, new_list=retrieve_out ) + last_query = success_map_fn(gen_out) context.param_type = ParameterType.RETRIEVER_OUTPUT def context_to_retrover_output(x): - return [ - adal.RetrieverOutput( - documents=x.data, query=[input] + queries, doc_indices=[] - ) - ] + return adal.RetrieverOutput( + documents=x.data, query=[input] + queries, doc_indices=[], id=id + ) context.data = context_to_retrover_output(context) + if not isinstance(context.data, adal.RetrieverOutput): + raise ValueError( + f"The output should be a list of RetrieverOutput, got {type(context.data)}" + ) + + printc(f"queries: {queries}", "yellow") + printc(f"MultiHopRetriever grad fn: {context.grad_fn}", "yellow") return context diff --git a/benchmarks/hotpot_qa/adal_exp/build_vanilla_rag.py b/benchmarks/hotpot_qa/adal_exp/build_vanilla_rag.py index 94329b25..ffa94de6 100644 --- a/benchmarks/hotpot_qa/adal_exp/build_vanilla_rag.py +++ b/benchmarks/hotpot_qa/adal_exp/build_vanilla_rag.py @@ -112,7 +112,7 @@ def __init__(self, top_k: int = 3): def call( self, input: str, top_k: Optional[int] = None, id: str = None - ) -> List[RetrieverOutput]: + ) -> RetrieverOutput: k = top_k or self.top_k @@ -121,18 +121,15 @@ def call( output = self.dspy_retriever(query_or_queries=input, k=k) # print(f"dsy_retriever output: {output}") - final_output: List[RetrieverOutput] = [] documents = output.passages - final_output.append( - RetrieverOutput( - query=input, - documents=documents, - doc_indices=[], - ) + return RetrieverOutput( + query=input, + documents=documents, + doc_indices=[], ) + # print(f"final_output: {final_output}") - return final_output # def forward(self, *args, **kwargs): # # Explicitly use adal.GradComponent's forward method @@ -147,12 +144,12 @@ def call( # else: # return self.call(*args, **kwargs) - def __call__(self, input: str, top_k: Optional[int] = None, id: str = None): - r"""Retrieves the top k relevant passages from using input as the subquery to obtain context for question""" - if self.training: - return adal.GradComponent.forward(self, input=input, top_k=top_k, id=id) - else: - return self.call(input=input, top_k=top_k, id=id) + # def __call__(self, input: str, top_k: Optional[int] = None, id: str = None): + # r"""Retrieves the top k relevant passages from using input as the subquery to obtain context for question""" + # if self.training: + # return adal.GradComponent.forward(self, input=input, top_k=top_k, id=id) + # else: + # return self.call(input=input, top_k=top_k, id=id) task_desc_str = r"""Answer questions with short factoid answers. @@ -225,7 +222,7 @@ def call(self, question: str, id: str = None) -> adal.GeneratorOutput: retriever_out = self.retriever.call(input=question, id=id) successor_map_fn = lambda x: ( # noqa E731 - "\n\n".join(x[0].documents) if x and x[0] and x[0].documents else "" + "\n\n".join(x.documents) if x and x.documents else "" ) retrieved_context = successor_map_fn(retriever_out) @@ -238,29 +235,16 @@ def call(self, question: str, id: str = None) -> adal.GeneratorOutput: prompt_kwargs=prompt_kwargs, id=id, ) - # self.llm.print_prompt(**prompt_kwargs) - # print(f"retrieved_context: {retrieved_context}") - # print(f"retriever_out: {retriever_out}") - return output - # def call(self, *, question: str, id: str = None) -> adal.GeneratorOutput: - # self.train() - # out = self.forward(question=question, id=id) - # if not isinstance(out, adal.Parameter): - # raise ValueError( - # "This output should be a Parameter, please check the forward function" - # ) - # self.eval() - # return out.data + return output - # TODO: add id in the retriever output def forward(self, question: str, id: str = None) -> adal.Parameter: if not self.training: raise ValueError("This component is not supposed to be called in eval mode") retriever_out = self.retriever.forward(input=question, id=id) successor_map_fn = lambda x: ( # noqa E731 - "\n\n".join(x.data[0].documents) - if x.data and x.data[0] and x.data[0].documents + "\n\n".join(x.data.documents) + if x.data and x.data and x.data.documents else "" ) retriever_out.add_successor_map_fn(successor=self.llm, map_fn=successor_map_fn) @@ -279,8 +263,8 @@ def bicall( retriever_out = self.retriever(input=question) if isinstance(retriever_out, adal.Parameter): successor_map_fn = lambda x: ( # noqa E731 - "\n\n".join(x.data[0].documents) - if x.data and x.data[0] and x.data[0].documents + "\n\n".join(x.data.documents) + if x.data and x.data and x.data.documents else "" ) retriever_out.add_successor_map_fn( @@ -289,7 +273,7 @@ def bicall( # retriever_out.requires_opt = False else: successor_map_fn = lambda x: ( # noqa E731 - "\n\n".join(x[0].documents) if x and x[0] and x[0].documents else "" + "\n\n".join(x.documents) if x and x.documents else "" ) retrieved_context = successor_map_fn(retriever_out) prompt_kwargs = { diff --git a/benchmarks/hotpot_qa/adal_exp/train_agent_rag.py b/benchmarks/hotpot_qa/adal_exp/train_agent_rag.py index e3ddd9ee..778301f8 100644 --- a/benchmarks/hotpot_qa/adal_exp/train_agent_rag.py +++ b/benchmarks/hotpot_qa/adal_exp/train_agent_rag.py @@ -7,6 +7,7 @@ from benchmarks.hotpot_qa.config import load_datasets from benchmarks.hotpot_qa.adal_exp.build_multi_hop_rag import AgenticRAG from use_cases.config import gpt_3_model, gpt_4o_model +from adalflow.utils import printc # TODO: look more into the loss function @@ -28,10 +29,15 @@ def __init__( model_client=model_client, model_kwargs=model_kwargs, ) - eval_fn = AnswerMatchAcc(type="fuzzy_match").compute_single_item + eval_fn = AnswerMatchAcc(type="exact_match").compute_single_item # 0.55 loss_fn = adal.EvalFnToTextLoss( eval_fn=eval_fn, eval_fn_desc="fuzzy_match: 1 if str(y_gt) in str(y) else 0" ) + # eval_fn = f1_score # 0.38 (hand crafted the finish, exat match 0.25) + + # loss_fn = adal.EvalFnToTextLoss( + # eval_fn=eval_fn, eval_fn_desc="Computes the overlaps between y and y_gt" + # ) super().__init__( task=task, eval_fn=eval_fn, @@ -53,11 +59,13 @@ def prepare_task(self, sample: HotPotQAData) -> Tuple[Callable[..., Any], Dict]: # eval mode: get the generator output, directly engage with the eval_fn def prepare_eval(self, sample: HotPotQAData, y_pred: ReActOutput) -> float: - # y_label = "" - # if y_pred and y_pred.data and y_pred.data.answer: - # y_label = y_pred.data.answer + y_label = "" + if y_pred is not None and y_pred.answer: + y_label = y_pred.answer + + printc(f"eval y_label: {y_label}, y_gt: {sample.answer}") - return self.eval_fn, {"y": y_pred, "y_gt": sample.answer} + return self.eval_fn, {"y": y_label, "y_gt": sample.answer} # train mode: get the loss and get the data from the full_response def prepare_loss(self, sample: HotPotQAData, pred: adal.Parameter): @@ -70,10 +78,14 @@ def prepare_loss(self, sample: HotPotQAData, pred: adal.Parameter): ) # pred's full_response is the output of the task pipeline which is GeneratorOutput - print(type(pred.data)) - pred.eval_input = ( - pred.data[-1].observation if pred.data and pred.data[-1] else "" - ) + # pred.eval_input = ( + # pred.data[-1].observation if pred.data and pred.data[-1] else "" + # ) + pred.eval_input = pred.data.answer if pred.data else "" + # pred.eval_input = ( + # pred.data[-1].observation if pred.data and pred.data[-1] else "" + # ) + printc(f"loss eval_input: {pred.eval_input}") return self.loss_fn, {"kwargs": {"y": pred, "y_gt": y_gt}, "id": sample.id} @@ -94,10 +106,10 @@ def train_diagnose( teacher_model_config=gpt_3_model, text_optimizer_model_config=gpt_3_model, ) - trainset = trainset[:5] + # trainset = trainset[:5] trainer = adal.Trainer(adaltask=adal_component) - trainer.diagnose(dataset=trainset, split="train") - # trainer.diagnose(dataset=valset, split="val") + # trainer.diagnose(dataset=trainset, split="train") + trainer.diagnose(dataset=valset, split="val") # trainer.diagnose(dataset=testset, split="test") @@ -196,8 +208,8 @@ def train( # task = MultiHopRAGAdal(**gpt_3_model) # print(task) - train_diagnose(**gpt_3_model) - exit() + # train_diagnose(**gpt_3_model) + # exit() ckpt = train( debug=False, diff --git a/benchmarks/hotpot_qa/adal_exp/train_multi_hop_retriever.py b/benchmarks/hotpot_qa/adal_exp/train_multi_hop_retriever.py new file mode 100644 index 00000000..c27877f1 --- /dev/null +++ b/benchmarks/hotpot_qa/adal_exp/train_multi_hop_retriever.py @@ -0,0 +1,259 @@ +from typing import Any, Callable, Dict, Tuple, List + +import adalflow as adal +from adalflow.eval.retriever_recall import RetrieverEvaluator +from adalflow.datasets.types import HotPotQAData +from benchmarks.hotpot_qa.config import load_datasets + +from benchmarks.hotpot_qa.adal_exp.build_multi_hop_rag import ( + MultiHopRetriever, +) +from use_cases.config import gpt_3_model, gpt_4o_model + + +def retriever_recall(y: List[str], y_gt: List[str]) -> float: + return RetrieverEvaluator().compute_single_item(y, y_gt)["recall"] + + +class MultiHopRetrieverAdal(adal.AdalComponent): + def __init__( + self, + model_client: adal.ModelClient, + model_kwargs: Dict, + backward_engine_model_config: Dict | None = None, + teacher_model_config: Dict | None = None, + text_optimizer_model_config: Dict | None = None, + ): + task = MultiHopRetriever( + model_client=model_client, + model_kwargs=model_kwargs, + passages_per_hop=3, + max_hops=2, + ) + eval_fn = retriever_recall + loss_fn = adal.EvalFnToTextLoss( + eval_fn=eval_fn, + eval_fn_desc="recall: len(y_gt.intersection(y)) / len(y_gt)", + ) + super().__init__( + task=task, + eval_fn=eval_fn, + loss_fn=loss_fn, + backward_engine_model_config=backward_engine_model_config, + teacher_model_config=teacher_model_config, + text_optimizer_model_config=text_optimizer_model_config, + ) + + # tell the trainer how to call the task + def prepare_task(self, sample: HotPotQAData) -> Tuple[Callable[..., Any], Dict]: + if self.task.training: + return self.task.forward, {"input": sample.question, "id": sample.id} + else: + return self.task.call, {"input": sample.question, "id": sample.id} + + # TODO: use two map fn to make the cde even simpler + + # eval mode: get the generator output, directly engage with the eval_fn + def prepare_eval(self, sample: HotPotQAData, y_pred: adal.RetrieverOutput) -> float: + if isinstance(y_pred, adal.Parameter): + raise ValueError("y_pred is not a RetrieverOutput") + documents = y_pred.documents + # get titles by split | + y_pred_titles = [] + for doc in documents: + title, content = doc.split("|") + y_pred_titles.append(title) + + return self.eval_fn, { + "y": y_pred_titles, + "y_gt": list(sample.gold_titles), + } + + # train mode: get the loss and get the data from the full_response + def prepare_loss(self, sample: HotPotQAData, pred: adal.Parameter): + # prepare gt parameter + y_gt = adal.Parameter( + name="y_gt", + data=sample.gold_titles, + eval_input=list(sample.gold_titles), + requires_opt=False, + ) + + pred_titles = [] + for doc in pred.data.documents: + title, content = doc.split("|") + pred_titles.append(title) + + # pred's full_response is the output of the task pipeline which is GeneratorOutput + # pred.eval_input = ( + # pred.data.do + # if pred.data and pred.data.data and pred.data.data.answer + # else "" + # ) + pred.eval_input = pred_titles + return self.loss_fn, { + "kwargs": {"y": pred, "y_gt": y_gt}, + "id": sample.id, + } + + +from adalflow.core.generator import BackwardPassSetup + + +# Note: diagnose is quite helpful, it helps you to quickly check if the evalfunction is the right metrics +# i checked the eval which does fuzzy match, and found some yes and Yes are not matched, then converted both strings to lower and +# the performances have gone up from 0.15 to 0.4 +def train_diagnose( + model_client: adal.ModelClient, + model_kwargs: Dict, +) -> Dict: + + trainset, valset, testset = load_datasets() + + adal_component = MultiHopRetrieverAdal( + model_client, + model_kwargs, + backward_engine_model_config=gpt_4o_model, + teacher_model_config=gpt_3_model, + text_optimizer_model_config=gpt_3_model, + ) + trainer = adal.Trainer(adaltask=adal_component) + # trainer.diagnose(dataset=trainset, split="train") # 0.69 recall + # trainer.diagnose(dataset=valset, split="val") # 0.675 recall + trainer.diagnose(dataset=testset, split="test") # 0.71 (0.665) + + +def train( + train_batch_size=4, # larger batch size is not that effective, probably because of llm's lost in the middle + raw_shots: int = 0, + bootstrap_shots: int = 4, + max_steps=1, + num_workers=10, + strategy="constrained", + optimization_order="sequential", + debug=False, + resume_from_ckpt=None, + exclude_input_fields_from_bootstrap_demos=True, + seed=None, + tg: bool = False, + max_proposals_per_step: int = 5, +): + adal_component = MultiHopRetrieverAdal( + **gpt_3_model, + teacher_model_config=gpt_4o_model, + text_optimizer_model_config=gpt_4o_model, # gpt3.5 is not enough to be used as a good optimizer, it struggles for long contenxt + backward_engine_model_config=gpt_4o_model, + ) + backward_pass_setup = None + if tg: + backward_pass_setup = BackwardPassSetup( + all_pred_at_once=False, + compute_grad_for_errors_only=False, + ) + # print(adal_component) + trainer = adal.Trainer( + train_batch_size=train_batch_size, + adaltask=adal_component, + strategy=strategy, + max_steps=max_steps, + num_workers=num_workers, + raw_shots=raw_shots, + bootstrap_shots=bootstrap_shots, + debug=debug, + weighted_sampling=True, + optimization_order=optimization_order, + exclude_input_fields_from_bootstrap_demos=exclude_input_fields_from_bootstrap_demos, + sequential_order=["text", "demo"], + max_proposals_per_step=max_proposals_per_step, + backward_pass_setup=backward_pass_setup, + ) + trainer.set_random_seed(seed) + print(trainer) + + train_dataset, val_dataset, test_dataset = load_datasets() + ckpt, _ = trainer.fit( + train_dataset=train_dataset, + val_dataset=val_dataset, + test_dataset=test_dataset, + resume_from_ckpt=resume_from_ckpt, + ) + return ckpt + + +if __name__ == "__main__": + from use_cases.config import gpt_3_model + + # log = adal.get_logger(level="DEBUG", enable_console=False) + + adal.setup_env() + + import json + + import random + + random.seed(2025) + + adal.setup_env() + + import argparse + + parser = argparse.ArgumentParser() + + parser.add_argument("--strategy", type=str, default="constrained") + parser.add_argument("--use_tg", action="store_false") + parser.add_argument("--max_proposals_per_step", type=int, default=5) + parser.add_argument( + "output_path", nargs="?", help="File path to save the checkpoint" + ) + + args = parser.parse_args() + + set_strategy = args.strategy + set_output_path = args.output_path + use_tg = args.use_tg + max_proposals_per_step = args.max_proposals_per_step + + # task = MultiHopRAGAdal(**gpt_3_model) + # print(task) + + # train_diagnose(**gpt_3_model) + # exit() + + # train: 0.15 before the evaluator converted to lower and 0.4 after the conversion + ckpt = train( + debug=False, + max_steps=12, + seed=2025, # pass the numpy seed + tg=use_tg, + strategy=set_strategy, + max_proposals_per_step=max_proposals_per_step, + # resume_from_ckpt="/Users/liyin/.adalflow/ckpt/ValinaRAGAdal/random_max_steps_12_7c091_run_1.json", + ) + print(f"ckpt: {ckpt}") + if set_output_path: + with open(set_output_path, "w") as f: + json.dump({"ckpt": ckpt}, f) + print(f"Checkpoint saved to {set_output_path}") + else: + print("No file path provided for saving the checkpoint.") + + # notes for debug: if have nontype, delete all model cache and try again + # raise ValueError(ValueError: score must be provided for each demo, + + # 12/11/2024 + # demo only: /Users/liyin/Documents/test/LightRAG/.adalflow/ckpt/MultiHopRAGAdal/constrained_max_steps_12_8cdfc_run_9.json + + # why text grad did not improve in the rag case? Do we need to improve the meta prompt? + # /Users/liyin/.adalflow/ckpt/MultiHopRAGAdal/constrained_max_steps_12_2686e_run_1.json + # 0.58 -> 0.68 on the test split + # 0.72 text grad /Users/liyin/.adalflow/ckpt/MultiHopRAGAdal/constrained_max_steps_12_c1660_run_1.json + # try cycle next + # 0.66 /Users/liyin/.adalflow/ckpt/MultiHopRAGAdal/constrained_max_steps_12_1d189_run_1.json + # no gradients 1021s (/Users/liyin/.adalflow/ckpt/MultiHopRAGAdal/constrained_max_steps_12_68e7e_run_1.json) -> 0.64 -> 0.68, pass 10/10+28 + # no gradient but scores (positive & negative) /Users/liyin/.adalflow/ckpt/MultiHopRAGAdal/constrained_max_steps_12_83871_run_1.json 0.64->0.66, test 0.64 -> 0.66 + # no gradient but only negative score + # no gradient but score + teacher demonstration. + # feedback while seeing the gt + y + # only negative feedback /Users/liyin/.adalflow/ckpt/MultiHopRAGAdal/constrained_max_steps_12_f5506_run_1.json 0.62 -> 0.7 + # /Users/liyin/.adalflow/ckpt/MultiHopRAGAdal/constrained_max_steps_12_b4aa5_run_1.json 0.74 pass rate 8 32 + # random cycle rag: /Users/liyin/.adalflow/ckpt/MultiHopRAGCycleAdal/random_max_steps_12_82bd2_run_1.json 0.64