From fadcf35157242a57b0b424c73db369f2fc53b788 Mon Sep 17 00:00:00 2001 From: Li Yin Date: Sat, 21 Dec 2024 12:26:53 -0800 Subject: [PATCH] separate the output parameter out from the base parameter which handles the component trace, and make the backward engine with only eval mode and without output a parameter --- adalflow/adalflow/core/generator.py | 27 ++- adalflow/adalflow/optim/grad_component.py | 4 +- adalflow/adalflow/optim/loss_component.py | 3 + adalflow/adalflow/optim/parameter.py | 188 ++++++++++++++---- adalflow/adalflow/optim/text_grad/ops.py | 10 +- .../optim/text_grad/text_loss_with_eval_fn.py | 55 ++--- .../adalflow/optim/text_grad/tgd_optimizer.py | 2 +- adalflow/adalflow/optim/trainer/trainer.py | 8 +- .../hotpot_qa/adal_exp/build_multi_hop_rag.py | 2 +- .../adal_exp/train_multi_hop_rag_cycle.py | 10 +- use_cases/classification/train.py | 3 + .../trec_task_structured_output.py | 14 +- 12 files changed, 217 insertions(+), 109 deletions(-) diff --git a/adalflow/adalflow/core/generator.py b/adalflow/adalflow/core/generator.py index 8ae4f196..d8f2f0ed 100644 --- a/adalflow/adalflow/core/generator.py +++ b/adalflow/adalflow/core/generator.py @@ -21,7 +21,12 @@ from adalflow.core.base_data_class import DataClass -from adalflow.optim.parameter import Parameter, GradientContext, Gradient +from adalflow.optim.parameter import ( + Parameter, + GradientContext, + Gradient, + OutputParameter, +) from adalflow.optim.types import ParameterType from adalflow.core.prompt_builder import Prompt @@ -523,7 +528,7 @@ def forward( if output and not output.error else f"Error: {output.error}, raw_response: {output.raw_response}" ) - response: Parameter = Parameter( + response: Parameter = OutputParameter( data=param_data, name=self.name + "_output", role_desc=f"Output from (llm) {self.name}", @@ -740,6 +745,11 @@ def _backward_through_one_predecessor( gradient_output: GeneratorOutput = backward_engine( prompt_kwargs=backward_engine_prompt_kwargs ) + if not isinstance(gradient_output, GeneratorOutput): + raise ValueError( + f"Generator: Backward Engine should return a GeneratorOutput. Got {gradient_output} instead." + ) + # USE this to trace each node's input and output, all nodes can be visualized log.info( f"Generator Backward Engine Prompt: {backward_engine.get_prompt( **backward_engine_prompt_kwargs)}" @@ -748,9 +758,6 @@ def _backward_through_one_predecessor( gradient_output.data or backward_engine.failure_message_to_optimizer(gradient_output) ) - log.info( - f"Generator Gradient value: {gradient_value}, raw response: {gradient_output.raw_response}" - ) # TODO: make it a debug feature var_gradient = Gradient( data=gradient_value, @@ -947,7 +954,11 @@ class BackwardEngine(Generator): # it is a generator with defaule template __doc__ = """The backward engine is a Generator with a default template for the backward pass. - If you want to customize the template, you can create your own backward engine""" + If you want to customize the template, you can create your own backward engine. + + Yet, we will forever keep the training mode to False for the backward engine. + This is achieved by making forward the same as call. + """ def __init__(self, **kwargs): if kwargs is None: @@ -965,6 +976,10 @@ def call(self, **kwargs) -> GeneratorOutputType: raise ValueError(f"Error in the backward engine: {output.error}") return output + def forward(self, **kwargs): + r"""Forward pass for the backward engine.""" + return self.call(**kwargs) + @staticmethod def failure_message_to_optimizer( gradient_response: GeneratorOutput, diff --git a/adalflow/adalflow/optim/grad_component.py b/adalflow/adalflow/optim/grad_component.py index 07249a64..14f77b3f 100644 --- a/adalflow/adalflow/optim/grad_component.py +++ b/adalflow/adalflow/optim/grad_component.py @@ -64,7 +64,7 @@ def forward(self, *args, **kwargs) -> "Parameter": 3. Return the parameter object. """ - from adalflow.optim.parameter import Parameter + from adalflow.optim.parameter import Parameter, OutputParameter log.debug( f"Forwarding through {self.name} with args: {args} and kwargs: {kwargs}" @@ -122,7 +122,7 @@ def forward(self, *args, **kwargs) -> "Parameter": # 4. Create a Parameter object to trace the forward pass input_args.update(kwargs) - response = Parameter( + response = OutputParameter( data=call_response, name=self.name + "_output", role_desc=self.name + " response", diff --git a/adalflow/adalflow/optim/loss_component.py b/adalflow/adalflow/optim/loss_component.py index e53ac609..cd773653 100644 --- a/adalflow/adalflow/optim/loss_component.py +++ b/adalflow/adalflow/optim/loss_component.py @@ -1,6 +1,7 @@ """Base class for Autograd Components that can be called and backpropagated through.""" from typing import TYPE_CHECKING +import uuid if TYPE_CHECKING: from adalflow.core.generator import BackwardEngine @@ -27,10 +28,12 @@ class LossComponent(Component): """ backward_engine: "BackwardEngine" _component_type = "loss" + id = None def __init__(self, *args, **kwargs): super().__init__() super().__setattr__("backward_engine", None) + super().__setattr__("id", str(uuid.uuid4())) def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) diff --git a/adalflow/adalflow/optim/parameter.py b/adalflow/adalflow/optim/parameter.py index de491521..70052019 100644 --- a/adalflow/adalflow/optim/parameter.py +++ b/adalflow/adalflow/optim/parameter.py @@ -46,7 +46,7 @@ class GradientContext: @dataclass(frozen=True) -class ComponentNode: +class ComponentNode(DataClass): """Used to represent a node in the component graph.""" id: str = field(metadata={"desc": "The unique id of the component"}) @@ -57,7 +57,7 @@ class ComponentNode: @dataclass -class ComponentTrace: +class ComponentTrace(DataClass): name: str = field(metadata={"desc": "The name of the component"}, default=None) id: str = field(metadata={"desc": "The unique id of the component"}, default=None) input_args: Dict[str, Any] = field( @@ -67,6 +67,9 @@ class ComponentTrace: full_response: object = field( metadata={"desc": "The full response of the GradComponent output"}, default=None ) + raw_response: str = field( + metadata={"desc": "The raw response of the generator"}, default=None + ) api_kwargs: Dict[str, Any] = field( metadata={ "desc": "The api_kwargs for components like Generator and Retriever that pass to the model client" @@ -172,8 +175,8 @@ class Parameter(Generic[T]): predecessors: Set["Parameter"] = set() # Predecessors of the parameter peers: Set["Parameter"] = set() # Peers of the parameter # TODO: input_args should be OrderedDict to keep the order of args - input_args: Dict[str, Any] = None # Input arguments of the GradComponent forward - full_response: object = None # Full response of the GradComponent output + # input_args: Dict[str, Any] = None # Input arguments of the GradComponent forward + # full_response: object = None # Full response of the GradComponent output eval_input: object = None # Eval input passing to the eval_fn or evaluator you use successor_map_fn: Dict[str, Callable] = ( None # Map function to get the data from the output @@ -183,7 +186,7 @@ class Parameter(Generic[T]): False # Disable the backward engine for the parameter ) - component_trace: ComponentTrace = None # Trace of the component + # component_trace: ComponentTrace = None # Trace of the component tgd_optimizer_trace: "TGDOptimizerTrace" = None # Trace of the TGD optimizer def __init__( @@ -196,7 +199,7 @@ def __init__( role_desc: str = "", param_type: ParameterType = ParameterType.NONE, name: str = None, # name is used to refer to the parameter in the prompt, easier to read for humans - raw_response: str = None, # use this to track the raw response of generator instead of the data (can be parsed) + # raw_response: str = None, # use this to track the raw response of generator instead of the data (can be parsed) instruction_to_optimizer: str = None, instruction_to_backward_engine: str = None, score: Optional[float] = None, @@ -228,7 +231,7 @@ def __init__( self.previous_data = None # used to store the previous data # context of the forward pass - self.raw_response = raw_response + # self.raw_response = raw_response self.instruction_to_optimizer: str = instruction_to_optimizer self.instruction_to_backward_engine: str = instruction_to_backward_engine @@ -247,9 +250,9 @@ def __init__( self._previous_demos: List[DataClass] = [] self.eval_input = eval_input - self.from_response_id = from_response_id # for gradient parameter + # self.from_response_id = from_response_id # for gradient parameter self.successor_map_fn = successor_map_fn or {} - self.component_trace = ComponentTrace() + # self.component_trace = ComponentTrace() def map_to_successor(self, successor: object) -> T: """Apply the map function to the successor based on the successor's id.""" @@ -385,21 +388,21 @@ def trace_optimizer(self, api_kwargs: Dict[str, Any], response: "TGDData"): ############################################################################################################ # Trace component, include trace_forward_pass & trace_api_kwargs for now ############################################################################################################ - def trace_forward_pass( - self, - input_args: Dict[str, Any], - full_response: object, - id: str = None, - name: str = None, - ): - r"""Trace the forward pass of the parameter. Adding the component information to the trace""" - self.input_args = input_args - self.full_response = full_response - # TODO: remove the input_args and full_response to use component_trace - self.component_trace.input_args = input_args - self.component_trace.full_response = full_response - self.component_trace.id = id - self.component_trace.name = name + # def trace_forward_pass( + # self, + # input_args: Dict[str, Any], + # full_response: object, + # id: str = None, + # name: str = None, + # ): + # r"""Trace the forward pass of the parameter. Adding the component information to the trace""" + # self.input_args = input_args + # self.full_response = full_response + # # TODO: remove the input_args and full_response to use component_trace + # self.component_trace.input_args = input_args + # self.component_trace.full_response = full_response + # self.component_trace.id = id + # self.component_trace.name = name def trace_api_kwargs(self, api_kwargs: Dict[str, Any]): r"""Trace the api_kwargs for components like Generator and Retriever that pass to the model client.""" @@ -852,7 +855,11 @@ def wrap_and_escape(text, width=40): node_label += f"Requires Optimization: {{'Yes'}}" if n.param_type: node_label += f"Type: {wrap_and_escape(n.param_type.name)}" - if full_trace and n.component_trace.api_kwargs is not None: + if ( + full_trace + and hasattr(n, "component_trace") + and n.component_trace.api_kwargs is not None + ): node_label += f" API kwargs: {wrap_and_escape(str(n.component_trace.api_kwargs))}" # show the score for intermediate nodes @@ -865,6 +872,8 @@ def wrap_and_escape(text, width=40): # combined_gradients_contexts = zip( # n.gradients, [n.gradients_context[g] for g in n.gradients] # ) + # if "output" in n.name: + print(f"Node: {n.name}, \n gradients: {n.gradients}") for g in n.gradients: gradient_context = g.context log.info(f"Gradient context display: {gradient_context}") @@ -881,9 +890,9 @@ def wrap_and_escape(text, width=40): node_label += f"TGD Optimizer Trace: {wrap_and_escape(str(n.tgd_optimizer_trace))}" # show component trace, id and name - if n.component_trace.id is not None: + if hasattr(n, "component_trace") and n.component_trace.id is not None: node_label += f"Component Trace ID: {wrap_and_escape(str(n.component_trace.id))}" - if n.component_trace.name is not None: + if hasattr(n, "component_trace") and n.component_trace.name is not None: node_label += f"Component Trace Name: {wrap_and_escape(str(n.component_trace.name))}" node_label += "" @@ -1189,6 +1198,8 @@ def traverse(node: "Parameter"): # Traverse predecessors and add edges for pred in node.predecessors: + if pred.param_type != ParameterType.OUTPUT: + continue pred_id = pred.component_trace.id or f"unknown_id_{uuid.uuid4()}" pred_name = pred.component_trace.name or "Unknown Component" @@ -1240,10 +1251,8 @@ def to_dict(self): "grad_fn": str( self.grad_fn ), # Simplify for serialization, modify as needed - "raw_response": self.raw_response, "score": self._score, "traces": {k: v.to_dict() for k, v in self._traces.items()}, - "input_args": self.input_args, # demos "demos": [d.to_dict() for d in self._demos], } @@ -1261,8 +1270,6 @@ def from_dict(cls, data: dict): predecessors=predecessors, gradients=[cls.from_dict(grad) for grad in data["gradients"]], previous_data=data["previous_data"], - raw_response=data["raw_response"], - input_args=data["input_args"], score=data["score"], # demos demos=[DataClass.from_dict(d) for d in data["demos"]], @@ -1274,12 +1281,14 @@ def from_dict(cls, data: dict): # TODO: very hard to read directly, need to simplify and let users use to_dict for better readability def __repr__(self): return f"Parameter(name={self.name}, requires_opt={self.requires_opt}, param_type={self.param_type}, role_desc={self.role_desc}, data={self.data}, predecessors={self.predecessors}, gradients={self.gradients},\ - raw_response={self.raw_response}, input_args={self.input_args}, traces={self._traces})" + traces={self._traces})" # TODO: separate the Parameter class into different classes and each class will have its own methods instead of all in one class class InputParameter(Parameter): - """One of the simplest types of parameters, representing an input to the system.""" + """One of the simplest types of parameters, representing an input to the system. + Input parameter will not be trainable, but serves a tracing purpose in the computation graph. + """ def __init__( self, @@ -1357,22 +1366,92 @@ def __init__( class OutputParameter(Parameter): + __doc__ = r"""The output parameter is the most complex type of parameter in the system. + + It will trace the predecessors, set up a grad_fn, store gradients, and trace the forward pass by tracking the component_trace. + """ + component_trace: ComponentTrace = ( + None # Trace of the component that produced this output + ) def __init__( self, - name: str, - role_desc: str, - data: Any, + *, + id: Optional[str] = None, # unique id of the parameter + data: T = None, # for generator output, the data will be set up as raw_response + data_id: str = None, # for tracing the data item in the training/val/test set requires_opt: bool = True, + role_desc: str = "", param_type: ParameterType = ParameterType.OUTPUT, + name: str = None, # name is used to refer to the parameter in the prompt, easier to read for humans + instruction_to_optimizer: str = None, + instruction_to_backward_engine: str = None, + score: Optional[float] = None, + eval_input: object = None, + from_response_id: Optional[str] = None, + successor_map_fn: Optional[Dict[str, Callable]] = None, ): super().__init__( - name=name, - role_desc=role_desc, + id=id, data=data, + data_id=data_id, requires_opt=requires_opt, + role_desc=role_desc, param_type=param_type, + name=name, + instruction_to_optimizer=instruction_to_optimizer, + instruction_to_backward_engine=instruction_to_backward_engine, + score=score, + eval_input=eval_input, + from_response_id=from_response_id, + successor_map_fn=successor_map_fn, ) + self.component_trace = ComponentTrace() + + ############################################################################################################ + # Trace component, include trace_forward_pass & trace_api_kwargs for now + ############################################################################################################ + def trace_forward_pass( + self, + input_args: Dict[str, Any], + full_response: object, + id: str = None, + name: str = None, + ): + r"""Trace the forward pass of the parameter. Adding the component information to the trace""" + self.input_args = input_args + self.full_response = full_response + # TODO: remove the input_args and full_response to use component_trace + self.component_trace.input_args = input_args + self.component_trace.full_response = full_response + self.component_trace.id = id + self.component_trace.name = name + + def trace_api_kwargs(self, api_kwargs: Dict[str, Any]): + r"""Trace the api_kwargs for components like Generator and Retriever that pass to the model client.""" + self.component_trace.api_kwargs = api_kwargs + + def to_dict(self): + super_dict = super().to_dict() + super_dict.update( + { + "component_trace": self.component_trace.to_dict(), + } + ) + + @classmethod + def from_dict(cls, data: dict): + component_trace = ComponentTrace.from_dict(data["component_trace"]) + return super().from_dict(data).update({"component_trace": component_trace}) + + def __repr__(self): + super_repr = super().__repr__() + # replace first Parameter with OutputParameter + super_repr = super_repr.replace("Parameter", "OutputParameter") + return super_repr + + +# gradients= List[Gradient] @dataclass @@ -1380,11 +1459,16 @@ class Gradient(DataClass): __doc__ = r"""It will handle gradients and feedbacks. It tracks the d_from_response_id / d_to_pred_id and the score of the whole response. - """ + if two gradients have the same data_id, different from_response_id, and same from_response_component_id, this is a cycle component structure. + """ + data_id: Optional[str] = None # the id of the response from data in the dataset from_response_id: str = ( None # the id of the response from which the gradient is calculated ) + from_response_component_id: str = ( + None # the id of the component from which the gradient is calculated + ) to_pred_id: str = ( None # the id of the parameter to which the gradient is calculated and attached to d(from_response_id) / d(to_pred_id) ) @@ -1393,7 +1477,6 @@ class Gradient(DataClass): context: GradientContext = None data: Any = None - data_id: Optional[str] = None # the id of the response from data in the dataset prompt: Optional[str] = None # the LLM prompt to generate the gradient def __init__( @@ -1408,6 +1491,11 @@ def __init__( ): self.id = id or str(uuid.uuid4()) self._generate_name(from_response, to_pred) + self.from_response_component_id = from_response.component_trace.id + if not self.from_response_component_id: + raise ValueError( + "The from_response_component_id should not be None. Please ensure the component_trace is set." + ) self.from_response_id = from_response.id self.to_pred_id = to_pred.id self.score = score @@ -1426,3 +1514,23 @@ def add_data(self, data: Any): def add_prompt(self, prompt: str): self.prompt = prompt + + +# Move the gradients representation to this class. +@dataclass +class Gradients(DataClass): + gradients: List[Gradient] + + def __init__(self, gradients: List[Gradient]): + self.gradients = gradients + + def to_dict(self): + return {"gradients": [g.to_dict() for g in self.gradients]} + + @classmethod + def from_dict(cls, data: dict): + gradients = [Gradient.from_dict(g) for g in data["gradients"]] + return cls(gradients) + + def __repr__(self): + return f"Gradients(gradients={self.gradients})" diff --git a/adalflow/adalflow/optim/text_grad/ops.py b/adalflow/adalflow/optim/text_grad/ops.py index d685f081..89833ec3 100644 --- a/adalflow/adalflow/optim/text_grad/ops.py +++ b/adalflow/adalflow/optim/text_grad/ops.py @@ -4,7 +4,7 @@ import logging from adalflow.optim.function import BackwardContext -from adalflow.optim.parameter import Parameter, Gradient +from adalflow.optim.parameter import Parameter, Gradient, OutputParameter from adalflow.optim.types import ParameterType from adalflow.optim.grad_component import GradComponent @@ -58,7 +58,7 @@ def forward(self, params: List[Parameter]) -> Parameter: role_descriptions = set([p.role_desc for p in params]) role_descriptions = ", ".join(role_descriptions) - total = Parameter( + total = OutputParameter( data=concat_values, role_desc=f"A combination of a list of variables: {role_descriptions}", requires_opt=any([p.requires_opt for p in params]), @@ -67,6 +67,12 @@ def forward(self, params: List[Parameter]) -> Parameter: param_type=ParameterType.SUM_OUTPUT, ) total.set_predecessors(params) + total.trace_forward_pass( + input_args=params, + full_response=concat_values, + id=total.id, + name=total.name, + ) log.info("Sum forward", extra={"total": total.data}) diff --git a/adalflow/adalflow/optim/text_grad/text_loss_with_eval_fn.py b/adalflow/adalflow/optim/text_grad/text_loss_with_eval_fn.py index ba8b8081..a717144b 100644 --- a/adalflow/adalflow/optim/text_grad/text_loss_with_eval_fn.py +++ b/adalflow/adalflow/optim/text_grad/text_loss_with_eval_fn.py @@ -11,7 +11,12 @@ from adalflow.core import ModelClient from adalflow.core.generator import BackwardEngine from adalflow.core.types import GeneratorOutput -from adalflow.optim.parameter import Parameter, GradientContext, Gradient +from adalflow.optim.parameter import ( + Parameter, + GradientContext, + Gradient, + OutputParameter, +) from adalflow.optim.types import ParameterType from adalflow.core.prompt_builder import Prompt @@ -25,44 +30,6 @@ log = logging.getLogger(__name__) -### Loss/Score Information ### -# LOSS_CONVERSATION_TEMPLATE_STRING = r""" -# The variable is passed to the eval function and compared with a target/ground truth value. - -# : {{eval_fn_desc}} -# : {{input_str}} -# : {{response_value}} -# {% if metadata %} -# Note: {{metadata}} -# {% endif %}""" - - -# Does not have gradient on the output, the loss function of the backpropagation chain -# CONVERSATION_START_INSTRUCTION_STRING_FN_BASE = r"""You will give feedback to a variable with the following role: -# {{variable_desc}} . -# Here is an evaluation of the variable using the eval function: -# {{conversation}}""" - -# Has the gradient on the output, the layer in the backpropagation chain -# Conversation will be provided differently. - -# ### Variable Information ### -# CONVERSATION_START_INSTRUCTION_STRING_FN = r""" -# TARGET VARIABLE: -# {{variable_name}} -# {{variable_desc}} -# {{variable_value}} -# {{conversation_str}} -# """ - -# Third part of the user prompt -# OBJECTIVE_INSTRUCTION_BASE = r""" -# Your only goal is to clearly states how it obtained the "". -# Especially when the score is low. -# Be CONCISE. -# If you have enough context, add a more specific feedback on how it failed. -# """ - OBJECTIVE_INSTRUCTION_CHAIN = r"""This conversation is part of a larger system. The was later used as "{{response_name}}: {{response_desc}}". @@ -141,6 +108,7 @@ def forward( kwargs: Dict[str, Parameter], response_desc: str = None, metadata: Dict[str, str] = None, # additional notes on the input kwargs + id: str = None, ) -> Parameter: if response_desc is None: response_desc = "Output of EvalFnToTextLoss." @@ -161,15 +129,22 @@ def forward( # Create a parameter # TODO: improve the readability of the input and response - eval_param: Parameter = Parameter( + eval_param: Parameter = OutputParameter( name=self.name + "_output", data=score, requires_opt=True, role_desc=response_desc, score=score, param_type=ParameterType.LOSS_OUTPUT, + data_id=id, ) eval_param.set_predecessors(predesessors) + eval_param.trace_forward_pass( + input_args=kwargs, + full_response=score, + id=self.id, + name=self.name, + ) log.info(f"EvalFnToTextLoss: Input: {kwargs}, Output: {eval_param}") eval_param.set_grad_fn( diff --git a/adalflow/adalflow/optim/text_grad/tgd_optimizer.py b/adalflow/adalflow/optim/text_grad/tgd_optimizer.py index ac6fc15e..c23b8467 100644 --- a/adalflow/adalflow/optim/text_grad/tgd_optimizer.py +++ b/adalflow/adalflow/optim/text_grad/tgd_optimizer.py @@ -129,7 +129,7 @@ class HistoryPrompt(DataClass): ### Notes: - In the absence of specific feedback, you may rephrase the initial variable to improve clarity or specificity without altering its core meaning. -- When feedback is provided, refine the variable with more detailed instructions or adjustments to directly address the feedback. +- When specific feedback is provided, you can either rephrase or refine the variable with more detailed instructions or adjustments to directly or indirectly address the feedback. {{output_format_str}} diff --git a/adalflow/adalflow/optim/trainer/trainer.py b/adalflow/adalflow/optim/trainer/trainer.py index 436d96ed..877839ad 100644 --- a/adalflow/adalflow/optim/trainer/trainer.py +++ b/adalflow/adalflow/optim/trainer/trainer.py @@ -953,11 +953,9 @@ def _fit_text_grads_one_step_for_debug(self, train_loader: Any) -> Dict[str, str total_loss = sum_ops([copy(failed_loss)]) total_loss.backward() - failed_debug_files = failed_loss.draw_graph( - filepath=debug_path, full_trace=True - ) - failed_output_file = failed_loss.draw_output_subgraph(filepath=debug_path) - failed_component_file = failed_loss.draw_component_subgraph(filepath=debug_path) + failed_debug_files = total_loss.draw_graph(filepath=debug_path, full_trace=True) + failed_output_file = total_loss.draw_output_subgraph(filepath=debug_path) + failed_component_file = total_loss.draw_component_subgraph(filepath=debug_path) failed_debug_files.update(failed_output_file) failed_debug_files.update(failed_component_file) diff --git a/benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py b/benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py index 10ca0072..fde07f94 100644 --- a/benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py +++ b/benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py @@ -301,7 +301,7 @@ def forward(self, *, input: str, id: str = None) -> adal.Parameter: name="question", data=input, role_desc="The question to be answered", - requires_opt=True, + requires_opt=False, param_type=ParameterType.INPUT, ) # context_param = adal.Parameter( diff --git a/benchmarks/hotpot_qa/adal_exp/train_multi_hop_rag_cycle.py b/benchmarks/hotpot_qa/adal_exp/train_multi_hop_rag_cycle.py index 376ef664..1860cd13 100644 --- a/benchmarks/hotpot_qa/adal_exp/train_multi_hop_rag_cycle.py +++ b/benchmarks/hotpot_qa/adal_exp/train_multi_hop_rag_cycle.py @@ -11,7 +11,7 @@ # TODO: look more into the loss function # TODO: test LLM judge too. -class MultiHopRAGAdal(adal.AdalComponent): +class MultiHopRAGCycleAdal(adal.AdalComponent): def __init__( self, model_client: adal.ModelClient, @@ -73,7 +73,7 @@ def prepare_loss(self, sample: HotPotQAData, pred: adal.Parameter): and pred.full_response.data.answer else "" ) - return self.loss_fn, {"kwargs": {"y": pred, "y_gt": y_gt}} + return self.loss_fn, {"kwargs": {"y": pred, "y_gt": y_gt}, "id": sample.id} # Note: diagnose is quite helpful, it helps you to quickly check if the evalfunction is the right metrics @@ -86,7 +86,7 @@ def train_diagnose( trainset, valset, testset = load_datasets() - adal_component = MultiHopRAGAdal( + adal_component = MultiHopRAGCycleAdal( model_client, model_kwargs, backward_engine_model_config=gpt_4o_model, @@ -111,7 +111,7 @@ def train( resume_from_ckpt=None, exclude_input_fields_from_bootstrap_demos=True, ): - adal_component = MultiHopRAGAdal( + adal_component = MultiHopRAGCycleAdal( **gpt_3_model, teacher_model_config=gpt_3_model, text_optimizer_model_config=gpt_4o_model, # gpt3.5 is not enough to be used as a good optimizer, it struggles for long contenxt @@ -157,7 +157,7 @@ def train( # train: 0.15 before the evaluator converted to lower and 0.4 after the conversion train( - debug=False, + debug=True, max_steps=12, # resume_from_ckpt="/Users/liyin/.adalflow/ckpt/ValinaRAGAdal/random_max_steps_12_7c091_run_1.json", ) diff --git a/use_cases/classification/train.py b/use_cases/classification/train.py index 0bdbd562..6e67ac73 100644 --- a/use_cases/classification/train.py +++ b/use_cases/classification/train.py @@ -146,6 +146,9 @@ def train( # NOTE: # continue from last best, 1 bootstrap, (both input and rational)86.1 val, 86.1 test (not really better) # TrecClassifierAdal/constrained_max_steps_12_2ffa7_run_2.json + # 1086s + # 0.88 validation (the steps are not right, it shows 56 steps) + # /Users/liyin/.adalflow/ckpt/TrecClassifierAdal/constrained_max_steps_12_5d1bf_run_1.json # theory: all few-shots demo or instruction, all so that the llm can reason better. Once it reches to its limits, no more shots can help or further instruction can. diff --git a/use_cases/classification/trec_task_structured_output.py b/use_cases/classification/trec_task_structured_output.py index eb5333cd..9d6c75fe 100644 --- a/use_cases/classification/trec_task_structured_output.py +++ b/use_cases/classification/trec_task_structured_output.py @@ -60,7 +60,7 @@ def __init__(self, model_client: adal.ModelClient, model_kwargs: Dict): # data="You are a classifier. Given a question, classify it into one of the following classes based on what the question is seeking:\n\nFormat: class_index. class_name, class_description\n\n0. ABBR, Abbreviation\n1. ENTY, Entity\n2. DESC, Description and abstract concept\n3. HUM, Human being\n4. LOC, Location\n5. NUM, Numeric value\n\nPay special attention to questions about entities versus descriptions, as well as those asking for specific terms or people. Do not try to answer the question:", # best # data="You are a classifier. For each question given, classify it into one of the following classes:\n\nFormat: class_index. class_name, class_description\n\n0. ABBR, Abbreviation (includes initials)\n1. ENTY, Entity (includes products, languages, objects, etc.)\n2. DESC, Description and abstract concept (includes explanations)\n3. HUM, Human being (includes individuals, groups, etc.)\n4. LOC, Location (includes addresses, places, etc.)\n5. NUM, Numeric value (includes distances, dates, ages, etc.)\n\n- Focus on identifying the primary subject of the question and classifying based on what is being explicitly asked for.", role_desc="Task description", - requires_opt=False, + requires_opt=True, param_type=adal.ParameterType.PROMPT, ), "output_format_str": adal.Parameter( @@ -70,12 +70,12 @@ def __init__(self, model_client: adal.ModelClient, model_kwargs: Dict): param_type=adal.ParameterType.PROMPT, ), # NOTE: 88.19% - "few_shot_demos": adal.Parameter( - data=None, - requires_opt=True, - role_desc="Few shot examples to help the model", - param_type=adal.ParameterType.DEMOS, - ), + # "few_shot_demos": adal.Parameter( + # data=None, + # requires_opt=True, + # role_desc="Few shot examples to help the model", + # param_type=adal.ParameterType.DEMOS, + # ), } self.llm = adal.Generator(