Skip to content

Commit

Permalink
Merge pull request #1677 from mabel-dev/#1676
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer authored May 23, 2024
2 parents ff861a5 + cae876c commit b2fb25e
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 15 deletions.
52 changes: 52 additions & 0 deletions .github/workflows/regression_suite_mac_ARM.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
name: Regression Suite (Mac ARM)

on:
push

jobs:
regression_matrix:
strategy:
max-parallel: 4
matrix:
python-version: ['3.10', '3.11']
os: ['macos-14']
runs-on: ${{ matrix.os }}
steps:

- name: Checkout code
uses: actions/checkout@v3

- name: Set up Python ${{ matrix.python-version }} x64
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
architecture: x64

- name: Install Rust
uses: actions-rs/toolchain@v1
with:
toolchain: stable
override: true

- name: Install Requirements
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade wheel numpy cython setuptools_rust pytest
python -m pip install --upgrade -r requirements.txt
python -m pip install --upgrade -r tests/requirements_arm.txt
python setup.py build_ext --inplace
- name: "Authenticate to Google Cloud"
uses: google-github-actions/auth@v1
with:
credentials_json: '${{ secrets.GCP_KEY }}'

- name: Run Regression Tests
run: python -m pytest --color=yes
env:
GCP_PROJECT_ID: mabeldev
MYSQL_USER: '${{ secrets.MYSQL_USER }}'
MYSQL_PASSWORD: '${{ secrets.MYSQL_PASSWORD }}'
POSTGRES_USER: '${{ secrets.POSTGRES_USER }}'
POSTGRES_PASSWORD: '${{ secrets.POSTGRES_PASSWORD }}'
MEMCACHED_SERVER: 'localhost:11211'
4 changes: 2 additions & 2 deletions .github/workflows/regression_suite_mac_x86.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Regression Suite (Mac)
name: Regression Suite (Mac x86)

on:
push
Expand All @@ -9,7 +9,7 @@ jobs:
max-parallel: 4
matrix:
python-version: ['3.10', '3.11']
os: ['macos-latest']
os: ['macos-13']
runs-on: ${{ matrix.os }}
steps:

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ jobs:
path: dist

build-macos:
runs-on: macos-latest
runs-on: macos-13
strategy:
max-parallel: 4
matrix:
Expand Down
2 changes: 1 addition & 1 deletion opteryx/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__build__ = 514
__build__ = 517

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
27 changes: 16 additions & 11 deletions opteryx/planner/sql_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,19 @@ def sql_parts(string):
+ r")",
re.IGNORECASE,
)
quoted_strings = re.compile(r"(\"(?:\\.|[^\"])*\"|\'(?:\\.|[^\'])*\'|`(?:\\.|[^`])*`)")
# Match ", ', b", b', `
# We match b prefixes separately after the non-prefix versions
quoted_strings = re.compile(
r"(\"(?:\\.|[^\"])*\"|\'(?:\\.|[^\'])*\'|\b[bB]\"(?:\\.|[^\"])*\"|\b[bB]\'(?:\\.|[^\'])*\'|`(?:\\.|[^`])*`)"
)

parts = []
for part in quoted_strings.split(string):
if part and part[-1] in ("'", '"', "`"):
parts.append(part)
if part[0] in ("b", "B"):
parts.append(f"blob({part[1:]})")
else:
parts.append(part)
else:
for subpart in keywords.split(part):
subpart = subpart.strip()
Expand Down Expand Up @@ -230,15 +237,13 @@ def _temporal_extration_state_machine(parts: List[str]) -> Tuple[List[Tuple[str,
Returns:
Tuple containing two lists, first with the temporal filters, second with the remaining SQL parts.
"""
"""
we use a four state machine to extract the temporal information from the query
and maintain the relation to filter information.
We separate out the two key parts of the algorithm, first we determine the state,
then we work out if the state transition means we should do something.
We're essentially using a bit mask to record state and transitions.
"""
# We use a four state machine to extract the temporal information from the query
# and maintain the relation to filter information.
#
# We separate out the two key parts of the algorithm, first we determine the state,
# then we work out if the state transition means we should do something.
#
# We're essentially using a bit mask to record state and transitions.

state = WAITING
relation = ""
Expand Down
98 changes: 98 additions & 0 deletions tests/query_planner/test_b_strings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import os
import pytest
import sys

sys.path.insert(1, os.path.join(sys.path[0], "../.."))

from opteryx.planner.sql_rewriter import sql_parts

# Define the test cases as a list of (input, expected_output) tuples
# fmt:off
test_cases = [
# Contrived cases
("This is a test string with b'abc' and B\"def\".", "This is a test string with blob('abc') and blob(\"def\")."),
("b'123' should become blob('123')", "blob('123') should become blob('123')"),
('B"xyz" should become blob("xyz")', 'blob("xyz") should become blob("xyz")'),
("Mix of b'one' and B\"two\"", "Mix of blob('one') and blob(\"two\")"),
("No prefixed strings here.", "No prefixed strings here."),
("B'' and b\"\" should be handled.", "blob('') and blob(\"\") should be handled."),

# Basic replacements
("SELECT * FROM table WHERE column = b'abc';", "SELECT * FROM table WHERE column = blob('abc');"),
("INSERT INTO table (column) VALUES (B\"def\");", "INSERT INTO table (column) VALUES (blob(\"def\"));"),
("UPDATE table SET column = b'123' WHERE id = 1;", "UPDATE table SET column = blob('123') WHERE id = 1;"),

# Mixed cases
("SELECT * FROM table WHERE column = B'xyz' OR column = b\"uvw\";", "SELECT * FROM table WHERE column = blob('xyz') OR column = blob(\"uvw\");"),
("INSERT INTO table (col1, col2) VALUES (b'val1', B\"val2\");", "INSERT INTO table (col1, col2) VALUES (blob('val1'), blob(\"val2\"));"),

# Edge cases
("SELECT * FROM table WHERE column = b'';", "SELECT * FROM table WHERE column = blob('');"),
("SELECT * FROM table WHERE column = B\"\";", "SELECT * FROM table WHERE column = blob(\"\");"),
("SELECT b'abc' AS col1, B'def' AS col2 FROM table;", "SELECT blob('abc') AS col1, blob('def') AS col2 FROM table;"),

# No replacements
("SELECT * FROM table WHERE column = 'abc';", "SELECT * FROM table WHERE column = 'abc';"),
("SELECT * FROM table WHERE column = \"def\";", "SELECT * FROM table WHERE column = \"def\";"),
("SELECT * FROM table WHERE column = '';", "SELECT * FROM table WHERE column = '';"),
("SELECT * FROM table WHERE column = \"\";", "SELECT * FROM table WHERE column = \"\";"),

# Complex statements
("SELECT * FROM table1 JOIN table2 ON table1.col = table2.col WHERE table1.col = b'join' AND table2.col = B\"join\";", "SELECT * FROM table1 JOIN table2 ON table1.col = table2.col WHERE table1.col = blob('join') AND table2.col = blob(\"join\");"),
("WITH cte AS (SELECT b'cte' AS col FROM table) SELECT * FROM cte WHERE col = B'cte';", "WITH cte AS (SELECT blob('cte') AS col FROM table) SELECT * FROM cte WHERE col = blob('cte');"),

# Specific cases
("SELECT * FROM table WHERE column = blob'a';", "SELECT * FROM table WHERE column = blob'a';"),
("SELECT * FROM table WHERE column = blob(\"a\");", "SELECT * FROM table WHERE column = blob(\"a\");"),
("SELECT * FROM table WHERE column = blob('a');", "SELECT * FROM table WHERE column = blob('a');"),
("SELECT * FROM table WHERE column = b'abc' AND function_call(b'xyz');", "SELECT * FROM table WHERE column = blob('abc') AND function_call(blob('xyz'));"),

# failed case
("SELECT * FROM $satellites WHERE (((id = 5 OR (10<11)) AND ('a'='b')) OR (name = 'Europa' AND (TRUE AND (11=11))));", "SELECT * FROM $satellites WHERE (((id = 5 OR (10<11)) AND ('a'='b')) OR (name = 'Europa' AND (TRUE AND (11=11)))) ;"),

# complex quotes
("SELECT * FROM table WHERE column = 'This is a ''test'' string';", "SELECT * FROM table WHERE column = 'This is a ''test'' string';"),
("SELECT * FROM table WHERE column = \"He said, \\\"Hello, World!\\\"\";", "SELECT * FROM table WHERE column = \"He said, \\\"Hello, World!\\\"\";"),
("SELECT * FROM table WHERE column = 'Single quote within '' single quotes';", "SELECT * FROM table WHERE column = 'Single quote within '' single quotes';"),
("SELECT * FROM table WHERE column = \"Double quote within \\\" double quotes\";", "SELECT * FROM table WHERE column = \"Double quote within \\\" double quotes\";"),
("SELECT * FROM table WHERE column = `Backticks are used for column names`;", "SELECT * FROM table WHERE column = `Backticks are used for column names`;"),
("SELECT * FROM table WHERE column = 'Multiple ''single quotes'' in one string';", "SELECT * FROM table WHERE column = 'Multiple ''single quotes'' in one string';"),
("SELECT * FROM table WHERE column = \"Multiple \\\"double quotes\\\" in one string\";", "SELECT * FROM table WHERE column = \"Multiple \\\"double quotes\\\" in one string\";"),
("SELECT * FROM table WHERE column = 'Combination of ''single'' and \"double\" quotes';", "SELECT * FROM table WHERE column = 'Combination of ''single'' and \"double\" quotes';"),
("SELECT * FROM table WHERE column = 'String with newline\ncharacter';", "SELECT * FROM table WHERE column = 'String with newline\ncharacter';")
]
# fmt:on


@pytest.mark.parametrize("input_text, expected_output", test_cases)
def test_replace_b_strings(input_text, expected_output):
assert " ".join(sql_parts(input_text)).replace(" ", "") == expected_output.replace(" ", "")


if __name__ == "__main__": # pragma: no cover

"""
Running in the IDE we do some formatting - it's not functional but helps
when reading the outputs.
"""

import shutil
import time

width = shutil.get_terminal_size((80, 20))[0] - 15

print(f"RUNNING BATTERY OF {len(test_cases)} B STRINGS")
for index, (case, expected) in enumerate(test_cases):
start = time.monotonic_ns()
print(
f"\033[0;36m{(index + 1):04}\033[0m {case[0:width - 1].ljust(width)}",
end="",
)
if " ".join(sql_parts(case)).replace(" ", "") == expected.replace(" ", ""):
print(f"\033[0;32m{str(int((time.monotonic_ns() - start)/1e6)).rjust(4)}ms\033[0m ✅")
else:
print(f"\033[0;31m{str(int((time.monotonic_ns() - start)/1e6)).rjust(4)}ms\033[0m ❌")
print("Expected:", expected)
print("Recieved:", " ".join(sql_parts(case)))

print("--- ✅ \033[0;32mdone\033[0m")
4 changes: 4 additions & 0 deletions tests/sql_battery/test_shapes_and_errors_battery.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@
("SELECT CAST('2022-01-0' || planetId::VARCHAR AS TIMESTAMP) FROM $satellites", 177, 1, None),
("SELECT planetId::INTEGER FROM $satellites", 177, 1, None),
("SELECT planetId::DOUBLE FROM $satellites", 177, 1, None),
("SELECT 1::double", 1, 1, None),
("SELECT TRY_CAST(planetId AS BOOLEAN) FROM $satellites", 177, 1, None),
("SELECT TRY_CAST(planetId AS VARCHAR) FROM $satellites", 177, 1, None),
("SELECT TRY_CAST(planetId AS TIMESTAMP) FROM $satellites", 177, 1, None),
Expand All @@ -471,6 +472,9 @@
("SELECT TRY_CAST(planetId AS DECIMAL) AS VALUE FROM $satellites", 177, 1, None),
("SELECT * FROM $planets WHERE id = GET(STRUCT('{\"a\":1,\"b\":\"c\"}'), 'a')", 1, 20, None),
# ("SELECT * FROM $planets WHERE id = STRUCT('{\"a\":1,\"b\":\"c\"}')->'a'", 1, 20, None),
("SELECT b'binary'", 1, 1, None),
("SELECT B'binary'", 1, 1, None),
("SELECT * FROM $planets WHERE name = b'Earth';", 1, 20, None),

("SELECT PI()", 1, 1, None),
("SELECT E()", 1, 1, None),
Expand Down

0 comments on commit b2fb25e

Please sign in to comment.