adapt the gradcomponent to a version where we support backpropagation instead of just pass-through gradient

liyin2015 committed Jan 8, 2025
1 parent e06fd7b commit e204688
Showing 21 changed files with 702 additions and 281 deletions.
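For context on the commit message above, here is a minimal, purely illustrative sketch of the difference between a pass-through gradient and true backpropagation in a GradComponent-style component. The class and method names mirror ones that appear in this diff (Parameter, Gradient, backward), but the bodies are simplified assumptions, not adalflow's actual implementation.

```python
# Illustrative sketch only: simplified stand-ins for adalflow's Parameter/Gradient.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Gradient:
    data: str                      # textual feedback flowing backward
    score: Optional[float] = None  # score attached to the feedback


@dataclass
class Parameter:
    name: str
    gradients: List[Gradient] = field(default_factory=list)

    def add_gradient(self, g: Gradient) -> None:
        self.gradients.append(g)


class PassThroughGradComponent:
    """Old behavior: hand the response's gradient to every predecessor unchanged."""

    def backward(self, response_grad: Gradient, predecessors: List[Parameter]) -> None:
        for pred in predecessors:
            pred.add_gradient(response_grad)  # identical feedback for every input


class BackpropGradComponent:
    """New behavior: produce predecessor-specific feedback (a backward engine in the real code)."""

    def backward(self, response_grad: Gradient, predecessors: List[Parameter]) -> None:
        for pred in predecessors:
            # The real implementation asks an LLM backward engine to write this
            # feedback; here it is just specialized per predecessor for illustration.
            pred.add_gradient(
                Gradient(
                    data=f"Feedback for {pred.name}: {response_grad.data}",
                    score=response_grad.score,
                )
            )
```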
49 changes: 30 additions & 19 deletions adalflow/adalflow/core/generator.py
@@ -59,7 +59,7 @@

log = logging.getLogger(__name__)

- DEBUG_MODE = os.environ.get("DEBUG_MODE", False)
+ DEBUG_MODE = os.environ.get("DEBUG_MODE", True)

PromptArgType = Dict[str, Union[str, Parameter]]

@@ -281,7 +281,9 @@ def set_parameters(self, prompt_kwargs: PromptArgType):
peers = [
p
for k, p in prompt_kwargs.items()
- if isinstance(p, Parameter) and k != key
+ if isinstance(p, Parameter)
+ and k != key
+ and p.param_type == ParameterType.PROMPT
]
p.set_peers(peers)
setattr(self, key, p)
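A small sketch of what the tightened peer filter above does: only other PROMPT-type parameters are registered as peers, so demo parameters no longer appear in a prompt parameter's peer list. The Parameter and ParameterType definitions below are simplified assumptions that only mirror the attributes used in the hunk.

```python
# Simplified stand-ins; only the attributes used by the peer filter are modeled.
from dataclasses import dataclass, field
from enum import Enum
from typing import List


class ParameterType(Enum):
    PROMPT = "prompt"
    DEMOS = "demos"


@dataclass
class Parameter:
    name: str
    param_type: ParameterType
    peers: List["Parameter"] = field(default_factory=list)

    def set_peers(self, peers: List["Parameter"]) -> None:
        self.peers = peers


prompt_kwargs = {
    "task_desc": Parameter("task_desc", ParameterType.PROMPT),
    "output_format": Parameter("output_format", ParameterType.PROMPT),
    "few_shot_demos": Parameter("few_shot_demos", ParameterType.DEMOS),
}

for key, p in prompt_kwargs.items():
    peers = [
        q
        for k, q in prompt_kwargs.items()
        if isinstance(q, Parameter)
        and k != key
        and q.param_type == ParameterType.PROMPT  # the condition added in this commit
    ]
    p.set_peers(peers)

print([q.name for q in prompt_kwargs["task_desc"].peers])  # ['output_format'], demos excluded
```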
@@ -338,7 +340,7 @@ def get_prompt(self, **kwargs) -> str:
return self.prompt.call(**kwargs)

def _extra_repr(self) -> str:
- s = f"model_kwargs={self.model_kwargs}, model_type={self.model_type}"
+ s = f"model_kwargs={self.model_kwargs}, model_type={self.model_type}, prompt={self.prompt}"
return s

def _post_call(self, completion: Any) -> GeneratorOutput:
@@ -641,7 +643,9 @@ def backward(

log.info(f"Generator: Backward: {response.name}")

- backward_pass_setup = backward_engine.backward_pass_setup
+ backward_pass_setup = (
+ backward_engine.backward_pass_setup if backward_engine else None
+ )
printc(
f"backward pass setup: {backward_pass_setup}, name: {self.name}",
color="red",
@@ -655,14 +659,14 @@
# backward score to the demo parameter
for pred in children_params:
# if pred.requires_opt:
- pred.set_score(response._score)
+ pred.set_score(response.score)
log.debug(
- f"backpropagate the score {response._score} to {pred.name}, is_teacher: {self.teacher_mode}"
+ f"backpropagate the score {response.score} to {pred.name}, is_teacher: {self.teacher_mode}"
)
if pred.param_type == ParameterType.DEMOS:
# Accumulate the score to the demo
pred.add_score_to_trace(
- trace_id=id, score=response._score, is_teacher=self.teacher_mode
+ trace_id=id, score=response.score, is_teacher=self.teacher_mode
)
log.debug(f"Pred: {pred.name}, traces: {pred._traces}")
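The hunk above also switches from the private `response._score` to the public `response.score` and routes that score to every child parameter, accumulating it per trace for DEMOS-type parameters. A rough sketch of that flow, with assumed method bodies:

```python
# Rough sketch; set_score/add_score_to_trace bodies are assumptions, not adalflow's.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class DemoParameter:
    name: str
    score: float = 0.0
    _traces: Dict[str, float] = field(default_factory=dict)

    def set_score(self, score: float) -> None:
        self.score = score

    def add_score_to_trace(self, trace_id: str, score: float, is_teacher: bool) -> None:
        # Accumulate the score for the demo that produced this trace.
        self._traces[trace_id] = score


def backward_scores(response_score: float, data_id: str, children: List[DemoParameter]) -> None:
    for pred in children:
        pred.set_score(response_score)
        pred.add_score_to_trace(trace_id=data_id, score=response_score, is_teacher=False)


demo = DemoParameter(name="few_shot_demos")
backward_scores(response_score=1.0, data_id="example-1", children=[demo])
print(demo.score, demo._traces)  # 1.0 {'example-1': 1.0}
```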

@@ -808,11 +812,11 @@ def _backward_through_all_predecessors(
response_gradient_list = [""] * len(children_params)
if (
backward_pass_setup.compute_grad_for_errors_only
- and response._score is not None
- and float(response._score)
+ and response.score is not None
+ and float(response.score)
> backward_pass_setup.threshold_score_to_compute_grad_for_errors
):
- manual_response_1 = f"You get score: {response._score}. No noticable error."
+ manual_response_1 = f"You get score: {response.score}. No noticable error."
response_gradient_list = [manual_response_1] * len(children_params)
raw_response = str(response_gradient_list)
gradient_output = GeneratorOutput(
@@ -860,7 +864,7 @@ def _backward_through_all_predecessors(
var_gradient = Gradient(
data=gradient_data,
data_id=response.data_id,
- score=response._score, # add score to gradient
+ score=response.score, # add score to gradient
from_response=response,
to_pred=pred,
)
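Each Gradient constructed above now carries `response.score` rather than the private `_score`, so every piece of feedback is tagged with how well the response did. A tiny illustrative example of why that is useful; the Gradient fields mirror the call above, while the sorting is only an assumed downstream use:

```python
# Illustrative only: score-tagged gradients can be prioritized by an optimizer.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Gradient:
    data: str
    data_id: Optional[str] = None
    score: Optional[float] = None


grads = [
    Gradient(data="Answer missed the second reasoning hop.", data_id="q1", score=0.0),
    Gradient(data="Correct and concise answer.", data_id="q2", score=1.0),
]

# Surface the lowest-scoring (most erroneous) feedback first.
grads.sort(key=lambda g: g.score if g.score is not None else 0.0)
print(grads[0].data)  # the q1 feedback
```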
@@ -873,7 +877,7 @@
)
var_gradient.add_prompt(backward_engine_prompt_str)
pred.add_gradient(var_gradient)
- pred.set_score(response._score)
+ pred.set_score(response.score)

@staticmethod
def _backward_through_one_predecessor(
@@ -923,6 +927,7 @@ def _backward_through_one_predecessor(
variable_dict = pred.get_param_info()

peers = [p.get_param_info() for p in pred.peers]
+ # peers = []

variable_and_peers_info = Prompt(
prompt_kwargs={"variable": variable_dict, "peers": peers},
@@ -939,12 +944,17 @@
raise ValueError(
f"Generator: No gradient found for {response}. Please check the response. pred: {pred}"
)

+ predecessors = [
+ pred.get_param_info()
+ for pred in response.predecessors
+ if pred not in pred.peers
+ ]
instruction_str = Prompt(
template=conv_ins_template,
prompt_kwargs={
"variable_and_peers_info": variable_and_peers_info,
"conversation_str": conversation_str,
+ "predecessors": predecessors,
},
)()
log.info(f"Conversation start instruction base str: {instruction_str}")
@@ -968,13 +978,13 @@
gradient_output: GeneratorOutput = None
if (
backward_pass_setup.compute_grad_for_errors_only
- and response._score is not None
- and float(response._score)
+ and response.score is not None
+ and float(response.score)
> backward_pass_setup.threshold_score_to_compute_grad_for_errors
):
log.debug(f"EvalFnToTextLoss: Skipping {pred} as the score is high enough.")
# TODO: plus score descriptions
- manual_response = f"You get score: {response._score}. No noticable error."
+ manual_response = f"You get score: {response.score}. No noticable error."
gradient_output = GeneratorOutput(
data=manual_response, raw_response=manual_response
)
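The branch above is the "errors only" shortcut: when the response already scored above a configured threshold, the backward-engine call is skipped and a canned "no noticeable error" message is used as the gradient. A hedged sketch of that gate; the setup fields follow the names in the diff, everything else is illustrative:

```python
# Sketch of the errors-only gate; BackwardPassSetup here is an assumed stand-in.
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class BackwardPassSetup:
    compute_grad_for_errors_only: bool = True
    threshold_score_to_compute_grad_for_errors: float = 0.9


def gradient_text(score: Optional[float], setup: BackwardPassSetup,
                  call_backward_engine: Callable[[], str]) -> str:
    if (
        setup.compute_grad_for_errors_only
        and score is not None
        and float(score) > setup.threshold_score_to_compute_grad_for_errors
    ):
        # High-scoring responses skip the LLM call entirely.
        return f"You get score: {score}. No noticeable error."
    return call_backward_engine()


print(gradient_text(0.95, BackwardPassSetup(), lambda: "LLM-written feedback"))  # canned message
print(gradient_text(0.20, BackwardPassSetup(), lambda: "LLM-written feedback"))  # engine output
```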
@@ -987,7 +997,7 @@
raise ValueError(
f"Generator: Backward Engine should return a GeneratorOutput. Got {gradient_output} instead."
)
- print(f"Backward engine gradient: {gradient_output}")
+ printc(f"Backward engine gradient: {gradient_output}")

# USE this to trace each node's input and output, all nodes can be visualized
log.info(
@@ -1001,7 +1011,7 @@ def _backward_through_one_predecessor(
var_gradient = Gradient(
data=gradient_value,
data_id=response.data_id,
- score=response._score, # add score to gradient
+ score=response.score, # add score to gradient
from_response=response,
to_pred=pred,
)
@@ … @@
)
var_gradient.add_prompt(backward_engine_prompt_str)
pred.add_gradient(var_gradient)
- pred.set_score(response._score)
+ pred.set_score(response.score)

def _run_callbacks(
self,
@@ -1168,6 +1178,7 @@ def _extra_repr(self) -> str:
]

s += f"trainable_prompt_kwargs={prompt_kwargs_repr}"
+ s += f", prompt={self.prompt}"
return s

def to_dict(self) -> Dict[str, Any]:
9 changes: 9 additions & 0 deletions adalflow/adalflow/core/prompt_builder.py
@@ -152,6 +152,15 @@ def _extra_repr(self) -> str:
s += f", prompt_variables: {self.prompt_variables}"
return s

+ def __repr__(self) -> str:
+ s = f"template: {self.template}"
+ prompt_kwargs_str = _convert_prompt_kwargs_to_str(self.prompt_kwargs)
+ if prompt_kwargs_str:
+ s += f", prompt_kwargs: {prompt_kwargs_str}"
+ if self.prompt_variables:
+ s += f", prompt_variables: {self.prompt_variables}"
+ return s

@classmethod
def from_dict(cls: type[T], data: Dict[str, Any]) -> T:
obj = super().from_dict(data)
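With the `__repr__` added above, printing a `Prompt` (and, via the generator's updated `_extra_repr`, the generator that owns it) now shows the template together with its prompt kwargs and variables. A hypothetical usage, assuming adalflow is installed; the printed format is approximate:

```python
# Hypothetical usage; exact output depends on _convert_prompt_kwargs_to_str.
from adalflow.core.prompt_builder import Prompt

prompt = Prompt(
    template="Answer the question: {{question}}",
    prompt_kwargs={"question": "What is backpropagation?"},
)
print(repr(prompt))
# roughly: template: Answer the question: {{question}}, prompt_kwargs: ..., prompt_variables: ['question']
```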
3 changes: 3 additions & 0 deletions adalflow/adalflow/datasets/types.py
@@ -38,6 +38,9 @@ class HotPotQAData(Example):
default=None,
)

+ __input_fields__ = ["question"]
+ __output_fields__ = ["answer"]

# @staticmethod
# def from_dict(d: Dict[str, Any]) -> "HotPotQAData":
# # Preprocess gold_titles
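The two class attributes added above mark which fields of `HotPotQAData` are inputs and which are outputs; the bootstrap optimizer change below uses them (through `get_input_fields()` / `get_output_fields()`) to serialize only those fields into few-shot demo strings. A minimal self-contained sketch of the idea; the real adalflow DataClass machinery is more involved:

```python
# Simplified sketch; assumes PyYAML for serialization instead of adalflow's to_yaml helper.
from dataclasses import dataclass, fields
from typing import List, Tuple

import yaml


@dataclass
class HotPotQAData:
    question: str = ""
    answer: str = ""
    gold_titles: Tuple[str, ...] = ()

    __input_fields__ = ["question"]
    __output_fields__ = ["answer"]

    def get_input_fields(self) -> List[str]:
        return list(self.__input_fields__)

    def get_output_fields(self) -> List[str]:
        return list(self.__output_fields__)

    def to_yaml(self, include: List[str]) -> str:
        data = {f.name: getattr(self, f.name) for f in fields(self) if f.name in include}
        return yaml.safe_dump(data, sort_keys=False)


sample = HotPotQAData(
    question="Who wrote Hamlet?",
    answer="William Shakespeare",
    gold_titles=("Hamlet",),
)
# Only the declared input/output fields end up in the demo string; gold_titles is omitted.
print(sample.to_yaml(include=sample.get_input_fields() + sample.get_output_fields()))
```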
6 changes: 5 additions & 1 deletion adalflow/adalflow/optim/few_shot/bootstrap_optimizer.py
@@ -14,6 +14,7 @@
from adalflow.core.functional import random_sample
from adalflow.optim.optimizer import DemoOptimizer
from adalflow.optim.types import ParameterType
+ from adalflow.utils import printc

log = logging.getLogger(__name__)

@@ -219,7 +220,10 @@ def samples_to_str(
yaml_str = sample.to_yaml(exclude=exclude_fields)

else:
- yaml_str = sample.to_yaml(exclude=["id", "score"])
+ yaml_str = sample.to_yaml(
+ include=sample.get_input_fields() + sample.get_output_fields()
+ )
+ printc(f"yaml_str: {yaml_str}")
sample_strs.append(yaml_str + "\n")
except Exception as e:
print(f"Error: {e} to yaml for {sample}")