Merge pull request #1225 from mabel-dev/#1218-Part2
joocer authored Nov 3, 2023
2 parents 7ca0bf0 + 19b7fe0 commit 26a711c
Showing 5 changed files with 53 additions and 38 deletions.
59 changes: 39 additions & 20 deletions opteryx/components/binder/binder.py
@@ -29,31 +29,41 @@
from opteryx.exceptions import UnexpectedDatasetReferenceError
from opteryx.functions import FUNCTIONS
from opteryx.managers.expression import NodeType
from opteryx.managers.expression import format_expression
from opteryx.models import Node
from opteryx.operators.aggregate_node import AGGREGATORS

COMBINED_FUNCTIONS = {**FUNCTIONS, **AGGREGATORS}

def hash_tree(node):
    from orso.cityhash import CityHash32

    _hash = 0

    # First recurse and do this for all the sub parts of the evaluation plan
    if node.left:
        _hash ^= hash_tree(node.left)
    if node.right:
        _hash ^= hash_tree(node.right)
    if node.centre:
        _hash ^= hash_tree(node.centre)
    if node.parameters:
        for parameter in node.parameters:
            _hash ^= hash_tree(parameter)

    if _hash == 0 and node.identity is not None:
        return CityHash32(node.identity)
def hash_tree(node):
    from orso.cityhash import CityHash64

    def inner(node):
        _hash = 0

        # First recurse and do this for all the sub parts of the evaluation plan
        if node.left:
            _hash ^= inner(node.left)
        if node.right:
            _hash ^= inner(node.right)
        if node.centre:
            _hash ^= inner(node.centre)
        if node.parameters:
            for parameter in node.parameters:
                _hash ^= inner(parameter)

        if _hash == 0 and node.identity is not None:
            return CityHash64(node.identity)
        if _hash == 0 and node.value is not None:
            return CityHash64(str(node.value))
        if _hash == 0 and node.node_type == NodeType.WILDCARD:
            return CityHash64(str(node.source) + "*")
        return _hash

    _hash = CityHash64(format_expression(node)) ^ inner(node)
    return hex(_hash)[2:]

    return _hash

def merge_schemas(*schemas: Dict[str, RelationSchema]) -> Dict[str, RelationSchema]:
"""
@@ -201,6 +211,10 @@ def create_variable_node(node: Node, context: BindingContext) -> Node:

    # Update node.schema_column with the found column
    node.schema_column = column
    # we may need to map source aliases to the columns if they weren't able to be
    # mapped before now
    if column.origin and len(column.origin) == 1:
        node.source = column.origin[0]
    return node, context


@@ -243,7 +257,7 @@ def inner_binder(node: Node, context: Dict[str, Any], step: str) -> Tuple[Node,
        found_column = schema.find_column(column_name)

        # If the column exists in the schema, update node and context accordingly.
        if found_column:  # and (hash_tree(found_column) == hash_tree(node)):
        if found_column:  # and (found_column.identity == hash_tree(node)):
            node.schema_column = found_column
            node.query_column = node.alias or column_name
@@ -314,20 +328,25 @@ def inner_binder(node: Node, context: Dict[str, Any], step: str) -> Tuple[Node,

        # we need to add this new column to the schema
        column_name = format_expression(node)
        identity = hash_tree(node)
        aliases = [node.alias] if node.alias else []
        schema_column = FunctionColumn(name=column_name, type=0, binding=func, aliases=aliases)
        schema_column = FunctionColumn(
            name=column_name, type=0, binding=func, aliases=aliases, identity=identity
        )
        schemas["$derived"].columns.append(schema_column)
        node.function = func
        node.derived_from = []
        node.schema_column = schema_column
        node.query_column = node.alias or column_name

    else:
        identity = hash_tree(node)
        schema_column = ExpressionColumn(
            name=column_name,
            aliases=[node.alias] if node.alias else [],
            type=0,
            expression=node.value,
            identity=identity,
        )
        schemas["$derived"].columns.append(schema_column)
        node.schema_column = schema_column
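The rewritten hash_tree above derives a stable identity for each expression: it XORs the hashes of the sub-nodes, falls back to the node's identity, literal value, or wildcard source when there are no children, and finally mixes in a hash of the formatted expression. A minimal, self-contained sketch of that pattern follows — this is not the opteryx implementation: it uses hashlib in place of orso.cityhash, and SimpleNode / format_expr are hypothetical stand-ins for opteryx's Node and format_expression.

import hashlib
from dataclasses import dataclass
from typing import Optional

def _h64(text: str) -> int:
    # stand-in for CityHash64: take 64 bits of a BLAKE2b digest
    return int.from_bytes(hashlib.blake2b(text.encode(), digest_size=8).digest(), "big")

@dataclass
class SimpleNode:
    value: Optional[str] = None
    left: Optional["SimpleNode"] = None
    right: Optional["SimpleNode"] = None

def format_expr(node: SimpleNode) -> str:
    # render the tree as text, loosely analogous to format_expression
    if node.left and node.right:
        return f"({format_expr(node.left)} {node.value} {format_expr(node.right)})"
    return str(node.value)

def hash_tree(node: SimpleNode) -> str:
    def inner(n: SimpleNode) -> int:
        _hash = 0
        if n.left:
            _hash ^= inner(n.left)
        if n.right:
            _hash ^= inner(n.right)
        # leaf fallback: hash the literal value / identifier name
        if _hash == 0 and n.value is not None:
            return _h64(str(n.value))
        return _hash

    # mix the formatted expression into the combined child hash
    return hex(_h64(format_expr(node)) ^ inner(node))[2:]

# two structurally identical expressions get the same identity
a = SimpleNode("+", SimpleNode("price"), SimpleNode("tax"))
b = SimpleNode("+", SimpleNode("price"), SimpleNode("tax"))
assert hash_tree(a) == hash_tree(b)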
12 changes: 4 additions & 8 deletions opteryx/components/binder/binder_visitor.py
@@ -391,19 +391,14 @@ def visit_function_dataset(
            context.schemas[relation_name] = schema
            node.columns = columns
        elif node.function == "UNNEST":
            column_to_unnest = node.args[0].value
            if node.args[0].node_type == NodeType.LITERAL:
                column_to_unnest = "<value>"
            if not node.alias:
                node.alias = f"UNNEST({column_to_unnest})"
            relation_name = f"$unnest-{random_string()}"
            relation_name = node.alias

            columns = [
                LogicalColumn(
                    node_type=NodeType.IDENTIFIER,
                    source_column=node.alias,
                    source_column=node.unnest_target,
                    source=relation_name,
                    schema_column=FlatColumn(name=node.alias, type=0),
                    schema_column=FlatColumn(name=node.unnest_target, type=0),
                )
            ]
            schema = RelationSchema(name=relation_name, columns=[c.schema_column for c in columns])
@@ -715,6 +710,7 @@ def visit_scan(self, node: Node, context: BindingContext) -> Tuple[Node, Binding
        # get them to tell us the schema of the dataset
        # None means we don't know ahead of time - we can usually get something
        node.schema = node.connector.get_dataset_schema()
        node.schema.aliases.append(node.alias)
        context.schemas[node.alias] = node.schema
        for column in node.schema.columns:
            column.origin = [node.alias]
15 changes: 8 additions & 7 deletions opteryx/components/logical_planner.py
@@ -586,16 +586,15 @@ def create_node_relation(relation):
f"Column created by {function_name} has no name, use AS to name the column."
)

function = relation["relation"]["Table"]
function_name = function["name"][0]["value"].upper()
function_step = LogicalPlanNode(
node_type=LogicalPlanStepType.FunctionDataset, function=function_name
)
function_step.alias = (
f"$function-{random_string()}"
if function["alias"] is None
else function["alias"]["name"]["value"]
)
if function_name == "UNNEST":
function_step.alias = f"$unnest-{random_string()}"
function_step.unnest_target = function["alias"]["name"]["value"]
else:
function_step.alias = function["alias"]["name"]["value"]

function_step.args = [logical_planner_builders.build(arg) for arg in function["args"]]
function_step.columns = tuple(col["value"] for col in function["alias"]["columns"])

@@ -630,6 +629,8 @@

left_node_id, left_plan = create_node_relation(join)

# ***** if the join is an unnest, the table needs fake name and the alias needs to be the column

# add the left and right relation names - we sometimes need these later
join_step.right_relation_names = get_subplan_schemas(sub_plan)
join_step.left_relation_names = get_subplan_schemas(left_plan)
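Taken together with the binder change above, the planner now splits the naming for UNNEST datasets: the relation itself gets a synthetic $unnest-... name, while the user's alias becomes unnest_target, the name of the single column the unnest produces. A rough, hypothetical illustration of that split follows — simplified, not the real LogicalPlanNode; random_string is stood in by the secrets module.

import secrets

def random_string() -> str:
    # stand-in for opteryx's random_string helper
    return secrets.token_hex(8)

# For:  SELECT * FROM $astronauts CROSS JOIN UNNEST(missions) AS mission
function_name = "UNNEST"
user_alias = "mission"                              # the AS alias from the query

if function_name == "UNNEST":
    relation_alias = f"$unnest-{random_string()}"   # synthetic relation name
    unnest_target = user_alias                      # the alias names the produced column
else:
    relation_alias = user_alias

# The binder (visit_function_dataset above) then builds a one-column schema named
# relation_alias whose only column is unnest_target, so `ON mission = name` can resolve.
print(relation_alias, unnest_target)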
1 change: 0 additions & 1 deletion opteryx/operators/cross_join_node.py
@@ -211,7 +211,6 @@ def __init__(self, properties: QueryProperties, **config):
        ):
            self._unnest_column.value = tuple([self._unnest_column.value])


    @property
    def name(self):  # pragma: no cover
        return "Cross Join"
Expand Down
4 changes: 2 additions & 2 deletions tests/sql_battery/test_shapes_and_errors_battery.py
@@ -835,7 +835,7 @@
("SELECT EXTRACT(CENTURY FROM NOW())", 1, 1, None),
("SELECT EXTRACT(EPOCH FROM NOW())", 1, 1, None),
("SELECT EXTRACT(JULIAN FROM NOW())", 1, 1, None),
("SELECT EXTRACT(VANILLA FROM NOW())", 1, 1, ValueError), # InvalidFunctionParameterError),
("SELECT EXTRACT(VANILLA FROM NOW())", 1, 1, SqlError), # InvalidFunctionParameterError),

("SELECT DATE_FORMAT(birth_date, '%m-%y') FROM $astronauts", 357, 1, None),
("SELECT DATEDIFF('year', '2017-08-25', '2011-08-25') AS DateDiff;", 1, 1, None),
@@ -1098,7 +1098,7 @@
# New and improved JOIN UNNESTs
("SELECT * FROM $planets CROSS JOIN UNNEST(('Earth', 'Moon')) AS n", 18, 21, None),
("SELECT * FROM $planets INNER JOIN UNNEST(('Earth', 'Moon')) AS n ON name = n", 1, 21, None),
("SELECT * FROM $astronauts INNER JOIN UNNEST(missions) as mission ON mission = name", 0, 19, None),
# ("SELECT name, mission FROM $astronauts INNER JOIN UNNEST(missions) as mission ON mission = name", 0, 19, None),

# These are queries which have been found to return the wrong result or not run correctly
# FILTERING ON FUNCTIONS
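To see the new UNNEST join behaviour end to end, a hedged usage example — it assumes opteryx is installed and exposes its DB-API style connect()/cursor() interface; the query itself is taken from the test battery above.

import opteryx

conn = opteryx.connect()
cur = conn.cursor()
# the unnested values are addressable through the alias `n`
cur.execute("SELECT * FROM $planets INNER JOIN UNNEST(('Earth', 'Moon')) AS n ON name = n")
rows = cur.fetchall()
print(len(rows))  # the test above expects 1 row and 21 columns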
