Skip to content

Commit

Permalink
Merge pull request #4 from gabrielguarisa/f/fix-list-constant-issue
Browse files Browse the repository at this point in the history
F/fix list constant issue
  • Loading branch information
gabrielguarisa authored May 2, 2023
2 parents c422396 + ad57fda commit 93b1c0c
Show file tree
Hide file tree
Showing 22 changed files with 1,086 additions and 294 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ parser = retrack.Parser(rule)
runner = retrack.Runner(parser)

# Run the rule/model passing the data
runner(data)
runner.execute(data)
```

The `Parser` class parses the rule/model and creates a graph of nodes. The `Runner` class runs the rule/model using the data passed to the runner. The `data` is a dictionary or a list of dictionaries containing the data that will be used to evaluate the conditions and execute the actions. To see wich data is required for the given rule/model, check the `runner.request_model` property that is a pydantic model used to validate the data.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "retrack"
version = "0.3.0"
version = "0.4.1"
description = "A business rules engine"
authors = ["Gabriel Guarisa <[email protected]>", "Nathalia Trotte <[email protected]>"]
license = "MIT"
Expand Down
227 changes: 141 additions & 86 deletions retrack/engine/parser.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,31 @@
import typing

from retrack.engine.validators import registry as GLOBAL_VALIDATOR_REGISTRY
from retrack.nodes import BaseNode
from retrack.nodes import registry as GLOBAL_NODE_REGISTRY
from retrack import nodes, validators
from retrack.utils.registry import Registry


class Parser:
def __init__(
self,
graph_data: dict,
component_registry: Registry = GLOBAL_NODE_REGISTRY,
validator_registry: Registry = GLOBAL_VALIDATOR_REGISTRY,
unknown_node_error: bool = False,
component_registry: Registry = nodes.registry(),
validator_registry: Registry = validators.registry(),
):
"""Parses a dictionary of nodes and returns a dictionary of BaseNode objects.
Args:
data (dict): A dictionary of nodes.
component_registry (Registry, optional): A registry of BaseNode objects. Defaults to GLOBAL_NODE_REGISTRY.
validator_registry (Registry, optional): A registry of BaseValidator objects. Defaults to GLOBAL_VALIDATOR_REGISTRY.
unknown_node_error (bool, optional): Whether to raise an error if an unknown node is found. Defaults to False.
"""
Parser._check_input_data(graph_data)

node_registry = Registry()
self._indexes_by_kind_map = {}
self._indexes_by_name_map = {}
self._execution_order = None
self.__components = {}
self.__edges = None

for node_id, node_data in graph_data["nodes"].items():
node_name = node_data.get("name", None)

Parser._check_node_name(node_name, node_id)

node_name = node_name.lower()

validation_model = component_registry.get(node_name)
if validation_model is not None:
if node_name not in self._indexes_by_name_map:
self._indexes_by_name_map[node_name] = []

self._indexes_by_name_map[node_name].append(node_id)

node_data["id"] = node_id
if node_id not in node_registry:
node = validation_model(**node_data)
node_registry.register(node_id, node)

if node.kind() not in self._indexes_by_kind_map:
self._indexes_by_kind_map[node.kind()] = []
self._check_input_data(graph_data)

self._indexes_by_kind_map[node.kind()].append(node_id)

elif unknown_node_error:
raise ValueError(f"Unknown node name: {node_name}")
self._set_components(graph_data, component_registry)
self._set_edges()

self._node_registry = node_registry

for validator_name, validator in validator_registry.data.items():
if not validator.validate(graph_data=graph_data, edges=self.edges):
raise ValueError(f"Invalid graph data: {validator_name}")
self._validate_graph(graph_data, validator_registry)

@staticmethod
def _check_node_name(node_name: str, node_id: str):
if node_name is None:
raise ValueError(f"BaseNode {node_id} has no name")
if not isinstance(node_name, str):
raise TypeError(f"BaseNode {node_id} name must be a string")
self._set_indexes_by_name_map()
self._set_indexes_by_kind_map()
self._set_execution_order()
self._set_indexes_by_memory_type_map()

@staticmethod
def _check_input_data(data: dict):
Expand All @@ -83,53 +41,150 @@ def _check_input_data(data: dict):
+ str(type(data["nodes"]))
)

@staticmethod
def _check_node_name(node_name: str, node_id: str):
if node_name is None:
raise ValueError(f"BaseNode {node_id} has no name")
if not isinstance(node_name, str):
raise TypeError(f"BaseNode {node_id} name must be a string")

@property
def nodes(self) -> typing.Dict[str, BaseNode]:
return self._node_registry.data
def components(self) -> typing.Dict[str, nodes.BaseNode]:
return self.__components

def _set_components(self, graph_data: dict, component_registry: Registry):
for node_id, node_metadata in graph_data["nodes"].items():
if node_id in self.__components:
raise ValueError(f"Duplicate node id: {node_id}")

node_name = node_metadata.get("name", None)
self._check_node_name(node_name, node_id)

node_name = node_name.lower()

validation_model = component_registry.get(node_name)

if validation_model is None:
raise ValueError(f"Unknown node name: {node_name}")

self.__components[node_id] = validation_model(**node_metadata)

@property
def edges(self) -> typing.List[typing.Tuple[str, str]]:
if self.__edges is None:
edges = []
return self.__edges

for node_id, node in self.nodes.items():
for _, output_connection in node.outputs:
for c in output_connection.connections:
edges.append((node_id, c.node))
def _set_edges(self):
self.__edges = []

self.__edges = edges
for node_id, node in self.components.items():
for _, output_connection in node.outputs:
for c in output_connection.connections:
self.__edges.append((node_id, c.node))

return self.__edges
def _validate_graph(self, graph_data: dict, validator_registry: Registry):
for validator_name, validator in validator_registry.data.items():
if not validator.validate(graph_data=graph_data, edges=self.edges):
raise ValueError(f"Invalid graph data: {validator_name}")

@property
def data(self) -> dict:
return {i: j.dict(by_alias=True) for i, j in self.nodes.items()}
def get_by_id(self, id_: str) -> nodes.BaseNode:
return self.components.get(id_)

@property
def tokens(self) -> dict:
"""Returns a dictionary of tokens (node name) and their associated node ids."""
def indexes_by_name_map(self) -> typing.Dict[str, typing.List[str]]:
return self._indexes_by_name_map

def _set_indexes_by_name_map(self):
self._indexes_by_name_map = {}

for node_id, node in self.components.items():
node_name = node.__class__.__name__.lower()
if node_name not in self._indexes_by_name_map:
self._indexes_by_name_map[node_name] = []

self._indexes_by_name_map[node_name].append(node_id)

def get_by_name(self, name: str) -> typing.List[nodes.BaseNode]:
return [self.get_by_id(id_) for id_ in self.indexes_by_name_map.get(name, [])]

@property
def indexes_by_kind_map(self) -> dict:
"""Returns a dictionary of node kinds and their associated node ids."""
def indexes_by_kind_map(self) -> typing.Dict[str, typing.List[str]]:
return self._indexes_by_kind_map

def get_node_by_id(self, node_id: str) -> BaseNode:
return self._node_registry.get(node_id)
def _set_indexes_by_kind_map(self):
self._indexes_by_kind_map = {}

for node_id, node in self.components.items():
if node.kind() not in self._indexes_by_kind_map:
self._indexes_by_kind_map[node.kind()] = []

self._indexes_by_kind_map[node.kind()].append(node_id)

def get_by_kind(self, kind: str) -> typing.List[nodes.BaseNode]:
return [self.get_by_id(id_) for id_ in self.indexes_by_kind_map.get(kind, [])]

@property
def indexes_by_memory_type_map(self) -> typing.Dict[str, typing.List[str]]:
return self._indexes_by_memory_type_map

def _set_indexes_by_memory_type_map(self):
self._indexes_by_memory_type_map = {}

for node_id, node in self.components.items():
memory_type = node.memory_type()
if memory_type not in self.indexes_by_memory_type_map:
self._indexes_by_memory_type_map[memory_type] = []

self._indexes_by_memory_type_map[memory_type].append(node_id)

def get_by_memory_type(self, memory_type: str) -> typing.List[nodes.BaseNode]:
return [
self.get_by_id(id_)
for id_ in self.indexes_by_memory_type_map.get(memory_type, [])
]

@property
def execution_order(self) -> typing.List[str]:
return self._execution_order

def _set_execution_order(self):
start_nodes = self.get_by_name("start")

self._execution_order = self._walk(start_nodes[0].id, [])

def get_node_connections(
self, node_id: str, is_input: bool = True, filter_by_connector=None
):
node_dict = self.get_by_id(node_id).dict(by_alias=True)

connectors = node_dict.get("inputs" if is_input else "outputs", {})
result = []

for connector_name, value in connectors.items():
if (
filter_by_connector is not None
and connector_name != filter_by_connector
):
continue

for connection in value["connections"]:
result.append(connection["node"])
return result

def _walk(self, actual_id: str, skiped_ids: list):
skiped_ids.append(actual_id)

def get_nodes_by_name(self, node_name: str) -> typing.List[BaseNode]:
node_name = node_name.lower()
return [self.get_node_by_id(i) for i in self.tokens.get(node_name, [])]
output_ids = self.get_node_connections(actual_id, is_input=False)

def get_nodes_by_multiple_names(self, node_names: list) -> typing.List[BaseNode]:
all_nodes = []
for node_name in node_names:
nodes = self.get_nodes_by_name(node_name)
if nodes is not None:
all_nodes.extend(nodes)
for next_id in output_ids:
if next_id not in skiped_ids:
next_node_input_ids = self.get_node_connections(next_id, is_input=True)
run_next = True
for next_node_input_id in next_node_input_ids:
if next_node_input_id not in skiped_ids:
run_next = False
break

return all_nodes
if run_next:
self._walk(next_id, skiped_ids)

def get_nodes_by_kind(self, kind: str) -> typing.List[BaseNode]:
return [self.get_node_by_id(i) for i in self.indexes_by_kind_map.get(kind, [])]
return skiped_ids
Loading

0 comments on commit 93b1c0c

Please sign in to comment.