From 9df3cf29b4f1a843c4c31a1b01a32d26400c51e3 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Fri, 17 Feb 2023 10:19:18 -0300 Subject: [PATCH 01/11] Add node kind method --- retrack/nodes/base.py | 22 +++++++++++++++++++--- retrack/nodes/check.py | 11 ++++++----- retrack/nodes/constants.py | 12 ++++++++---- retrack/nodes/inputs.py | 10 +++++++++- retrack/nodes/match.py | 10 +++++++++- retrack/nodes/math.py | 4 ---- retrack/nodes/outputs.py | 6 ++++-- retrack/nodes/start.py | 5 ++++- 8 files changed, 59 insertions(+), 21 deletions(-) diff --git a/retrack/nodes/base.py b/retrack/nodes/base.py index 09162aa..a98ef7e 100644 --- a/retrack/nodes/base.py +++ b/retrack/nodes/base.py @@ -1,7 +1,24 @@ import typing +import enum + import pydantic +############################################################### +# Node Kind +############################################################### + + +class NodeKind(str, enum.Enum): + INPUT = "input" + CONSTANT = "constant" + OUTPUT = "output" + FILTER = "filter" + CONNECTOR = "connector" + START = "start" + OTHER = "other" + + ############################################################### # Connection Models ############################################################### @@ -38,6 +55,5 @@ class BaseNode(pydantic.BaseModel): def run(self, **kwargs) -> typing.Dict[str, typing.Any]: return {} - @property - def node_type(self) -> str: - raise NotImplementedError + def kind(self) -> NodeKind: + raise NodeKind.OTHER diff --git a/retrack/nodes/check.py b/retrack/nodes/check.py index 582bbbd..17d0deb 100644 --- a/retrack/nodes/check.py +++ b/retrack/nodes/check.py @@ -5,7 +5,12 @@ import pandas as pd import pydantic -from retrack.nodes.base import BaseNode, InputConnectionModel, OutputConnectionModel +from retrack.nodes.base import ( + BaseNode, + InputConnectionModel, + NodeKind, + OutputConnectionModel, +) ############################################################### # Check Metadata Models @@ -49,10 +54,6 @@ class Check(BaseNode): 
inputs: CheckInputsModel outputs: CheckOutputsModel - @property - def node_type(self) -> str: - return "logic.check" - def run( self, input_value_0: pd.Series, diff --git a/retrack/nodes/constants.py b/retrack/nodes/constants.py index baf9300..05766ca 100644 --- a/retrack/nodes/constants.py +++ b/retrack/nodes/constants.py @@ -2,7 +2,12 @@ import pydantic -from retrack.nodes.base import BaseNode, InputConnectionModel, OutputConnectionModel +from retrack.nodes.base import ( + BaseNode, + InputConnectionModel, + NodeKind, + OutputConnectionModel, +) ####################################################### # Constant Metadata Models @@ -56,9 +61,8 @@ class BoolOutputsModel(pydantic.BaseModel): class BaseConstant(BaseNode): inputs: typing.Optional[ConstantInputsModel] = None - @property - def node_type(self) -> str: - return "variable.constant" + def kind(self) -> NodeKind: + return NodeKind.CONSTANT class Constant(BaseConstant): diff --git a/retrack/nodes/inputs.py b/retrack/nodes/inputs.py index 9eae6ec..72db14f 100644 --- a/retrack/nodes/inputs.py +++ b/retrack/nodes/inputs.py @@ -2,7 +2,12 @@ import pydantic -from retrack.nodes.base import BaseNode, InputConnectionModel, OutputConnectionModel +from retrack.nodes.base import ( + BaseNode, + InputConnectionModel, + NodeKind, + OutputConnectionModel, +) ################################################ # Input Metadata Models @@ -36,3 +41,6 @@ class Input(BaseNode): data: InputMetadataModel inputs: typing.Optional[InputInputsModel] = None outputs: InputOutputsModel + + def kind(self) -> NodeKind: + return NodeKind.INPUT diff --git a/retrack/nodes/match.py b/retrack/nodes/match.py index be50662..f424466 100644 --- a/retrack/nodes/match.py +++ b/retrack/nodes/match.py @@ -3,7 +3,12 @@ import pandas as pd import pydantic -from retrack.nodes.base import BaseNode, InputConnectionModel, OutputConnectionModel +from retrack.nodes.base import ( + BaseNode, + InputConnectionModel, + NodeKind, + OutputConnectionModel, +) 
################################################ # If Inputs and Outputs @@ -28,6 +33,9 @@ class If(BaseNode): inputs: IfInputsModel outputs: IfOutputsModel + def kind(self) -> NodeKind: + return NodeKind.FILTER + def run(self, input_bool: pd.Series) -> typing.Dict[str, pd.Series]: return { f"output_then_filter": input_bool, diff --git a/retrack/nodes/math.py b/retrack/nodes/math.py index b2c4246..9ce2d63 100644 --- a/retrack/nodes/math.py +++ b/retrack/nodes/math.py @@ -47,10 +47,6 @@ class Math(BaseNode): inputs: MathInputsModel outputs: MathOutputsModel - @property - def node_type(self) -> str: - return "logic.math" - def run( self, input_value_0: pd.Series, diff --git a/retrack/nodes/outputs.py b/retrack/nodes/outputs.py index 8468dab..1e74054 100644 --- a/retrack/nodes/outputs.py +++ b/retrack/nodes/outputs.py @@ -1,10 +1,9 @@ import typing -import numpy as np import pandas as pd import pydantic -from retrack.nodes.base import BaseNode, InputConnectionModel +from retrack.nodes.base import BaseNode, InputConnectionModel, NodeKind from retrack.utils import constants ################################################ @@ -32,6 +31,9 @@ class BoolOutput(BaseNode): inputs: typing.Optional[BoolOutputInputsModel] data: OutputMetadataModel + def kind(self) -> NodeKind: + return NodeKind.OUTPUT + def run(self, input_bool: pd.Series) -> typing.Dict[str, pd.Series]: return { constants.OUTPUT_REFERENCE_COLUMN: input_bool, diff --git a/retrack/nodes/start.py b/retrack/nodes/start.py index 3010bb6..54fce82 100644 --- a/retrack/nodes/start.py +++ b/retrack/nodes/start.py @@ -1,6 +1,6 @@ import pydantic -from retrack.nodes.base import BaseNode, OutputConnectionModel +from retrack.nodes.base import BaseNode, NodeKind, OutputConnectionModel ################################################ # Start Inputs and Outputs @@ -19,3 +19,6 @@ class StartOutputsModel(pydantic.BaseModel): class Start(BaseNode): outputs: StartOutputsModel + + def kind(self) -> NodeKind: + return NodeKind.START 
From 5706d14ec3e37a87c6b29b305c30e1621a466212 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Fri, 17 Feb 2023 12:40:08 -0300 Subject: [PATCH 02/11] Change search to use NodeKind enum --- retrack/engine/parser.py | 25 ++++++++++++++++--------- retrack/engine/runner.py | 4 ++-- retrack/nodes/base.py | 2 +- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/retrack/engine/parser.py b/retrack/engine/parser.py index 52c9d80..e6a5cce 100644 --- a/retrack/engine/parser.py +++ b/retrack/engine/parser.py @@ -21,6 +21,7 @@ def __init__( Parser._check_input_data(data) node_registry = Registry() + self._kind_index_map = {} tokens = {} for node_id, node_data in data["nodes"].items(): @@ -39,7 +40,14 @@ def __init__( node_data["id"] = node_id if node_id not in node_registry: - node_registry.register(node_id, validation_model(**node_data)) + node = validation_model(**node_data) + node_registry.register(node_id, node) + + if node.kind() not in self._kind_index_map: + self._kind_index_map[node.kind()] = [] + + self._kind_index_map[node.kind()].append(node_id) + elif unknown_node_error: raise ValueError(f"Unknown node name: {node_name}") @@ -77,8 +85,14 @@ def data(self) -> dict: @property def tokens(self) -> dict: + """Returns a dictionary of tokens (node name) and their associated node ids.""" return self._tokens + @property + def kind_index_map(self) -> dict: + """Returns a dictionary of node kinds and their associated node ids.""" + return self._kind_index_map + def get_node_by_id(self, node_id: str) -> BaseNode: return self._node_registry.get(node_id) @@ -96,11 +110,4 @@ def get_nodes_by_multiple_names(self, node_names: list) -> typing.List[BaseNode] return all_nodes def get_nodes_by_kind(self, kind: str) -> typing.List[BaseNode]: - if kind == "input": - return self.get_nodes_by_multiple_names(INPUT_TOKENS) - if kind == "output": - return self.get_nodes_by_multiple_names(OUTPUT_TOKENS) - if kind == "constant": - return 
self.get_nodes_by_multiple_names(CONSTANT_TOKENS) - - return [] + return [self.get_node_by_id(i) for i in self.kind_index_map.get(kind, [])] diff --git a/retrack/engine/runner.py b/retrack/engine/runner.py index 21511ec..e0d9d0f 100644 --- a/retrack/engine/runner.py +++ b/retrack/engine/runner.py @@ -6,13 +6,13 @@ from retrack.engine.parser import Parser from retrack.engine.payload_manager import PayloadManager from retrack.utils import constants, graph - +from retrack.nodes.base import NodeKind class Runner: def __init__(self, parser: Parser): self._parser = parser - input_nodes = self._parser.get_nodes_by_kind("input") + input_nodes = self._parser.get_nodes_by_kind(NodeKind.INPUT) self._input_new_columns = { f"{node.id}@{constants.INPUT_OUTPUT_VALUE_CONNECTOR_NAME}": node.data.name for node in input_nodes diff --git a/retrack/nodes/base.py b/retrack/nodes/base.py index a98ef7e..27c8488 100644 --- a/retrack/nodes/base.py +++ b/retrack/nodes/base.py @@ -56,4 +56,4 @@ def run(self, **kwargs) -> typing.Dict[str, typing.Any]: return {} def kind(self) -> NodeKind: - raise NodeKind.OTHER + return NodeKind.OTHER From 931e0ccd1569be44ac874b267ee3ace389cafbe6 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Fri, 17 Feb 2023 13:00:55 -0300 Subject: [PATCH 03/11] Add runner test --- retrack/engine/parser.py | 31 +- retrack/engine/runner.py | 27 +- tests/resources/age-negative.json | 506 ++++++++++++++++++++++++++++++ tests/test_engine/test_runner.py | 26 ++ 4 files changed, 569 insertions(+), 21 deletions(-) create mode 100644 tests/resources/age-negative.json create mode 100644 tests/test_engine/test_runner.py diff --git a/retrack/engine/parser.py b/retrack/engine/parser.py index e6a5cce..4d6715c 100644 --- a/retrack/engine/parser.py +++ b/retrack/engine/parser.py @@ -4,12 +4,6 @@ from retrack.nodes import registry as GLOBAL_NODE_REGISTRY from retrack.utils.registry import Registry -INPUT_TOKENS = ["input"] - -OUTPUT_TOKENS = ["booloutput"] - -CONSTANT_TOKENS = 
["constant", "list", "bool"] - class Parser: def __init__( @@ -21,8 +15,8 @@ def __init__( Parser._check_input_data(data) node_registry = Registry() - self._kind_index_map = {} - tokens = {} + self._indexes_by_kind_map = {} + self._indexes_by_name_map = {} for node_id, node_data in data["nodes"].items(): node_name = node_data.get("name", None) @@ -33,26 +27,25 @@ def __init__( validation_model = component_registry.get(node_name) if validation_model is not None: - if node_name not in tokens: - tokens[node_name] = [] + if node_name not in self._indexes_by_name_map: + self._indexes_by_name_map[node_name] = [] - tokens[node_name].append(node_id) + self._indexes_by_name_map[node_name].append(node_id) node_data["id"] = node_id if node_id not in node_registry: node = validation_model(**node_data) node_registry.register(node_id, node) - if node.kind() not in self._kind_index_map: - self._kind_index_map[node.kind()] = [] + if node.kind() not in self._indexes_by_kind_map: + self._indexes_by_kind_map[node.kind()] = [] - self._kind_index_map[node.kind()].append(node_id) + self._indexes_by_kind_map[node.kind()].append(node_id) elif unknown_node_error: raise ValueError(f"Unknown node name: {node_name}") self._node_registry = node_registry - self._tokens = tokens @staticmethod def _check_node_name(node_name: str, node_id: str): @@ -86,12 +79,12 @@ def data(self) -> dict: @property def tokens(self) -> dict: """Returns a dictionary of tokens (node name) and their associated node ids.""" - return self._tokens + return self._indexes_by_name_map @property - def kind_index_map(self) -> dict: + def indexes_by_kind_map(self) -> dict: """Returns a dictionary of node kinds and their associated node ids.""" - return self._kind_index_map + return self._indexes_by_kind_map def get_node_by_id(self, node_id: str) -> BaseNode: return self._node_registry.get(node_id) @@ -110,4 +103,4 @@ def get_nodes_by_multiple_names(self, node_names: list) -> typing.List[BaseNode] return all_nodes def 
get_nodes_by_kind(self, kind: str) -> typing.List[BaseNode]: - return [self.get_node_by_id(i) for i in self.kind_index_map.get(kind, [])] + return [self.get_node_by_id(i) for i in self.indexes_by_kind_map.get(kind, [])] diff --git a/retrack/engine/runner.py b/retrack/engine/runner.py index e0d9d0f..0e8c5af 100644 --- a/retrack/engine/runner.py +++ b/retrack/engine/runner.py @@ -2,11 +2,13 @@ import numpy as np import pandas as pd +import pydantic from retrack.engine.parser import Parser from retrack.engine.payload_manager import PayloadManager -from retrack.utils import constants, graph from retrack.nodes.base import NodeKind +from retrack.utils import constants, graph + class Runner: def __init__(self, parser: Parser): @@ -28,10 +30,26 @@ def __init__(self, parser: Parser): def payload_manager(self) -> PayloadManager: return self._payload_manager + @property + def model(self) -> pydantic.BaseModel: + return self._payload_manager.model + @property def state_df(self) -> pd.DataFrame: return self._state_df + @property + def states(self) -> list: + return self._state_df.to_dict(orient="records") + + @property + def filters(self) -> dict: + return self._filters + + @property + def filter_df(self) -> pd.DataFrame: + return pd.DataFrame(self._filters) + def __get_initial_state_df(self, payload: typing.Union[dict, list]) -> pd.DataFrame: validated_payload = self.payload_manager.validate(payload) validated_payload = pd.DataFrame([p.dict() for p in validated_payload]) @@ -134,7 +152,9 @@ def __run_node(self, node_id: str): f"{node_id}@{output_name}", output_value, current_node_filter ) - def __call__(self, payload: typing.Union[dict, list]) -> pd.DataFrame: + def __call__( + self, payload: typing.Union[dict, list], to_dict: bool = True + ) -> pd.DataFrame: self._state_df = self.__get_initial_state_df(payload) self._filters = {} @@ -146,4 +166,7 @@ def __call__(self, payload: typing.Union[dict, list]) -> pd.DataFrame: if 
self._state_df[constants.OUTPUT_REFERENCE_COLUMN].isna().sum() == 0: break + if to_dict: + return self.__get_output_state_df(self._state_df).to_dict(orient="records") + return Runner.__get_output_state_df(self._state_df) diff --git a/tests/resources/age-negative.json b/tests/resources/age-negative.json new file mode 100644 index 0000000..c7dc057 --- /dev/null +++ b/tests/resources/age-negative.json @@ -0,0 +1,506 @@ +{ + "id": "demo@0.1.0", + "nodes": { + "0": { + "id": 0, + "data": {}, + "inputs": {}, + "outputs": { + "output_up_void": { + "connections": [ + { + "node": 2, + "input": "input_void", + "data": {} + } + ] + }, + "output_down_void": { + "connections": [ + { + "node": 3, + "input": "input_void", + "data": {} + } + ] + } + }, + "position": [ + -570.16015625, + -16.7578125 + ], + "name": "Start" + }, + "2": { + "id": 2, + "data": { + "name": "age", + "default": null + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 0, + "output": "output_up_void", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 4, + "input": "input_value_0", + "data": {} + } + ] + } + }, + "position": [ + -262.082231911288, + -229.52363816128795 + ], + "name": "Input" + }, + "3": { + "id": 3, + "data": { + "value": "18" + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 0, + "output": "output_down_void", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 4, + "input": "input_value_1", + "data": {} + } + ] + } + }, + "position": [ + -266.79444352384587, + 85.59488398537597 + ], + "name": "Constant" + }, + "4": { + "id": 4, + "data": { + "operator": ">=" + }, + "inputs": { + "input_value_0": { + "connections": [ + { + "node": 2, + "output": "output_value", + "data": {} + } + ] + }, + "input_value_1": { + "connections": [ + { + "node": 3, + "output": "output_value", + "data": {} + } + ] + } + }, + "outputs": { + "output_bool": { + "connections": [ + { + 
"node": 6, + "input": "input_bool", + "data": {} + } + ] + } + }, + "position": [ + 83.84765625, + -146.80078125 + ], + "name": "Check" + }, + "6": { + "id": 6, + "data": {}, + "inputs": { + "input_bool": { + "connections": [ + { + "node": 4, + "output": "output_bool", + "data": {} + } + ] + } + }, + "outputs": { + "output_then_filter": { + "connections": [ + { + "node": 9, + "input": "input_void", + "data": {} + } + ] + }, + "output_else_filter": { + "connections": [ + { + "node": 32, + "input": "input_void", + "data": {} + }, + { + "node": 33, + "input": "input_void", + "data": {} + } + ] + } + }, + "position": [ + 387.98276806872417, + -127.24641593097007 + ], + "name": "If" + }, + "9": { + "id": 9, + "data": { + "value": true + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 6, + "output": "output_then_filter", + "data": {} + } + ] + } + }, + "outputs": { + "output_bool": { + "connections": [ + { + "node": 10, + "input": "input_bool", + "data": {} + } + ] + } + }, + "position": [ + 704.5628174930008, + -191.34367642675832 + ], + "name": "Bool" + }, + "10": { + "id": 10, + "data": { + "message": "valid age" + }, + "inputs": { + "input_bool": { + "connections": [ + { + "node": 9, + "output": "output_bool", + "data": {} + } + ] + } + }, + "outputs": {}, + "position": [ + 1010.059101990149, + -204.43876489419645 + ], + "name": "BoolOutput" + }, + "32": { + "id": 32, + "data": { + "value": "0" + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 6, + "output": "output_else_filter", + "data": {} + } + ] + } + }, + "outputs": { + "output_value": { + "connections": [ + { + "node": 34, + "input": "input_value_1", + "data": {} + } + ] + } + }, + "position": [ + 692.9667216252685, + 232.40342195520302 + ], + "name": "Constant" + }, + "33": { + "id": 33, + "data": { + "name": "age", + "default": null + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 6, + "output": "output_else_filter", + "data": {} + } + ] + } + }, + 
"outputs": { + "output_value": { + "connections": [ + { + "node": 34, + "input": "input_value_0", + "data": {} + } + ] + } + }, + "position": [ + 695.7957828925545, + -11.978323498583437 + ], + "name": "Input" + }, + "34": { + "id": 34, + "data": { + "operator": "<=" + }, + "inputs": { + "input_value_0": { + "connections": [ + { + "node": 33, + "output": "output_value", + "data": {} + } + ] + }, + "input_value_1": { + "connections": [ + { + "node": 32, + "output": "output_value", + "data": {} + } + ] + } + }, + "outputs": { + "output_bool": { + "connections": [ + { + "node": 35, + "input": "input_bool", + "data": {} + } + ] + } + }, + "position": [ + 967.5446870413394, + 81.34660260832283 + ], + "name": "Check" + }, + "35": { + "id": 35, + "data": {}, + "inputs": { + "input_bool": { + "connections": [ + { + "node": 34, + "output": "output_bool", + "data": {} + } + ] + } + }, + "outputs": { + "output_then_filter": { + "connections": [ + { + "node": 36, + "input": "input_void", + "data": {} + } + ] + }, + "output_else_filter": { + "connections": [ + { + "node": 37, + "input": "input_void", + "data": {} + } + ] + } + }, + "position": [ + 1260.2577089540875, + 106.17681353648126 + ], + "name": "If" + }, + "36": { + "id": 36, + "data": { + "value": false + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 35, + "output": "output_then_filter", + "data": {} + } + ] + } + }, + "outputs": { + "output_bool": { + "connections": [ + { + "node": 38, + "input": "input_bool", + "data": {} + } + ] + } + }, + "position": [ + 1545.6248398426094, + -14.61885194908769 + ], + "name": "Bool" + }, + "37": { + "id": 37, + "data": { + "value": false + }, + "inputs": { + "input_void": { + "connections": [ + { + "node": 35, + "output": "output_else_filter", + "data": {} + } + ] + } + }, + "outputs": { + "output_bool": { + "connections": [ + { + "node": 39, + "input": "input_bool", + "data": {} + } + ] + } + }, + "position": [ + 1546.5621972662288, + 187.61935817965565 + ], 
+ "name": "Bool" + }, + "38": { + "id": 38, + "data": { + "message": "invalid age" + }, + "inputs": { + "input_bool": { + "connections": [ + { + "node": 36, + "output": "output_bool", + "data": {} + } + ] + } + }, + "outputs": {}, + "position": [ + 1809.2160825804826, + -20.217793077837147 + ], + "name": "BoolOutput" + }, + "39": { + "id": 39, + "data": { + "message": "underage" + }, + "inputs": { + "input_bool": { + "connections": [ + { + "node": 37, + "output": "output_bool", + "data": {} + } + ] + } + }, + "outputs": {}, + "position": [ + 1805.5744940644936, + 197.36671274352045 + ], + "name": "BoolOutput" + } + } +} \ No newline at end of file diff --git a/tests/test_engine/test_runner.py b/tests/test_engine/test_runner.py new file mode 100644 index 0000000..1248366 --- /dev/null +++ b/tests/test_engine/test_runner.py @@ -0,0 +1,26 @@ +import json + +import pytest + +from retrack import Parser, Runner + + +@pytest.fixture +def age_negative_json() -> dict: + with open("tests/resources/age-negative.json", "r") as f: + return json.load(f) + + +def test_age_negative(age_negative_json): + parser = Parser(age_negative_json) + runner = Runner(parser) + in_values = [10, -10, 18, 19, 100] + out_values = runner([{"age": val} for val in in_values]) + + assert out_values == [ + {"message": "underage", "output": False}, + {"message": "invalid age", "output": False}, + {"message": "valid age", "output": True}, + {"message": "valid age", "output": True}, + {"message": "valid age", "output": True}, + ] From feaa1e7443742d8bf137d501541f1ea4ee48c4ae Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Fri, 17 Feb 2023 13:05:55 -0300 Subject: [PATCH 04/11] Change payload_manager to request_manager --- README.md | 2 +- ...{payload_manager.py => request_manager.py} | 2 +- retrack/engine/runner.py | 14 ++++++------- ...oad_manager.py => test_request_manager.py} | 20 +++++++++---------- 4 files changed, 19 insertions(+), 19 deletions(-) rename retrack/engine/{payload_manager.py => 
request_manager.py} (98%) rename tests/test_engine/{test_payload_manager.py => test_request_manager.py} (50%) diff --git a/README.md b/README.md index b03fb7f..5952b3b 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ runner = retrack.Runner(parser) runner(data) ``` -The `Parser` class parses the rule/model and creates a graph of nodes. The `Runner` class runs the rule/model using the data passed to the runner. The `data` is a dictionary or a list of dictionaries containing the data that will be used to evaluate the conditions and execute the actions. To see wich data is required for the given rule/model, check the `runner.payload_manager.model` property that is a pydantic model used to validate the data. +The `Parser` class parses the rule/model and creates a graph of nodes. The `Runner` class runs the rule/model using the data passed to the runner. The `data` is a dictionary or a list of dictionaries containing the data that will be used to evaluate the conditions and execute the actions. To see which data is required for the given rule/model, check the `runner.request_model` property that is a pydantic model used to validate the data. 
### Creating a rule/model diff --git a/retrack/engine/payload_manager.py b/retrack/engine/request_manager.py similarity index 98% rename from retrack/engine/payload_manager.py rename to retrack/engine/request_manager.py index f78301c..2db3875 100644 --- a/retrack/engine/payload_manager.py +++ b/retrack/engine/request_manager.py @@ -5,7 +5,7 @@ from retrack.nodes.inputs import Input -class PayloadManager: +class RequestManager: def __init__(self, inputs: typing.List[Input]): self._model = None self.inputs = inputs diff --git a/retrack/engine/runner.py b/retrack/engine/runner.py index 0e8c5af..83de27b 100644 --- a/retrack/engine/runner.py +++ b/retrack/engine/runner.py @@ -5,7 +5,7 @@ import pydantic from retrack.engine.parser import Parser -from retrack.engine.payload_manager import PayloadManager +from retrack.engine.request_manager import RequestManager from retrack.nodes.base import NodeKind from retrack.utils import constants, graph @@ -19,7 +19,7 @@ def __init__(self, parser: Parser): f"{node.id}@{constants.INPUT_OUTPUT_VALUE_CONNECTOR_NAME}": node.data.name for node in input_nodes } - self._payload_manager = PayloadManager(input_nodes) + self._request_manager = RequestManager(input_nodes) self._execution_order = graph.get_execution_order(self._parser) self._state_df = None @@ -27,12 +27,12 @@ def __init__(self, parser: Parser): self._filters = {} @property - def payload_manager(self) -> PayloadManager: - return self._payload_manager + def request_manager(self) -> RequestManager: + return self._request_manager @property - def model(self) -> pydantic.BaseModel: - return self._payload_manager.model + def request_model(self) -> pydantic.BaseModel: + return self._request_manager.model @property def state_df(self) -> pd.DataFrame: @@ -51,7 +51,7 @@ def filter_df(self) -> pd.DataFrame: return pd.DataFrame(self._filters) def __get_initial_state_df(self, payload: typing.Union[dict, list]) -> pd.DataFrame: - validated_payload = self.payload_manager.validate(payload) + 
validated_payload = self.request_manager.validate(payload) validated_payload = pd.DataFrame([p.dict() for p in validated_payload]) state_df = pd.DataFrame([]) diff --git a/tests/test_engine/test_payload_manager.py b/tests/test_engine/test_request_manager.py similarity index 50% rename from tests/test_engine/test_payload_manager.py rename to tests/test_engine/test_request_manager.py index da27203..259fcd8 100644 --- a/tests/test_engine/test_payload_manager.py +++ b/tests/test_engine/test_request_manager.py @@ -1,35 +1,35 @@ import pytest -from retrack.engine.payload_manager import PayloadManager +from retrack.engine.request_manager import RequestManager -def test_create_payload_manager_with_dict(valid_input_dict_before_validation): +def test_create_request_manager_with_dict(valid_input_dict_before_validation): with pytest.raises(TypeError): - PayloadManager(valid_input_dict_before_validation) + RequestManager(valid_input_dict_before_validation) -def test_create_payload_manager_with_list_of_dicts(valid_input_dict_before_validation): - pm = PayloadManager([valid_input_dict_before_validation]) +def test_create_request_manager_with_list_of_dicts(valid_input_dict_before_validation): + pm = RequestManager([valid_input_dict_before_validation]) assert len(pm.inputs) == 1 assert pm.model is not None -def test_create_payload_manager_with_list_of_dicts_and_duplicate_names( +def test_create_request_manager_with_list_of_dicts_and_duplicate_names( valid_input_dict_before_validation, ): - pm = PayloadManager( + pm = RequestManager( [valid_input_dict_before_validation, valid_input_dict_before_validation] ) assert len(pm.inputs) == 1 assert pm.model is not None -def test_create_payload_manager_with_invalid_input(valid_input_dict_before_validation): +def test_create_request_manager_with_invalid_input(valid_input_dict_before_validation): with pytest.raises(TypeError): - PayloadManager([valid_input_dict_before_validation, "invalid"]) + 
RequestManager([valid_input_dict_before_validation, "invalid"]) def test_validate_payload_with_valid_payload(valid_input_dict_before_validation): - pm = PayloadManager([valid_input_dict_before_validation]) + pm = RequestManager([valid_input_dict_before_validation]) payload = pm.model(example="test") assert pm.validate({"example": "test"})[0] == payload From 49f7242696fcf02a8c4675e12b956fec98b175dc Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Fri, 17 Feb 2023 13:11:40 -0300 Subject: [PATCH 05/11] Add optional default value to request model fields --- retrack/engine/request_manager.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/retrack/engine/request_manager.py b/retrack/engine/request_manager.py index 2db3875..60b5a7d 100644 --- a/retrack/engine/request_manager.py +++ b/retrack/engine/request_manager.py @@ -44,15 +44,17 @@ def model(self) -> typing.Type[pydantic.BaseModel]: return self._model def _create_model(self) -> typing.Type[pydantic.BaseModel]: + fields = {} + for input_field in self.inputs: + fields[input_field.data.name] = ( + (str, ...) 
+ if input_field.data.default is None + else (str, input_field.data.default) + ) + return pydantic.create_model( "Payload", - **{ - input_.data.name: ( - str, - ..., - ) - for input_ in self.inputs - }, + **fields, ) def validate( From 3680170327e8f5bd949a9d43adf2381036752d25 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Fri, 17 Feb 2023 13:13:31 -0300 Subject: [PATCH 06/11] Change payload name to request --- retrack/engine/request_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/retrack/engine/request_manager.py b/retrack/engine/request_manager.py index 60b5a7d..3ccc76c 100644 --- a/retrack/engine/request_manager.py +++ b/retrack/engine/request_manager.py @@ -53,7 +53,7 @@ def _create_model(self) -> typing.Type[pydantic.BaseModel]: ) return pydantic.create_model( - "Payload", + "RequestModel", **fields, ) From 33e3d507dcb9bf872b68ba1f3266b604522c1139 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Fri, 17 Feb 2023 18:17:34 -0300 Subject: [PATCH 07/11] Remove request manager dict args --- retrack/engine/request_manager.py | 46 ++++++++++++++++++----- tests/test_engine/test_request_manager.py | 17 +++++---- 2 files changed, 45 insertions(+), 18 deletions(-) diff --git a/retrack/engine/request_manager.py b/retrack/engine/request_manager.py index 3ccc76c..295c4eb 100644 --- a/retrack/engine/request_manager.py +++ b/retrack/engine/request_manager.py @@ -2,20 +2,20 @@ import pydantic -from retrack.nodes.inputs import Input +from retrack.nodes.base import BaseNode, NodeKind class RequestManager: - def __init__(self, inputs: typing.List[Input]): + def __init__(self, inputs: typing.List[BaseNode]): self._model = None self.inputs = inputs @property - def inputs(self) -> typing.List[Input]: + def inputs(self) -> typing.List[BaseNode]: return self._inputs @inputs.setter - def inputs(self, inputs: typing.List[Input]): + def inputs(self, inputs: typing.List[BaseNode]): if not isinstance(inputs, list): raise TypeError(f"inputs must be a 
list, not {type(inputs)}") @@ -23,13 +23,16 @@ def inputs(self, inputs: typing.List[Input]): formated_inputs = [] for i in range(len(inputs)): - if isinstance(inputs[i], dict): - inputs[i] = Input(**inputs[i]) - elif not isinstance(inputs[i], Input): + if not isinstance(inputs[i], BaseNode): raise TypeError( f"inputs[{i}] must be a dict or an InputModel, not {type(inputs[i])}" ) + if inputs[i].kind() != NodeKind.INPUT: + raise TypeError( + f"inputs[{i}] must be an InputModel, not {type(inputs[i])}" + ) + if inputs[i].data.name not in input_names: input_names.add(inputs[i].data.name) formated_inputs.append(inputs[i]) @@ -37,13 +40,25 @@ def inputs(self, inputs: typing.List[Input]): self._inputs = formated_inputs if len(self.inputs) > 0: - self._model = self._create_model() + self._model = self.__create_model() + else: + self._model = None @property def model(self) -> typing.Type[pydantic.BaseModel]: return self._model - def _create_model(self) -> typing.Type[pydantic.BaseModel]: + def __create_model( + self, model_name: str = "RequestModel" + ) -> typing.Type[pydantic.BaseModel]: + """Create a pydantic model from the RequestManager's inputs + + Args: + model_name (str, optional): The name of the model. Defaults to "RequestModel". 
+ + Returns: + typing.Type[pydantic.BaseModel]: The pydantic model + """ fields = {} for input_field in self.inputs: fields[input_field.data.name] = ( @@ -53,7 +68,7 @@ def _create_model(self) -> typing.Type[pydantic.BaseModel]: ) return pydantic.create_model( - "RequestModel", + model_name, **fields, ) @@ -63,6 +78,17 @@ def validate( typing.Dict[str, str], typing.List[typing.Dict[str, str]] ], ) -> typing.List[pydantic.BaseModel]: + """Validate the payload against the RequestManager's model + + Args: + payload (typing.Union[typing.Dict[str, str], typing.List[typing.Dict[str, str]]]): The payload to validate + + Raises: + ValueError: If the RequestManager has no model + + Returns: + typing.List[pydantic.BaseModel]: The validated payload + """ if self.model is None: raise ValueError("No inputs found") diff --git a/tests/test_engine/test_request_manager.py b/tests/test_engine/test_request_manager.py index 259fcd8..3758920 100644 --- a/tests/test_engine/test_request_manager.py +++ b/tests/test_engine/test_request_manager.py @@ -1,24 +1,25 @@ import pytest from retrack.engine.request_manager import RequestManager +from retrack.nodes.inputs import Input -def test_create_request_manager_with_dict(valid_input_dict_before_validation): +def test_create_request_manager(valid_input_dict_before_validation): with pytest.raises(TypeError): - RequestManager(valid_input_dict_before_validation) + RequestManager(Input(**valid_input_dict_before_validation)) -def test_create_request_manager_with_list_of_dicts(valid_input_dict_before_validation): - pm = RequestManager([valid_input_dict_before_validation]) +def test_create_request_manager_with_list_of_inputs(valid_input_dict_before_validation): + pm = RequestManager([Input(**valid_input_dict_before_validation)]) assert len(pm.inputs) == 1 assert pm.model is not None -def test_create_request_manager_with_list_of_dicts_and_duplicate_names( +def test_create_request_manager_with_list_of_inputs_and_duplicate_names( 
valid_input_dict_before_validation, ): pm = RequestManager( - [valid_input_dict_before_validation, valid_input_dict_before_validation] + [Input(**valid_input_dict_before_validation), Input(**valid_input_dict_before_validation)] ) assert len(pm.inputs) == 1 assert pm.model is not None @@ -26,10 +27,10 @@ def test_create_request_manager_with_list_of_dicts_and_duplicate_names( def test_create_request_manager_with_invalid_input(valid_input_dict_before_validation): with pytest.raises(TypeError): - RequestManager([valid_input_dict_before_validation, "invalid"]) + RequestManager([Input(**valid_input_dict_before_validation), "invalid"]) def test_validate_payload_with_valid_payload(valid_input_dict_before_validation): - pm = RequestManager([valid_input_dict_before_validation]) + pm = RequestManager([Input(**valid_input_dict_before_validation)]) payload = pm.model(example="test") assert pm.validate({"example": "test"})[0] == payload From 54272c467a4c7a02ac3ea182bae7b9f89df04027 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Fri, 17 Feb 2023 18:59:56 -0300 Subject: [PATCH 08/11] Change BoolOutput to Output --- examples/age-check.json | 64 ++++++------- retrack/nodes/__init__.py | 4 +- retrack/nodes/outputs.py | 12 +-- tests/resources/age-negative.json | 152 +++++++++++++++--------------- 4 files changed, 116 insertions(+), 116 deletions(-) diff --git a/examples/age-check.json b/examples/age-check.json index a543cd8..63a47ec 100644 --- a/examples/age-check.json +++ b/examples/age-check.json @@ -135,8 +135,8 @@ } }, "position": [ - 83.84765625, - -146.80078125 + 45.51953125, + -136.8515625 ], "name": "Check" }, @@ -180,29 +180,6 @@ ], "name": "If" }, - "7": { - "id": 7, - "data": { - "message": "invalid age" - }, - "inputs": { - "input_bool": { - "connections": [ - { - "node": 8, - "output": "output_bool", - "data": {} - } - ] - } - }, - "outputs": {}, - "position": [ - 972.3247551810931, - -8.462733235746624 - ], - "name": "BoolOutput" - }, "8": { "id": 8, "data": { @@ 
-223,8 +200,8 @@ "output_bool": { "connections": [ { - "node": 7, - "input": "input_bool", + "node": 11, + "input": "input_value", "data": {} } ] @@ -257,7 +234,7 @@ "connections": [ { "node": 10, - "input": "input_bool", + "input": "input_value", "data": {} } ] @@ -275,7 +252,7 @@ "message": "valid age" }, "inputs": { - "input_bool": { + "input_value": { "connections": [ { "node": 9, @@ -287,10 +264,33 @@ }, "outputs": {}, "position": [ - 974.0269089794663, - -190.8242802623253 + 1015.5346416468075, + -247.2703893983769 + ], + "name": "Output" + }, + "11": { + "id": 11, + "data": { + "message": "invalid age" + }, + "inputs": { + "input_value": { + "connections": [ + { + "node": 8, + "output": "output_bool", + "data": {} + } + ] + } + }, + "outputs": {}, + "position": [ + 1008.826235840524, + -68.88453626006572 ], - "name": "BoolOutput" + "name": "Output" } } } \ No newline at end of file diff --git a/retrack/nodes/__init__.py b/retrack/nodes/__init__.py index 196a512..dc5cf35 100644 --- a/retrack/nodes/__init__.py +++ b/retrack/nodes/__init__.py @@ -8,7 +8,7 @@ from retrack.nodes.logic import And, Not, Or from retrack.nodes.match import If from retrack.nodes.math import Math -from retrack.nodes.outputs import BoolOutput +from retrack.nodes.outputs import Output from retrack.nodes.start import Start from retrack.nodes.startswith import StartsWith from retrack.nodes.startswithany import StartsWithAny @@ -21,7 +21,7 @@ registry.register("Constant", Constant) registry.register("List", List) registry.register("Bool", Bool) -registry.register("BoolOutput", BoolOutput) +registry.register("Output", Output) registry.register("Check", Check) registry.register("If", If) registry.register("And", And) diff --git a/retrack/nodes/outputs.py b/retrack/nodes/outputs.py index 1e74054..53c991f 100644 --- a/retrack/nodes/outputs.py +++ b/retrack/nodes/outputs.py @@ -18,8 +18,8 @@ class OutputMetadataModel(pydantic.BaseModel): ################################################ # Output 
Inputs and Outputs ################################################ -class BoolOutputInputsModel(pydantic.BaseModel): - input_bool: InputConnectionModel +class OutputInputsModel(pydantic.BaseModel): + input_value: InputConnectionModel ################################################ @@ -27,15 +27,15 @@ class BoolOutputInputsModel(pydantic.BaseModel): ################################################ -class BoolOutput(BaseNode): - inputs: typing.Optional[BoolOutputInputsModel] +class Output(BaseNode): + inputs: typing.Optional[OutputInputsModel] data: OutputMetadataModel def kind(self) -> NodeKind: return NodeKind.OUTPUT - def run(self, input_bool: pd.Series) -> typing.Dict[str, pd.Series]: + def run(self, input_value: pd.Series) -> typing.Dict[str, pd.Series]: return { - constants.OUTPUT_REFERENCE_COLUMN: input_bool, + constants.OUTPUT_REFERENCE_COLUMN: input_value, constants.OUTPUT_MESSAGE_REFERENCE_COLUMN: self.data.message, } diff --git a/tests/resources/age-negative.json b/tests/resources/age-negative.json index c7dc057..982c5de 100644 --- a/tests/resources/age-negative.json +++ b/tests/resources/age-negative.json @@ -135,8 +135,8 @@ } }, "position": [ - 83.84765625, - -146.80078125 + 45.51953125, + -136.8515625 ], "name": "Check" }, @@ -167,12 +167,12 @@ "output_else_filter": { "connections": [ { - "node": 32, + "node": 13, "input": "input_void", "data": {} }, { - "node": 33, + "node": 14, "input": "input_void", "data": {} } @@ -180,8 +180,8 @@ } }, "position": [ - 387.98276806872417, - -127.24641593097007 + 320.17898461073565, + -126.14455525956922 ], "name": "If" }, @@ -206,15 +206,15 @@ "connections": [ { "node": 10, - "input": "input_bool", + "input": "input_value", "data": {} } ] } }, "position": [ - 704.5628174930008, - -191.34367642675832 + 603.3983427750909, + -198.97829545178527 ], "name": "Bool" }, @@ -224,7 +224,7 @@ "message": "valid age" }, "inputs": { - "input_bool": { + "input_value": { "connections": [ { "node": 9, @@ -236,15 +236,16 @@ }, 
"outputs": {}, "position": [ - 1010.059101990149, - -204.43876489419645 + 862.2873631452995, + -249.30686509884737 ], - "name": "BoolOutput" + "name": "Output" }, - "32": { - "id": 32, + "13": { + "id": 13, "data": { - "value": "0" + "name": "age", + "default": null }, "inputs": { "input_void": { @@ -261,24 +262,23 @@ "output_value": { "connections": [ { - "node": 34, - "input": "input_value_1", + "node": 15, + "input": "input_value_0", "data": {} } ] } }, "position": [ - 692.9667216252685, - 232.40342195520302 + 604.0864802067674, + -3.07067572994199 ], - "name": "Constant" + "name": "Input" }, - "33": { - "id": 33, + "14": { + "id": 14, "data": { - "name": "age", - "default": null + "value": "0" }, "inputs": { "input_void": { @@ -295,29 +295,29 @@ "output_value": { "connections": [ { - "node": 34, - "input": "input_value_0", + "node": 15, + "input": "input_value_1", "data": {} } ] } }, "position": [ - 695.7957828925545, - -11.978323498583437 + 600.676903010976, + 257.67370457848597 ], - "name": "Input" + "name": "Constant" }, - "34": { - "id": 34, + "15": { + "id": 15, "data": { - "operator": "<=" + "operator": ">" }, "inputs": { "input_value_0": { "connections": [ { - "node": 33, + "node": 13, "output": "output_value", "data": {} } @@ -326,7 +326,7 @@ "input_value_1": { "connections": [ { - "node": 32, + "node": 14, "output": "output_value", "data": {} } @@ -337,7 +337,7 @@ "output_bool": { "connections": [ { - "node": 35, + "node": 16, "input": "input_bool", "data": {} } @@ -345,19 +345,19 @@ } }, "position": [ - 967.5446870413394, - 81.34660260832283 + 861.6756216443376, + 78.00454256820485 ], "name": "Check" }, - "35": { - "id": 35, + "16": { + "id": 16, "data": {}, "inputs": { "input_bool": { "connections": [ { - "node": 34, + "node": 15, "output": "output_bool", "data": {} } @@ -368,7 +368,7 @@ "output_then_filter": { "connections": [ { - "node": 36, + "node": 17, "input": "input_void", "data": {} } @@ -377,7 +377,7 @@ "output_else_filter": { "connections": 
[ { - "node": 37, + "node": 18, "input": "input_void", "data": {} } @@ -385,13 +385,13 @@ } }, "position": [ - 1260.2577089540875, - 106.17681353648126 + 1148.8003095752115, + 93.57764585438737 ], "name": "If" }, - "36": { - "id": 36, + "17": { + "id": 17, "data": { "value": false }, @@ -399,7 +399,7 @@ "input_void": { "connections": [ { - "node": 35, + "node": 16, "output": "output_then_filter", "data": {} } @@ -410,29 +410,29 @@ "output_bool": { "connections": [ { - "node": 38, - "input": "input_bool", + "node": 19, + "input": "input_value", "data": {} } ] } }, "position": [ - 1545.6248398426094, - -14.61885194908769 + 1448.476245602431, + 6.675835758755483 ], "name": "Bool" }, - "37": { - "id": 37, + "18": { + "id": 18, "data": { - "value": false + "value": null }, "inputs": { "input_void": { "connections": [ { - "node": 35, + "node": 16, "output": "output_else_filter", "data": {} } @@ -443,29 +443,29 @@ "output_bool": { "connections": [ { - "node": 39, - "input": "input_bool", + "node": 20, + "input": "input_value", "data": {} } ] } }, "position": [ - 1546.5621972662288, - 187.61935817965565 + 1445.5353160759287, + 203.0356798844884 ], "name": "Bool" }, - "38": { - "id": 38, + "19": { + "id": 19, "data": { - "message": "invalid age" + "message": "underage" }, "inputs": { - "input_bool": { + "input_value": { "connections": [ { - "node": 36, + "node": 17, "output": "output_bool", "data": {} } @@ -474,21 +474,21 @@ }, "outputs": {}, "position": [ - 1809.2160825804826, - -20.217793077837147 + 1727.6510435313726, + 0.5048435292831783 ], - "name": "BoolOutput" + "name": "Output" }, - "39": { - "id": 39, + "20": { + "id": 20, "data": { - "message": "underage" + "message": "invalid age" }, "inputs": { - "input_bool": { + "input_value": { "connections": [ { - "node": 37, + "node": 18, "output": "output_bool", "data": {} } @@ -497,10 +497,10 @@ }, "outputs": {}, "position": [ - 1805.5744940644936, - 197.36671274352045 + 1719.7704046829726, + 214.48677799020416 ], - 
"name": "BoolOutput" + "name": "Output" } } } \ No newline at end of file From 7cc94a423f3c3b7d3288c313ad815e2c9fab0179 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Thu, 23 Feb 2023 11:07:35 -0300 Subject: [PATCH 09/11] Create validators validation --- retrack/engine/parser.py | 20 +- retrack/engine/validators/__init__.py | 12 ++ retrack/engine/validators/base.py | 13 ++ retrack/engine/validators/node_exists.py | 38 ++++ retrack/nodes/check.py | 6 + tests/resources/age-negative-data.json | 221 ++++++++++++++++++++++ tests/test_engine/test_request_manager.py | 5 +- tests/test_parser.py | 60 ++---- 8 files changed, 331 insertions(+), 44 deletions(-) create mode 100644 retrack/engine/validators/__init__.py create mode 100644 retrack/engine/validators/base.py create mode 100644 retrack/engine/validators/node_exists.py create mode 100644 tests/resources/age-negative-data.json diff --git a/retrack/engine/parser.py b/retrack/engine/parser.py index 4d6715c..a63b1eb 100644 --- a/retrack/engine/parser.py +++ b/retrack/engine/parser.py @@ -1,5 +1,6 @@ import typing +from retrack.engine.validators import registry as GLOBAL_VALIDATOR_REGISTRY from retrack.nodes import BaseNode from retrack.nodes import registry as GLOBAL_NODE_REGISTRY from retrack.utils.registry import Registry @@ -8,17 +9,26 @@ class Parser: def __init__( self, - data: dict, + graph_data: dict, component_registry: Registry = GLOBAL_NODE_REGISTRY, + validator_registry: Registry = GLOBAL_VALIDATOR_REGISTRY, unknown_node_error: bool = False, ): - Parser._check_input_data(data) + """Parses a dictionary of nodes and returns a dictionary of BaseNode objects. + + Args: + data (dict): A dictionary of nodes. + component_registry (Registry, optional): A registry of BaseNode objects. Defaults to GLOBAL_NODE_REGISTRY. + validator_registry (Registry, optional): A registry of BaseValidator objects. Defaults to GLOBAL_VALIDATOR_REGISTRY. 
+ unknown_node_error (bool, optional): Whether to raise an error if an unknown node is found. Defaults to False. + """ + Parser._check_input_data(graph_data) node_registry = Registry() self._indexes_by_kind_map = {} self._indexes_by_name_map = {} - for node_id, node_data in data["nodes"].items(): + for node_id, node_data in graph_data["nodes"].items(): node_name = node_data.get("name", None) Parser._check_node_name(node_name, node_id) @@ -45,6 +55,10 @@ def __init__( elif unknown_node_error: raise ValueError(f"Unknown node name: {node_name}") + for validator_name, validator in validator_registry.data.items(): + if not validator.validate(graph_data): + raise ValueError(f"Invalid graph data: {validator_name}") + self._node_registry = node_registry @staticmethod diff --git a/retrack/engine/validators/__init__.py b/retrack/engine/validators/__init__.py new file mode 100644 index 0000000..9e41bf3 --- /dev/null +++ b/retrack/engine/validators/__init__.py @@ -0,0 +1,12 @@ +from retrack.engine.validators.base import BaseValidator +from retrack.engine.validators.node_exists import NodeExistsValidator +from retrack.utils.registry import Registry + +registry = Registry() + +registry.register( + "single_start_node_exists", + NodeExistsValidator("start", min_quantity=1, max_quantity=1), +) + +__all__ = ["registry", "BaseValidator"] diff --git a/retrack/engine/validators/base.py b/retrack/engine/validators/base.py new file mode 100644 index 0000000..ff754aa --- /dev/null +++ b/retrack/engine/validators/base.py @@ -0,0 +1,13 @@ +class BaseValidator: + """Base class for all validators.""" + + def validate(self, graph_data: dict) -> bool: + """Validate the graph data. + + Args: + graph_data: The graph data to validate. + + Returns: + True if the graph data is valid, False otherwise. 
+ """ + raise NotImplementedError diff --git a/retrack/engine/validators/node_exists.py b/retrack/engine/validators/node_exists.py new file mode 100644 index 0000000..fdacb42 --- /dev/null +++ b/retrack/engine/validators/node_exists.py @@ -0,0 +1,38 @@ +from retrack.engine.validators.base import BaseValidator + + +class NodeExistsValidator(BaseValidator): + """Validator that checks if a node exists in the graph.""" + + def __init__( + self, node_name: str, max_quantity: int = None, min_quantity: int = None + ) -> None: + """Initialize the validator. + + Args: + node_name: The name of the node to validate. + max_quantity: The maximum quantity of nodes to validate. + min_quantity: The minimum quantity of nodes to validate. + """ + self.node_name = node_name.lower() + self.max_quantity = max_quantity + self.min_quantity = min_quantity + + def validate(self, graph_data: dict) -> bool: + """Validate the graph data. + + Args: + graph_data: The graph data to validate. + + Returns: + True if the graph data is valid, False otherwise. 
+ """ + nodes = graph_data.get("nodes", []) + nodes = [ + node for _, node in nodes.items() if node["name"].lower() == self.node_name + ] + if self.max_quantity is not None and len(nodes) > self.max_quantity: + return False + if self.min_quantity is not None and len(nodes) < self.min_quantity: + return False + return True diff --git a/retrack/nodes/check.py b/retrack/nodes/check.py index 17d0deb..17170c7 100644 --- a/retrack/nodes/check.py +++ b/retrack/nodes/check.py @@ -25,6 +25,12 @@ class CheckOperator(str, enum.Enum): GREATER_THAN_OR_EQUAL = ">=" LESS_THAN_OR_EQUAL = "<=" + def __str__(self): + return self.value + + def __repr__(self) -> str: + return self.value + class CheckMetadataModel(pydantic.BaseModel): operator: typing.Optional[CheckOperator] = CheckOperator.EQUAL diff --git a/tests/resources/age-negative-data.json b/tests/resources/age-negative-data.json new file mode 100644 index 0000000..4f59313 --- /dev/null +++ b/tests/resources/age-negative-data.json @@ -0,0 +1,221 @@ +{ + "0": { + "id": "0", + "inputs": {}, + "outputs": { + "output_up_void": { + "connections": [{"node": "2", "input": "input_void"}] + }, + "output_down_void": { + "connections": [{"node": "3", "input": "input_void"}] + } + } + }, + "2": { + "id": "2", + "inputs": { + "input_void": { + "connections": [{"node": "0", "output": "output_up_void"}] + } + }, + "outputs": { + "output_value": { + "connections": [{"node": "4", "input": "input_value_0"}] + } + }, + "data": {"name": "age", "default": null} + }, + "3": { + "id": "3", + "inputs": { + "input_void": { + "connections": [{"node": "0", "output": "output_down_void"}] + } + }, + "outputs": { + "output_value": { + "connections": [{"node": "4", "input": "input_value_1"}] + } + }, + "data": {"value": "18"} + }, + "4": { + "id": "4", + "inputs": { + "input_value_0": { + "connections": [{"node": "2", "output": "output_value"}] + }, + "input_value_1": { + "connections": [{"node": "3", "output": "output_value"}] + } + }, + "outputs": { + 
"output_bool": { + "connections": [{"node": "6", "input": "input_bool"}] + } + }, + "data": {"operator": ">="} + }, + "6": { + "id": "6", + "inputs": { + "input_bool": { + "connections": [{"node": "4", "output": "output_bool"}] + } + }, + "outputs": { + "output_then_filter": { + "connections": [{"node": "9", "input": "input_void"}] + }, + "output_else_filter": { + "connections": [ + {"node": "13", "input": "input_void"}, + {"node": "14", "input": "input_void"} + ] + } + } + }, + "9": { + "id": "9", + "inputs": { + "input_void": { + "connections": [ + {"node": "6", "output": "output_then_filter"} + ] + } + }, + "outputs": { + "output_bool": { + "connections": [{"node": "10", "input": "input_value"}] + } + }, + "data": {"value": true} + }, + "10": { + "id": "10", + "inputs": { + "input_value": { + "connections": [{"node": "9", "output": "output_bool"}] + } + }, + "outputs": {}, + "data": {"message": "valid age"} + }, + "13": { + "id": "13", + "inputs": { + "input_void": { + "connections": [ + {"node": "6", "output": "output_else_filter"} + ] + } + }, + "outputs": { + "output_value": { + "connections": [{"node": "15", "input": "input_value_0"}] + } + }, + "data": {"name": "age", "default": null} + }, + "14": { + "id": "14", + "inputs": { + "input_void": { + "connections": [ + {"node": "6", "output": "output_else_filter"} + ] + } + }, + "outputs": { + "output_value": { + "connections": [{"node": "15", "input": "input_value_1"}] + } + }, + "data": {"value": "0"} + }, + "15": { + "id": "15", + "inputs": { + "input_value_0": { + "connections": [{"node": "13", "output": "output_value"}] + }, + "input_value_1": { + "connections": [{"node": "14", "output": "output_value"}] + } + }, + "outputs": { + "output_bool": { + "connections": [{"node": "16", "input": "input_bool"}] + } + }, + "data": {"operator": ">"} + }, + "16": { + "id": "16", + "inputs": { + "input_bool": { + "connections": [{"node": "15", "output": "output_bool"}] + } + }, + "outputs": { + "output_then_filter": { 
+ "connections": [{"node": "17", "input": "input_void"}] + }, + "output_else_filter": { + "connections": [{"node": "18", "input": "input_void"}] + } + } + }, + "17": { + "id": "17", + "inputs": { + "input_void": { + "connections": [ + {"node": "16", "output": "output_then_filter"} + ] + } + }, + "outputs": { + "output_bool": { + "connections": [{"node": "19", "input": "input_value"}] + } + }, + "data": {"value": false} + }, + "18": { + "id": "18", + "inputs": { + "input_void": { + "connections": [ + {"node": "16", "output": "output_else_filter"} + ] + } + }, + "outputs": { + "output_bool": { + "connections": [{"node": "20", "input": "input_value"}] + } + }, + "data": {"value": false} + }, + "19": { + "id": "19", + "inputs": { + "input_value": { + "connections": [{"node": "17", "output": "output_bool"}] + } + }, + "outputs": {}, + "data": {"message": "underage"} + }, + "20": { + "id": "20", + "inputs": { + "input_value": { + "connections": [{"node": "18", "output": "output_bool"}] + } + }, + "outputs": {}, + "data": {"message": "invalid age"} + } +} \ No newline at end of file diff --git a/tests/test_engine/test_request_manager.py b/tests/test_engine/test_request_manager.py index 3758920..2a04fc7 100644 --- a/tests/test_engine/test_request_manager.py +++ b/tests/test_engine/test_request_manager.py @@ -19,7 +19,10 @@ def test_create_request_manager_with_list_of_inputs_and_duplicate_names( valid_input_dict_before_validation, ): pm = RequestManager( - [Input(**valid_input_dict_before_validation), Input(**valid_input_dict_before_validation)] + [ + Input(**valid_input_dict_before_validation), + Input(**valid_input_dict_before_validation), + ] ) assert len(pm.inputs) == 1 assert pm.model is not None diff --git a/tests/test_parser.py b/tests/test_parser.py index 6db1263..b578b5a 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,58 +1,38 @@ +import json + import pytest from retrack.engine.parser import Parser @pytest.mark.parametrize( - 
"input_data,expected_output_data", + "data_filename,expected_data_filename,expected_tokens", [ ( + "tests/resources/age-negative.json", + "tests/resources/age-negative-data.json", { - "nodes": { - "1": { - "id": 3, - "data": {"name": "age"}, - "inputs": { - "input_void": { - "connections": [ - {"node": 0, "output": "output_void", "data": {}} - ] - } - }, - "outputs": { - "output_value": { - "connections": [ - {"node": 5, "input": "input_value", "data": {}} - ] - } - }, - "name": "Input", - } - } - }, - { - "1": { - "id": "1", - "data": {"name": "age", "default": None}, - "inputs": { - "input_void": { - "connections": [{"node": "0", "output": "output_void"}] - } - }, - "outputs": { - "output_value": { - "connections": [{"node": "5", "input": "input_value"}] - } - }, - } + "start": ["0"], + "input": ["2", "13"], + "constant": ["3", "14"], + "check": ["4", "15"], + "if": ["6", "16"], + "bool": ["9", "17", "18"], + "output": ["10", "19", "20"], }, ) ], ) -def test_parser_extract(input_data, expected_output_data): +def test_parser_extract(data_filename, expected_data_filename, expected_tokens): + with open(data_filename) as f: + input_data = json.load(f) + + with open(expected_data_filename) as f: + expected_output_data = json.load(f) + parser = Parser(input_data) assert parser.data == expected_output_data - assert parser.tokens == {"input": ["1"]} + assert parser.tokens == expected_tokens def test_parser_with_unknown_node(): From dabdd09212521591c90b73a9d6e96302da6fa680 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Mon, 27 Feb 2023 14:41:41 -0300 Subject: [PATCH 10/11] Add dag validator --- pyproject.toml | 1 + retrack/engine/parser.py | 21 ++++++++++++++++++--- retrack/engine/validators/__init__.py | 2 ++ retrack/engine/validators/base.py | 5 +---- retrack/engine/validators/check_is_dag.py | 15 +++++++++++++++ retrack/engine/validators/node_exists.py | 2 +- 6 files changed, 38 insertions(+), 8 deletions(-) create mode 100644 
retrack/engine/validators/check_is_dag.py diff --git a/pyproject.toml b/pyproject.toml index cf8cd47..7ef65d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ keywords = ["rules", "models", "business", "node", "graph"] python = "^3.8.16" pydantic = "^1.10.4" pandas = "^1.5.2" +networkx = "^3.0" [tool.poetry.dev-dependencies] pytest = "^6.2.4" diff --git a/retrack/engine/parser.py b/retrack/engine/parser.py index a63b1eb..a735e59 100644 --- a/retrack/engine/parser.py +++ b/retrack/engine/parser.py @@ -27,6 +27,7 @@ def __init__( node_registry = Registry() self._indexes_by_kind_map = {} self._indexes_by_name_map = {} + self.__edges = None for node_id, node_data in graph_data["nodes"].items(): node_name = node_data.get("name", None) @@ -55,12 +56,12 @@ def __init__( elif unknown_node_error: raise ValueError(f"Unknown node name: {node_name}") + self._node_registry = node_registry + for validator_name, validator in validator_registry.data.items(): - if not validator.validate(graph_data): + if not validator.validate(graph_data=graph_data, edges=self.edges): raise ValueError(f"Invalid graph data: {validator_name}") - self._node_registry = node_registry - @staticmethod def _check_node_name(node_name: str, node_id: str): if node_name is None: @@ -86,6 +87,20 @@ def _check_input_data(data: dict): def nodes(self) -> typing.Dict[str, BaseNode]: return self._node_registry.data + @property + def edges(self) -> typing.List[typing.Tuple[str, str]]: + if self.__edges is None: + edges = [] + + for node_id, node in self.nodes.items(): + for _, output_connection in node.outputs: + for c in output_connection.connections: + edges.append((node_id, c.node)) + + self.__edges = edges + + return self.__edges + @property def data(self) -> dict: return {i: j.dict(by_alias=True) for i, j in self.nodes.items()} diff --git a/retrack/engine/validators/__init__.py b/retrack/engine/validators/__init__.py index 9e41bf3..2f9b0f3 100644 --- a/retrack/engine/validators/__init__.py +++ 
b/retrack/engine/validators/__init__.py @@ -1,4 +1,5 @@ from retrack.engine.validators.base import BaseValidator +from retrack.engine.validators.check_is_dag import CheckIsDAG from retrack.engine.validators.node_exists import NodeExistsValidator from retrack.utils.registry import Registry @@ -8,5 +9,6 @@ "single_start_node_exists", NodeExistsValidator("start", min_quantity=1, max_quantity=1), ) +registry.register("check_is_dag", CheckIsDAG()) __all__ = ["registry", "BaseValidator"] diff --git a/retrack/engine/validators/base.py b/retrack/engine/validators/base.py index ff754aa..bd47890 100644 --- a/retrack/engine/validators/base.py +++ b/retrack/engine/validators/base.py @@ -1,12 +1,9 @@ class BaseValidator: """Base class for all validators.""" - def validate(self, graph_data: dict) -> bool: + def validate(self, **kwargs) -> bool: """Validate the graph data. - Args: - graph_data: The graph data to validate. - Returns: True if the graph data is valid, False otherwise. """ diff --git a/retrack/engine/validators/check_is_dag.py b/retrack/engine/validators/check_is_dag.py new file mode 100644 index 0000000..1c2c9db --- /dev/null +++ b/retrack/engine/validators/check_is_dag.py @@ -0,0 +1,15 @@ +import networkx as nx + +from retrack.engine.validators.base import BaseValidator + + +class CheckIsDAG(BaseValidator): + def validate(self, edges: list, **kwargs) -> bool: + """Validate the graph data. + + Returns: + True if the graph data is valid, False otherwise. 
+ """ + graph = nx.DiGraph() + graph.add_edges_from(edges) + return nx.is_directed_acyclic_graph(graph) diff --git a/retrack/engine/validators/node_exists.py b/retrack/engine/validators/node_exists.py index fdacb42..58172fc 100644 --- a/retrack/engine/validators/node_exists.py +++ b/retrack/engine/validators/node_exists.py @@ -18,7 +18,7 @@ def __init__( self.max_quantity = max_quantity self.min_quantity = min_quantity - def validate(self, graph_data: dict) -> bool: + def validate(self, graph_data: dict, **kwargs) -> bool: """Validate the graph data. Args: From 8322ded79b9d7428f1d8852f9bdf2ea87c45a295 Mon Sep 17 00:00:00 2001 From: Gabriel Guarisa Date: Mon, 27 Feb 2023 14:43:29 -0300 Subject: [PATCH 11/11] v0.3.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7ef65d4..03282b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "retrack" -version = "0.2.2" +version = "0.3.0" description = "A business rules engine" authors = ["Gabriel Guarisa ", "Nathalia Trotte "] license = "MIT"