diff --git a/opteryx/functions/__init__.py b/opteryx/functions/__init__.py index 02496d5c9..e0cfa6197 100644 --- a/opteryx/functions/__init__.py +++ b/opteryx/functions/__init__.py @@ -165,13 +165,27 @@ def _raise_exception(text): raise UnsupportedSyntaxError(text) -def _coalesce(*args): - """wrap the pyarrow coalesce function because NaN != None""" - coerced = [] - for arg in args: - # there's no reasonable test to see if we need to do this before we start - coerced.append([None if value != value else value for value in arg]) # nosemgrep - return compute.coalesce(*coerced) +def _coalesce(*arrays): + """ + Element-wise coalesce function for multiple numpy arrays. + Selects the first non-None item in each row across the input arrays. + + Parameters: + arrays: tuple of numpy arrays + + Returns: + numpy array with coalesced values + """ + # Start with an array full of None values + result = numpy.array(arrays[0], dtype=object) + + mask = result == None + + for arr in arrays[1:]: + mask = numpy.array([None if value != value else value for value in result]) == None + numpy.copyto(result, arr, where=mask) + + return result # fmt:off diff --git a/tests/sql_battery/test_shapes_and_errors_battery.py b/tests/sql_battery/test_shapes_and_errors_battery.py index f800d0435..28ec75705 100644 --- a/tests/sql_battery/test_shapes_and_errors_battery.py +++ b/tests/sql_battery/test_shapes_and_errors_battery.py @@ -790,11 +790,10 @@ ("SELECT s.* FROM $planets AS s INNER JOIN $planets AS p USING (id, name)", 9, 20, None), ("SELECT p.* FROM $planets AS s INNER JOIN $planets AS p USING (id, name)", 9, 20, None), ("SELECT id, name FROM $planets AS s INNER JOIN $planets AS p USING (id, name)", 9, 2, None), -] -A = [ + ("SELECT DATE_TRUNC('month', birth_date) FROM $astronauts", 357, 1, None), - ("SELECT DISTINCT * FROM (SELECT DATE_TRUNC('year', birth_date) AS BIRTH_YEAR FROM $astronauts)", 54, 1, None), - ("SELECT DISTINCT * FROM (SELECT DATE_TRUNC('month', birth_date) AS BIRTH_YEAR_MONTH FROM $astronauts)", 247, 1, None), + ("SELECT DISTINCT * FROM (SELECT DATE_TRUNC('year', birth_date) AS BIRTH_YEAR FROM $astronauts) AS SQ", 54, 1, None), + ("SELECT DISTINCT * FROM (SELECT DATE_TRUNC('month', birth_date) AS BIRTH_YEAR_MONTH FROM $astronauts) AS SQ", 247, 1, None), ("SELECT time_bucket(birth_date, 10, 'year') AS decade, count(*) from $astronauts GROUP BY time_bucket(birth_date, 10, 'year')", 6, 2, None), ("SELECT time_bucket(birth_date, 6, 'month') AS half, count(*) from $astronauts GROUP BY time_bucket(birth_date, 6, 'month')", 97, 2, None), @@ -802,7 +801,8 @@ ("SELECT graduate_major, undergraduate_major FROM $astronauts WHERE COALESCE(graduate_major, undergraduate_major) = 'Aeronautical Engineering'", 41, 2, None), ("SELECT COALESCE(death_date, '2030-01-01') FROM $astronauts", 357, 1, None), ("SELECT * FROM $astronauts WHERE COALESCE(death_date, '2030-01-01') < '2000-01-01'", 30, 19, None), - +] +A = [ ("SELECT SEARCH(name, 'al'), name FROM $satellites", 177, 2, None), ("SELECT name FROM $satellites WHERE SEARCH(name, 'al')", 18, 1, None), ("SELECT SEARCH(missions, 'Apollo 11'), missions FROM $astronauts", 357, 2, None),