
use random seed to control the order of training samples, add backward pass setup for the backward engine via the trainer.fit function
liyin2015 committed Jan 2, 2025
1 parent a185e7f commit 540b161
Showing 9 changed files with 223 additions and 129 deletions.
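
The headline change is a new BackwardPassSetup threaded from the trainer down to the backward engine. Below is a minimal sketch of the intended call path, pieced together from the commit title and the diffs that follow; the keyword names backward_pass_setup and random_seed in fit, and the task/train_data/val_data variables, are assumptions here, since the trainer diff itself did not render on this page:

import adalflow as adal
from adalflow.core.generator import BackwardPassSetup

setup = BackwardPassSetup(
    all_pred_at_once=False,                          # send predecessors to the backward engine one at a time
    threshold_score_to_compute_grad_for_errors=0.9,  # scores above this count as "correct"
    compute_grad_for_errors_only=True,               # skip full gradients for correct samples
)
trainer = adal.Trainer(adaltask=task)
trainer.fit(
    train_dataset=train_data,
    val_dataset=val_data,
    backward_pass_setup=setup,  # forwarded to the backward engine (per the commit title)
    random_seed=42,             # controls the order of training samples (per the commit title)
)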
62 changes: 54 additions & 8 deletions adalflow/adalflow/core/generator.py
@@ -9,6 +9,7 @@

from typing import Any, Dict, Optional, Union, Callable, Tuple, List
import logging
from dataclasses import dataclass, field


from adalflow.core.types import (
@@ -63,6 +64,20 @@
PromptArgType = Dict[str, Union[str, Parameter]]


@dataclass
class BackwardPassSetup(DataClass):
all_pred_at_once: bool = field(
default=True, metadata={"desc": "Backward all predecessors at once."}
)
threshold_score_to_compute_grad_for_errors: float = field(
default=0.9,
metadata={"desc": "Threshold score to compute gradient for errors."},
)
compute_grad_for_errors_only: bool = field(
default=False, metadata={"desc": "Compute gradient for errors only."}
)


class Generator(GradComponent, CachedEngine, CallbackManager):
__doc__ = """An user-facing orchestration component for LLM prediction.
@@ -95,6 +110,10 @@ class Generator(GradComponent, CachedEngine, CallbackManager):
{}
) # to create teacher generator from student TODO: might reassess this

backward_pass_setup: BackwardPassSetup = (
BackwardPassSetup()
) # default setup for the backward pass

def __init__(
self,
*,
@@ -184,6 +203,9 @@ def __init__(
{}
) # used by dynamic computation graph and backpropagation

def update_default_backward_pass_setup(self, setup: BackwardPassSetup):
self.backward_pass_setup = setup
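
The new method swaps an instance-level setup in over the class-level default. A small sketch, where backward_engine stands for any already-constructed Generator (or BackwardEngine) instance:

from adalflow.core.generator import BackwardPassSetup

custom = BackwardPassSetup(all_pred_at_once=False, compute_grad_for_errors_only=True)
backward_engine.update_default_backward_pass_setup(custom)
assert backward_engine.backward_pass_setup is custom  # instance attribute now shadows the class default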

def set_cache_path(self, cache_path: str, model_client: object, model: str):
"""Set the cache path for the generator."""

@@ -593,6 +615,7 @@ def data_to_prompt_map_fn(data: Parameter) -> str:
log.debug(f"Backward engine: {self.backward_engine}")

# attach a function to compute gradients for predecessors

response.set_grad_fn(
BackwardContext(
backward_fn=self.backward,
@@ -602,7 +625,6 @@ def data_to_prompt_map_fn(data: Parameter) -> str:
template=self.template,
prompt_str=self.get_prompt(**combined_prompt_kwargs),
id=id,
all_pred_at_once=True,
)
)
return response
@@ -615,11 +637,16 @@ def backward(
prompt_str: str,
backward_engine: Optional["Generator"] = None,
id: Optional[str] = None, # the id of the input
all_pred_at_once: bool = True,
) -> Parameter:

log.info(f"Generator: Backward: {response.name}")

backward_pass_setup = backward_engine.backward_pass_setup
printc(
f"backward pass setup: {backward_pass_setup}, name: {self.name}",
color="red",
)

children_params = response.predecessors
is_intermediate_node = True
if response.get_gradient_and_context_text().strip() == "":
@@ -648,6 +675,9 @@
for pred in children_params:
pred.backward_engine_disabled = True
return

all_pred_at_once = backward_pass_setup.all_pred_at_once

if not all_pred_at_once:
for pred in children_params:
if not pred.requires_opt or pred.param_type == ParameterType.DEMOS:
@@ -663,6 +693,7 @@
template=template,
backward_engine=backward_engine,
prompt_str=prompt_str,
backward_pass_setup=backward_pass_setup,
is_intermediate_node=is_intermediate_node,
)
else:
@@ -680,6 +711,7 @@
template=template,
backward_engine=backward_engine,
prompt_str=prompt_str,
backward_pass_setup=backward_pass_setup,
is_intermediate_node=is_intermediate_node,
)
else:
@@ -693,6 +725,7 @@ def _backward_through_all_predecessors(
template: str,
backward_engine: "BackwardEngine",
prompt_str: str,
backward_pass_setup: BackwardPassSetup,
is_intermediate_node: bool = False,
):
parser = JsonParser()
@@ -762,8 +795,13 @@ def _backward_through_all_predecessors(

gradient_output: GeneratorOutput = None
response_gradient_list = [""] * len(children_params)
if response._score is not None and float(response._score) > 0.9:
manual_response_1 = f"You get score: {response._score}."
if (
backward_pass_setup.compute_grad_for_errors_only
and response._score is not None
and float(response._score)
> backward_pass_setup.threshold_score_to_compute_grad_for_errors
):
manual_response_1 = f"You get score: {response._score}. No noticeable error."
response_gradient_list = [manual_response_1] * len(children_params)
raw_response = str(response_gradient_list)
gradient_output = GeneratorOutput(
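
The hard-coded score check (response._score > 0.9) is now gated on the setup. The new condition reduces to the following predicate (a paraphrase for clarity, not code copied from the diff):

from adalflow.core.generator import BackwardPassSetup

def skip_full_gradients(setup: BackwardPassSetup, score) -> bool:
    """True when a manual 'you scored well' message should stand in for a backward-engine call."""
    return (
        setup.compute_grad_for_errors_only
        and score is not None
        and float(score) > setup.threshold_score_to_compute_grad_for_errors
    )

Because compute_grad_for_errors_only defaults to False, the shortcut is now opt-in: by default, full textual gradients are computed even for high-scoring samples, whereas before this commit any score above 0.9 skipped them.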
@@ -832,6 +870,7 @@ def _backward_through_one_predecessor(
template: str,
backward_engine: "BackwardEngine",
prompt_str: str,
backward_pass_setup: BackwardPassSetup,
is_intermediate_node: bool = False,
):
"""Creating gradient/textual feedback for prompt type parameters."""
@@ -840,7 +879,7 @@ def _backward_through_one_predecessor(
f"Generator: Skipping {pred} as it does not require optimization."
)
return
log.debug(
printc(
f"Generator: Backward through {pred}, is_intermediate_node: {is_intermediate_node}"
)

@@ -872,8 +911,10 @@

variable_dict = pred.get_param_info()

peers = [p.get_param_info() for p in pred.peers]

variable_and_peers_info = Prompt(
prompt_kwargs={"variable": variable_dict, "peers": pred.peers},
prompt_kwargs={"variable": variable_dict, "peers": peers},
template=VARIABLE_AND_PEERS_INFO,
)()

@@ -914,10 +955,15 @@
)
print(f"Backward engine prompt: {backward_engine_prompt_str}")
gradient_output: GeneratorOutput = None
if response._score is not None and float(response._score) > 0.9:
if (
backward_pass_setup.compute_grad_for_errors_only
and response._score is not None
and float(response._score)
> backward_pass_setup.threshold_score_to_compute_grad_for_errors
):
log.debug(f"EvalFnToTextLoss: Skipping {pred} as the score is high enough.")
# TODO: plus score descriptions
manual_response = f"You get score: {response._score}."
manual_response = f"You get score: {response._score}. No noticeable error."
gradient_output = GeneratorOutput(
data=manual_response, raw_response=manual_response
)
2 changes: 1 addition & 1 deletion adalflow/adalflow/optim/text_grad/tgd_optimizer.py
@@ -374,7 +374,7 @@ def _get_user_prompt_kwargs(self, param: Parameter) -> Dict[str, str]:
variable=param.get_param_info(), peers=peers_params
)

variable_grad = param.get_gradients_component_schema(skip_correct_sample=True)
variable_grad = param.get_gradients_component_schema(skip_correct_sample=False)

user_prompt_kwargs = {
"variable_and_peers_info": variable_and_peer_info,
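
Flipping skip_correct_sample from True to False means the optimizer's prompt now includes gradient feedback from correct samples as well. This plausibly pairs with compute_grad_for_errors_only defaulting to False above: once high-scoring samples carry real textual gradients, there is something worth showing the optimizer for them.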
6 changes: 5 additions & 1 deletion adalflow/adalflow/optim/trainer/adal.py
@@ -9,7 +9,7 @@

if TYPE_CHECKING:
from adalflow.core.model_client import ModelClient
from adalflow.core.generator import Generator, BackwardEngine
from adalflow.core.generator import Generator, BackwardEngine, BackwardPassSetup
from adalflow.optim.parameter import Parameter

from adalflow.core.component import Component
@@ -187,6 +187,7 @@ def configure_backward_engine(self, *args, **kwargs):
self.configure_backward_engine_helper(
model_client=self.backward_engine_model_config["model_client"],
model_kwargs=self.backward_engine_model_config["model_kwargs"],
backward_pass_setup=kwargs.get("backward_pass_setup", None),
)

# def configure_backward_engine(self, *args, **kwargs):
@@ -594,6 +595,7 @@ def configure_backward_engine_helper(
model_client: "ModelClient",
model_kwargs: Dict[str, Any],
template: Optional[str] = None,
backward_pass_setup: Optional["BackwardPassSetup"] = None,
):
r"""Configure a backward engine for all generators in the task for bootstrapping examples."""
from adalflow.core.generator import BackwardEngine
@@ -603,6 +605,8 @@
model_kwargs=model_kwargs,
template=template,
)
if backward_pass_setup is not None:
self.backward_engine.update_default_backward_pass_setup(backward_pass_setup)

# set all generator's backward engine

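
With this, a setup passed to configure_backward_engine(backward_pass_setup=...) is applied to the shared backward engine right after it is built. A sketch of calling the helper directly, where adal_component stands for your AdalComponent subclass instance; the OpenAIClient and model choice are illustrative, not taken from this diff:

from adalflow.components.model_client import OpenAIClient
from adalflow.core.generator import BackwardPassSetup

adal_component.configure_backward_engine_helper(
    model_client=OpenAIClient(),
    model_kwargs={"model": "gpt-4o"},
    backward_pass_setup=BackwardPassSetup(compute_grad_for_errors_only=True),
)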
(The remaining changed files in this commit did not load in this view.)