Skip to content

Commit

Permalink
clean up on react
Browse files Browse the repository at this point in the history
  • Loading branch information
liyin2015 committed Dec 24, 2024
1 parent ee842b8 commit f8d006c
Show file tree
Hide file tree
Showing 16 changed files with 140 additions and 231 deletions.
152 changes: 18 additions & 134 deletions adalflow/adalflow/components/agent/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,32 +110,22 @@ def call(
return step_history


class StepHistoryToOutput(GradComponent):
    """Gradient-aware component that reduces a ReAct step history to its final answer.

    The final answer of a ReAct run is, by convention, the observation recorded
    on the last executed step.
    """

    def __init__(self):
        super().__init__()

    def call(self, step_history: List[StepOutput]) -> Any:
        """Return the observation of the last step, or ``None`` for an empty history."""
        if step_history:
            return step_history[-1].observation
        return None


class GeneroutorOutputToStepOutput(GradComponent):
    """Gradient-aware adapter that turns a planner's ``GeneratorOutput`` into a ``StepOutput``.

    Thin wrapper around the module-level :func:`execute_action_fn`, exposed as a
    ``GradComponent`` so the conversion participates in the computation graph.

    NOTE(review): the class name looks like a typo of ``GeneratorOutputToStepOutput``;
    it is kept verbatim because callers reference it by this name.
    """

    def __init__(self):
        super().__init__()

    def call(
        self,
        generator_output: GeneratorOutput,
        step_output: StepOutput,
        step: int,
        execute_action: Any,
    ) -> StepOutput:
        """Parse/execute the planned action and return the populated ``StepOutput``.

        Delegates entirely to ``execute_action_fn``; ``execute_action`` is the
        callable used to run the parsed action.
        """
        result = execute_action_fn(generator_output, step_output, step, execute_action)
        return result


# class GeneroutorOutputToStepOutput(GradComponent):
# def __init__(self):
# super().__init__()

# def call(
# self,
# generator_output: GeneratorOutput,
# step_output: StepOutput,
# step: int,
# execute_action: Any,
# ) -> StepOutput:
# """Convert the generator output to the step output."""
# return execute_action_fn(generator_output, step_output, step, execute_action)


# TODO: make execute_action_fn to a GradComponent to enable the training of the tools too.
def execute_action_fn(
x: GeneratorOutput, step_output: StepOutput, step: int, execute_action: Any
) -> StepOutput:
Expand Down Expand Up @@ -284,11 +274,8 @@ def __init__(
model_kwargs=model_kwargs,
)

# self.step_history: Union[List[StepOutput], Parameter] = None
# added this component to the computation graph
self.append_step_history = AppendStepHistory()
self.step_history_to_output = StepHistoryToOutput()
# self.generator_output_to_step_output = GeneroutorOutputToStepOutput()

def _init_tools(
self,
Expand Down Expand Up @@ -329,12 +316,6 @@ def finish(answer: str) -> str:
tools.append(finish)
self.tool_manager: ToolManager = ToolManager(tools=tools)

# def reset(self):
# r"""Reset the agent to start a new query."""
# self.step_history = None
# # if isinstance(self.step_history, Parameter):
# # self.step_history = self.step_history.data

# TODO: add async execution
def _execute_action(self, action_step: StepOutput) -> Optional[StepOutput]:
"""Parse the action string to a function call and execute it. Update the action_step with the result."""
Expand Down Expand Up @@ -366,10 +347,6 @@ def _run_one_step(
"""
from functools import partial

# prompt_kwargs["step_history"] = self.step_history

# step_history = prompt_kwargs["step_history"]
# add step history to the prompt_kwargs
prompt_kwargs["step_history"] = step_history

log.debug(
Expand All @@ -383,33 +360,6 @@ def _run_one_step(
# create a new step output
step_output: StepOutput = StepOutput(step=step)

# def execute_action_fn(x: GeneratorOutput, step_output: StepOutput = step_output) ->StepOutput:
# """Execute the action and update the step_output."""
# if x.error:
# error_msg = f"Error planning step {step}: {x.error}"
# step_output.observation = error_msg
# log.error(error_msg)
# else:
# try:
# fun_expr: FunctionExpression = x.data
# step_output.action = fun_expr
# log.debug(f"Step {step}: {fun_expr}")

# if step_output and step_output.action:
# step_output = self._execute_action(step_output)
# printc(f"Step {step}: \n{step_output}\n_______\n", color="blue")
# return step_output
# else:
# printc(f"Failed to parse response for step {step}", color="red")
# log.error(f"Failed to parse response for step {step}")
# return step_output
# except Exception as e:
# error_msg = f"Error parsing response for step {step}: {e}"
# step_output.observation = error_msg
# log.error(error_msg)
# printc(error_msg, color="red")
# return step_output

# connecting two generators in the computation graph, it will set up self.step_history
if isinstance(response, Parameter):

Expand All @@ -426,49 +376,20 @@ def map_fn(
f"Error: {x} does not have full_response attribute."
)

def map_fn2(x: Parameter) -> GeneratorOutput:
if x and hasattr(x, "full_response"):
return x.full_response
else:
raise ValueError(
f"Error: {x} does not have full_response attribute."
)

def map_parameter_to_step_output(x: Parameter) -> StepOutput:
if x and x.data:
return x.data
else:
raise ValueError(f"Error: {x} does not have data attribute.")

# Bind `step_output` to a specific value using partial
preinitialized_map_fn = partial(map_fn, step_output=step_output)
# execute the function and get the output
# response.add_successor_map_fn(
# successor=self.generator_output_to_step_output, map_fn=map_fn2
# )
# output = self.generator_output_to_step_output.forward(
# response, step_output, step, self._execute_action
# )
# # add the output to the step history
# output.add_successor_map_fn(
# successor=self.append_step_history, map_fn=map_parameter_to_step_output
# )

# # connect response to append_step_history
response.add_successor_map_fn(
successor=self.append_step_history, map_fn=preinitialized_map_fn
)

# # call self.append_step_history with the response
# self.step_history = self.append_step_history.forward(
# output, self.step_history
# )
step_history = self.append_step_history.forward(response, step_history)
# connect step_history to the next planner
step_history.add_successor_map_fn(
successor=self.planner, map_fn=lambda x: x.data
)
# printc(f"step_history 2: {self.step_history}", color="yellow")
# convert step history back to data
return step_history

Expand All @@ -477,28 +398,6 @@ def map_parameter_to_step_output(x: Parameter) -> StepOutput:
step_history.append(step_output)
return step_history

# if response_generator_output.error:
# error_msg = f"Error planning step {step}: {response_generator_output.error}"
# step_output.observation = error_msg
# log.error(error_msg)
# else:
# try:
# fun_expr: FunctionExpression = response_generator_output.data
# step_output.action = fun_expr
# log.debug(f"Step {step}: {fun_expr}")

# if step_output and step_output.action:
# step_output = self._execute_action(step_output)
# printc(f"Step {step}: \n{step_output}\n_______\n", color="blue")
# else:
# log.error(f"Failed to parse response for step {step}")
# except Exception as e:
# error_msg = f"Error parsing response for step {step}: {e}"
# step_output.observation = error_msg
# log.error(error_msg)

return response

def _check_last_step(
self, step_history: Union["Parameter", List[str]] = None
) -> bool:
Expand Down Expand Up @@ -532,13 +431,6 @@ def _get_answer(
last_step: StepOutput = None
if isinstance(step_history, Parameter):
try:

# last_step = self.step_history.add_successor_map_fn(
# self.step_history_to_output, map_fn=lambda x: x.data
# )
# # self.step_history.draw_graph()
# last_step = self.step_history_to_output.forward(self.step_history)
# printc(f"last_step: {last_step.data}", color="yellow")
return step_history

except Exception as e:
Expand Down Expand Up @@ -601,26 +493,18 @@ def bicall(
step_history = self._run_one_step(
step, prompt_kwargs, model_kwargs, id, step_history
)
# if (
# self.step_history[-1].function
# and self.step_history[-1].function.name == "finish"
# ):

if self._check_last_step(step_history):
break
# if self._is_step_output_last_step(step_output):
# break

except Exception as e:
log.error(f"Error running step {step}: {e}")

# answer = self.step_history[-1].observation
# answer = self._get_answer()
answer = self._get_answer(step_history)
if self.training:
return answer
# wrap the output
output = ReActOutput(step_history=step_history, id=id, answer=answer)
# printc(f"answer:\n {answer}", color="green")
# log.info(f"step_history: {self.step_history}")
return output

def _extra_repr(self) -> str:
Expand Down
46 changes: 33 additions & 13 deletions adalflow/adalflow/core/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from adalflow.utils.cache import CachedEngine
from adalflow.tracing.callback_manager import CallbackManager
from adalflow.utils.global_config import get_adalflow_default_root_path
from adalflow.core.string_parser import ListParser
from adalflow.core.string_parser import JsonParser


from adalflow.optim.text_grad.backend_engine_prompt import (
Expand Down Expand Up @@ -628,6 +628,10 @@ def backward(
log.debug(
f"Generator: Backward engine is set for the generator. {backward_engine}"
)
if response.backward_engine_disabled:
for pred in children_params:
pred.backward_engine_disabled = True
return
if not all_pred_at_once:
for pred in children_params:
if not pred.requires_opt or pred.param_type == ParameterType.DEMOS:
Expand All @@ -646,16 +650,22 @@ def backward(
is_intermediate_node=is_intermediate_node,
)
else:
# 2nd approach, backward all that need opt at once.
self._backward_through_all_predecessors(
children_params=children_params,
response=response,
prompt_kwargs=prompt_kwargs,
template=template,
backward_engine=backward_engine,
prompt_str=prompt_str,
is_intermediate_node=is_intermediate_node,
)
backward = False
for pred in children_params:
if pred.requires_opt and pred.param_type == ParameterType.PROMPT:
backward = True
break
if backward:
# 2nd approach, backward all that need opt at once.
self._backward_through_all_predecessors(
children_params=children_params,
response=response,
prompt_kwargs=prompt_kwargs,
template=template,
backward_engine=backward_engine,
prompt_str=prompt_str,
is_intermediate_node=is_intermediate_node,
)
else:
log.debug("Backward engine is not set for the generator. No text gradient.")

Expand All @@ -669,7 +679,7 @@ def _backward_through_all_predecessors(
prompt_str: str,
is_intermediate_node: bool = False,
):
parser = ListParser()
parser = JsonParser()
# instruction and objective is the same for all the children
instruction_str, objective_str = None, None

Expand Down Expand Up @@ -748,6 +758,7 @@ def _backward_through_all_predecessors(
gradient_output: GeneratorOutput = backward_engine(
prompt_kwargs=backward_engine_prompt_kwargs
)
print(f"gradient_output: {gradient_output}")
if not isinstance(gradient_output, GeneratorOutput):
raise ValueError(
f"Generator: Backward Engine should return a GeneratorOutput. Got {gradient_output} instead."
Expand All @@ -767,6 +778,8 @@ def _backward_through_all_predecessors(
response_gradient_list = [failure_message] * len(children_params)
print(f"failure_message: {failure_message}")

print(f"response_gradient_list: {response_gradient_list}")

# generate the gradient for each child
for i, pred in enumerate(children_params):
if not pred.requires_opt or pred.param_type == ParameterType.DEMOS:
Expand All @@ -775,8 +788,15 @@ def _backward_through_all_predecessors(
)
continue

gradient_data = (
response_gradient_list[i]
if response_gradient_list and len(response_gradient_list) > i
else "Failed to get the gradient."
)
print(f"i: {i}, gradient_data: {gradient_data}")

var_gradient = Gradient(
data=response_gradient_list[i],
data=gradient_data,
data_id=response.data_id,
score=response._score, # add score to gradient
from_response=response,
Expand Down
5 changes: 5 additions & 0 deletions adalflow/adalflow/optim/grad_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ def backward(self, *, response: "Parameter", id: str = None, **kwargs):
if response.get_gradient_and_context_text().strip() == "":
log.info(f"Generator: Backward: No gradient found for {response}.")

# backward the backward engine disable signal
if response.backward_engine_disabled:
for pred in children_params:
pred.backward_engine_disabled = True

for pred in children_params:
pred.set_score(response._score)
from adalflow.utils.logger import printc
Expand Down
9 changes: 5 additions & 4 deletions adalflow/adalflow/optim/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,7 @@ class ComponentNode(DataClass):
{% endfor %}
</COMPONENT_SCHEMA>
{% endif %}
<DESCRIPTION>
If the same DataID has multiple gradients, it means this component/variable is called multiple times in the same order as it appears in the gradient list.
Use this info to have more clarity while reasoning and proposing new variables.
</DESCRIPTION>
{% if combined_gradients %}
{% for group in combined_gradients %}
<DataID: {{ group.data_id }}>
Expand All @@ -160,9 +157,11 @@ class ComponentNode(DataClass):
{% endfor %}
</DataID>
{% endfor %}
{% endif %}
"""
# Use this info to have more clarity while reasoning and proposing new variables.

# {% if combined_gradients %}
# {# Group gradients by data_id #}
Expand Down Expand Up @@ -1729,6 +1728,8 @@ def __init__(
self.to_pred_id = to_pred.id
self.score = score
self.data_id = data_id
if self.data_id is None:
raise ValueError("The data_id should not be None.")
self.data = data
self.order = None

Expand Down
Loading

0 comments on commit f8d006c

Please sign in to comment.