Skip to content

Commit

Permalink
Merge pull request #1243 from mabel-dev/#1242
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer authored Nov 7, 2023
2 parents de231d4 + cb60d56 commit 7a03de3
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 16 deletions.
6 changes: 0 additions & 6 deletions opteryx/components/binder/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,6 @@ def inner_binder(node: Node, context: Dict[str, Any]) -> Tuple[Node, Dict[str, A
if node_type in (NodeType.IDENTIFIER, NodeType.EVALUATED):
return locate_identifier(node, context)

# Expression Lists are part of how CASE statements are represented
if node_type == NodeType.EXPRESSION_LIST:
node.value, new_contexts = zip(*(inner_binder(parm, context) for parm in node.value))
merged_schemas = merge_schemas(*[ctx.schemas for ctx in new_contexts])
context.schemas = merged_schemas

# Early exit for nodes representing calculated columns.
# If the node represents a calculated column, if we're seeing it again it's because it
# has appeared earlier in the plan and in that case we don't need to recalculate, we just
Expand Down
2 changes: 0 additions & 2 deletions opteryx/components/logical_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,8 +638,6 @@ def create_node_relation(relation):

left_node_id, left_plan = create_node_relation(join)

# ***** if the join is an unnest, the table needs fake name and the alias needs to be the column

# add the left and right relation names - we sometimes need these later
join_step.right_relation_names = get_subplan_schemas(sub_plan)
join_step.left_relation_names = get_subplan_schemas(left_plan)
Expand Down
4 changes: 2 additions & 2 deletions opteryx/components/logical_planner_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,14 +564,14 @@ def case_when(value, alias: list = None, key=None):
)
if else_result is not None:
conditions.append(Node(NodeType.LITERAL, type=OrsoTypes.BOOLEAN, value=True))
conditions_node = Node(NodeType.EXPRESSION_LIST, value=conditions)
conditions_node = Node(NodeType.EXPRESSION_LIST, parameters=conditions)

results = []
for result in value["results"]:
results.append(build(result))
if else_result is not None:
results.append(else_result)
results_node = Node(NodeType.EXPRESSION_LIST, value=results)
results_node = Node(NodeType.EXPRESSION_LIST, parameters=results)

return Node(
NodeType.FUNCTION,
Expand Down
2 changes: 1 addition & 1 deletion opteryx/managers/expression/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def _inner_evaluate(root: Node, table: Table, context: ExecutionContext):
context.store(identity, result)
return result
if node_type == NodeType.EXPRESSION_LIST:
values = [_inner_evaluate(val, table, context) for val in root.value]
values = [_inner_evaluate(val, table, context) for val in root.parameters]
return values


Expand Down
44 changes: 41 additions & 3 deletions opteryx/managers/expression/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any

from orso.schema import FlatColumn
from orso.tools import random_string
from orso.types import OrsoTypes


Expand All @@ -10,6 +11,41 @@ class ExpressionColumn(FlatColumn):
expression: Any = None


def _format_interval(value):
import datetime

# MonthDayNano is a superclass of list, do before list

if isinstance(value, tuple):
months, days, seconds = value
elif hasattr(value, "days"):
days = value.days
months = value.months
seconds = value.nanoseconds / 1e9
elif isinstance(value, datetime.timedelta):
days = value.days
months = 0
seconds = value.microseconds / 1e6 + value.seconds

hours, seconds = divmod(seconds, 3600)
minutes, seconds = divmod(seconds, 60)
years, months = divmod(months, 12)
parts = []
if years >= 1:
parts.append(f"{int(years)} YEAR")
if months >= 1:
parts.append(f"{int(months)} MONTH")
if days >= 1:
parts.append(f"{int(days)} DAY")
if hours >= 1:
parts.append(f"{int(hours)} HOUR")
if minutes >= 1:
parts.append(f"{int(minutes)} MINUTE")
if seconds > 0:
parts.append(f"{seconds:.2f} SECOND")
return " ".join(parts)


def format_expression(root, qualify: bool = False):
# circular imports
from . import INTERNAL_TYPE
Expand Down Expand Up @@ -38,16 +74,16 @@ def format_expression(root, qualify: bool = False):
if literal_type == OrsoTypes.TIMESTAMP:
return "'" + str(root.value) + "'"
if literal_type == OrsoTypes.INTERVAL:
return "<INTERVAL>"
return _format_interval(root.value)
if literal_type == OrsoTypes.NULL:
return "null"
return str(root.value)
# INTERNAL IDENTIFIERS
if node_type & INTERNAL_TYPE == INTERNAL_TYPE:
if node_type in (NodeType.FUNCTION, NodeType.AGGREGATOR):
if root.value == "CASE":
con = [format_expression(a, qualify) for a in root.parameters[0].value]
vals = [format_expression(a, qualify) for a in root.parameters[1].value]
con = [format_expression(a, qualify) for a in root.parameters[0].parameters]
vals = [format_expression(a, qualify) for a in root.parameters[1].parameters]
return "CASE " + "".join([f"WHEN {c} THEN {v} " for c, v in zip(con, vals)]) + "END"
distinct = "DISTINCT " if root.distinct else ""
order = ""
Expand Down Expand Up @@ -77,6 +113,8 @@ def format_expression(root, qualify: bool = False):
"ShiftRight": ">>",
}
return f"{format_expression(root.left, qualify)} {_map.get(root.value, root.value).upper()} {format_expression(root.right, qualify)}"
if node_type == NodeType.EXPRESSION_LIST:
return f"<EXPRESSIONS {random_string(4)}>"
if node_type == NodeType.COMPARISON_OPERATOR:
_map = {
"Eq": "=",
Expand Down
4 changes: 2 additions & 2 deletions tests/sql_battery/test_results_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

def get_tests(test_type):
suites = glob.glob(f"**/**.{test_type}", recursive=True)
for suite in suites:
for suite in sorted(suites):
with open(suite, mode="r") as test_file:
try:
yield {"file": suite, **orjson.loads(test_file.read())}
Expand All @@ -51,7 +51,7 @@ def test_results_tests(test):
cursor.execute(sql)
result = cursor.arrow().to_pydict()

printable_result = orjson.dumps(result).decode()
printable_result = orjson.dumps(result, default=str).decode()
printable_expected = orjson.dumps(test["result"]).decode()

assert (
Expand Down
7 changes: 7 additions & 0 deletions tests/sql_battery/test_shapes_and_errors_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,9 @@
("SELECT * FROM FAKE(100, (Name, Name)) AS FK", 100, 2, None),
("SELECT * FROM FAKE(100, 10) AS FK(nom, nim, nam)", 100, 10, None),

("SELECT * FROM $planets WHERE diameter > 10000 AND gravity BETWEEN 0.5 AND 2.0;", 0, 20, None),
("SELECT * FROM $planets WHERE diameter > 100 AND gravity BETWEEN 0.5 AND 2.0;", 1, 20, None),

# virtual dataset doesn't exist
("SELECT * FROM $RomanGods", None, None, DatasetNotFoundError),
# disk dataset doesn't exist
Expand Down Expand Up @@ -1257,6 +1260,10 @@
("SELECT p.name, s.name FROM $planets as p, $satellites as s WHERE p.id = s.planetId", 177, 2, None),
# Can't qualify fields used in subscripts
("SELECT d.birth_place['town'] FROM $astronauts AS d", 357, 1, None),
# Null columns can't be inverted
("SELECT NOT NULL", 1, 1, None),
# Columns in CASE statements appear to not be bound correctly
("SELECT SUM(CASE WHEN gm > 10 THEN 1 ELSE 0 END) AS gm_big_count FROM $satellites", 1, 1, None),
]
# fmt:on

Expand Down
9 changes: 9 additions & 0 deletions tests/sql_battery/tests/results/complex_001.results_tests
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"summary": "Complex tests - This is a variation of the regression 001 query",
"statement": "SELECT name AS nom, IFNULL(satellites_count.gm_big_count, 0) AS big_occurrences, IFNULL(satellites_count.gm_small_count, 0) AS small_occurrences FROM (SELECT id, name FROM $planets WHERE name = 'Saturn') AS planets LEFT JOIN (SELECT planetId, SUM(CASE WHEN gm > 10 THEN 1 ELSE 0 END) AS gm_big_count, SUM(CASE WHEN gm <= 10 THEN 1 ELSE 0 END) AS gm_small_count FROM $satellites GROUP BY planetId) AS satellites_count ON planets.id = satellites_count.planetId;",
"result": {
"nom": ["Saturn"],
"big_occurrences": [5],
"small_occurrences": [56]
}
}
11 changes: 11 additions & 0 deletions tests/sql_battery/tests/results/complex_002.results_tests
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"summary": "Retrieve planets with specific gravity and diameter metrics, along with count and average radius of high GM satellites and average magnitude of visible satellites.",
"statement": "SELECT p.name AS planet_name, p.diameter, high_gm_stats.high_gm_satellites_count, high_gm_stats.avg_high_gm_radius, visible_stats.avg_magnitude FROM $planets p LEFT JOIN (SELECT planetId, COUNT(*) AS high_gm_satellites_count, AVG(radius) AS avg_high_gm_radius FROM $satellites WHERE gm > 5 GROUP BY planetId) high_gm_stats ON p.id = high_gm_stats.planetId LEFT JOIN (SELECT planetId, AVG(magnitude) AS avg_magnitude FROM $satellites WHERE magnitude < 2.0 GROUP BY planetId) visible_stats ON p.id = visible_stats.planetId WHERE p.diameter > 100 AND p.gravity BETWEEN 0.5 AND 2.0 ORDER BY high_gm_stats.high_gm_satellites_count DESC, visible_stats.avg_magnitude ASC;",
"result": {
"planet_name": ["Pluto"],
"p.diameter": [2370],
"high_gm_stats.high_gm_satellites_count": [1],
"high_gm_stats.avg_high_gm_radius": [603.6],
"visible_stats.avg_magnitude": [null]
}
}
13 changes: 13 additions & 0 deletions tests/sql_battery/tests/results/complex_003.results_tests
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"summary": "This query retrieves planets with their orbital periods and diameters, along with counts and averages of their dense and bright satellites. It filters planets based on their distance from the sun and orbital eccentricity.",
"statement": "SELECT pl.name AS planet_name, pl.orbital_period, pl.diameter, dense_moons_stats.total_dense_moons, dense_moons_stats.avg_density, bright_moons_stats.avg_magnitude, bright_moons_stats.total_bright_moons FROM $planets pl LEFT JOIN (SELECT planetId, COUNT(*) AS total_dense_moons, AVG(density) AS avg_density FROM $satellites WHERE density > 2 GROUP BY planetId) dense_moons_stats ON pl.id = dense_moons_stats.planetId LEFT JOIN (SELECT planetId, AVG(magnitude) AS avg_magnitude, COUNT(*) AS total_bright_moons FROM $satellites WHERE magnitude < 5 GROUP BY planetId) bright_moons_stats ON pl.id = bright_moons_stats.planetId WHERE pl.distance_from_sun BETWEEN 100 AND 200 AND pl.orbital_eccentricity < 0.1 ORDER BY dense_moons_stats.total_dense_moons DESC, bright_moons_stats.avg_magnitude ASC LIMIT 10;",
"result": {
"bright_moons_stats.avg_magnitude": [-12.74, null],
"bright_moons_stats.total_bright_moons": [1, null],
"dense_moons_stats.total_dense_moons": [1, null],
"dense_moons_stats.avg_density": [3.344, null],
"planet_name": ["Earth", "Venus"],
"pl.diameter": [12756, 12104],
"pl.orbital_period": [365.2, 224.7]
}
}

0 comments on commit 7a03de3

Please sign in to comment.