Skip to content

Commit

Permalink
Merge pull request #1478 from mabel-dev/#1459
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer authored Feb 27, 2024
2 parents 88af380 + ab885fa commit 2536938
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 5 deletions.
2 changes: 1 addition & 1 deletion opteryx/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__build__ = 322
__build__ = 323

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion opteryx/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def query_planner(operation: str, parameters: Union[Iterable, None], connection,
from opteryx.models import QueryProperties
from opteryx.third_party import sqloxide

# SQL Rewriter removes whitespace and comments, and extracts temporal filters
# SQL Rewriter extracts temporal filters
clean_sql, temporal_filters = do_sql_rewrite(operation)
params = [p for p in parameters or []]

Expand Down
4 changes: 3 additions & 1 deletion opteryx/components/ast_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@
from typing import List
from typing import Union

import numpy

from opteryx.exceptions import ParameterError

LiteralNode = Dict[str, Any]
Expand All @@ -89,7 +91,7 @@ def _build_literal_node(value: Any) -> LiteralNode:
return {"Value": {"Boolean": value}}
elif isinstance(value, str):
return {"Value": {"SingleQuotedString": value}}
elif isinstance(value, (int, float, decimal.Decimal)):
elif isinstance(value, (int, float, decimal.Decimal, numpy.int64)):
return {"Value": {"Number": [value, False]}}
elif isinstance(value, (datetime.date, datetime.datetime, datetime.time)):
return {"Value": {"SingleQuotedString": value.isoformat()}}
Expand Down
52 changes: 51 additions & 1 deletion opteryx/components/logical_planner/logical_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,56 @@ def create_node_relation(relation):
return root_node, sub_plan


def plan_execute_query(statement) -> LogicalPlan:

import orjson

from opteryx.components.ast_rewriter import do_ast_rewriter
from opteryx.components.logical_planner import do_logical_planning_phase
from opteryx.components.sql_rewriter import do_sql_rewrite
from opteryx.exceptions import SqlError
from opteryx.third_party import sqloxide
from opteryx.utils import sql

statement_name = statement["Execute"]["name"]["value"]
parameters = [
logical_planner_builders.build(p["Value"]) for p in statement["Execute"]["parameters"]
]
try:
with open("prepared_statements.json", "r") as ps:
prepared_statatements = orjson.loads(ps.read())
except OSError:
prepared_statatements = {}

# we have an inbuilt statements as a fallback
prepared_statatements["planets_by_id"] = "SELECT * FROM $planets WHERE id = ?"
prepared_statatements["version"] = "SELECT version()"

if statement_name not in prepared_statatements:
raise SqlError("Unable to EXECUTE prepared statement, '{statement_name}' not defined.")
operation = prepared_statatements[statement_name]

operation = sql.remove_comments(operation)
operation = sql.clean_statement(operation)
statements = sql.split_sql_statements(operation)
if len(statements) > 1:
raise UnsupportedSyntaxError("EXECUTE cannot run multi-part and batch prepared statements.")

clean_sql, temporal_filters = do_sql_rewrite(operation)
try:
parsed_statements = sqloxide.parse_sql(clean_sql, dialect="mysql")
except ValueError as parser_error:
raise SqlError(parser_error) from parser_error

parsed_statements = do_ast_rewriter(
parsed_statements,
temporal_filters=temporal_filters,
paramters=parameters,
connection=None,
)
return list(do_logical_planning_phase(parsed_statements))[0][0]


def plan_explain(statement) -> LogicalPlan:
plan = LogicalPlan()
explain_node = LogicalPlanNode(node_type=LogicalPlanStepType.Explain)
Expand Down Expand Up @@ -928,7 +978,7 @@ def plan_show_variables(statement):

QUERY_BUILDERS = {
# "Analyze": analyze_query,
# "Execute": plan_execute_query,
"Execute": plan_execute_query,
"Explain": plan_explain,
"Query": plan_query,
"SetVariable": plan_set_variable,
Expand Down
3 changes: 2 additions & 1 deletion prepared_statements.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{
"get_planet_by_id": "SELECT id, name FROM $planets WHERE id = ?"
"get_satellites_by_planet_name": "/* A test case FOR the EXECUTE query functionality */ SELECT s.name AS satellite_name FROM $satellites AS s INNER JOIN $planets AS p ON p.id = s.planetId WHERE p.name = ?",
"multiply_two_numbers": "SELECT ? * ?"
}
5 changes: 5 additions & 0 deletions tests/sql_battery/test_shapes_and_errors_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,6 +1294,11 @@

("SELECT missions[0] as m FROM $astronauts CROSS JOIN FAKE(1, 1) AS F order by m", 357, 1, None),

("EXECUTE planets_by_id (1)", 1, 20, None), # simple case
("EXECUTE version", 1, 1, None), # no paramters
("EXECUTE get_satellites_by_planet_name('Jupiter')", 67, 1, None), # string param
("EXECUTE multiply_two_numbers (1.0, 9.9)", 1, 1, None), # multiple params

# These are queries which have been found to return the wrong result or not run correctly
# FILTERING ON FUNCTIONS
("SELECT DATE(birth_date) FROM $astronauts FOR TODAY WHERE DATE(birth_date) < '1930-01-01'", 14, 1, None),
Expand Down

0 comments on commit 2536938

Please sign in to comment.