Skip to content

Commit

Permalink
adding parameters in components, state_dict does not work yet
Browse files Browse the repository at this point in the history
  • Loading branch information
liyin2015 committed May 25, 2024
1 parent effdec8 commit 0a8951e
Show file tree
Hide file tree
Showing 5 changed files with 363 additions and 84 deletions.
320 changes: 251 additions & 69 deletions core/component.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections import OrderedDict
from collections import OrderedDict, namedtuple
from typing import (
Callable,
Dict,
Expand All @@ -22,6 +22,8 @@
import matplotlib.pyplot as plt
import uuid

from core.parameter import Parameter


# TODO: design hooks.
_global_pre_call_hooks: Dict[int, Callable] = OrderedDict()
Expand Down Expand Up @@ -72,87 +74,127 @@ class Component:
_execution_graph: List[str] = [] # This will store the graph of execution.
_graph = nx.DiGraph()
_last_called = None # Tracks the last component called
_parameters: Dict[str, Optional[Parameter]]
training: bool

# def _generate_unique_name(self):
# # Generate a unique identifier that includes the class name
# return f"{self.__class__.__name__}_{uuid.uuid4().hex[:8]}"

def __init__(self, *args, **kwargs) -> None:
super().__setattr__("_components", {})
super().__setattr__("_components", OrderedDict())
super().__setattr__("_parameters", OrderedDict())
super().__setattr__("training", True)

def __setattr__(self, name: str, value: Any) -> None:
def remove_from(*dicts_or_sets):
for d in dicts_or_sets:
if name in d:
if isinstance(d, dict):
del d[name]
else:
d.discard(name)
def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
r"""Add a parameter to the component.
components = self.__dict__.get("_components")
if isinstance(value, Component):
if components is None:
raise AttributeError(
"cant assign component before Component.__init__() call"
)
remove_from(self.__dict__)
components[name] = value
The parameter can be accessed as an attribute using given name.
Args:
name (str): name of the parameter. The parameter can be accessed using this name.
param (Parameter): parameter to be added.
"""
if "_parameters" not in self.__dict__:
raise AttributeError(
"cant assign parameter before Component.__init__() call"
)
elif "." in name:
raise ValueError('parameter name can\'t contain "."')
elif name == "":
raise ValueError('parameter name can\'t be empty string ""')
elif hasattr(self, name) and name not in self._parameters:
raise KeyError("attribute '{}' already exists".format(name))

if param is None:
self._parameters[name] = None
elif not isinstance(param, Parameter):
raise TypeError(
f"cannot assign'{type(param)}' object to parameter '{name}'(Parameter or None required)"
)
else:
super().__setattr__(name, value)
self._parameters[name] = param
print(f"Registered parameter {name} with value {param}")

def __getattr__(self, name: str) -> Any:
if "_components" in self.__dict__:
components = self.__dict__.get("_components")
if name in components:
return components[name]
# else:
# super().__getattr__(name)
def parameters(self, recursive: bool = True) -> Iterable[Parameter]:
r"""Returns an iterator over module parameters.
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
Args:
recursive (bool): if True, then yields parameters of this module and all submodules.
Otherwise, yields only parameters that are direct members of this module.
def __delattr__(self, name: str) -> None:
if name in self._components:
del self._components[name]
else:
super().__delattr__(name)
Yields:
Parameter: module parameter
def _extra_repr(self) -> str:
"""
Normally implemented by subcomponents to print additional positional or keyword arguments.
# NOTE: Dont add components as it will have its own __repr__
Examples:
>>> for param in model.parameters():
>>> print(param)
"""
return ""
for name, param in self.named_parameters():
yield param

def _get_name(self):
# return self._name
return self.__class__.__name__
def _named_members(
self,
get_members_fn,
prefix: str = "",
recurse: bool = True,
remove_duplicate: bool = True,
):
r"""Helper method for yielding various names + members of the module.
def __repr__(self):
# We treat the extra repr like the sub-module, one item per line
extra_lines = []
extra_repr = self._extra_repr()
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split("\n")
child_lines = []
for key, component in self._components.items():
mod_str = repr(component)
mod_str = _addindent(mod_str, 2)
child_lines.append("(" + key + "): " + mod_str)
lines = extra_lines + child_lines
Args:
get_members_fn (Callable): callable to extract the members from the module.
prefix (str): prefix to prepend to all parameter names.
recurse (bool): if True, then yields parameters of this module and all submodules.
Otherwise, yields only parameters that are direct members of this module.
main_str = self._get_name() + "("
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += "\n " + "\n ".join(lines) + "\n"
Yields:
Tuple[str, Any]: Tuple containing the name and member
main_str += ")"
return main_str
Examples:
>>> for name, param in model._named_members(model.named_parameters):
>>> print(name, param)
"""
memo = set()
components = self.named_components(
prefix=prefix, remove_duplicate=remove_duplicate
)
for component_prefix, component in components:
members = get_members_fn(component)
for k, v in members:
if v is None or v in memo:
continue
memo.add(v)
name = component_prefix + ("." if component_prefix else "") + k
yield name, v

def named_parameters(
    self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterable[Tuple[str, Parameter]]:
    r"""Iterate over parameters as ``(name, parameter)`` pairs.

    The traversal itself is delegated to ``_named_members``; this method only
    supplies the getter that reads a component's ``_parameters`` mapping.

    Args:
        prefix (str): forwarded to ``_named_members`` as the name prefix.
        recurse (bool): forwarded to ``_named_members`` unchanged.
        remove_duplicate (bool): forwarded to ``_named_members`` unchanged.

    Yields:
        Tuple[str, Parameter]: the name and the parameter.

    Examples:
        >>> for name, param in model.named_parameters():
        >>>     print(name, param)
    """

    def _own_parameters(component):
        # Each component keeps its directly registered parameters here.
        return component._parameters.items()

    yield from self._named_members(
        _own_parameters,
        prefix=prefix,
        recurse=recurse,
        remove_duplicate=remove_duplicate,
    )

@staticmethod
def visualize_graph_html(filename="graph.html"):
Expand Down Expand Up @@ -253,12 +295,29 @@ def children(self) -> Iterable["Component"]:
for name, component in self.named_children():
yield component

def components(self) -> Iterable["Component"]:
    r"""Yield every component produced by ``named_children``, dropping its name."""
    yield from (child for _, child in self.named_children())

def named_components(
self,
memo: Optional[Set["Component"]] = None,
prefix: str = "",
remove_duplicate: bool = True,
):
r"""Return an iterator over all components in the pipeline, yielding both the name of the component as well as the component itself.
Args:
memo (Optional[Set["Component"]]): a memo to store the set of components already added to the result
prefix (str): a prefix to prepend to all component names
remove_duplicate (bool): if True, then yields only unique components
Yields:
Tuple[str, "Component"]: Tuple containing the name and component
"""
if memo is None:
memo = set()
if self not in memo:
Expand All @@ -273,12 +332,37 @@ def named_components(
memo, submodule_prefix, remove_duplicate
)

def components(self) -> Iterable["Component"]:
r"""
Returns an iterator over all components in the Module.
def state_dict(
self, destination: Optional[Dict[str, Any]] = None, prefix: Optional[str] = ""
) -> Dict[str, Any]:
r"""Returns a dictionary containing references to the whole state of the component.
Parameters are included for now.
..note:
The returned object is a shallow copy. It contains references to the original data.
Args:
destination (Dict[str, Any]): If provided, the state of component will be copied into it.
And the same object is returned.
Otherwise, an ``OrderedDict`` will be created and returned.
prefix (str): a prefix to add to the keys in the state_dict.
Returns:
Dict[str, Any]: a dictionary containing the state of the component.
"""
for name, component in self.named_children():
yield component
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
local_metadata = dict(version=self._version)
if hasattr(self, "_metadata"):
destination._metadata[prefix[:-1]] = local_metadata
for name, component in self._components.items():
if component is not None:
component.state_dict(
destination=destination, prefix=prefix + name + "."
)
return destination

def apply(self: "Component", fn: Callable[["Component", Any], None]) -> None:
r"""
Expand All @@ -290,6 +374,104 @@ def apply(self: "Component", fn: Callable[["Component", Any], None]) -> None:
fn(self)
return self

def __setattr__(self, name: str, value: Any) -> None:
    r"""Route an attribute assignment to the right storage.

    * A ``Parameter`` value is stored through ``register_parameter``.
    * A value assigned to a name that is already a registered parameter must
      be ``None`` (which clears the slot); anything else is rejected.
    * A ``Component`` value is stored in ``_components``.
    * Everything else becomes a plain instance attribute.

    Args:
        name (str): attribute name being assigned.
        value (Any): value being assigned.

    Raises:
        AttributeError: if a ``Parameter`` or ``Component`` is assigned before
            ``Component.__init__()`` has created the registries.
        TypeError: if a non-``None``, non-``Parameter`` value is assigned to a
            name that is already a registered parameter.
    """

    def remove_from(*dicts_or_sets):
        # Drop ``name`` from every container that currently holds it.
        for d in dicts_or_sets:
            if name in d:
                if isinstance(d, dict):
                    del d[name]
                else:
                    d.discard(name)

    params = self.__dict__.get("_parameters")
    components = self.__dict__.get("_components")

    # set parameter
    if isinstance(value, Parameter):
        if params is None:
            raise AttributeError(
                "cant assign parameter before Component.__init__() call"
            )
        remove_from(self.__dict__)
        # A component registered under the same name would still be visible
        # through __getattr__, making register_parameter reject the name as
        # "already exists"; remove the stale entry first.
        if components is not None:
            remove_from(components)
        self.register_parameter(name, value)
        return

    if params is not None and name in params:
        if value is not None:
            raise TypeError(
                f"cannot assign '{type(value)}' object to parameter '{name}' (Parameter or None required)"
            )
        self.register_parameter(name, value)
        return

    # set component
    if isinstance(value, Component):
        if components is None:
            raise AttributeError(
                "cant assign component before Component.__init__() call"
            )
        remove_from(self.__dict__)
        components[name] = value
        return

    # set attribute
    super().__setattr__(name, value)

def __getattr__(self, name: str) -> Any:
if "_parameters" in self.__dict__:
parameters = self.__dict__.get("_parameters")
if name in parameters:
return parameters[name]
if "_components" in self.__dict__:
components = self.__dict__.get("_components")
if name in components:
return components[name]
# else:
# super().__getattr__(name)

raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)

def __delattr__(self, name: str) -> None:
if name in self._parameters:
del self._parameters[name]
if name in self._components:
del self._components[name]
else:
super().__delattr__(name)

def _extra_repr(self) -> str:
"""
Normally implemented by subcomponents to print additional positional or keyword arguments.
# NOTE: Dont add components as it will have its own __repr__
"""
return ""

def _get_name(self):
# return self._name
return self.__class__.__name__

def __repr__(self):
# We treat the extra repr like the sub-module, one item per line
extra_lines = []
extra_repr = self._extra_repr()
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split("\n")
child_lines = []
for key, component in self._components.items():
mod_str = repr(component)
mod_str = _addindent(mod_str, 2)
child_lines.append("(" + key + "): " + mod_str)
lines = extra_lines + child_lines

main_str = self._get_name() + "("
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += "\n " + "\n ".join(lines) + "\n"

main_str += ")"
return main_str


class Sequential(Component):
r"""A sequential container. Components will be added to it in the order they are passed to the constructor.
Expand Down
5 changes: 5 additions & 0 deletions core/default_prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
"context_str",
"steps_str",
]
# Names of prompt template variables listed as "trainable params".
# NOTE(review): presumably these are the variables exposed as Parameters for
# optimization — confirm against the code that consumes this list.
# "output_format_str" is currently commented out, so it is not in the list.
LIGHTRAG_DEFAULT_PROMPT_TRAINABLE_PARAMS = [
"task_desc_str",
# "output_format_str",
"examples_str",
]

DEFAULT_LIGHTRAG_SYSTEM_PROMPT = r"""{# task desc #}
{% if task_desc_str %}
Expand Down
Loading

0 comments on commit 0a8951e

Please sign in to comment.