From cb60d56f714582e06fa3ef173292cd18e0cdf948 Mon Sep 17 00:00:00 2001 From: joocer Date: Tue, 7 Nov 2023 07:51:19 +0000 Subject: [PATCH] #1242 --- opteryx/components/binder/binder.py | 6 --- opteryx/components/logical_planner.py | 2 - .../components/logical_planner_builders.py | 4 +- opteryx/managers/expression/__init__.py | 2 +- opteryx/managers/expression/formatter.py | 44 +++++++++++++++++-- tests/sql_battery/test_results_battery.py | 4 +- .../test_shapes_and_errors_battery.py | 7 +++ .../tests/results/complex_001.results_tests | 9 ++++ .../tests/results/complex_002.results_tests | 11 +++++ .../tests/results/complex_003.results_tests | 13 ++++++ 10 files changed, 86 insertions(+), 16 deletions(-) create mode 100644 tests/sql_battery/tests/results/complex_001.results_tests create mode 100644 tests/sql_battery/tests/results/complex_002.results_tests create mode 100644 tests/sql_battery/tests/results/complex_003.results_tests diff --git a/opteryx/components/binder/binder.py b/opteryx/components/binder/binder.py index 5a45cdafe..2480d2616 100644 --- a/opteryx/components/binder/binder.py +++ b/opteryx/components/binder/binder.py @@ -229,12 +229,6 @@ def inner_binder(node: Node, context: Dict[str, Any]) -> Tuple[Node, Dict[str, A if node_type in (NodeType.IDENTIFIER, NodeType.EVALUATED): return locate_identifier(node, context) - # Expression Lists are part of how CASE statements are represented - if node_type == NodeType.EXPRESSION_LIST: - node.value, new_contexts = zip(*(inner_binder(parm, context) for parm in node.value)) - merged_schemas = merge_schemas(*[ctx.schemas for ctx in new_contexts]) - context.schemas = merged_schemas - # Early exit for nodes representing calculated columns. 
# If the node represents a calculated column, if we're seeing it again it's because it # has appeared earlier in the plan and in that case we don't need to recalcuate, we just diff --git a/opteryx/components/logical_planner.py b/opteryx/components/logical_planner.py index f6189dd25..2dce3168a 100644 --- a/opteryx/components/logical_planner.py +++ b/opteryx/components/logical_planner.py @@ -638,8 +638,6 @@ def create_node_relation(relation): left_node_id, left_plan = create_node_relation(join) - # ***** if the join is an unnest, the table needs fake name and the alias needs to be the column - # add the left and right relation names - we sometimes need these later join_step.right_relation_names = get_subplan_schemas(sub_plan) join_step.left_relation_names = get_subplan_schemas(left_plan) diff --git a/opteryx/components/logical_planner_builders.py b/opteryx/components/logical_planner_builders.py index 060e28479..f7c139c95 100644 --- a/opteryx/components/logical_planner_builders.py +++ b/opteryx/components/logical_planner_builders.py @@ -564,14 +564,14 @@ def case_when(value, alias: list = None, key=None): ) if else_result is not None: conditions.append(Node(NodeType.LITERAL, type=OrsoTypes.BOOLEAN, value=True)) - conditions_node = Node(NodeType.EXPRESSION_LIST, value=conditions) + conditions_node = Node(NodeType.EXPRESSION_LIST, parameters=conditions) results = [] for result in value["results"]: results.append(build(result)) if else_result is not None: results.append(else_result) - results_node = Node(NodeType.EXPRESSION_LIST, value=results) + results_node = Node(NodeType.EXPRESSION_LIST, parameters=results) return Node( NodeType.FUNCTION, diff --git a/opteryx/managers/expression/__init__.py b/opteryx/managers/expression/__init__.py index 04c633f0c..1996c6017 100644 --- a/opteryx/managers/expression/__init__.py +++ b/opteryx/managers/expression/__init__.py @@ -231,7 +231,7 @@ def _inner_evaluate(root: Node, table: Table, context: ExecutionContext): 
context.store(identity, result) return result if node_type == NodeType.EXPRESSION_LIST: - values = [_inner_evaluate(val, table, context) for val in root.value] + values = [_inner_evaluate(val, table, context) for val in root.parameters] return values diff --git a/opteryx/managers/expression/formatter.py b/opteryx/managers/expression/formatter.py index c00abed17..3dd6177fd 100644 --- a/opteryx/managers/expression/formatter.py +++ b/opteryx/managers/expression/formatter.py @@ -2,6 +2,7 @@ from typing import Any from orso.schema import FlatColumn +from orso.tools import random_string from orso.types import OrsoTypes @@ -10,6 +11,41 @@ class ExpressionColumn(FlatColumn): expression: Any = None +def _format_interval(value): + import datetime + + # MonthDayNano is a superclass of list, do before list + + if isinstance(value, tuple): + months, days, seconds = value + elif hasattr(value, "days"): + days = value.days + months = value.months + seconds = value.nanoseconds / 1e9 + elif isinstance(value, datetime.timedelta): + days = value.days + months = 0 + seconds = value.microseconds / 1e6 + value.seconds + + hours, seconds = divmod(seconds, 3600) + minutes, seconds = divmod(seconds, 60) + years, months = divmod(months, 12) + parts = [] + if years >= 1: + parts.append(f"{int(years)} YEAR") + if months >= 1: + parts.append(f"{int(months)} MONTH") + if days >= 1: + parts.append(f"{int(days)} DAY") + if hours >= 1: + parts.append(f"{int(hours)} HOUR") + if minutes >= 1: + parts.append(f"{int(minutes)} MINUTE") + if seconds > 0: + parts.append(f"{seconds:.2f} SECOND") + return " ".join(parts) + + def format_expression(root, qualify: bool = False): # circular imports from . 
import INTERNAL_TYPE @@ -38,7 +74,7 @@ def format_expression(root, qualify: bool = False): if literal_type == OrsoTypes.TIMESTAMP: return "'" + str(root.value) + "'" if literal_type == OrsoTypes.INTERVAL: - return "" + return _format_interval(root.value) if literal_type == OrsoTypes.NULL: return "null" return str(root.value) @@ -46,8 +82,8 @@ def format_expression(root, qualify: bool = False): if node_type & INTERNAL_TYPE == INTERNAL_TYPE: if node_type in (NodeType.FUNCTION, NodeType.AGGREGATOR): if root.value == "CASE": - con = [format_expression(a, qualify) for a in root.parameters[0].value] - vals = [format_expression(a, qualify) for a in root.parameters[1].value] + con = [format_expression(a, qualify) for a in root.parameters[0].parameters] + vals = [format_expression(a, qualify) for a in root.parameters[1].parameters] return "CASE " + "".join([f"WHEN {c} THEN {v} " for c, v in zip(con, vals)]) + "END" distinct = "DISTINCT " if root.distinct else "" order = "" @@ -77,6 +113,8 @@ def format_expression(root, qualify: bool = False): "ShiftRight": ">>", } return f"{format_expression(root.left, qualify)} {_map.get(root.value, root.value).upper()} {format_expression(root.right, qualify)}" + if node_type == NodeType.EXPRESSION_LIST: + return f"" if node_type == NodeType.COMPARISON_OPERATOR: _map = { "Eq": "=", diff --git a/tests/sql_battery/test_results_battery.py b/tests/sql_battery/test_results_battery.py index a9be5c144..e5f20a668 100644 --- a/tests/sql_battery/test_results_battery.py +++ b/tests/sql_battery/test_results_battery.py @@ -28,7 +28,7 @@ def get_tests(test_type): suites = glob.glob(f"**/**.{test_type}", recursive=True) - for suite in suites: + for suite in sorted(suites): with open(suite, mode="r") as test_file: try: yield {"file": suite, **orjson.loads(test_file.read())} @@ -51,7 +51,7 @@ def test_results_tests(test): cursor.execute(sql) result = cursor.arrow().to_pydict() - printable_result = orjson.dumps(result).decode() + printable_result = 
orjson.dumps(result, default=str).decode() printable_expected = orjson.dumps(test["result"]).decode() assert ( diff --git a/tests/sql_battery/test_shapes_and_errors_battery.py b/tests/sql_battery/test_shapes_and_errors_battery.py index 39a7dfa31..56121f80b 100644 --- a/tests/sql_battery/test_shapes_and_errors_battery.py +++ b/tests/sql_battery/test_shapes_and_errors_battery.py @@ -1046,6 +1046,9 @@ ("SELECT * FROM FAKE(100, (Name, Name)) AS FK", 100, 2, None), ("SELECT * FROM FAKE(100, 10) AS FK(nom, nim, nam)", 100, 10, None), + ("SELECT * FROM $planets WHERE diameter > 10000 AND gravity BETWEEN 0.5 AND 2.0;", 0, 20, None), + ("SELECT * FROM $planets WHERE diameter > 100 AND gravity BETWEEN 0.5 AND 2.0;", 1, 20, None), + # virtual dataset doesn't exist ("SELECT * FROM $RomanGods", None, None, DatasetNotFoundError), # disk dataset doesn't exist @@ -1257,6 +1260,10 @@ ("SELECT p.name, s.name FROM $planets as p, $satellites as s WHERE p.id = s.planetId", 177, 2, None), # Can't qualify fields used in subscripts ("SELECT d.birth_place['town'] FROM $astronauts AS d", 357, 1, None), + # Null columns can't be inverted + ("SELECT NOT NULL", 1, 1, None), + # Columns in CASE statements appear to not be bound correctly + ("SELECT SUM(CASE WHEN gm > 10 THEN 1 ELSE 0 END) AS gm_big_count FROM $satellites", 1, 1, None), ] # fmt:on diff --git a/tests/sql_battery/tests/results/complex_001.results_tests b/tests/sql_battery/tests/results/complex_001.results_tests new file mode 100644 index 000000000..08e292452 --- /dev/null +++ b/tests/sql_battery/tests/results/complex_001.results_tests @@ -0,0 +1,9 @@ +{ + "summary": "Complex tests - This is a variation of the regression 001 query", + "statement": "SELECT name AS nom, IFNULL(satellites_count.gm_big_count, 0) AS big_occurrences, IFNULL(satellites_count.gm_small_count, 0) AS small_occurrences FROM (SELECT id, name FROM $planets WHERE name = 'Saturn') AS planets LEFT JOIN (SELECT planetId, SUM(CASE WHEN gm > 10 THEN 1 ELSE 0 END) AS 
gm_big_count, SUM(CASE WHEN gm <= 10 THEN 1 ELSE 0 END) AS gm_small_count FROM $satellites GROUP BY planetId) AS satellites_count ON planets.id = satellites_count.planetId;", + "result": { + "nom": ["Saturn"], + "big_occurrences": [5], + "small_occurrences": [56] + } +} diff --git a/tests/sql_battery/tests/results/complex_002.results_tests b/tests/sql_battery/tests/results/complex_002.results_tests new file mode 100644 index 000000000..6526641a6 --- /dev/null +++ b/tests/sql_battery/tests/results/complex_002.results_tests @@ -0,0 +1,11 @@ +{ + "summary": "Retrieve planets with specific gravity and diameter metrics, along with count and average radius of high GM satellites and average magnitude of visible satellites.", + "statement": "SELECT p.name AS planet_name, p.diameter, high_gm_stats.high_gm_satellites_count, high_gm_stats.avg_high_gm_radius, visible_stats.avg_magnitude FROM $planets p LEFT JOIN (SELECT planetId, COUNT(*) AS high_gm_satellites_count, AVG(radius) AS avg_high_gm_radius FROM $satellites WHERE gm > 5 GROUP BY planetId) high_gm_stats ON p.id = high_gm_stats.planetId LEFT JOIN (SELECT planetId, AVG(magnitude) AS avg_magnitude FROM $satellites WHERE magnitude < 2.0 GROUP BY planetId) visible_stats ON p.id = visible_stats.planetId WHERE p.diameter > 100 AND p.gravity BETWEEN 0.5 AND 2.0 ORDER BY high_gm_stats.high_gm_satellites_count DESC, visible_stats.avg_magnitude ASC;", + "result": { + "planet_name": ["Pluto"], + "p.diameter": [2370], + "high_gm_stats.high_gm_satellites_count": [1], + "high_gm_stats.avg_high_gm_radius": [603.6], + "visible_stats.avg_magnitude": [null] + } +} diff --git a/tests/sql_battery/tests/results/complex_003.results_tests b/tests/sql_battery/tests/results/complex_003.results_tests new file mode 100644 index 000000000..841bb1a59 --- /dev/null +++ b/tests/sql_battery/tests/results/complex_003.results_tests @@ -0,0 +1,13 @@ +{ + "summary": "This query retrieves planets with their orbital periods and diameters, along with counts 
and averages of their dense and bright satellites. It filters planets based on their distance from the sun and orbital eccentricity.", + "statement": "SELECT pl.name AS planet_name, pl.orbital_period, pl.diameter, dense_moons_stats.total_dense_moons, dense_moons_stats.avg_density, bright_moons_stats.avg_magnitude, bright_moons_stats.total_bright_moons FROM $planets pl LEFT JOIN (SELECT planetId, COUNT(*) AS total_dense_moons, AVG(density) AS avg_density FROM $satellites WHERE density > 2 GROUP BY planetId) dense_moons_stats ON pl.id = dense_moons_stats.planetId LEFT JOIN (SELECT planetId, AVG(magnitude) AS avg_magnitude, COUNT(*) AS total_bright_moons FROM $satellites WHERE magnitude < 5 GROUP BY planetId) bright_moons_stats ON pl.id = bright_moons_stats.planetId WHERE pl.distance_from_sun BETWEEN 100 AND 200 AND pl.orbital_eccentricity < 0.1 ORDER BY dense_moons_stats.total_dense_moons DESC, bright_moons_stats.avg_magnitude ASC LIMIT 10;", + "result": { + "bright_moons_stats.avg_magnitude": [-12.74, null], + "bright_moons_stats.total_bright_moons": [1, null], + "dense_moons_stats.total_dense_moons": [1, null], + "dense_moons_stats.avg_density": [3.344, null], + "planet_name": ["Earth", "Venus"], + "pl.diameter": [12756, 12104], + "pl.orbital_period": [365.2, 224.7] + } +}