Skip to content

Commit

Permalink
Merge pull request #2183 from mabel-dev/#2159
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer authored Jan 1, 2025
2 parents 323a5bd + 0e97892 commit bad4b67
Show file tree
Hide file tree
Showing 12 changed files with 253 additions and 62 deletions.
2 changes: 1 addition & 1 deletion opteryx/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__build__ = 939
__build__ = 941

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
1 change: 1 addition & 0 deletions opteryx/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ def sleep(x):
"STR": cast_varchar,
"STRUCT": _iterate_single_parameter(lambda x: orjson.loads(str(x)) if x is not None else None),
"DATE": lambda x: compute.cast(x, pyarrow.date32()),
"PASSTHRU": lambda x: x,
"BLOB": cast_blob,
"TRY_TIMESTAMP": try_cast("TIMESTAMP"),
"TRY_BOOLEAN": try_cast("BOOLEAN"),
Expand Down
10 changes: 5 additions & 5 deletions opteryx/functions/other_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def search(array, item, ignore_case: Optional[List[bool]] = None):
return results_mask


def if_null(values, replacement):
def if_null(values, replacements):
"""
Replace null values in the input array with corresponding values from the replacement array.
Expand All @@ -102,10 +102,10 @@ def if_null(values, replacement):
from opteryx.managers.expression.unary_operations import _is_null

# Create a mask for null values
is_null_array = _is_null(values)
is_null_mask = _is_null(values)

# Use NumPy's where function to vectorize the operation
return numpy.where(is_null_array, replacement, values)
return numpy.where(is_null_mask, replacements, values).astype(replacements.dtype)


def if_not_null(values: numpy.ndarray, replacements: numpy.ndarray) -> numpy.ndarray:
Expand All @@ -122,8 +122,8 @@ def if_not_null(values: numpy.ndarray, replacements: numpy.ndarray) -> numpy.nda
"""
from opteryx.managers.expression.unary_operations import _is_not_null

not_null_mask = _is_not_null(values)
return numpy.where(not_null_mask, replacements, values)
is_not_null_mask = _is_not_null(values)
return numpy.where(is_not_null_mask, replacements, values).astype(replacements.dtype)


def null_if(col1, col2):
Expand Down
1 change: 0 additions & 1 deletion opteryx/functions/string_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,6 @@ def _inner(val, _from, _for):
_from -= 1
_for = int(_for) if _for and _for == _for else None # nosec
if _for is None:
print(val, _from)
return val[_from:]
return val[_from : _for + _from]

Expand Down
8 changes: 4 additions & 4 deletions opteryx/operators/cross_join_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from . import JoinNode

INTERNAL_BATCH_SIZE: int = 7500 # config
INTERNAL_BATCH_SIZE: int = 10000 # config
MAX_JOIN_SIZE: int = 1000 # config
MORSEL_SIZE_BYTES: int = 16 * 1024 * 1024
CROSS_JOIN_UNNEST_BATCH_SIZE = 10000
Expand Down Expand Up @@ -115,9 +115,9 @@ def _cross_join_unnest_column(
else:
# Rebuild the block with the new column data if we have any rows to build for

total_rows = len(indices) # Both arrays have the same length
total_rows = indices.size # Both arrays have the same length
block_size = MORSEL_SIZE_BYTES / (left_block.nbytes / left_block.num_rows)
block_size = int(block_size // 1000) * 1000
block_size = int(block_size / 1000) * 1000

for start_block in range(0, total_rows, block_size):
# Compute the end index for the current chunk
Expand All @@ -128,7 +128,7 @@ def _cross_join_unnest_column(
new_column_data_chunk = new_column_data[start_block:end_block]

# Create a new block using the chunk of indices
indices_chunk = numpy.array(indices_chunk, dtype=numpy.int32)
indices_chunk = numpy.array(indices_chunk, dtype=numpy.int64)
new_block = left_block.take(indices_chunk)
new_block = pyarrow.Table.from_batches([new_block], schema=morsel.schema)

Expand Down
128 changes: 83 additions & 45 deletions opteryx/planner/cost_based_optimizer/strategies/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,27 @@
from .optimization_strategy import OptimizerContext


def _build_if_not_null_node(root, value, value_if_not_null) -> Node:
    """
    Build an IFNOTNULL function node to stand in for a folded expression.

    Parameters:
        root: Node
            The expression being replaced; its schema_column and query_column
            are copied onto the new node.
        value: Node
            Passed as the first parameter of the IFNOTNULL function.
        value_if_not_null: Node
            Passed as the second parameter of the IFNOTNULL function.

    Returns:
        A new FUNCTION node with value "IFNOTNULL" and the two parameters above.
    """
    replacement = Node(node_type=NodeType.FUNCTION)
    replacement.value = "IFNOTNULL"
    replacement.parameters = [value, value_if_not_null]

    # Keep the column identity of the expression we are replacing
    replacement.schema_column = root.schema_column
    replacement.query_column = root.query_column
    return replacement


def _build_passthru_node(root, value) -> Node:
    """
    Wrap `value` in a PASSTHRU function node that takes on the identity of `root`.

    Parameters:
        root: Node
            The expression being replaced; its schema_column and query_column
            are copied onto the new node.
        value: Node
            The single parameter handed to the PASSTHRU function.

    Returns:
        `root` itself, unchanged, when it is a comparison operator; otherwise a
        new FUNCTION node with value "PASSTHRU" wrapping `value`.
    """
    if root.node_type != NodeType.COMPARISON_OPERATOR:
        wrapper = Node(node_type=NodeType.FUNCTION)
        wrapper.value = "PASSTHRU"
        wrapper.parameters = [value]

        # Keep the column identity of the expression we are replacing
        wrapper.schema_column = root.schema_column
        wrapper.query_column = root.query_column
        return wrapper

    # Comparison operators are handed back as-is, no wrapper is built
    return root


def fold_constants(root: Node, statistics: QueryStatistics) -> Node:
if root.node_type == NodeType.LITERAL:
# if we're already a literal (constant), we can't fold
Expand All @@ -58,70 +79,70 @@ def fold_constants(root: Node, statistics: QueryStatistics) -> Node:
and root.right.node_type == NodeType.IDENTIFIER
and root.left.value == 0
):
# 0 * anything = 0
root.left.schema_column = root.schema_column
# 0 * anything = 0 (except NULL)
node = _build_if_not_null_node(root, root.right, build_literal_node(0))
statistics.optimization_constant_fold_reduce += 1
return root.left # 0
return node
if (
root.value == "Multiply"
and root.right.node_type == NodeType.LITERAL
and root.left.node_type == NodeType.IDENTIFIER
and root.right.value == 0
):
# anything * 0 = 0
root.right.schema_column = root.schema_column
# anything * 0 = 0 (except NULL)
node = _build_if_not_null_node(root, root.left, build_literal_node(0))
statistics.optimization_constant_fold_reduce += 1
return root.right # 0
return node
if (
root.value == "Multiply"
and root.left.node_type == NodeType.LITERAL
and root.right.node_type == NodeType.IDENTIFIER
and root.left.value == 1
):
# 1 * anything = anything
root.right.query_column = root.query_column
# 1 * anything = anything (except NULL)
node = _build_passthru_node(root, root.right)
statistics.optimization_constant_fold_reduce += 1
return root.right # anything
return node
if (
root.value == "Multiply"
and root.right.node_type == NodeType.LITERAL
and root.left.node_type == NodeType.IDENTIFIER
and root.right.value == 1
):
# anything * 1 = anything
root.left.query_column = root.query_column
# anything * 1 = anything (except NULL)
node = _build_passthru_node(root, root.left)
statistics.optimization_constant_fold_reduce += 1
return root.left # anything
return node
if (
root.value in "Plus"
and root.left.node_type == NodeType.LITERAL
and root.right.node_type == NodeType.IDENTIFIER
and root.left.value == 0
):
# 0 + anything = anything
root.right.query_column = root.query_column
# 0 + anything = anything (except NULL)
node = _build_passthru_node(root, root.right)
statistics.optimization_constant_fold_reduce += 1
return root.right # anything
return node
if (
root.value in ("Plus", "Minus")
and root.right.node_type == NodeType.LITERAL
and root.left.node_type == NodeType.IDENTIFIER
and root.right.value == 0
):
# anything +/- 0 = anything
root.left.query_column = root.query_column
# anything +/- 0 = anything (except NULL)
node = _build_passthru_node(root, root.left)
statistics.optimization_constant_fold_reduce += 1
return root.left # anything
return node
if (
root.value == "Divide"
and root.right.node_type == NodeType.LITERAL
and root.left.node_type == NodeType.IDENTIFIER
and root.right.value == 1
):
# anything / 1 = anything
root.left.schema_column = root.schema_column
# anything / 1 = anything (except NULL)
node = _build_passthru_node(root, root.left)
statistics.optimization_constant_fold_reduce += 1
return root.left # anything
return node

if root.node_type == NodeType.COMPARISON_OPERATOR:
# anything LIKE '%' is true for non null values
Expand All @@ -138,6 +159,7 @@ def fold_constants(root: Node, statistics: QueryStatistics) -> Node:
node.schema_column = root.schema_column
node.centre = root.left
node.query_column = root.query_column
node.alias = root.alias
statistics.optimization_constant_fold_reduce += 1
return node

Expand All @@ -154,75 +176,76 @@ def fold_constants(root: Node, statistics: QueryStatistics) -> Node:
and root.left.type == OrsoTypes.BOOLEAN
and root.left.value
):
# True OR anything is True
root.left.schema_column = root.schema_column
# True OR anything is True (including NULL)
node = _build_passthru_node(root, root.left)
statistics.optimization_constant_fold_boolean_reduce += 1
return root.left
return node
if (
root.right.node_type == NodeType.LITERAL
and root.right.type == OrsoTypes.BOOLEAN
and root.right.value
):
# anything OR True is True
root.right.schema_column = root.schema_column
# anything OR True is True (including NULL)
node = _build_passthru_node(root, root.right)
statistics.optimization_constant_fold_boolean_reduce += 1
return root.right
return node
if (
root.left.node_type == NodeType.LITERAL
and root.left.type == OrsoTypes.BOOLEAN
and not root.left.value
):
# False OR anything is anything
root.right.schema_column = root.schema_column
# False OR anything is anything (except NULL)
node = _build_passthru_node(root, root.right)
statistics.optimization_constant_fold_boolean_reduce += 1
return root.right
return node
if (
root.right.node_type == NodeType.LITERAL
and root.right.type == OrsoTypes.BOOLEAN
and not root.right.value
):
# anything OR False is anything
root.left.schema_column = root.schema_column
# anything OR False is anything (except NULL)
node = _build_passthru_node(root, root.left)
statistics.optimization_constant_fold_boolean_reduce += 1
return root.left
return node

elif root.node_type == NodeType.AND:
if (
root.left.node_type == NodeType.LITERAL
and root.left.type == OrsoTypes.BOOLEAN
and not root.left.value
):
# False AND anything is False
root.left.schema_column = root.schema_column
# False AND anything is False (including NULL)
node = _build_passthru_node(root, root.left)
statistics.optimization_constant_fold_boolean_reduce += 1
return root.left
return node
if (
root.right.node_type == NodeType.LITERAL
and root.right.type == OrsoTypes.BOOLEAN
and not root.right.value
):
# anything AND False is False
root.right.schema_column = root.schema_column
# anything AND False is False (including NULL)
node = _build_passthru_node(root, root.right)
statistics.optimization_constant_fold_boolean_reduce += 1
return root.right
return node
if (
root.left.node_type == NodeType.LITERAL
and root.left.type == OrsoTypes.BOOLEAN
and root.left.value
):
# True AND anything is anything
root.right.schema_column = root.schema_column
# True AND anything is anything (except NULL)
node = _build_passthru_node(root, root.right)
statistics.optimization_constant_fold_boolean_reduce += 1
return root.right
return node
if (
root.right.node_type == NodeType.LITERAL
and root.right.type == OrsoTypes.BOOLEAN
and root.right.value
):
# anything AND True is anything
root.left.schema_column = root.schema_column
# anything AND True is anything (except NULL)
node = _build_passthru_node(root, root.left)
node.type = OrsoTypes.BOOLEAN
statistics.optimization_constant_fold_boolean_reduce += 1
return root.left
return node

return root

Expand All @@ -248,6 +271,7 @@ def fold_constants(root: Node, statistics: QueryStatistics) -> Node:
statistics.optimization_constant_aggregation += 1
return root
if agg.value in ("AVG", "MIN", "MAX"):
# AVG, MIN, MAX of a constant is the constant
statistics.optimization_constant_aggregation += 1
return build_literal_node(agg.parameters[0].value, root, root.schema_column.type)

Expand Down Expand Up @@ -286,6 +310,20 @@ def visit(self, node: LogicalPlanNode, context: OptimizerContext) -> OptimizerCo
node.columns = [fold_constants(c, self.statistics) for c in node.columns]
context.optimized_plan[context.node_id] = node

# remove nesting in order by and group by clauses
if node.node_type == LogicalPlanStepType.Order:
new_order_by = []
for field, order in node.order_by:
while field.node_type == NodeType.NESTED:
field = field.centre
new_order_by.append((field, order))
node.order_by = new_order_by
context.optimized_plan[context.node_id] = node

if node.node_type == LogicalPlanStepType.AggregateAndGroup:
node.groups = [g.centre if g.node_type == NodeType.NESTED else g for g in node.groups]
context.optimized_plan[context.node_id] = node

return context

def complete(self, plan: LogicalPlan, context: OptimizerContext) -> LogicalPlan:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,7 @@ def pattern_match(branch, alias: Optional[List[str]] = None, key=None):
value=key,
left=left,
right=right,
alias=alias,
)


Expand Down
Loading

0 comments on commit bad4b67

Please sign in to comment.