Skip to content

Commit

Permalink
Merge pull request #1489 from mabel-dev/#1486
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer authored Mar 3, 2024
2 parents 4f04b6c + dce84cc commit 05dd9ff
Show file tree
Hide file tree
Showing 15 changed files with 421 additions and 173 deletions.
2 changes: 1 addition & 1 deletion opteryx/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__build__ = 330
__build__ = 333

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 0 additions & 2 deletions opteryx/components/binder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@
- schema lookup and propagation (add columns and types, add aliases)
"""


from opteryx.components.binder.binder_visitor import BinderVisitor
from opteryx.components.binder.binding_context import BindingContext
from opteryx.components.logical_planner import LogicalPlan
Expand Down
8 changes: 5 additions & 3 deletions opteryx/components/binder/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from orso.schema import FlatColumn
from orso.schema import FunctionColumn
from orso.schema import RelationSchema
from orso.types import OrsoTypes

from opteryx.components.binder.operator_map import determine_type
from opteryx.exceptions import AmbiguousIdentifierError
from opteryx.exceptions import ColumnNotFoundError
from opteryx.exceptions import InvalidInternalStateError
Expand Down Expand Up @@ -308,11 +310,11 @@ def inner_binder(node: Node, context: Any) -> Tuple[Node, Any]:

elif node.value and node.value.startswith("AnyOp"):
# IMPROVE: check types here
schema_column = ExpressionColumn(name=column_name, type=0)
schema_column = ExpressionColumn(name=column_name, type=OrsoTypes.BOOLEAN)
node.schema_column = schema_column
elif node.value and node.value.startswith("AllOp"):
# IMPROVE: check types here
schema_column = ExpressionColumn(name=column_name, type=0)
schema_column = ExpressionColumn(name=column_name, type=OrsoTypes.BOOLEAN)
node.schema_column = schema_column
else:
# fmt:off
Expand All @@ -329,7 +331,7 @@ def inner_binder(node: Node, context: Any) -> Tuple[Node, Any]:
schema_column = ExpressionColumn(
name=column_name,
aliases=[node.alias] if node.alias else [],
type=0,
type=determine_type(node),
expression=node.value,
)
schemas["$derived"].columns.append(schema_column)
Expand Down
5 changes: 4 additions & 1 deletion opteryx/components/binder/binder_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,10 @@ def visit_join(self, node: Node, context: BindingContext) -> Tuple[Node, Binding
node.unnest_target, found_source_relation = locate_identifier_in_loaded_schemas(
node.unnest_alias, context.schemas
)
if node.unnest_column.schema_column.type not in (0, OrsoTypes.ARRAY):
if node.unnest_column.schema_column.type not in (
OrsoTypes._MISSING_TYPE,
OrsoTypes.ARRAY,
):
from opteryx.exceptions import IncorrectTypeError

raise IncorrectTypeError("CROSS JOIN UNNEST requires an ARRAY type column.")
Expand Down
207 changes: 207 additions & 0 deletions opteryx/components/binder/operator_map.py

Large diffs are not rendered by default.

26 changes: 13 additions & 13 deletions opteryx/components/logical_planner/logical_planner_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from typing import Tuple

import numpy
import pyarrow
from orso.types import OrsoTypes

from opteryx import functions
Expand Down Expand Up @@ -141,6 +140,11 @@ def binary_op(branch, alias: Optional[List[str]] = None, key=None):
operator = branch["op"]
right = build(branch["right"])

if operator in ("PGRegexMatch", "SimilarTo"):
operator = "RLike"
if operator in ("PGRegexNotMatch", "NotSimilarTo"):
operator = "NotRLike"

operator_type = NodeType.COMPARISON_OPERATOR
if operator in BINARY_OPERATORS:
operator_type = NodeType.BINARY_OPERATOR
Expand Down Expand Up @@ -418,7 +422,7 @@ def literal_interval(branch, alias: Optional[List[str]] = None, key=None):

unit_index = parts.index(leading_unit)

month, day, nano = (0, 0, 0)
month, seconds = (0, 0)

for index, value in enumerate(values):
value = int(value)
Expand All @@ -428,21 +432,15 @@ def literal_interval(branch, alias: Optional[List[str]] = None, key=None):
if unit == "Month":
month += value
if unit == "Day":
day = value
seconds = value * 24 * 60 * 60
if unit == "Hour":
nano += value * 60 * 60 * 1000000000
seconds += value * 60 * 60
if unit == "Minute":
nano += value * 60 * 1000000000
seconds += value * 60
if unit == "Second":
nano += value * 1000000000
seconds += value

interval = pyarrow.MonthDayNano(
(
month,
day,
nano,
)
)
interval = (month, seconds)

return Node(NodeType.LITERAL, type=OrsoTypes.INTERVAL, value=interval, alias=alias)

Expand Down Expand Up @@ -529,6 +527,8 @@ def pattern_match(branch, alias: Optional[List[str]] = None, key=None):
negated = branch["negated"]
left = build(branch["expr"])
right = build(branch["pattern"])
if key in ("PGRegexMatch", "SimilarTo"):
key = "RLike"
if negated:
key = f"Not{key}"
return Node(
Expand Down
2 changes: 1 addition & 1 deletion opteryx/connectors/sql_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def read_dataset( # type:ignore
# DEBUG: log ("READ DATASET\n", str(query_builder))
# DEBUG: log ("PARAMETERS\n", parameters)
# Execution Options allows us to handle datasets larger than memory
result = conn.execution_options(stream_results=True, max_row_buffer=5000).execute(
result = conn.execution_options(stream_results=True, max_row_buffer=10000).execute(
text(str(query_builder)), parameters=parameters
)

Expand Down
Empty file.
121 changes: 121 additions & 0 deletions opteryx/custom_types/intervals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Tuple

import numpy
import pyarrow
import pyarrow.compute
from orso.types import OrsoTypes


def add_months_numpy(dates, months_to_add):
    """
    Adds a specified number of calendar months to numpy datetimes.

    The day-of-month is clamped to the last day of the target month (so
    31 January plus one month lands on the last day of February) and the
    time-of-day of every value is carried over unchanged.

    Parameters:
    - dates: np.ndarray of dates (numpy.datetime64), or a single numpy.datetime64
    - months_to_add: int, the number of months to add to each date (may be negative)
    Returns:
    - np.ndarray: Adjusted dates as datetime64[us]
    """
    # Truncate to month and day granularity so the day-of-month and the
    # time-of-day can be handled independently of each other.
    month_starts = dates.astype("datetime64[M]")
    day_starts = dates.astype("datetime64[D]")
    day_offset = day_starts - month_starts  # whole days since the 1st of the month
    time_of_day = dates - day_starts  # remainder, in the resolution of the input

    # First day of the target month, and how many days after it the month ends
    target_months = month_starts + numpy.timedelta64(months_to_add, "M")
    target_first_day = target_months.astype("datetime64[D]")
    following_first_day = (target_months + numpy.timedelta64(1, "M")).astype("datetime64[D]")
    last_day_offset = following_first_day - numpy.timedelta64(1, "D") - target_first_day

    # Compare whole days only: including the time-of-day in this comparison
    # would flag the last valid day of the target month as an overflow.
    overflow_mask = day_offset > last_day_offset
    clamped_offset = numpy.where(overflow_mask, last_day_offset, day_offset)

    adjusted_dates = target_first_day + clamped_offset + time_of_day

    # asarray keeps the return type an ndarray even when a scalar was passed in
    return numpy.asarray(adjusted_dates).astype("datetime64[us]")


def _date_plus_interval(left, left_type, right, right_type, operator):
    """
    Adds an interval to, or subtracts it from, an array of dates/timestamps.

    Parameters:
    - left, right: the two operands; one holds the dates and the other the
      interval(s). The date operand must support `.astype("datetime64[s]")`
      and the interval operand must yield `(months, seconds)` from
      `[0].as_py()` (presumably a numpy array and a pyarrow array
      respectively - confirm against the caller).
    - left_type, right_type: the OrsoTypes of the operands; only `left_type`
      is inspected, to work out which operand is the interval.
    - operator: "Plus" adds the interval, any other value subtracts it.
    Returns:
    - np.ndarray: the shifted values as datetime64[s]
    """
    signum = 1 if operator == "Plus" else -1

    # Normalise the operands so `left` is always the dates and `right` the interval
    if left_type == OrsoTypes.INTERVAL:
        left, right = right, left

    # Only the first interval value is read; it is applied to every row
    months, seconds = right[0].as_py()

    # Seconds can be applied with plain integer arithmetic at second resolution
    result = left.astype("datetime64[s]") + (seconds * signum)

    # Handle months separately, requiring special logic
    if months:
        # add_months_numpy operates on whole arrays, so one call replaces a
        # per-row Python loop; it returns microseconds, so cast back to the
        # second resolution this function has always returned.
        result = add_months_numpy(result, months * signum).astype("datetime64[s]")

    return result


def _simple_interval_op(left, left_type, right, right_type, operator):
    """
    Applies `operator` to two INTERVAL operands using only their seconds component.

    Each operand is read with `pyarrow.compute.list_element`: element 0 is
    treated as the months part and element 1 as the seconds part.

    Parameters:
    - left, right: the interval operands (list-like pyarrow values)
    - left_type, right_type: accepted for a uniform kernel signature; not used here
    - operator: operator name handed through to `_inner_filter_operations`
    Returns:
    - whatever `_inner_filter_operations` returns for the two seconds arrays
    Raises:
    - UnsupportedSyntaxError: if either operand has a non-zero months part
    """
    from opteryx.third_party.pyarrow_ops.ops import _inner_filter_operations

    compute = pyarrow.compute

    # Pull both parts out of each operand up front (left first, then right)
    lhs_months, lhs_seconds = (compute.list_element(left, part) for part in (0, 1))
    rhs_months, rhs_seconds = (compute.list_element(right, part) for part in (0, 1))

    def _has_month_part(month_values):
        # True when at least one interval carries a non-zero months part
        return compute.any(compute.not_equal(month_values, 0)).as_py()

    # Months have no fixed length in seconds, so such intervals are rejected
    if _has_month_part(lhs_months) or _has_month_part(rhs_months):
        from opteryx.exceptions import UnsupportedSyntaxError

        raise UnsupportedSyntaxError("Cannot compare INTERVALs with MONTH or YEAR components.")

    return _inner_filter_operations(lhs_seconds, operator, rhs_seconds)


# Dispatch table: (left type, right type, operator name) -> kernel implementing it.
INTERVAL_KERNELS: Dict[Tuple[OrsoTypes, OrsoTypes, str], Optional[Callable]] = {
    # INTERVAL <op> INTERVAL: arithmetic and comparisons
    **{
        (OrsoTypes.INTERVAL, OrsoTypes.INTERVAL, op): _simple_interval_op
        for op in ("Plus", "Minus", "Eq", "NotEq", "Gt", "GtEq", "Lt", "LtEq")
    },
    # INTERVAL combined with a TIMESTAMP or DATE, in either operand order
    **{
        (left_type, right_type, op): _date_plus_interval
        for left_type, right_type in (
            (OrsoTypes.INTERVAL, OrsoTypes.TIMESTAMP),
            (OrsoTypes.INTERVAL, OrsoTypes.DATE),
            (OrsoTypes.TIMESTAMP, OrsoTypes.INTERVAL),
            (OrsoTypes.DATE, OrsoTypes.INTERVAL),
        )
        for op in ("Plus", "Minus")
    },
    # we need to type the outcome of calcs better - until then an operand
    # typed as 0 paired with an INTERVAL is routed to the interval kernel
    **{
        (left_type, right_type, op): _simple_interval_op
        for left_type, right_type in (
            (0, OrsoTypes.INTERVAL),
            (OrsoTypes.INTERVAL, 0),
        )
        for op in ("Plus", "Minus", "Eq", "NotEq", "Gt", "GtEq", "Lt", "LtEq")
    },
}
Loading

0 comments on commit 05dd9ff

Please sign in to comment.