
Commit

fix the hotpotqa dataset, add the retriever recall and precision, train multi-hop retriever
liyin2015 committed Jan 6, 2025
1 parent 0ad2490 commit e06fd7b
Showing 16 changed files with 671 additions and 361 deletions.
31 changes: 28 additions & 3 deletions adalflow/adalflow/components/agent/react.py
@@ -49,8 +49,14 @@
<END_OF_TASK_SPEC>
"""

# - In this case, you are working as a multi-hop retriever and your answer in finish MUST be verbatim short factoid responses from retrieved context.
# - Answer with only the exact answer phrase, not a full sentence or paragraph.

DEFAULT_REACT_AGENT_SYSTEM_PROMPT = r"""<START_OF_SYSTEM_PROMPT>
{{react_agent_task_desc}}
- You have a maximum of {{max_steps}} steps to complete the task. Plan your steps carefully.
{# Tools #}
{% if tools %}
<START_OF_TOOLS>
@@ -95,6 +101,8 @@
"Action": "{{history.action.action}}",
{% endif %}
"Observation": "{{history.observation}}"
Current Step/Max Step: {{step_history|length + 1}} / {{max_steps}}
------------------------
{% endfor %}
</STEPS>
@@ -114,6 +122,14 @@ def map_step_history_to_prompt(x: Parameter) -> str:
return "\n".join(output)


def map_step_history_list_to_prompt(x: Parameter) -> str:
output = []
for i, step in enumerate(x.data.step_history):
step_str = f"Step {i + 1}.\n"
output.append(step_str + step.to_prompt_str())
return "\n".join(output)


class AppendStepHistory(GradComponent):
def __init__(self):
super().__init__()
@@ -283,6 +299,7 @@ def __init__(
requires_opt=True,
),
"context_variables": self.context_variables,
"max_steps": self.max_steps,
}
self.planner = Generator(
template=template,
@@ -326,7 +343,7 @@ def llm_tool(input: str, **kwargs) -> str:
return None

def finish(answer: str, **kwargs) -> str:
"""Finish the task with answer."""
"""Finish the task with verbatim short factoid responses from retrieved context."""
return answer

self._finish = finish
@@ -676,8 +693,16 @@ def _get_answer(
return None

last_step: StepOutput = None
if isinstance(step_history, Parameter):
if isinstance(
step_history, Parameter
): # change the step history at the last step
try:
output = ReActOutput(
step_history=step_history.data,
answer=str(step_history.data[-1].observation),
)
step_history.data = output
step_history.data_in_prompt = map_step_history_list_to_prompt
return step_history

except Exception as e:
@@ -687,7 +712,7 @@
last_step = step_history[-1]
# printc(f"last_step: {last_step}", color="yellow")

return last_step.observation
return str(last_step.observation)

def call(self, *args, **kwargs) -> ReActOutput:
output = self.bicall(*args, **kwargs)
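
A minimal sketch of the behavior added by the new map_step_history_list_to_prompt helper: once the final step history is wrapped into a ReActOutput, each recorded step is flattened into a numbered "Step i." block for the prompt. FakeStep is a hypothetical stand-in that only mimics the to_prompt_str() method the real StepOutput is assumed to expose.

from dataclasses import dataclass

@dataclass
class FakeStep:
    # Hypothetical stand-in for adalflow's StepOutput; only mimics its assumed to_prompt_str().
    action: str
    observation: str

    def to_prompt_str(self) -> str:
        return f'"Action": "{self.action}",\n"Observation": "{self.observation}"'

steps = [FakeStep("search('capital of France')", "Paris is the capital of France.")]
prompt = "\n".join(f"Step {i + 1}.\n" + step.to_prompt_str() for i, step in enumerate(steps))
print(prompt)
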
2 changes: 1 addition & 1 deletion adalflow/adalflow/core/generator.py
@@ -67,7 +67,7 @@
@dataclass
class BackwardPassSetup(DataClass):
all_pred_at_once: bool = field(
default=True, metadata={"desc": "Backward all predecessors at once."}
default=False, metadata={"desc": "Backward all predecessors at once."}
)
threshold_score_to_compute_grad_for_errors: float = field(
default=0.9,
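
The default for all_pred_at_once flips from True to False here. A minimal sketch of pinning the behavior explicitly instead of relying on the default; it only constructs the dataclass shown in the hunk above, and how the Generator consumes it is not assumed.

from adalflow.core.generator import BackwardPassSetup

# Opt back into backpropagating all predecessors at once now that False is the default.
setup = BackwardPassSetup(all_pred_at_once=True)
print(setup)
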
2 changes: 1 addition & 1 deletion adalflow/adalflow/core/retriever.py
@@ -135,7 +135,7 @@ def forward(
)
response.trace_forward_pass(
input_args={"input": input, "top_k": top_k},
full_response=response,
full_response=response.data,
id=self.id,
name=self.name,
)
16 changes: 11 additions & 5 deletions adalflow/adalflow/core/types.py
@@ -281,19 +281,25 @@ class RetrieverOutput(DataClass):
It is up to the subclass of Retriever to specify the type of query and document.
"""

doc_indices: List[int] = field(metadata={"desc": "List of document indices"})
doc_scores: Optional[List[float]] = field(
id: str = field(default=None, metadata={"desc": "The unique id of the output"})

doc_indices: List[int] = field(
default=required_field, metadata={"desc": "List of document indices"}
)
doc_scores: List[float] = field(
default=None, metadata={"desc": "List of document scores"}
)
query: Optional[RetrieverQueryType] = field(
query: RetrieverQueryType = field(
default=None, metadata={"desc": "The query used to retrieve the documents"}
)
documents: Optional[List[RetrieverDocumentType]] = field(
documents: List[RetrieverDocumentType] = field(
default=None, metadata={"desc": "List of retrieved documents"}
)


RetrieverOutputType = List[RetrieverOutput] # so to support multiple queries at once
RetrieverOutputType = Union[
List[RetrieverOutput], RetrieverOutput
] # so to support multiple queries at once


#######################################################################################
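
RetrieverOutputType is now a Union, so downstream code can receive either one RetrieverOutput or a list of them. A small hypothetical helper (not part of the library) that normalizes both shapes:

from typing import List, Union

from adalflow.core.types import RetrieverOutput


def as_output_list(out: Union[RetrieverOutput, List[RetrieverOutput]]) -> List[RetrieverOutput]:
    # Fold the widened RetrieverOutputType back into a list for uniform handling.
    return out if isinstance(out, list) else [out]
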
7 changes: 4 additions & 3 deletions adalflow/adalflow/datasets/hotpot_qa.py
@@ -1,6 +1,6 @@
import random
import os
from typing import Literal
from typing import Literal, List

from adalflow.utils.lazy_import import safe_import, OptionalPackages

@@ -198,10 +198,11 @@ def _check_or_download_dataset(
# target_path = prepare_dataset_path(self.root, task_name, split)
target_path = os.path.join(data_path_dir, f"{split}.json")
# filter the examples with only the keys
save_examples = []
save_examples: List[HotPotQAData] = []
for example in examples:
save_example = {k: example[k] for k in keys if k in example}
save_examples.append(save_example)
save_example = HotPotQAData.from_dict(save_example)
save_examples.append(save_example.to_dict())
save_json(save_examples, f=target_path)
if split == "train":
print(f"train example: {examples[0]}")
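
The download path now routes every raw example through HotPotQAData before writing JSON, so what lands in {split}.json is the dataclass's dict form. A tiny sketch of that save shape, assuming from_dict/to_dict behave as in the __main__ demo in datasets/types.py below:

from adalflow.datasets.types import HotPotQAData
from adalflow.utils.file_io import save_json

examples = [{"question": "What is the capital of France?", "answer": "Paris"}]
# Mirror the new save path: typed example in, plain dict out, then JSON on disk.
save_examples = [HotPotQAData.from_dict(e).to_dict() for e in examples]
save_json(save_examples, f="train.json")
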
66 changes: 46 additions & 20 deletions adalflow/adalflow/datasets/types.py
@@ -2,8 +2,6 @@
from dataclasses import dataclass, field
from typing import Dict
from adalflow.core.base_data_class import DataClass
import json
from typing import Any


@dataclass
@@ -40,26 +38,26 @@ class HotPotQAData(Example):
default=None,
)

@staticmethod
def from_dict(d: Dict[str, Any]) -> "HotPotQAData":
# Preprocess gold_titles
if "gold_titles" in d and isinstance(d["gold_titles"], str):
try:
d["gold_titles"] = json.loads(d["gold_titles"])
except json.JSONDecodeError:
# Replace single quotes with double quotes
fixed_str = d["gold_titles"].replace("'", '"')
d["gold_titles"] = set(json.loads(fixed_str))
# @staticmethod
# def from_dict(d: Dict[str, Any]) -> "HotPotQAData":
# # Preprocess gold_titles
# if "gold_titles" in d and isinstance(d["gold_titles"], str):
# try:
# d["gold_titles"] = json.loads(d["gold_titles"])
# except json.JSONDecodeError:
# # Replace single quotes with double quotes
# fixed_str = d["gold_titles"].replace("'", '"')
# d["gold_titles"] = set(json.loads(fixed_str))

# Preprocess context
if "context" in d and isinstance(d["context"], str):
try:
d["context"] = json.loads(d["context"])
except json.JSONDecodeError:
fixed_str = d["context"].replace("'", '"')
d["context"] = json.loads(fixed_str)
# # Preprocess context
# if "context" in d and isinstance(d["context"], str):
# try:
# d["context"] = json.loads(d["context"])
# except json.JSONDecodeError:
# fixed_str = d["context"].replace("'", '"')
# d["context"] = json.loads(fixed_str)

return HotPotQAData(**d)
# return HotPotQAData(**d)


@dataclass
@@ -80,3 +78,31 @@ class TrecData(BaseData):

__input_fields__ = ["question"] # follow this order too.
__output_fields__ = ["class_name", "class_index"]


if __name__ == "__main__":
# test the hotpotqa data
data = HotPotQAData(
question="What is the capital of France?",
answer="Paris",
gold_titles=set(["Paris", "France"]),
context={"Paris": "The capital of France"},
)

data_dict = data.to_dict()
print("data_dict", data_dict)
data = HotPotQAData.from_dict(data_dict)
print("data", data)

from adalflow.utils.file_io import save_json, load_json

# save json
save_json(data_dict, f="task.json")
# load json
data_dict_loaded = load_json(f="task.json")

print("data_dict_loaded", data_dict_loaded)

# restore the data
data_restored = HotPotQAData.from_dict(data_dict_loaded)
print("data_restored", data_restored)
4 changes: 2 additions & 2 deletions adalflow/adalflow/eval/__init__.py
@@ -1,5 +1,5 @@
from .answer_match_acc import AnswerMatchAcc
from .retriever_recall import RetrieverRecall
from .retriever_recall import RetrieverEvaluator
from .llm_as_judge import LLMasJudge, DEFAULT_LLM_EVALUATOR_PROMPT
from .g_eval import (
GEvalJudgeEvaluator,
@@ -10,7 +10,7 @@

__all__ = [
"AnswerMatchAcc",
"RetrieverRecall",
"RetrieverEvaluator",
"LLMasJudge",
"DEFAULT_LLM_EVALUATOR_PROMPT",
"GEvalJudgeEvaluator",
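
RetrieverRecall is replaced by RetrieverEvaluator, which per the commit message reports precision as well as recall. Its interface isn't shown in this hunk, so the sketch below illustrates only the metrics themselves with plain set arithmetic, not the evaluator's API.

from typing import Set, Tuple


def recall_and_precision(retrieved: Set[str], relevant: Set[str]) -> Tuple[float, float]:
    # Set-based definitions of the two retrieval metrics the new evaluator is meant to report.
    hits = len(retrieved & relevant)
    recall = hits / len(relevant) if relevant else 0.0
    precision = hits / len(retrieved) if retrieved else 0.0
    return recall, precision


print(recall_and_precision({"Paris", "France", "Lyon"}, {"Paris", "France"}))  # (1.0, 0.666...)
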
47 changes: 1 addition & 46 deletions adalflow/adalflow/eval/answer_match_acc.py
@@ -3,52 +3,7 @@
from typing import List, Literal
from adalflow.eval.base import BaseEvaluator, EvaluationResult
from adalflow.optim.parameter import Parameter

import re

import string
from collections import Counter


def normalize_answer(s):

def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)

def white_space_fix(text):
return " ".join(text.split())

def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)

def lower(text):
return text.lower()

return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
prediction_tokens = normalize_answer(prediction).split()
ground_truth_tokens = normalize_answer(ground_truth).split()

common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())

if len(prediction_tokens) == len(ground_truth_tokens) == 0:
# Unlike most tasks, QReCC and SQuAD-2.0 assign 1.0 in this edge case. We don't for uniformity.
print(
"\n#> F1 Metric: Rare edge case of len(prediction_tokens) == len(ground_truth_tokens) == 0.\n"
)

if num_same == 0:
return 0

precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)

return f1
from adalflow.eval.utils import normalize_answer, f1_score


class AnswerMatchAcc(BaseEvaluator):
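
normalize_answer and f1_score now live in adalflow.eval.utils instead of being defined inline. A quick usage sketch, assuming the moved functions keep the signatures of the removed definitions above:

from adalflow.eval.utils import normalize_answer, f1_score

print(normalize_answer("The Eiffel Tower!"))         # "eiffel tower"
print(f1_score("the Eiffel Tower", "Eiffel Tower"))  # 1.0 after normalization
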
