Skip to content

Commit

Permalink
Merge pull request #3 from gabrielguarisa/f/DNX-1791/graph-validators
Browse files Browse the repository at this point in the history
F/dnx 1791/graph validators
  • Loading branch information
gabrielguarisa authored Feb 27, 2023
2 parents 0d6e30e + 8322ded commit 3d51b72
Show file tree
Hide file tree
Showing 26 changed files with 1,178 additions and 239 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ runner = retrack.Runner(parser)
runner(data)
```

The `Parser` class parses the rule/model and creates a graph of nodes. The `Runner` class runs the rule/model using the data passed to the runner. The `data` is a dictionary or a list of dictionaries containing the data that will be used to evaluate the conditions and execute the actions. To see wich data is required for the given rule/model, check the `runner.payload_manager.model` property that is a pydantic model used to validate the data.
The `Parser` class parses the rule/model and creates a graph of nodes. The `Runner` class runs the rule/model using the data passed to the runner. The `data` is a dictionary or a list of dictionaries containing the data that will be used to evaluate the conditions and execute the actions. To see wich data is required for the given rule/model, check the `runner.request_model` property that is a pydantic model used to validate the data.

### Creating a rule/model

Expand Down
64 changes: 32 additions & 32 deletions examples/age-check.json
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@
}
},
"position": [
83.84765625,
-146.80078125
45.51953125,
-136.8515625
],
"name": "Check"
},
Expand Down Expand Up @@ -180,29 +180,6 @@
],
"name": "If"
},
"7": {
"id": 7,
"data": {
"message": "invalid age"
},
"inputs": {
"input_bool": {
"connections": [
{
"node": 8,
"output": "output_bool",
"data": {}
}
]
}
},
"outputs": {},
"position": [
972.3247551810931,
-8.462733235746624
],
"name": "BoolOutput"
},
"8": {
"id": 8,
"data": {
Expand All @@ -223,8 +200,8 @@
"output_bool": {
"connections": [
{
"node": 7,
"input": "input_bool",
"node": 11,
"input": "input_value",
"data": {}
}
]
Expand Down Expand Up @@ -257,7 +234,7 @@
"connections": [
{
"node": 10,
"input": "input_bool",
"input": "input_value",
"data": {}
}
]
Expand All @@ -275,7 +252,7 @@
"message": "valid age"
},
"inputs": {
"input_bool": {
"input_value": {
"connections": [
{
"node": 9,
Expand All @@ -287,10 +264,33 @@
},
"outputs": {},
"position": [
974.0269089794663,
-190.8242802623253
1015.5346416468075,
-247.2703893983769
],
"name": "Output"
},
"11": {
"id": 11,
"data": {
"message": "invalid age"
},
"inputs": {
"input_value": {
"connections": [
{
"node": 8,
"output": "output_bool",
"data": {}
}
]
}
},
"outputs": {},
"position": [
1008.826235840524,
-68.88453626006572
],
"name": "BoolOutput"
"name": "Output"
}
}
}
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "retrack"
version = "0.2.2"
version = "0.3.0"
description = "A business rules engine"
authors = ["Gabriel Guarisa <[email protected]>", "Nathalia Trotte <[email protected]>"]
license = "MIT"
Expand All @@ -13,6 +13,7 @@ keywords = ["rules", "models", "business", "node", "graph"]
python = "^3.8.16"
pydantic = "^1.10.4"
pandas = "^1.5.2"
networkx = "^3.0"

[tool.poetry.dev-dependencies]
pytest = "^6.2.4"
Expand Down
77 changes: 53 additions & 24 deletions retrack/engine/parser.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,35 @@
import typing

from retrack.engine.validators import registry as GLOBAL_VALIDATOR_REGISTRY
from retrack.nodes import BaseNode
from retrack.nodes import registry as GLOBAL_NODE_REGISTRY
from retrack.utils.registry import Registry

INPUT_TOKENS = ["input"]

OUTPUT_TOKENS = ["booloutput"]

CONSTANT_TOKENS = ["constant", "list", "bool"]


class Parser:
def __init__(
self,
data: dict,
graph_data: dict,
component_registry: Registry = GLOBAL_NODE_REGISTRY,
validator_registry: Registry = GLOBAL_VALIDATOR_REGISTRY,
unknown_node_error: bool = False,
):
Parser._check_input_data(data)
"""Parses a dictionary of nodes and returns a dictionary of BaseNode objects.
Args:
data (dict): A dictionary of nodes.
component_registry (Registry, optional): A registry of BaseNode objects. Defaults to GLOBAL_NODE_REGISTRY.
validator_registry (Registry, optional): A registry of BaseValidator objects. Defaults to GLOBAL_VALIDATOR_REGISTRY.
unknown_node_error (bool, optional): Whether to raise an error if an unknown node is found. Defaults to False.
"""
Parser._check_input_data(graph_data)

node_registry = Registry()
tokens = {}
self._indexes_by_kind_map = {}
self._indexes_by_name_map = {}
self.__edges = None

for node_id, node_data in data["nodes"].items():
for node_id, node_data in graph_data["nodes"].items():
node_name = node_data.get("name", None)

Parser._check_node_name(node_name, node_id)
Expand All @@ -32,19 +38,29 @@ def __init__(

validation_model = component_registry.get(node_name)
if validation_model is not None:
if node_name not in tokens:
tokens[node_name] = []
if node_name not in self._indexes_by_name_map:
self._indexes_by_name_map[node_name] = []

tokens[node_name].append(node_id)
self._indexes_by_name_map[node_name].append(node_id)

node_data["id"] = node_id
if node_id not in node_registry:
node_registry.register(node_id, validation_model(**node_data))
node = validation_model(**node_data)
node_registry.register(node_id, node)

if node.kind() not in self._indexes_by_kind_map:
self._indexes_by_kind_map[node.kind()] = []

self._indexes_by_kind_map[node.kind()].append(node_id)

elif unknown_node_error:
raise ValueError(f"Unknown node name: {node_name}")

self._node_registry = node_registry
self._tokens = tokens

for validator_name, validator in validator_registry.data.items():
if not validator.validate(graph_data=graph_data, edges=self.edges):
raise ValueError(f"Invalid graph data: {validator_name}")

@staticmethod
def _check_node_name(node_name: str, node_id: str):
Expand All @@ -71,13 +87,33 @@ def _check_input_data(data: dict):
def nodes(self) -> typing.Dict[str, BaseNode]:
return self._node_registry.data

@property
def edges(self) -> typing.List[typing.Tuple[str, str]]:
if self.__edges is None:
edges = []

for node_id, node in self.nodes.items():
for _, output_connection in node.outputs:
for c in output_connection.connections:
edges.append((node_id, c.node))

self.__edges = edges

return self.__edges

@property
def data(self) -> dict:
return {i: j.dict(by_alias=True) for i, j in self.nodes.items()}

@property
def tokens(self) -> dict:
return self._tokens
"""Returns a dictionary of tokens (node name) and their associated node ids."""
return self._indexes_by_name_map

@property
def indexes_by_kind_map(self) -> dict:
"""Returns a dictionary of node kinds and their associated node ids."""
return self._indexes_by_kind_map

def get_node_by_id(self, node_id: str) -> BaseNode:
return self._node_registry.get(node_id)
Expand All @@ -96,11 +132,4 @@ def get_nodes_by_multiple_names(self, node_names: list) -> typing.List[BaseNode]
return all_nodes

def get_nodes_by_kind(self, kind: str) -> typing.List[BaseNode]:
if kind == "input":
return self.get_nodes_by_multiple_names(INPUT_TOKENS)
if kind == "output":
return self.get_nodes_by_multiple_names(OUTPUT_TOKENS)
if kind == "constant":
return self.get_nodes_by_multiple_names(CONSTANT_TOKENS)

return []
return [self.get_node_by_id(i) for i in self.indexes_by_kind_map.get(kind, [])]
70 changes: 0 additions & 70 deletions retrack/engine/payload_manager.py

This file was deleted.

Loading

0 comments on commit 3d51b72

Please sign in to comment.