From 3c4e39bf717187ebecaa239e3bc69bf6d13e4cc6 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Thu, 27 Apr 2023 18:43:38 -0300 Subject: [PATCH 1/7] Add broken test with memory reuse --- tests/resources/multiple-ifs.json | 716 ++++++++++++++++++++++++++++++ tests/test_engine/test_runner.py | 49 +- 2 files changed, 748 insertions(+), 17 deletions(-) create mode 100644 tests/resources/multiple-ifs.json diff --git a/tests/resources/multiple-ifs.json b/tests/resources/multiple-ifs.json new file mode 100644 index 0000000..ae790b5 --- /dev/null +++ b/tests/resources/multiple-ifs.json @@ -0,0 +1,716 @@ +{ + "id": "demo@0.1.0", + "nodes": { + "0": { + "id": 0, + "data": {}, + "inputs": {}, + "outputs": { + "output_up_void": { + "connections": [ + { + "node": 2, + "input": "input_void", + "data": {} + } + ] + }, + "output_down_void": { + "connections": [ + { + "node": 3, + "input": "input_void", + "data": {} + } + ] + } + }, + "position": [ + -541.046875, + -69.94140625 + ], + "name": "Start" + }, + "2": { + "id": 2, + "data": { + "name": "number", + "default": null + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 0, + "output": "output_up_void", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 4, + "input": "input_value_0", + "data": {} + } + ] + } + }, + "position": [ + -232.140625, + -265.17578125 + ], + "name": "Input" + }, + "3": { + "id": 3, + "data": { + "value": "1" + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 0, + "output": "output_down_void", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 4, + "input": "input_value_1", + "data": {} + } + ] + } + }, + "position": [ + -233.8984375, + 9.2109375 + ], + "name": "Constant" + }, + "4": { + "id": 4, + "data": { + "operator": "==" + }, + "inputs": { + "input_value_0": { + "connections": [ + { + "node": 2, + "output": "output_value", + "data": {} + } + ] + }, + "input_value_1": 
{ + "connections": [ + { + "node": 3, + "output": "output_value", + "data": {} + } + ] + } + }, + "outputs": { + "output_bool": { + "connections": [ + { + "node": 5, + "input": "input_bool", + "data": {} + } + ] + } + }, + "position": [ + 85.94921875, + -151.54296875 + ], + "name": "Check" + }, + "5": { + "id": 5, + "data": {}, + "inputs": { + "input_bool": { + "connections": [ + { + "node": 4, + "output": "output_bool", + "data": {} + } + ] + } + }, + "outputs": { + "output_then_filter": { + "connections": [ + { + "node": 10, + "input": "input_void", + "data": {} + } + ] + }, + "output_else_filter": { + "connections": [ + { + "node": 6, + "input": "input_void", + "data": {} + }, + { + "node": 7, + "input": "input_void", + "data": {} + } + ] + } + }, + "position": [ + 391.59645390537196, + -195.2584515796633 + ], + "name": "If" + }, + "6": { + "id": 6, + "data": { + "name": "number", + "default": null + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 5, + "output": "output_else_filter", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 12, + "input": "input_value_0", + "data": {} + } + ] + } + }, + "position": [ + 677.273253686871, + -99.51029607389484 + ], + "name": "Input" + }, + "7": { + "id": 7, + "data": { + "value": "2" + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 5, + "output": "output_else_filter", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 12, + "input": "input_value_1", + "data": {} + } + ] + } + }, + "position": [ + 674.8144465573389, + 156.26079784779859 + ], + "name": "Constant" + }, + "10": { + "id": 10, + "data": { + "value": "1" + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 5, + "output": "output_then_filter", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 11, + "input": "input_value", + "data": {} + } + ] + } + }, + "position": [ 
+ 683.1885458669273, + -305.04161735974293 + ], + "name": "Constant" + }, + "11": { + "id": 11, + "data": { + "message": "first" + }, + "inputs": { + "input_value": { + "connections": [ + { + "node": 10, + "output": "output_value", + "data": {} + } + ] + } + }, + "outputs": {}, + "position": [ + 1005.135434526846, + -287.55703688913104 + ], + "name": "Output" + }, + "12": { + "id": 12, + "data": { + "operator": "==" + }, + "inputs": { + "input_value_0": { + "connections": [ + { + "node": 6, + "output": "output_value", + "data": {} + } + ] + }, + "input_value_1": { + "connections": [ + { + "node": 7, + "output": "output_value", + "data": {} + } + ] + } + }, + "outputs": { + "output_bool": { + "connections": [ + { + "node": 13, + "input": "input_bool", + "data": {} + } + ] + } + }, + "position": [ + 980.7535311889883, + 2.928947121487271 + ], + "name": "Check" + }, + "13": { + "id": 13, + "data": {}, + "inputs": { + "input_bool": { + "connections": [ + { + "node": 12, + "output": "output_bool", + "data": {} + } + ] + } + }, + "outputs": { + "output_then_filter": { + "connections": [ + { + "node": 14, + "input": "input_void", + "data": {} + } + ] + }, + "output_else_filter": { + "connections": [ + { + "node": 16, + "input": "input_void", + "data": {} + }, + { + "node": 17, + "input": "input_void", + "data": {} + } + ] + } + }, + "position": [ + 1269.314834650143, + -11.264604317155857 + ], + "name": "If" + }, + "14": { + "id": 14, + "data": { + "value": "2" + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 13, + "output": "output_then_filter", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 15, + "input": "input_value", + "data": {} + } + ] + } + }, + "position": [ + 1553.666222556354, + -123.72348595289333 + ], + "name": "Constant" + }, + "15": { + "id": 15, + "data": { + "message": "second" + }, + "inputs": { + "input_value": { + "connections": [ + { + "node": 14, + "output": "output_value", + 
"data": {} + } + ] + } + }, + "outputs": {}, + "position": [ + 1836.9009191990704, + -152.8446598115035 + ], + "name": "Output" + }, + "16": { + "id": 16, + "data": { + "name": "number", + "default": null + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 13, + "output": "output_else_filter", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 18, + "input": "input_value_0", + "data": {} + } + ] + } + }, + "position": [ + 1547.1232183994962, + 82.63101131150805 + ], + "name": "Input" + }, + "17": { + "id": 17, + "data": { + "value": "3" + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 13, + "output": "output_else_filter", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 18, + "input": "input_value_1", + "data": {} + } + ] + } + }, + "position": [ + 1545.9895761297034, + 344.9279787495316 + ], + "name": "Constant" + }, + "18": { + "id": 18, + "data": { + "operator": "==" + }, + "inputs": { + "input_value_0": { + "connections": [ + { + "node": 16, + "output": "output_value", + "data": {} + } + ] + }, + "input_value_1": { + "connections": [ + { + "node": 17, + "output": "output_value", + "data": {} + } + ] + } + }, + "outputs": { + "output_bool": { + "connections": [ + { + "node": 21, + "input": "input_bool", + "data": {} + } + ] + } + }, + "position": [ + 1817.9033565634622, + 178.64075343168426 + ], + "name": "Check" + }, + "20": { + "id": 20, + "data": { + "message": "third" + }, + "inputs": { + "input_value": { + "connections": [ + { + "node": 22, + "output": "output_value", + "data": {} + } + ] + } + }, + "outputs": {}, + "position": [ + 2558.3564942509242, + 68.94427065989652 + ], + "name": "Output" + }, + "21": { + "id": 21, + "data": {}, + "inputs": { + "input_bool": { + "connections": [ + { + "node": 18, + "output": "output_bool", + "data": {} + } + ] + } + }, + "outputs": { + "output_then_filter": { + "connections": [ + { + 
"node": 22, + "input": "input_void", + "data": {} + } + ] + }, + "output_else_filter": { + "connections": [ + { + "node": 23, + "input": "input_void", + "data": {} + } + ] + } + }, + "position": [ + 2057.9117508031372, + 161.72627964236332 + ], + "name": "If" + }, + "22": { + "id": 22, + "data": { + "value": "3" + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 21, + "output": "output_then_filter", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 20, + "input": "input_value", + "data": {} + } + ] + } + }, + "position": [ + 2315.1936187054284, + 38.81957542871652 + ], + "name": "Constant" + }, + "23": { + "id": 23, + "data": { + "value": "0" + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 21, + "output": "output_else_filter", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 24, + "input": "input_value", + "data": {} + } + ] + } + }, + "position": [ + 2316.6487997845193, + 282.7729752267469 + ], + "name": "Constant" + }, + "24": { + "id": 24, + "data": { + "message": "other" + }, + "inputs": { + "input_value": { + "connections": [ + { + "node": 23, + "output": "output_value", + "data": {} + } + ] + } + }, + "outputs": {}, + "position": [ + 2555.308493387755, + 281.51261293236865 + ], + "name": "Output" + } + } +} \ No newline at end of file diff --git a/tests/test_engine/test_runner.py b/tests/test_engine/test_runner.py index 1248366..f133a17 100644 --- a/tests/test_engine/test_runner.py +++ b/tests/test_engine/test_runner.py @@ -4,23 +4,38 @@ from retrack import Parser, Runner +@pytest.mark.parametrize( + "filename, in_values, expected_out_values", + [ + ( + "multiple-ifs", + [{"number": 1}, {"number": 2}, {"number": 3}, {"number": 4}], + [ + {"message": "first", "output": "1"}, + {"message": "second", "output": "2"}, + {"message": "third", "output": "3"}, + {"message": "other", "output": "0"}, + ], + ), + ( + "age-negative", + 
[{"age": 10}, {"age": -10}, {"age": 18}, {"age": 19}, {"age": 100}], + [ + {"message": "underage", "output": False}, + {"message": "invalid age", "output": False}, + {"message": "valid age", "output": True}, + {"message": "valid age", "output": True}, + {"message": "valid age", "output": True}, + ], -@pytest.fixture -def age_negative_json() -> dict: - with open("tests/resources/age-negative.json", "r") as f: - return json.load(f) + ], +) +def test_flows(filename, in_values, expected_out_values): + with open(f"tests/resources/{filename}.json", "r") as f: + rule = json.load(f) + runner = Runner(Parser(rule)) + out_values = runner(in_values) -def test_age_negative(age_negative_json): - parser = Parser(age_negative_json) - runner = Runner(parser) - in_values = [10, -10, 18, 19, 100] - out_values = runner([{"age": val} for val in in_values]) - - assert out_values == [ - {"message": "underage", "output": False}, - {"message": "invalid age", "output": False}, - {"message": "valid age", "output": True}, - {"message": "valid age", "output": True}, - {"message": "valid age", "output": True}, - ] + assert out_values == expected_out_values From 65d74d5310d3a555675bbbcf684a68765089eac8 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Thu, 27 Apr 2023 21:57:28 -0300 Subject: [PATCH 2/7] Refactor nodes and validators registry --- retrack/engine/parser.py | 11 ++-- retrack/engine/validators/__init__.py | 14 ----- retrack/nodes/__init__.py | 51 +++++++++++-------- retrack/validators/__init__.py | 24 +++++++++ retrack/{engine => }/validators/base.py | 0 .../{engine => }/validators/check_is_dag.py | 2 +- .../{engine => }/validators/node_exists.py | 2 +- tests/test_engine/test_runner.py | 2 +- 8 files changed, 62 insertions(+), 44 deletions(-) delete mode 100644 retrack/engine/validators/__init__.py create mode 100644 retrack/validators/__init__.py rename retrack/{engine => }/validators/base.py (100%) rename retrack/{engine => }/validators/check_is_dag.py (86%) rename
retrack/{engine => }/validators/node_exists.py (95%) diff --git a/retrack/engine/parser.py b/retrack/engine/parser.py index a735e59..70a1ab3 100644 --- a/retrack/engine/parser.py +++ b/retrack/engine/parser.py @@ -1,8 +1,7 @@ import typing -from retrack.engine.validators import registry as GLOBAL_VALIDATOR_REGISTRY +from retrack import nodes, validators from retrack.nodes import BaseNode -from retrack.nodes import registry as GLOBAL_NODE_REGISTRY from retrack.utils.registry import Registry @@ -10,16 +9,16 @@ class Parser: def __init__( self, graph_data: dict, - component_registry: Registry = GLOBAL_NODE_REGISTRY, - validator_registry: Registry = GLOBAL_VALIDATOR_REGISTRY, + component_registry: Registry = nodes.registry(), + validator_registry: Registry = validators.registry(), unknown_node_error: bool = False, ): """Parses a dictionary of nodes and returns a dictionary of BaseNode objects. Args: data (dict): A dictionary of nodes. - component_registry (Registry, optional): A registry of BaseNode objects. Defaults to GLOBAL_NODE_REGISTRY. - validator_registry (Registry, optional): A registry of BaseValidator objects. Defaults to GLOBAL_VALIDATOR_REGISTRY. + component_registry (Registry, optional): A registry of BaseNode objects. Defaults to retrack.nodes.registry(). + validator_registry (Registry, optional): A registry of BaseValidator objects. Defaults to retrack.validators.registry(). unknown_node_error (bool, optional): Whether to raise an error if an unknown node is found. Defaults to False. 
""" Parser._check_input_data(graph_data) diff --git a/retrack/engine/validators/__init__.py b/retrack/engine/validators/__init__.py deleted file mode 100644 index 2f9b0f3..0000000 --- a/retrack/engine/validators/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from retrack.engine.validators.base import BaseValidator -from retrack.engine.validators.check_is_dag import CheckIsDAG -from retrack.engine.validators.node_exists import NodeExistsValidator -from retrack.utils.registry import Registry - -registry = Registry() - -registry.register( - "single_start_node_exists", - NodeExistsValidator("start", min_quantity=1, max_quantity=1), -) -registry.register("check_is_dag", CheckIsDAG()) - -__all__ = ["registry", "BaseValidator"] diff --git a/retrack/nodes/__init__.py b/retrack/nodes/__init__.py index dc5cf35..e7443e7 100644 --- a/retrack/nodes/__init__.py +++ b/retrack/nodes/__init__.py @@ -14,24 +14,33 @@ from retrack.nodes.startswithany import StartsWithAny from retrack.utils.registry import Registry -registry = Registry() - -registry.register("Input", Input) -registry.register("Start", Start) -registry.register("Constant", Constant) -registry.register("List", List) -registry.register("Bool", Bool) -registry.register("Output", Output) -registry.register("Check", Check) -registry.register("If", If) -registry.register("And", And) -registry.register("Or", Or) -registry.register("Not", Not) -registry.register("Math", Math) -registry.register("StartsWith", StartsWith) -registry.register("EndsWith", EndsWith) -registry.register("StartsWithAny", StartsWithAny) -registry.register("EndsWithAny", EndsWithAny) -registry.register("Contains", Contains) - -__all__ = ["registry", "BaseNode"] +_registry = Registry() + + +def registry() -> Registry: + return _registry + + +def register(name: str, node: BaseNode) -> None: + registry().register(name, node) + + +register("Input", Input) +register("Start", Start) +register("Constant", Constant) +register("List", List) +register("Bool", Bool) 
+register("Output", Output) +register("Check", Check) +register("If", If) +register("And", And) +register("Or", Or) +register("Not", Not) +register("Math", Math) +register("StartsWith", StartsWith) +register("EndsWith", EndsWith) +register("StartsWithAny", StartsWithAny) +register("EndsWithAny", EndsWithAny) +register("Contains", Contains) + +__all__ = ["registry", "register", "BaseNode"] diff --git a/retrack/validators/__init__.py b/retrack/validators/__init__.py new file mode 100644 index 0000000..8f0f759 --- /dev/null +++ b/retrack/validators/__init__.py @@ -0,0 +1,24 @@ +from retrack.utils.registry import Registry +from retrack.validators.base import BaseValidator +from retrack.validators.check_is_dag import CheckIsDAG +from retrack.validators.node_exists import NodeExistsValidator + +_registry = Registry() + + +def registry() -> Registry: + return _registry + + +def register(name: str, validator: BaseValidator) -> None: + registry().register(name, validator) + + +register( + "single_start_node_exists", + NodeExistsValidator("start", min_quantity=1, max_quantity=1), +) +register("check_is_dag", CheckIsDAG()) + + +__all__ = ["registry", "register", "BaseValidator"] diff --git a/retrack/engine/validators/base.py b/retrack/validators/base.py similarity index 100% rename from retrack/engine/validators/base.py rename to retrack/validators/base.py diff --git a/retrack/engine/validators/check_is_dag.py b/retrack/validators/check_is_dag.py similarity index 86% rename from retrack/engine/validators/check_is_dag.py rename to retrack/validators/check_is_dag.py index 1c2c9db..3115c1e 100644 --- a/retrack/engine/validators/check_is_dag.py +++ b/retrack/validators/check_is_dag.py @@ -1,6 +1,6 @@ import networkx as nx -from retrack.engine.validators.base import BaseValidator +from retrack.validators.base import BaseValidator class CheckIsDAG(BaseValidator): diff --git a/retrack/engine/validators/node_exists.py b/retrack/validators/node_exists.py similarity index 95% rename 
from retrack/engine/validators/node_exists.py rename to retrack/validators/node_exists.py index 58172fc..a0894db 100644 --- a/retrack/engine/validators/node_exists.py +++ b/retrack/validators/node_exists.py @@ -1,4 +1,4 @@ -from retrack.engine.validators.base import BaseValidator +from retrack.validators.base import BaseValidator class NodeExistsValidator(BaseValidator): diff --git a/tests/test_engine/test_runner.py b/tests/test_engine/test_runner.py index f133a17..b58a7ac 100644 --- a/tests/test_engine/test_runner.py +++ b/tests/test_engine/test_runner.py @@ -4,6 +4,7 @@ from retrack import Parser, Runner + @pytest.mark.parametrize( "filename, in_values, expected_out_values", [ @@ -28,7 +29,6 @@ {"message": "valid age", "output": True}, ], ), - ], ) def test_flows(filename, in_values, expected_out_values): From 96c0de5beb673708bf43241a556b6ce17c94f01a Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Thu, 27 Apr 2023 22:38:18 -0300 Subject: [PATCH 3/7] Refactor parser class --- retrack/engine/parser.py | 151 ++++++++++++++++++--------------------- retrack/engine/runner.py | 4 +- retrack/utils/graph.py | 6 +- tests/test_parser.py | 13 ++-- 4 files changed, 78 insertions(+), 96 deletions(-) diff --git a/retrack/engine/parser.py b/retrack/engine/parser.py index 70a1ab3..b69cc93 100644 --- a/retrack/engine/parser.py +++ b/retrack/engine/parser.py @@ -11,62 +11,19 @@ def __init__( graph_data: dict, component_registry: Registry = nodes.registry(), validator_registry: Registry = validators.registry(), - unknown_node_error: bool = False, ): - """Parses a dictionary of nodes and returns a dictionary of BaseNode objects. - - Args: - data (dict): A dictionary of nodes. - component_registry (Registry, optional): A registry of BaseNode objects. Defaults to retrack.nodes.registry(). - validator_registry (Registry, optional): A registry of BaseValidator objects. Defaults to retrack.validators.registry(). 
- unknown_node_error (bool, optional): Whether to raise an error if an unknown node is found. Defaults to False. - """ - Parser._check_input_data(graph_data) - - node_registry = Registry() - self._indexes_by_kind_map = {} - self._indexes_by_name_map = {} + self.__components = {} self.__edges = None - for node_id, node_data in graph_data["nodes"].items(): - node_name = node_data.get("name", None) - - Parser._check_node_name(node_name, node_id) - - node_name = node_name.lower() - - validation_model = component_registry.get(node_name) - if validation_model is not None: - if node_name not in self._indexes_by_name_map: - self._indexes_by_name_map[node_name] = [] - - self._indexes_by_name_map[node_name].append(node_id) + self._check_input_data(graph_data) - node_data["id"] = node_id - if node_id not in node_registry: - node = validation_model(**node_data) - node_registry.register(node_id, node) + self._set_components(graph_data, component_registry) + self._set_edges() - if node.kind() not in self._indexes_by_kind_map: - self._indexes_by_kind_map[node.kind()] = [] + self._validate_graph(graph_data, validator_registry) - self._indexes_by_kind_map[node.kind()].append(node_id) - - elif unknown_node_error: - raise ValueError(f"Unknown node name: {node_name}") - - self._node_registry = node_registry - - for validator_name, validator in validator_registry.data.items(): - if not validator.validate(graph_data=graph_data, edges=self.edges): - raise ValueError(f"Invalid graph data: {validator_name}") - - @staticmethod - def _check_node_name(node_name: str, node_id: str): - if node_name is None: - raise ValueError(f"BaseNode {node_id} has no name") - if not isinstance(node_name, str): - raise TypeError(f"BaseNode {node_id} name must be a string") + self._set_indexes_by_name_map() + self._set_indexes_by_kind_map() @staticmethod def _check_input_data(data: dict): @@ -82,53 +39,83 @@ def _check_input_data(data: dict): + str(type(data["nodes"])) ) + @staticmethod + def 
_check_node_name(node_name: str, node_id: str): + if node_name is None: + raise ValueError(f"BaseNode {node_id} has no name") + if not isinstance(node_name, str): + raise TypeError(f"BaseNode {node_id} name must be a string") + @property - def nodes(self) -> typing.Dict[str, BaseNode]: - return self._node_registry.data + def components(self) -> typing.Dict[str, BaseNode]: + return self.__components + + def _set_components(self, graph_data: dict, component_registry: Registry): + for node_id, node_metadata in graph_data["nodes"].items(): + if node_id in self.__components: + raise ValueError(f"Duplicate node id: {node_id}") + + node_name = node_metadata.get("name", None) + self._check_node_name(node_name, node_id) + + node_name = node_name.lower() + + validation_model = component_registry.get(node_name) + + if validation_model is None: + raise ValueError(f"Unknown node name: {node_name}") + + self.__components[node_id] = validation_model(**node_metadata) @property def edges(self) -> typing.List[typing.Tuple[str, str]]: - if self.__edges is None: - edges = [] + return self.__edges - for node_id, node in self.nodes.items(): - for _, output_connection in node.outputs: - for c in output_connection.connections: - edges.append((node_id, c.node)) + def _set_edges(self): + self.__edges = [] - self.__edges = edges + for node_id, node in self.components.items(): + for _, output_connection in node.outputs: + for c in output_connection.connections: + self.__edges.append((node_id, c.node)) - return self.__edges + def _validate_graph(self, graph_data: dict, validator_registry: Registry): + for validator_name, validator in validator_registry.data.items(): + if not validator.validate(graph_data=graph_data, edges=self.edges): + raise ValueError(f"Invalid graph data: {validator_name}") - @property - def data(self) -> dict: - return {i: j.dict(by_alias=True) for i, j in self.nodes.items()} + def get_by_id(self, id_: str) -> BaseNode: + return self.components.get(id_) @property - def 
tokens(self) -> dict: - """Returns a dictionary of tokens (node name) and their associated node ids.""" + def indexes_by_name_map(self) -> typing.Dict[str, typing.List[str]]: return self._indexes_by_name_map + def _set_indexes_by_name_map(self): + self._indexes_by_name_map = {} + + for node_id, node in self.components.items(): + node_name = node.__class__.__name__.lower() + if node_name not in self._indexes_by_name_map: + self._indexes_by_name_map[node_name] = [] + + self._indexes_by_name_map[node_name].append(node_id) + + def get_by_name(self, name: str) -> typing.List[BaseNode]: + return [self.get_by_id(id_) for id_ in self.indexes_by_name_map[name]] + @property - def indexes_by_kind_map(self) -> dict: - """Returns a dictionary of node kinds and their associated node ids.""" + def indexes_by_kind_map(self) -> typing.Dict[str, typing.List[str]]: return self._indexes_by_kind_map - def get_node_by_id(self, node_id: str) -> BaseNode: - return self._node_registry.get(node_id) - - def get_nodes_by_name(self, node_name: str) -> typing.List[BaseNode]: - node_name = node_name.lower() - return [self.get_node_by_id(i) for i in self.tokens.get(node_name, [])] + def _set_indexes_by_kind_map(self): + self._indexes_by_kind_map = {} - def get_nodes_by_multiple_names(self, node_names: list) -> typing.List[BaseNode]: - all_nodes = [] - for node_name in node_names: - nodes = self.get_nodes_by_name(node_name) - if nodes is not None: - all_nodes.extend(nodes) + for node_id, node in self.components.items(): + if node.kind() not in self._indexes_by_kind_map: + self._indexes_by_kind_map[node.kind()] = [] - return all_nodes + self._indexes_by_kind_map[node.kind()].append(node_id) - def get_nodes_by_kind(self, kind: str) -> typing.List[BaseNode]: - return [self.get_node_by_id(i) for i in self.indexes_by_kind_map.get(kind, [])] + def get_by_kind(self, kind: str) -> typing.List[BaseNode]: + return [self.get_by_id(id_) for id_ in self.indexes_by_kind_map[kind]] diff --git 
a/retrack/engine/runner.py b/retrack/engine/runner.py index 83de27b..02c5c2a 100644 --- a/retrack/engine/runner.py +++ b/retrack/engine/runner.py @@ -14,7 +14,7 @@ class Runner: def __init__(self, parser: Parser): self._parser = parser - input_nodes = self._parser.get_nodes_by_kind(NodeKind.INPUT) + input_nodes = self._parser.get_by_kind(NodeKind.INPUT) self._input_new_columns = { f"{node.id}@{constants.INPUT_OUTPUT_VALUE_CONNECTOR_NAME}": node.data.name for node in input_nodes @@ -126,7 +126,7 @@ def __set_state_data( self._state_df.loc[filter_by, column] = value def __run_node(self, node_id: str): - node = self._parser.get_node_by_id(node_id) + node = self._parser.get_by_id(node_id) current_node_filter = self._filters.get(node_id, None) input_params = self.__get_input_params( diff --git a/retrack/utils/graph.py b/retrack/utils/graph.py index cf4eb76..1a9ff13 100644 --- a/retrack/utils/graph.py +++ b/retrack/utils/graph.py @@ -16,7 +16,7 @@ def get_node_connections(node, is_input: bool = True, filter_by_connector=None): def walk(parser, actual_id: str, skiped_ids=[], callback=None): - node = parser.get_node_by_id(actual_id) + node = parser.get_by_id(actual_id) if callback: callback(node.id) skiped_ids.append(actual_id) @@ -25,7 +25,7 @@ def walk(parser, actual_id: str, skiped_ids=[], callback=None): for next_id in output_ids: if next_id not in skiped_ids: - next_node = parser.get_node_by_id(next_id) + next_node = parser.get_by_id(next_id) next_node_input_ids = get_node_connections(next_node, is_input=True) run_next = True @@ -41,7 +41,7 @@ def walk(parser, actual_id: str, skiped_ids=[], callback=None): def get_execution_order(parser): - start_nodes = parser.get_nodes_by_name("start") + start_nodes = parser.get_by_name("start") if len(start_nodes) == 0: raise ValueError("No start node found") elif len(start_nodes) > 1: diff --git a/tests/test_parser.py b/tests/test_parser.py index b578b5a..a94a9eb 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ 
-6,11 +6,10 @@ @pytest.mark.parametrize( - "data_filename,expected_data_filename,expected_tokens", + "data_filename,expected_tokens", [ ( "tests/resources/age-negative.json", - "tests/resources/age-negative-data.json", { "start": ["0"], "input": ["2", "13"], @@ -23,21 +22,17 @@ ) ], ) -def test_parser_extract(data_filename, expected_data_filename, expected_tokens): +def test_parser_extract(data_filename, expected_tokens): with open(data_filename) as f: input_data = json.load(f) - with open(expected_data_filename) as f: - expected_output_data = json.load(f) - parser = Parser(input_data) - assert parser.data == expected_output_data - assert parser.tokens == expected_tokens + assert parser.indexes_by_name_map == expected_tokens def test_parser_with_unknown_node(): with pytest.raises(ValueError): - Parser({"nodes": {"1": {"name": "Unknown"}}}, unknown_node_error=True) + Parser({"nodes": {"1": {"name": "Unknown"}}}) def test_parser_invalid_input_data(): From 6740f66d1c85efe4ac0b6d8ee2e312ae57a0e579 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Mon, 1 May 2023 22:09:40 -0300 Subject: [PATCH 4/7] Fix shared memory issue --- retrack/engine/parser.py | 83 +++++++++++++-- retrack/engine/runner.py | 173 ++++++++++++++++--------------- retrack/nodes/base.py | 14 +++ retrack/nodes/constants.py | 4 + retrack/nodes/match.py | 4 + retrack/utils/graph.py | 50 --------- tests/test_engine/test_runner.py | 6 +- 7 files changed, 192 insertions(+), 142 deletions(-) delete mode 100644 retrack/utils/graph.py diff --git a/retrack/engine/parser.py b/retrack/engine/parser.py index b69cc93..f67d60e 100644 --- a/retrack/engine/parser.py +++ b/retrack/engine/parser.py @@ -1,7 +1,6 @@ import typing from retrack import nodes, validators -from retrack.nodes import BaseNode from retrack.utils.registry import Registry @@ -12,6 +11,7 @@ def __init__( component_registry: Registry = nodes.registry(), validator_registry: Registry = validators.registry(), ): + self._execution_order = None 
self.__components = {} self.__edges = None @@ -24,6 +24,8 @@ def __init__( self._set_indexes_by_name_map() self._set_indexes_by_kind_map() + self._set_execution_order() + self._set_indexes_by_memory_type_map() @staticmethod def _check_input_data(data: dict): @@ -47,7 +49,7 @@ def _check_node_name(node_name: str, node_id: str): raise TypeError(f"BaseNode {node_id} name must be a string") @property - def components(self) -> typing.Dict[str, BaseNode]: + def components(self) -> typing.Dict[str, nodes.BaseNode]: return self.__components def _set_components(self, graph_data: dict, component_registry: Registry): @@ -84,7 +86,7 @@ def _validate_graph(self, graph_data: dict, validator_registry: Registry): if not validator.validate(graph_data=graph_data, edges=self.edges): raise ValueError(f"Invalid graph data: {validator_name}") - def get_by_id(self, id_: str) -> BaseNode: + def get_by_id(self, id_: str) -> nodes.BaseNode: return self.components.get(id_) @property @@ -101,8 +103,8 @@ def _set_indexes_by_name_map(self): self._indexes_by_name_map[node_name].append(node_id) - def get_by_name(self, name: str) -> typing.List[BaseNode]: - return [self.get_by_id(id_) for id_ in self.indexes_by_name_map[name]] + def get_by_name(self, name: str) -> typing.List[nodes.BaseNode]: + return [self.get_by_id(id_) for id_ in self.indexes_by_name_map.get(name, [])] @property def indexes_by_kind_map(self) -> typing.Dict[str, typing.List[str]]: @@ -117,5 +119,72 @@ def _set_indexes_by_kind_map(self): self._indexes_by_kind_map[node.kind()].append(node_id) - def get_by_kind(self, kind: str) -> typing.List[BaseNode]: - return [self.get_by_id(id_) for id_ in self.indexes_by_kind_map[kind]] + def get_by_kind(self, kind: str) -> typing.List[nodes.BaseNode]: + return [self.get_by_id(id_) for id_ in self.indexes_by_kind_map.get(kind, [])] + + @property + def indexes_by_memory_type_map(self) -> typing.Dict[str, typing.List[str]]: + return self._indexes_by_memory_type_map + + def 
_set_indexes_by_memory_type_map(self): + self._indexes_by_memory_type_map = {} + + for node_id, node in self.components.items(): + memory_type = node.memory_type() + if memory_type not in self.indexes_by_memory_type_map: + self._indexes_by_memory_type_map[memory_type] = [] + + self._indexes_by_memory_type_map[memory_type].append(node_id) + + def get_by_memory_type(self, memory_type: str) -> typing.List[nodes.BaseNode]: + return [ + self.get_by_id(id_) + for id_ in self.indexes_by_memory_type_map.get(memory_type, []) + ] + + @property + def execution_order(self) -> typing.List[str]: + return self._execution_order + + def _set_execution_order(self): + start_nodes = self.get_by_name("start") + + self._execution_order = self._walk(start_nodes[0].id, []) + + def get_node_connections( + self, node_id: str, is_input: bool = True, filter_by_connector=None + ): + node_dict = self.get_by_id(node_id).dict(by_alias=True) + + connectors = node_dict.get("inputs" if is_input else "outputs", {}) + result = [] + + for connector_name, value in connectors.items(): + if ( + filter_by_connector is not None + and connector_name != filter_by_connector + ): + continue + + for connection in value["connections"]: + result.append(connection["node"]) + return result + + def _walk(self, actual_id: str, skiped_ids: list): + skiped_ids.append(actual_id) + + output_ids = self.get_node_connections(actual_id, is_input=False) + + for next_id in output_ids: + if next_id not in skiped_ids: + next_node_input_ids = self.get_node_connections(next_id, is_input=True) + run_next = True + for next_node_input_id in next_node_input_ids: + if next_node_input_id not in skiped_ids: + run_next = False + break + + if run_next: + self._walk(next_id, skiped_ids) + + return skiped_ids diff --git a/retrack/engine/runner.py b/retrack/engine/runner.py index 02c5c2a..9bf5f04 100644 --- a/retrack/engine/runner.py +++ b/retrack/engine/runner.py @@ -6,25 +6,21 @@ from retrack.engine.parser import Parser from 
retrack.engine.request_manager import RequestManager -from retrack.nodes.base import NodeKind -from retrack.utils import constants, graph +from retrack.nodes.base import NodeKind, NodeMemoryType +from retrack.utils import constants class Runner: def __init__(self, parser: Parser): self._parser = parser + self.reset() + self._set_constants() + self._set_input_columns() + self._request_manager = RequestManager(self._parser.get_by_kind(NodeKind.INPUT)) - input_nodes = self._parser.get_by_kind(NodeKind.INPUT) - self._input_new_columns = { - f"{node.id}@{constants.INPUT_OUTPUT_VALUE_CONNECTOR_NAME}": node.data.name - for node in input_nodes - } - self._request_manager = RequestManager(input_nodes) - - self._execution_order = graph.get_execution_order(self._parser) - self._state_df = None - - self._filters = {} + @property + def parser(self) -> Parser: + return self._parser @property def request_manager(self) -> RequestManager: @@ -35,27 +31,59 @@ def request_model(self) -> pydantic.BaseModel: return self._request_manager.model @property - def state_df(self) -> pd.DataFrame: - return self._state_df - - @property - def states(self) -> list: - return self._state_df.to_dict(orient="records") + def states(self) -> pd.DataFrame: + return self._states @property def filters(self) -> dict: return self._filters @property - def filter_df(self) -> pd.DataFrame: - return pd.DataFrame(self._filters) + def constants(self) -> dict: + return self._constants + + def _set_constants(self): + constant_nodes = self.parser.get_by_memory_type(NodeMemoryType.CONSTANT) + self._constants = {node.id: node.data.value for node in constant_nodes} + + @property + def input_columns(self) -> dict: + return self._input_columns + + def _set_input_columns(self): + input_nodes = self._parser.get_by_kind(NodeKind.INPUT) + self._input_columns = { + f"{node.id}@{constants.INPUT_OUTPUT_VALUE_CONNECTOR_NAME}": node.data.name + for node in input_nodes + } + + def reset(self): + self._states = None + self._filters 
= {} - def __get_initial_state_df(self, payload: typing.Union[dict, list]) -> pd.DataFrame: + def __set_output_connection_filters( + self, node_id: str, filter: typing.Any, filter_by_connector=None + ): + if filter is not None: + output_connections = self.parser.get_node_connections( + node_id, is_input=False, filter_by_connector=filter_by_connector + ) + for output_connection_id in output_connections: + if self._filters.get(output_connection_id, None) is None: + self._filters[output_connection_id] = filter + else: + self._filters[output_connection_id] = ( + self._filters[output_connection_id] & filter + ) + + def _create_state_from_payload( + self, payload: typing.Union[dict, list] + ) -> pd.DataFrame: validated_payload = self.request_manager.validate(payload) validated_payload = pd.DataFrame([p.dict() for p in validated_payload]) state_df = pd.DataFrame([]) - for node_id, input_name in self._input_new_columns.items(): + for node_id, input_name in self.input_columns.items(): state_df[node_id] = validated_payload[input_name] state_df[constants.OUTPUT_REFERENCE_COLUMN] = np.nan @@ -63,38 +91,6 @@ def __get_initial_state_df(self, payload: typing.Union[dict, list]) -> pd.DataFr return state_df - @staticmethod - def __get_output_state_df(state_df: pd.DataFrame) -> pd.DataFrame: - output_state_df = state_df[ - [ - constants.OUTPUT_REFERENCE_COLUMN, - constants.OUTPUT_MESSAGE_REFERENCE_COLUMN, - ] - ].copy() - - output_state_df = output_state_df.rename( - columns={ - constants.OUTPUT_REFERENCE_COLUMN: "output", - constants.OUTPUT_MESSAGE_REFERENCE_COLUMN: "message", - } - ) - - return output_state_df - - def __set_output_connection_filters( - self, node, value: typing.Any, filter_by_connector=None - ): - output_connections = graph.get_node_connections( - node, is_input=False, filter_by_connector=filter_by_connector - ) - for output_connection_id in output_connections: - if self._filters.get(output_connection_id, None) is None: - self._filters[output_connection_id] = 
value - else: - self._filters[output_connection_id] = ( - self._filters[output_connection_id] & value - ) - def __get_input_params( self, node_dict: dict, current_node_filter: pd.Series ) -> dict: @@ -111,34 +107,31 @@ def __get_input_params( return input_params - def __get_state_data(self, column: str, filter_by: typing.Any = None): - if filter_by is None: - return self._state_df[column] - else: - return self._state_df.loc[filter_by, column] - def __set_state_data( self, column: str, value: typing.Any, filter_by: typing.Any = None ): if filter_by is None: - self._state_df[column] = value + self._states[column] = value + else: + self._states.loc[filter_by, column] = value + + def __get_state_data(self, column: str, filter_by: typing.Any = None): + if filter_by is None: + return self._states[column] else: - self._state_df.loc[filter_by, column] = value + return self._states.loc[filter_by, column] def __run_node(self, node_id: str): - node = self._parser.get_by_id(node_id) current_node_filter = self._filters.get(node_id, None) + # if there is a filter, we need to set the children nodes to receive filtered data + self.__set_output_connection_filters(node_id, current_node_filter) + node = self.parser.get_by_id(node_id) input_params = self.__get_input_params( node.dict(by_alias=True), current_node_filter ) - - if ( - current_node_filter is not None - ): # if there is a filter, we need to set the children nodes to receive filtered data - self.__set_output_connection_filters(node, current_node_filter) - output = node.run(**input_params) + for output_name, output_value in output.items(): if ( output_name == constants.OUTPUT_REFERENCE_COLUMN @@ -146,27 +139,41 @@ def __run_node(self, node_id: str): ): # Setting output values self.__set_state_data(output_name, output_value, current_node_filter) elif output_name.endswith(constants.FILTER_SUFFIX): # Setting filters - self.__set_output_connection_filters(node, output_value, output_name) + 
self.__set_output_connection_filters(node_id, output_value, output_name) else: # Setting node outputs to be used as inputs by other nodes self.__set_state_data( f"{node_id}@{output_name}", output_value, current_node_filter ) - def __call__( - self, payload: typing.Union[dict, list], to_dict: bool = True - ) -> pd.DataFrame: - self._state_df = self.__get_initial_state_df(payload) - self._filters = {} + @staticmethod + def __get_output_state_df(state_df: pd.DataFrame) -> pd.DataFrame: + output_state_df = state_df[ + [ + constants.OUTPUT_REFERENCE_COLUMN, + constants.OUTPUT_MESSAGE_REFERENCE_COLUMN, + ] + ].copy() - for node_id in self._execution_order: + output_state_df = output_state_df.rename( + columns={ + constants.OUTPUT_REFERENCE_COLUMN: "output", + constants.OUTPUT_MESSAGE_REFERENCE_COLUMN: "message", + } + ) + + return output_state_df + + def execute(self, payload: typing.Union[dict, list]) -> pd.DataFrame: + self.reset() + self._states = self._create_state_from_payload(payload) + + for node_id in self.parser.execution_order: try: self.__run_node(node_id) except Exception as e: - raise e # TODO: Handle errors - if self._state_df[constants.OUTPUT_REFERENCE_COLUMN].isna().sum() == 0: + print(f"Error running node {node_id}") + raise e + if self.states[constants.OUTPUT_REFERENCE_COLUMN].isna().sum() == 0: break - if to_dict: - return self.__get_output_state_df(self._state_df).to_dict(orient="records") - - return Runner.__get_output_state_df(self._state_df) + return Runner.__get_output_state_df(self.states) diff --git a/retrack/nodes/base.py b/retrack/nodes/base.py index 27c8488..65f36b8 100644 --- a/retrack/nodes/base.py +++ b/retrack/nodes/base.py @@ -19,6 +19,17 @@ class NodeKind(str, enum.Enum): OTHER = "other" +############################################################### +# Node Memory Types +############################################################### + + +class NodeMemoryType(str, enum.Enum): + STATE = "state" + FILTER = "filter" + CONSTANT = 
"constant" + + ############################################################### # Connection Models ############################################################### @@ -57,3 +68,6 @@ def run(self, **kwargs) -> typing.Dict[str, typing.Any]: def kind(self) -> NodeKind: return NodeKind.OTHER + + def memory_type(self) -> NodeMemoryType: + return NodeMemoryType.STATE diff --git a/retrack/nodes/constants.py b/retrack/nodes/constants.py index 05766ca..96cda38 100644 --- a/retrack/nodes/constants.py +++ b/retrack/nodes/constants.py @@ -6,6 +6,7 @@ BaseNode, InputConnectionModel, NodeKind, + NodeMemoryType, OutputConnectionModel, ) @@ -80,6 +81,9 @@ class List(BaseConstant): def run(self, **kwargs) -> typing.Dict[str, typing.Any]: return {"output_list": self.data.value} + def memory_type(self) -> NodeMemoryType: + return NodeMemoryType.CONSTANT + class Bool(BaseConstant): data: BoolMetadataModel = BoolMetadataModel(value=False) diff --git a/retrack/nodes/match.py b/retrack/nodes/match.py index f424466..df76e2f 100644 --- a/retrack/nodes/match.py +++ b/retrack/nodes/match.py @@ -7,6 +7,7 @@ BaseNode, InputConnectionModel, NodeKind, + NodeMemoryType, OutputConnectionModel, ) @@ -41,3 +42,6 @@ def run(self, input_bool: pd.Series) -> typing.Dict[str, pd.Series]: f"output_then_filter": input_bool, f"output_else_filter": ~input_bool, } + + def memory_type(self) -> NodeMemoryType: + return NodeMemoryType.FILTER diff --git a/retrack/utils/graph.py b/retrack/utils/graph.py deleted file mode 100644 index 1a9ff13..0000000 --- a/retrack/utils/graph.py +++ /dev/null @@ -1,50 +0,0 @@ -def get_node_connections(node, is_input: bool = True, filter_by_connector=None): - if isinstance(node, dict): - node_dict = node - else: - node_dict = node.dict(by_alias=True) - connectors = node_dict.get("inputs" if is_input else "outputs", {}) - result = [] - - for connector_name, value in connectors.items(): - if filter_by_connector is not None and connector_name != filter_by_connector: - continue - - for 
connection in value["connections"]: - result.append(connection["node"]) - return result - - -def walk(parser, actual_id: str, skiped_ids=[], callback=None): - node = parser.get_by_id(actual_id) - if callback: - callback(node.id) - skiped_ids.append(actual_id) - - output_ids = get_node_connections(node, is_input=False) - - for next_id in output_ids: - if next_id not in skiped_ids: - next_node = parser.get_by_id(next_id) - - next_node_input_ids = get_node_connections(next_node, is_input=True) - run_next = True - for next_node_input_id in next_node_input_ids: - if next_node_input_id not in skiped_ids: - run_next = False - break - - if run_next: - walk(parser, next_id, skiped_ids, callback) - - return skiped_ids - - -def get_execution_order(parser): - start_nodes = parser.get_by_name("start") - if len(start_nodes) == 0: - raise ValueError("No start node found") - elif len(start_nodes) > 1: - raise ValueError("Multiple start nodes found") - - return walk(parser, start_nodes[0].id) diff --git a/tests/test_engine/test_runner.py b/tests/test_engine/test_runner.py index b58a7ac..c0944f0 100644 --- a/tests/test_engine/test_runner.py +++ b/tests/test_engine/test_runner.py @@ -3,6 +3,7 @@ import pytest from retrack import Parser, Runner +import pandas as pd @pytest.mark.parametrize( @@ -36,6 +37,7 @@ def test_flows(filename, in_values, expected_out_values): rule = json.load(f) runner = Runner(Parser(rule)) - out_values = runner(in_values) + out_values = runner.execute(in_values) - assert out_values == expected_out_values + assert isinstance(out_values, pd.DataFrame) + assert out_values.to_dict(orient="records") == expected_out_values From 4fc5946d6f65dfedec49428a9e92588e04cbaf08 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Tue, 2 May 2023 10:42:54 -0300 Subject: [PATCH 5/7] Version 0.4.0 --- pyproject.toml | 2 +- tests/test_engine/test_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 03282b4..a2f8ab9 
100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "retrack" -version = "0.3.0" +version = "0.4.0" description = "A business rules engine" authors = ["Gabriel Guarisa ", "Nathalia Trotte "] license = "MIT" diff --git a/tests/test_engine/test_runner.py b/tests/test_engine/test_runner.py index c0944f0..d88cbc4 100644 --- a/tests/test_engine/test_runner.py +++ b/tests/test_engine/test_runner.py @@ -1,9 +1,9 @@ import json +import pandas as pd import pytest from retrack import Parser, Runner -import pandas as pd @pytest.mark.parametrize( From 27f82331d32a6bc3f046e2500ed1d1f49d30d9e4 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Tue, 2 May 2023 14:20:18 -0300 Subject: [PATCH 6/7] Fix list constant issue --- retrack/engine/runner.py | 45 ++++++++++++++++++---------------- retrack/nodes/check.py | 7 +----- retrack/nodes/constants.py | 2 +- retrack/nodes/contains.py | 3 ++- retrack/nodes/endswithany.py | 4 ++- retrack/nodes/startswithany.py | 4 ++- retrack/utils/transformers.py | 8 ++++++ 7 files changed, 42 insertions(+), 31 deletions(-) create mode 100644 retrack/utils/transformers.py diff --git a/retrack/engine/runner.py b/retrack/engine/runner.py index 9bf5f04..fe3c447 100644 --- a/retrack/engine/runner.py +++ b/retrack/engine/runner.py @@ -44,7 +44,10 @@ def constants(self) -> dict: def _set_constants(self): constant_nodes = self.parser.get_by_memory_type(NodeMemoryType.CONSTANT) - self._constants = {node.id: node.data.value for node in constant_nodes} + self._constants = {} + for node in constant_nodes: + for output_connector_name, _ in node.outputs: + self._constants[f"{node.id}@{output_connector_name}"] = node.data.value @property def input_columns(self) -> dict: @@ -76,9 +79,10 @@ def __set_output_connection_filters( self._filters[output_connection_id] & filter ) - def _create_state_from_payload( + def _create_initial_state_from_payload( self, payload: typing.Union[dict, list] ) -> pd.DataFrame: + """Create initial state 
from payload. This is the first step of the runner.""" validated_payload = self.request_manager.validate(payload) validated_payload = pd.DataFrame([p.dict() for p in validated_payload]) @@ -116,10 +120,13 @@ def __set_state_data( self._states.loc[filter_by, column] = value def __get_state_data(self, column: str, filter_by: typing.Any = None): + if column in self._constants: + return self._constants[column] + if filter_by is None: return self._states[column] - else: - return self._states.loc[filter_by, column] + + return self._states.loc[filter_by, column] def __run_node(self, node_id: str): current_node_filter = self._filters.get(node_id, None) @@ -127,6 +134,10 @@ def __run_node(self, node_id: str): self.__set_output_connection_filters(node_id, current_node_filter) node = self.parser.get_by_id(node_id) + + if node.memory_type == NodeMemoryType.CONSTANT: + return + input_params = self.__get_input_params( node.dict(by_alias=True), current_node_filter ) @@ -145,27 +156,19 @@ def __run_node(self, node_id: str): f"{node_id}@{output_name}", output_value, current_node_filter ) - @staticmethod - def __get_output_state_df(state_df: pd.DataFrame) -> pd.DataFrame: - output_state_df = state_df[ - [ - constants.OUTPUT_REFERENCE_COLUMN, - constants.OUTPUT_MESSAGE_REFERENCE_COLUMN, - ] - ].copy() - - output_state_df = output_state_df.rename( - columns={ - constants.OUTPUT_REFERENCE_COLUMN: "output", - constants.OUTPUT_MESSAGE_REFERENCE_COLUMN: "message", + def __get_output_states(self) -> pd.DataFrame: + """Returns a dataframe with the final states of the flow""" + return pd.DataFrame( + { + "output": self.states[constants.OUTPUT_REFERENCE_COLUMN], + "message": self.states[constants.OUTPUT_MESSAGE_REFERENCE_COLUMN], } ) - return output_state_df - def execute(self, payload: typing.Union[dict, list]) -> pd.DataFrame: + """Executes the flow with the given payload""" self.reset() - self._states = self._create_state_from_payload(payload) + self._states = 
self._create_initial_state_from_payload(payload) for node_id in self.parser.execution_order: try: @@ -176,4 +179,4 @@ def execute(self, payload: typing.Union[dict, list]) -> pd.DataFrame: if self.states[constants.OUTPUT_REFERENCE_COLUMN].isna().sum() == 0: break - return Runner.__get_output_state_df(self.states) + return self.__get_output_states() diff --git a/retrack/nodes/check.py b/retrack/nodes/check.py index 17170c7..fecfc40 100644 --- a/retrack/nodes/check.py +++ b/retrack/nodes/check.py @@ -5,12 +5,7 @@ import pandas as pd import pydantic -from retrack.nodes.base import ( - BaseNode, - InputConnectionModel, - NodeKind, - OutputConnectionModel, -) +from retrack.nodes.base import BaseNode, InputConnectionModel, OutputConnectionModel ############################################################### # Check Metadata Models diff --git a/retrack/nodes/constants.py b/retrack/nodes/constants.py index 96cda38..8faf69c 100644 --- a/retrack/nodes/constants.py +++ b/retrack/nodes/constants.py @@ -79,7 +79,7 @@ class List(BaseConstant): outputs: ListOutputsModel def run(self, **kwargs) -> typing.Dict[str, typing.Any]: - return {"output_list": self.data.value} + return {} # {"output_list": self.data.value} def memory_type(self) -> NodeMemoryType: return NodeMemoryType.CONSTANT diff --git a/retrack/nodes/contains.py b/retrack/nodes/contains.py index 81d62d8..9cd6bea 100644 --- a/retrack/nodes/contains.py +++ b/retrack/nodes/contains.py @@ -2,6 +2,7 @@ import pydantic from retrack.nodes.base import BaseNode, InputConnectionModel, OutputConnectionModel +from retrack.utils import transformers ################################################ # Contains Inputs Outputs @@ -27,4 +28,4 @@ class Contains(BaseNode): outputs: ContainsOutputsModel def run(self, input_list: pd.Series, input_value: pd.Series) -> pd.Series: - return {"output_bool": input_value.isin(input_list.to_list())} + return {"output_bool": input_value.isin(transformers.to_list(input_list))} diff --git 
a/retrack/nodes/endswithany.py b/retrack/nodes/endswithany.py index 88fa851..b24a3d5 100644 --- a/retrack/nodes/endswithany.py +++ b/retrack/nodes/endswithany.py @@ -2,6 +2,7 @@ import pydantic from retrack.nodes.base import BaseNode, InputConnectionModel, OutputConnectionModel +from retrack.utils import transformers ################################################ # EndsWithAny Inputs Outputs @@ -27,4 +28,5 @@ class EndsWithAny(BaseNode): outputs: EndsWithAnyOutputsModel def run(self, input_value: pd.Series, input_list: pd.Series) -> pd.Series: - return {"output_bool": input_value.str.endswith(tuple(input_list.to_list()))} + input_list = transformers.to_list(input_list) + return {"output_bool": input_value.str.endswith(tuple(input_list))} diff --git a/retrack/nodes/startswithany.py b/retrack/nodes/startswithany.py index 50975f0..92f5a54 100644 --- a/retrack/nodes/startswithany.py +++ b/retrack/nodes/startswithany.py @@ -2,6 +2,7 @@ import pydantic from retrack.nodes.base import BaseNode, InputConnectionModel, OutputConnectionModel +from retrack.utils import transformers ################################################ # StartsWithAny Inputs Outputs @@ -27,4 +28,5 @@ class StartsWithAny(BaseNode): outputs: StartsWithAnyOutputsModel def run(self, input_value: pd.Series, input_list: pd.Series) -> pd.Series: - return {"output_bool": input_value.str.startswith(tuple(input_list.to_list()))} + input_list = transformers.to_list(input_list) + return {"output_bool": input_value.str.startswith(tuple(input_list))} diff --git a/retrack/utils/transformers.py b/retrack/utils/transformers.py new file mode 100644 index 0000000..4c10c15 --- /dev/null +++ b/retrack/utils/transformers.py @@ -0,0 +1,8 @@ +import pandas as pd + + +def to_list(input_list): + if isinstance(input_list, pd.Series): + input_list = input_list.to_list() + + return input_list From ad57fda64e7b5a58e9eb009e98826f8642bfd19c Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Tue, 2 May 2023 14:22:35 -0300 
Subject: [PATCH 7/7] Version 0.4.1 --- README.md | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 5952b3b..6948784 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ parser = retrack.Parser(rule) runner = retrack.Runner(parser) # Run the rule/model passing the data -runner(data) +runner.execute(data) ``` The `Parser` class parses the rule/model and creates a graph of nodes. The `Runner` class runs the rule/model using the data passed to the runner. The `data` is a dictionary or a list of dictionaries containing the data that will be used to evaluate the conditions and execute the actions. To see wich data is required for the given rule/model, check the `runner.request_model` property that is a pydantic model used to validate the data. diff --git a/pyproject.toml b/pyproject.toml index a2f8ab9..7312434 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "retrack" -version = "0.4.0" +version = "0.4.1" description = "A business rules engine" authors = ["Gabriel Guarisa ", "Nathalia Trotte "] license = "MIT"