Commit

change the GradComponent to one that can do both pass-through and normal gradients
liyin2015 committed Jan 10, 2025
1 parent ee81bb7 commit 4060b18
Showing 9 changed files with 63 additions and 48 deletions.
20 changes: 8 additions & 12 deletions adalflow/adalflow/components/agent/react.py
@@ -9,7 +9,7 @@


from adalflow.core.generator import Generator
from adalflow.optim.grad_component import GradComponent
from adalflow.optim.grad_component import GradComponent2
from adalflow.optim.parameter import Parameter, ParameterType
from adalflow.core.func_tool import FunctionTool, AsyncCallable
from adalflow.core.tool_manager import ToolManager
@@ -130,11 +130,9 @@ def map_step_history_list_to_prompt(x: Parameter) -> str:
return "\n".join(output)


class AppendStepHistory(GradComponent):
class AppendStepHistory(GradComponent2):
def __init__(self):
super().__init__()
self.name = "AppendStepHistory"
self._component_desc = "Append the step_output to the step_history."
super().__init__(desc="Append the step_output to the step_history.")

def call(
self, step_output: StepOutput, step_history: List[StepOutput]
@@ -154,11 +152,9 @@ def forward(self, *args, **kwargs) -> Parameter:
return output


class FunctionOutputToStepOutput(GradComponent):
class FunctionOutputToStepOutput(GradComponent2):
def __init__(self):
super().__init__()
self.name = "FunctionOutputToStepOutput"
self._component_desc = "Convert the FunctionOutput to StepOutput"
super().__init__(desc="Convert the FunctionOutput to StepOutput")

def call(
self,
@@ -368,7 +364,7 @@ def set_step_output_with_error(
step_output: StepOutput, error: str, response: Any
):
"""Set the step_output with error."""
step_output.observation = f"erro: {error} at {response.data}"
step_output.observation = f"error: {error} at {response.data}"
return step_output

response.add_successor_map_fn(
@@ -418,7 +414,6 @@ def set_step_output_with_error(
return handle_error(response, e)

try:

# printc(f"func: {func}", color="yellow")
# replace the id
if isinstance(func, Parameter):
@@ -497,7 +492,7 @@ def _execute_action_eval_mode(
id=None,
) -> StepOutput:
"""Execute the action and update the step_output."""
if x.error:
if x.error or not x.data:
error_msg = f"Error planning step {step}: {x.error}"
step_output.observation = error_msg
step_output.action = None
@@ -506,6 +501,7 @@
else:
try:
fun_expr: FunctionExpression = x.data
printc(f"Step {step}: {fun_expr}", color="blue")
step_output.action = fun_expr
log.debug(f"Step {step}: {fun_expr}")

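The AppendStepHistory and FunctionOutputToStepOutput changes above replace the manual `self.name` / `self._component_desc` assignments with a single `desc` argument passed to the `GradComponent2` constructor. A minimal sketch of a subclass written against the new constructor (the body of `call` is an assumption for illustration; only the `super().__init__(desc=...)` line is taken from the diff):

```python
from typing import List

from adalflow.core.types import StepOutput
from adalflow.optim.grad_component import GradComponent2


class AppendStepHistorySketch(GradComponent2):
    """Illustrative subclass: the component describes itself via `desc`."""

    def __init__(self):
        # name is now optional and defaults inside GradComponent2;
        # only the human-readable description is passed explicitly
        super().__init__(desc="Append the step_output to the step_history.")

    def call(
        self, step_output: StepOutput, step_history: List[StepOutput]
    ) -> List[StepOutput]:
        # assumed eval-mode behavior: append the new step and return the history
        step_history.append(step_output)
        return step_history
```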
8 changes: 5 additions & 3 deletions adalflow/adalflow/core/func_tool.py
@@ -18,7 +18,7 @@
)
from adalflow.core import Component
from adalflow.optim.parameter import Parameter
from adalflow.optim.grad_component import GradComponent
from adalflow.optim.grad_component import GradComponent2
from adalflow.core.functional import (
get_fun_schema,
)
@@ -59,7 +59,7 @@ def find_instance_name_from_self(instance):


# TODO: improve the support for async functions, similarly a component might be used as a tool
class FunctionTool(GradComponent):
class FunctionTool(GradComponent2):
__doc__ = r"""Describing and executing a function via call with arguments.
@@ -116,7 +116,9 @@ def __init__(
component: Optional[Component] = None,
definition: Optional[FunctionDefinition] = None,
):
super().__init__()
super().__init__(
name="FunctionTool", desc="A component calls and executes a function."
)
nest_asyncio.apply()
assert fn is not None, "fn must be provided"

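With this change, `FunctionTool` initializes its `GradComponent2` base with a fixed name and description instead of a bare `super().__init__()`. A hedged usage sketch (the `multiply` helper is hypothetical; callers do not pass `name`/`desc` themselves):

```python
from adalflow.core.func_tool import FunctionTool


def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


# The name="FunctionTool" / desc=... pair shown in the diff is supplied
# internally by FunctionTool; the caller only provides the function to wrap.
tool = FunctionTool(fn=multiply)
result = tool.call(2, 3)  # assumed to return a FunctionOutput wrapping 6
```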
33 changes: 17 additions & 16 deletions adalflow/adalflow/core/tool_manager.py
@@ -21,7 +21,7 @@
import warnings

from adalflow.core.container import ComponentList
from adalflow.optim.grad_component import GradComponent
from adalflow.optim.grad_component import GradComponent2
from adalflow.core.component import Component
from adalflow.core.func_tool import FunctionTool
from adalflow.core.types import (
@@ -107,9 +107,9 @@ def dummy_pass_through_for_untrainable_fn(output, func):
return output


class FunctionExperssionToFunction(GradComponent):
class FunctionExperssionToFunction(GradComponent2):
def __init__(self):
super().__init__()
super().__init__(desc="Convert FunctionExpression to Function")

def call(self, expr: FunctionExpression, context: Dict[str, object]) -> Function:

@@ -227,21 +227,21 @@ def parse_func_expr(
r"""Parse the function call expression."""

if isinstance(expr, Parameter):
try:
# try:

func = FunctionExperssionToFunction()
expr.add_successor_map_fn(func, map_fn=map_fn)
# print("FunctionExperssionToFunction")
output = func.forward(expr, context=self.context)
# print(f"output data: {output.data}")
return output
func = FunctionExperssionToFunction()
expr.add_successor_map_fn(func, map_fn=map_fn)
# print("FunctionExperssionToFunction")
output = func.forward(expr, context=self.context)
# print(f"output data: {output.data}")
return output

except Exception as e:
error_msg = (
f"Error {e} parsing function call expression: {map_fn(expr)}"
)
return error_msg
else:
# except Exception as e:
# error_msg = (
# f"Error {e} parsing function call expression: {map_fn(expr)}"
# )
# return error_msg
# else:
try:
expr_str = expr.action
func_name, args, kwargs = parse_function_call_expr(
@@ -278,6 +278,7 @@ def call(
expr_or_fun: Union[FunctionExpression, Function],
step: Literal["execute"] = "execute",
) -> Union[FunctionOutput, Function, Parameter]:
print(f"self.training: {self.training}, expr_or_fun: {expr_or_fun}")
if not isinstance(expr_or_fun, (Function, FunctionExpression)):
raise ValueError(
f"expr_or_fun should be either a Function or FunctionExpression. Got {expr_or_fun}"
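For context, the eval-mode branch kept below the commented-out try/except resolves the expression string with `parse_function_call_expr`, mirroring `expr_str = expr.action` above. A rough standalone sketch of that conversion, assuming the helper takes the expression string plus a name-to-callable context map (the `multiply` entry is hypothetical):

```python
from adalflow.core.functional import parse_function_call_expr
from adalflow.core.types import Function

# hypothetical planner output and tool context
expr_str = "multiply(a=2, b=3)"
context = {"multiply": lambda a, b: a * b}

# assumed signature: (expression string, context map) -> (name, args, kwargs)
func_name, args, kwargs = parse_function_call_expr(expr_str, context)
fn = Function(name=func_name, args=args, kwargs=kwargs)
```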
7 changes: 4 additions & 3 deletions adalflow/adalflow/optim/grad_component.py
@@ -323,8 +323,8 @@ class GradComponent2(GradComponent):

def __init__(
self,
name: str,
desc: str,
name: Optional[str] = None,
backward_engine: Optional["BackwardEngine"] = None,
model_client: "ModelClient" = None,
model_kwargs: Dict[str, object] = None,
@@ -532,15 +532,16 @@ def backward(self, *, response: "OutputParameter", id: str = None, **kwargs):
for pred in children_params:
pred.backward_engine_disabled = True

if not self.backward_engine:
# use pass through gradient when there is one predecessor
if not self.backward_engine or len(children_params) < 2:
super().backward(response=response, id=id)

else:

for _, pred in enumerate(children_params):
if response.score is not None:
pred.set_score(response.score)
printc(f"score score for pred name: {pred.name}")
printc(f"score {response.score} for pred name: {pred.name}")
if not pred.requires_opt:
continue

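The core of this commit is the extra `len(children_params) < 2` condition: with a single predecessor, `GradComponent2.backward` now falls back to the parent's pass-through behavior instead of invoking the backward engine. A minimal standalone sketch of that decision (names are illustrative, not adalflow's API):

```python
from typing import Any, List, Optional


def choose_backward_mode(
    backward_engine: Optional[Any], predecessors: List[Any]
) -> str:
    """Illustrative only: mirrors the branch added in GradComponent2.backward."""
    if backward_engine is None or len(predecessors) < 2:
        # pass-through: the response's feedback is handed to the predecessor(s)
        # without calling the (LLM-based) backward engine
        return "pass_through"
    # otherwise each trainable predecessor gets its own engine-generated gradient
    return "backward_engine"


assert choose_backward_mode(None, ["pred"]) == "pass_through"
assert choose_backward_mode(object(), ["pred"]) == "pass_through"
assert choose_backward_mode(object(), ["pred1", "pred2"]) == "backward_engine"
```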
22 changes: 17 additions & 5 deletions adalflow/adalflow/optim/parameter.py
@@ -22,6 +22,8 @@
from adalflow.optim.types import ParameterType
from adalflow.core.base_data_class import DataClass
from adalflow.utils.logger import printc
import html


if TYPE_CHECKING:
from adalflow.optim.text_grad.tgd_optimizer import TGDData, TGDOptimizerTrace
@@ -1229,19 +1231,29 @@ def draw_output_subgraph(
node_ids = set()

for node in nodes:
escaped_name = html.escape(node.name if node.name else "")
escaped_param_type = html.escape(
node.param_type.name if node.param_type else ""
)
escaped_value = html.escape(
node.get_short_value() if node.get_short_value() else ""
)

node_label = f"""
<table border="0" cellborder="1" cellspacing="0">
<tr><td><b>Name:</b></td><td>{node.name}</td></tr>
<tr><td><b>Type:</b></td><td>{node.param_type}</td></tr>
<tr><td><b>Value:</b></td><td>{node.get_short_value()}</td></tr>"""
<tr><td><b>Name:</b></td><td>{escaped_name}</td></tr>
<tr><td><b>Type:</b></td><td>{escaped_param_type}</td></tr>
<tr><td><b>Value:</b></td><td>{escaped_value}</td></tr>"""
# add the component trace id and name
if hasattr(node, "component_trace") and node.component_trace.id is not None:
node_label += f"<tr><td><b>Component Trace ID:</b></td><td>{node.component_trace.id}</td></tr>"
escaped_ct_id = html.escape(str(node.component_trace.id))
node_label += f"<tr><td><b>Component Trace ID:</b></td><td>{escaped_ct_id}</td></tr>"
if (
hasattr(node, "component_trace")
and node.component_trace.name is not None
):
node_label += f"<tr><td><b>Component Trace Name:</b></td><td>{node.component_trace.name}</td></tr>"
escaped_ct_name = html.escape(str(node.component_trace.name))
node_label += f"<tr><td><b>Component Trace Name:</b></td><td>{escaped_ct_name}</td></tr>"

node_label += "</table>"
dot.node(
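Escaping matters here because these strings are embedded in Graphviz HTML-like labels, where a raw `<`, `>`, or `&` in a parameter value would break the table markup. A small standalone illustration of the escaping step:

```python
import html

# hypothetical parameter value containing markup-like characters
raw_value = 'planner output: <final_answer>"42" & sources</final_answer>'
escaped_value = html.escape(raw_value)

node_label = f"""
<table border="0" cellborder="1" cellspacing="0">
<tr><td><b>Value:</b></td><td>{escaped_value}</td></tr>
</table>"""
print(node_label)  # the <final_answer> tags render as text instead of breaking the label
```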
@@ -308,7 +308,7 @@ def _backward_through_one_predecessor(
input_output=conversation_str,
response_desc=response.role_desc,
variable_desc=pred.role_desc,
input=input,
# input=input,
# ground_truth=ground_truth,
)
)
11 changes: 7 additions & 4 deletions adalflow/adalflow/optim/text_grad/tgd_optimizer.py
@@ -183,7 +183,7 @@ class HistoryPrompt(DataClass):
1. **Address Feedback**: Resolve concerns raised in the feedback while preserving the positive aspects of the original variable.
2. Observe past performance patterns (when available) to retain good qualities in the variable.
3. **System Awareness**: When other system variables are given, ensure you understand how this variable works in the whole system.
You have a choice to not update a variable if it is not responsible for the error. Just keep the `update` field as `False`.
You have a choice to not update a variable if it is not responsible for the error by setting `update: false` and `proposed_variable: None`.
You MUST not update variable when there is no clear error indicated in a multi-component system.
4. **Peer Awareness**: This variable works together with Peer variables, ensure you are aware of their roles and constraints.
5. Be Creative. If adding new elements, be concise.
@@ -280,15 +280,18 @@ class TGDData(DataClass):
"desc": "Which solution did you choose, which prompt engineering technique did you use? Why? Be Concise (maximum 2 sentences)"
}
)
proposed_variable: str = field(
metadata={"desc": "The proposed variable"}, default=None
)
update: bool = field(
default=True,
metadata={
"desc": "Depending on the feedback, update the variable if it is responsible for the error, else, keep it"
},
)
proposed_variable: str = field(
metadata={
"desc": "The proposed variable, ignoring the field when update: false"
},
default=None,
)


@dataclass
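Reordering the fields places `update` before `proposed_variable` in the generated output format, presumably so the optimizer decides whether to update before writing a new value. A rough illustration of the two output shapes the updated schema allows (plain dicts for illustration; only the `update` and `proposed_variable` names come from the diff):

```python
# the variable is judged responsible for the error -> propose a new value
proposal = {
    "update": True,
    "proposed_variable": "You are a concise QA assistant. Answer in one sentence.",
}

# the variable is not responsible -> leave it untouched
no_change = {
    "update": False,
    "proposed_variable": None,  # ignored when update is false
}
```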
4 changes: 2 additions & 2 deletions adalflow/adalflow/optim/types.py
@@ -19,9 +19,9 @@ class ParameterType(Enum):
3. Third element: whether the parameter is trainable.
To access each element, use the following:
1. name: `ParameterType.PROMPT.name`
1. name: `ParameterType.PROMPT.value`
2. description: `ParameterType.PROMPT.description`
3. trainable: `ParameterType.PROMPT.trainable`
3. trainable: `ParameterType.PROMPT.default_trainable`
"""

# trainable parameters with optimizers
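The docstring fix points readers at `.value`, `.description`, and `.default_trainable`. A hedged sketch of an Enum carrying those extra attributes via a custom `__new__` (a stand-in, not adalflow's actual ParameterType definition):

```python
from enum import Enum


class ParameterTypeSketch(Enum):
    """Illustrative stand-in: each member carries a value, a description,
    and a default-trainable flag."""

    PROMPT = ("prompt", "Instruction to the language model", True)

    def __new__(cls, value: str, description: str, default_trainable: bool):
        # unpacked from the member tuple; .value becomes the short string
        obj = object.__new__(cls)
        obj._value_ = value
        obj.description = description
        obj.default_trainable = default_trainable
        return obj


assert ParameterTypeSketch.PROMPT.value == "prompt"
assert ParameterTypeSketch.PROMPT.description == "Instruction to the language model"
assert ParameterTypeSketch.PROMPT.default_trainable is True
```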
4 changes: 2 additions & 2 deletions benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py
@@ -723,7 +723,7 @@ class AgenticRAG(adal.GradComponent):
def __init__(self, model_client, model_kwargs):
super().__init__()

self.dspy_retriever = DspyRetriever(top_k=3)
self.dspy_retriever = DspyRetriever(top_k=2)
# self.llm_parser = adal.DataClassParser(
# data_class=AnswerData, return_data_class=True, format_type="json"
# )
@@ -756,7 +756,7 @@ def dspy_retriever_as_tool(
# context_variables: Dict,
id: Optional[str] = None,
) -> List[str]:
r"""Retrieves the top k passages from using input as the query.
r"""Retrieves the top 2 passages from using input as the query.
Ensure you get all the context to answer the original question.
"""
output = self.dspy_retriever(input=input, id=id)
