diff --git a/opteryx/__version__.py b/opteryx/__version__.py index cc2f499bf..6ee2e14cb 100644 --- a/opteryx/__version__.py +++ b/opteryx/__version__.py @@ -1,4 +1,4 @@ -__build__ = 323 +__build__ = 324 # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/opteryx/connectors/base/base_connector.py b/opteryx/connectors/base/base_connector.py index fa084a49c..6aba5b5a1 100644 --- a/opteryx/connectors/base/base_connector.py +++ b/opteryx/connectors/base/base_connector.py @@ -25,7 +25,7 @@ MIN_CHUNK_SIZE: int = 500 INITIAL_CHUNK_SIZE: int = 500 -DEFAULT_MORSEL_SIZE: int = 1024 * 1024 +DEFAULT_MORSEL_SIZE: int = 16 * 1024 * 1024 class BaseConnector: diff --git a/opteryx/connectors/sql_connector.py b/opteryx/connectors/sql_connector.py index fadc6670e..907c012e3 100644 --- a/opteryx/connectors/sql_connector.py +++ b/opteryx/connectors/sql_connector.py @@ -77,6 +77,12 @@ class SqlConnector(BaseConnector, PredicatePushable): "LtEq": "<=", "Like": "LIKE", "NotLike": "NOT LIKE", + "IsTrue": "IS TRUE", + "IsNotTrue": "IS NOT TRUE", + "IsFalse": "IS FALSE", + "IsNotFalse": "IS NOT FALSE", + "IsNull": "IS NULL", + "IsNotNull": "IS NOT NULL", } def __init__(self, *args, connection: str = None, engine=None, **kwargs): @@ -104,6 +110,11 @@ def __init__(self, *args, connection: str = None, engine=None, **kwargs): self.schema = None # type: ignore self.metadata = MetaData() + def can_push(self, operator: Node, types: set = None) -> bool: + if super().can_push(operator, types): + return True + return operator.condition.node_type == NodeType.UNARY_OPERATOR + def read_dataset( # type:ignore self, *, @@ -134,14 +145,20 @@ def read_dataset( # type:ignore # Update SQL if we've pushed predicates parameters: dict = {} for predicate in predicates: - left_operand = predicate.left - right_operand = predicate.right - operator = self.OPS_XLAT[predicate.value] + if predicate.node_type == NodeType.UNARY_OPERATOR: + operand = predicate.centre.current_name + operator = self.OPS_XLAT[predicate.value] + + query_builder.WHERE(f"{operand} {operator}") + else: + left_operand = predicate.left + right_operand = predicate.right + operator = self.OPS_XLAT[predicate.value] - left_value, parameters = _handle_operand(left_operand, parameters) - right_value, parameters = _handle_operand(right_operand, parameters) + left_value, parameters = _handle_operand(left_operand, parameters) + right_value, parameters = _handle_operand(right_operand, parameters) - query_builder.WHERE(f"{left_value} {operator} {right_value}") + query_builder.WHERE(f"{left_value} {operator} {right_value}") # Use orso as an intermediatary, it's row-based so is well suited to processing # records coming back from a SQL query, and it has a well-optimized to arrow @@ -154,7 +171,7 @@ def read_dataset( # type:ignore # DEBUG: log ("READ DATASET\n", str(query_builder)) # DEBUG: log ("PARAMETERS\n", parameters) # Execution Options allows us to handle datasets larger than memory - result = conn.execution_options(stream_results=True, max_row_buffer=500).execute( + result = conn.execution_options(stream_results=True, max_row_buffer=5000).execute( text(str(query_builder)), parameters=parameters ) @@ -171,7 +188,7 @@ def read_dataset( # type:ignore # Dynamically adjust chunk size based on the data size, we start by downloading # 500 records to get an idea of the row size, assuming these 500 are - # representative, we work out how many rows fit into 8Mb. + # representative, we work out how many rows fit into 16Mb (check setting). # Don't keep recalculating, this is not a cheap operation and it's predicting # the future so isn't going to ever be 100% correct if self.chunk_size == INITIAL_CHUNK_SIZE and morsel.nbytes() > 0: diff --git a/tests/plan_optimization/test_predicate_pushdown_postgres.py b/tests/plan_optimization/test_predicate_pushdown_postgres.py index 276835212..9edeff642 100644 --- a/tests/plan_optimization/test_predicate_pushdown_postgres.py +++ b/tests/plan_optimization/test_predicate_pushdown_postgres.py @@ -7,8 +7,11 @@ sys.path.insert(1, os.path.join(sys.path[0], "../..")) +import time +import pytest import opteryx from opteryx.connectors import SqlConnector +from opteryx.utils.formatter import format_sql from tests.tools import is_arm, is_mac, is_version, is_windows, skip_if POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD") @@ -22,48 +25,76 @@ connection=f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}@trumpet.db.elephantsql.com/{POSTGRES_USER}", ) +test_cases = [ + ("SELECT * FROM pg.planets WHERE gravity <= 3.7", 3, 3), + ("SELECT * FROM pg.planets WHERE name != 'Earth'", 8, 8), + ("SELECT * FROM pg.planets WHERE name != 'E\"arth'", 9, 9), + ("SELECT * FROM pg.planets WHERE gravity != 3.7", 7, 7), + ("SELECT * FROM pg.planets WHERE gravity < 3.7", 1, 1), + ("SELECT * FROM pg.planets WHERE gravity > 3.7", 6, 6), + ("SELECT * FROM pg.planets WHERE gravity >= 3.7", 8, 8), + ("SELECT * FROM pg.planets WHERE name LIKE '%a%'", 4, 4), + ("SELECT * FROM pg.planets WHERE id > gravity", 2, 2), + ("SELECT * FROM pg.planets WHERE surface_pressure IS NULL", 4, 4), +] + # skip to reduce contention @skip_if(is_arm() or is_windows() or is_mac() or not is_version("3.10")) -def test_predicate_pushdown_postgres_other(): - res = opteryx.query("SELECT * FROM pg.planets WHERE gravity <= 3.7") - assert res.rowcount == 3, res.rowcount - assert res.stats.get("rows_read", 0) == 3, res.stats - - res = opteryx.query("SELECT * FROM pg.planets WHERE name != 'Earth'") - assert res.rowcount == 8, res.rowcount - assert res.stats.get("rows_read", 0) == 8, res.stats - - res = opteryx.query("SELECT * FROM pg.planets WHERE name != 'E\"arth'") - assert res.rowcount == 9, res.rowcount - assert res.stats.get("rows_read", 0) == 9, res.stats - - res = opteryx.query("SELECT * FROM pg.planets WHERE gravity != 3.7") - assert res.rowcount == 7, res.rowcount - assert res.stats.get("rows_read", 0) == 7, res.stats - - res = opteryx.query("SELECT * FROM pg.planets WHERE gravity < 3.7") - assert res.rowcount == 1, res.rowcount - assert res.stats.get("rows_read", 0) == 1, res.stats - - res = opteryx.query("SELECT * FROM pg.planets WHERE gravity > 3.7") - assert res.rowcount == 6, res.rowcount - assert res.stats.get("rows_read", 0) == 6, res.stats - - res = opteryx.query("SELECT * FROM pg.planets WHERE gravity >= 3.7") - assert res.rowcount == 8, res.rowcount - assert res.stats.get("rows_read", 0) == 8, res.stats - - res = opteryx.query("SELECT * FROM pg.planets WHERE name LIKE '%a%'") - assert res.rowcount == 4, res.rowcount - assert res.stats.get("rows_read", 0) == 4, res.stats - - res = opteryx.query("SELECT * FROM pg.planets WHERE id > gravity") - assert res.rowcount == 2, res.rowcount - assert res.stats.get("rows_read", 0) == 2, res.stats +@pytest.mark.parametrize("statement,expected_rowcount,expected_rows_read", test_cases) +def test_predicate_pushdown_postgres_parameterized( + statement, expected_rowcount, expected_rows_read +): + res = opteryx.query(statement) + assert res.rowcount == expected_rowcount, f"Expected {expected_rowcount}, got {res.rowcount}" + assert ( + res.stats.get("rows_read", 0) == expected_rows_read + ), f"Expected {expected_rows_read}, got {res.stats.get('rows_read', 0)}" if __name__ == "__main__": # pragma: no cover - from tests.tools import run_tests - - run_tests() + import shutil + + from tests.tools import trunc_printable + + start_suite = time.monotonic_ns() + passed = 0 + failed = 0 + + width = shutil.get_terminal_size((80, 20))[0] - 15 + + print(f"RUNNING BATTERY OF {len(test_cases)} TESTS") + for index, (statement, returned_rows, read_rows) in enumerate(test_cases): + print( + f"\033[38;2;255;184;108m{(index + 1):04}\033[0m" + f" {trunc_printable(format_sql(statement), width - 1)}", + end="", + flush=True, + ) + try: + start = time.monotonic_ns() + test_predicate_pushdown_postgres_parameterized(statement, returned_rows, read_rows) + print( + f"\033[38;2;26;185;67m{str(int((time.monotonic_ns() - start)/1e6)).rjust(4)}ms\033[0m ✅", + end="", + ) + passed += 1 + if failed > 0: + print(" \033[0;31m*\033[0m") + else: + print() + except Exception as err: + print(f"\033[0;31m{str(int((time.monotonic_ns() - start)/1e6)).rjust(4)}ms ❌ *\033[0m") + print(">", err) + failed += 1 + + print("--- ✅ \033[0;32mdone\033[0m") + + if failed > 0: + print("\n\033[38;2;139;233;253m\033[3mFAILURES\033[0m") + + print( + f"\n\033[38;2;139;233;253m\033[3mCOMPLETE\033[0m ({((time.monotonic_ns() - start_suite) / 1e9):.2f} seconds)\n" + f" \033[38;2;26;185;67m{passed} passed ({(passed * 100) // (passed + failed)}%)\033[0m\n" + f" \033[38;2;255;121;198m{failed} failed\033[0m" + ) diff --git a/tests/plan_optimization/test_predicate_pushdown_sqlite.py b/tests/plan_optimization/test_predicate_pushdown_sqlite.py index 514deeff6..85a858b29 100644 --- a/tests/plan_optimization/test_predicate_pushdown_sqlite.py +++ b/tests/plan_optimization/test_predicate_pushdown_sqlite.py @@ -53,6 +53,20 @@ def test_predicate_pushdowns_sqlite_eq(): assert cur.rowcount == 1, cur.rowcount assert cur.stats.get("rows_read", 0) == 2, cur.stats + cur = conn.cursor() + cur.execute("SELECT * FROM sqlite.planets WHERE surfacePressure IS NULL;") + # We push unary ops to SQL + assert cur.rowcount == 4, cur.rowcount + assert cur.stats.get("rows_read", 0) == 4, cur.stats + + cur = conn.cursor() + cur.execute( + "SELECT * FROM sqlite.planets WHERE orbitalInclination IS FALSE AND name IN ('Earth', 'Mars');" + ) + # We push unary ops to SQL + assert cur.rowcount == 1, cur.rowcount + assert cur.stats.get("rows_read", 0) == 1, cur.stats + conn.close() diff --git a/tests/sql_battery/test_shapes_and_errors_battery.py b/tests/sql_battery/test_shapes_and_errors_battery.py index 312e91765..73b0cc8ff 100644 --- a/tests/sql_battery/test_shapes_and_errors_battery.py +++ b/tests/sql_battery/test_shapes_and_errors_battery.py @@ -1556,7 +1556,7 @@ def test_sql_battery(statement, rows, columns, exception): print(f"RUNNING BATTERY OF {len(STATEMENTS)} SHAPE TESTS") for index, (statement, rows, cols, err) in enumerate(STATEMENTS): - start = time.monotonic_ns() + printable = statement if hasattr(printable, "decode"): printable = printable.decode() @@ -1567,6 +1567,7 @@ def test_sql_battery(statement, rows, columns, exception): flush=True, ) try: + start = time.monotonic_ns() test_sql_battery(statement, rows, cols, err) print( f"\033[38;2;26;185;67m{str(int((time.monotonic_ns() - start)/1e6)).rjust(4)}ms\033[0m ✅",