Skip to content

Commit

Permalink
[BUGFIX] UnexpectedRowsExpectation - Unable to use {batch} keywor…
Browse files Browse the repository at this point in the history
…d with partitioner for some backends (#10721)
  • Loading branch information
NathanFarmer authored Dec 4, 2024
1 parent ed9dec0 commit 7b93b5d
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 98 deletions.
68 changes: 9 additions & 59 deletions great_expectations/expectations/metrics/query_metric_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,43 +75,6 @@ def _get_query_from_metric_value_kwargs(cls, metric_value_kwargs: dict) -> str:

return query

@classmethod
def _get_query_string_with_substituted_batch_parameters(
    cls, query: str, batch_subquery: sa.sql.Subquery | sa.sql.Alias
) -> str:
    """Inject the batch's table name and filter into a runtime query string.

    Specifying a runtime query string returns the active batch as a Subquery or Alias type.
    There is no object-based way to apply the subquery alias to columns in the SELECT and
    WHERE clauses. Instead, we extract the subquery parameters from the batch selectable
    and inject them into the SQL string.

    Args:
        query: Query template containing the ``{batch}`` placeholder.
        batch_subquery: The active batch; its ``selectable.element`` supplies the
            FROM target and the WHERE clause that are spliced into ``query``.

    Returns:
        ``query`` with ``{batch}`` replaced by the batch's first FROM target and the
        batch filter merged in: ANDed onto an existing WHERE, inserted before
        GROUP BY / ORDER BY, or appended as a new WHERE clause. When the batch has
        no WHERE clause, only ``{batch}`` is substituted.

    Raises:
        MissingElementError if the batch_subquery.selectable does not have an element
            for which to extract query parameters.
    """
    import re  # function-scope import keeps this block self-contained

    try:
        element = batch_subquery.selectable.element  # type: ignore[attr-defined] # possible AttributeError handled
        froms = element.get_final_froms()
        try:
            batch_table = froms[0].name
        except AttributeError:
            # FROM targets without a ``name`` attribute are rendered as text instead
            batch_table = str(froms[0])
        whereclause = element.whereclause
    except (AttributeError, IndexError) as e:
        raise MissingElementError() from e

    unfiltered_query = query.format(batch=batch_table)

    # Without a batch filter there is nothing to inject; stringifying ``None``
    # would otherwise emit the invalid clause "WHERE None".
    if whereclause is None:
        return unfiltered_query

    batch_filter = str(whereclause)

    # Detection and replacement must agree on case: the keyword is matched
    # case-insensitively in both steps so a lowercase "where" is not detected
    # and then left untouched (which silently dropped the batch filter).
    # NOTE: this is still text-based; a keyword inside a string literal matches too.
    clause_rewrites = (
        (r"\bWHERE\b", f"WHERE {batch_filter} AND"),
        (r"\bGROUP\s+BY\b", f"WHERE {batch_filter} GROUP BY"),
        (r"\bORDER\s+BY\b", f"WHERE {batch_filter} ORDER BY"),
    )
    for pattern, replacement in clause_rewrites:
        keyword = re.compile(pattern, flags=re.IGNORECASE)
        if keyword.search(query):
            # A callable replacement keeps backslashes in the filter literal.
            return keyword.sub(lambda _match, text=replacement: text, unfiltered_query)

    return unfiltered_query + f" WHERE {batch_filter}"

@classmethod
def _get_parameters_dict_from_query_parameters(
cls, query_parameters: Optional[QueryParameters]
Expand All @@ -137,30 +100,17 @@ def _get_substituted_batch_subquery_from_query_and_batch_selectable(

if isinstance(batch_selectable, sa.Table):
query = query.format(batch=batch_selectable, **parameters)
elif isinstance(batch_selectable, get_sqlalchemy_subquery_type()):
if (
not parameters
and execution_engine.dialect_name in cls.dialect_columns_require_subquery_aliases
):
try:
query = cls._get_query_string_with_substituted_batch_parameters(
query=query,
batch_subquery=batch_selectable,
)
except MissingElementError:
# if we are unable to extract the subquery parameters,
# we fall back to the default behavior for all dialects
batch = batch_selectable.compile(compile_kwargs={"literal_binds": True})
query = query.format(batch=f"({batch})", **parameters)
else:
batch = batch_selectable.compile(compile_kwargs={"literal_binds": True})
query = query.format(batch=f"({batch})", **parameters)
elif isinstance(
batch_selectable, sa.sql.Select
): # Specifying a row_condition returns the active batch as a Select object
# requiring compilation & aliasing when formatting the parameterized query
batch_selectable, (sa.sql.Select, get_sqlalchemy_subquery_type())
): # specifying a row_condition returns the active batch as a Select
# specifying an unexpected_rows_query returns the active batch as a Subquery or Alias
# this requires compilation & aliasing when formatting the parameterized query
batch = batch_selectable.compile(compile_kwargs={"literal_binds": True})
query = query.format(batch=f"({batch}) AS subselect", **parameters)
# all join queries require the user to have taken care of aliasing themselves
if "JOIN" in query.upper():
query = query.format(batch=f"({batch})", **parameters)
else:
query = query.format(batch=f"({batch}) AS subselect", **parameters)
else:
query = query.format(batch=f"({batch_selectable})", **parameters)

Expand Down
41 changes: 2 additions & 39 deletions tests/expectations/metrics/test_metric_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,44 +339,6 @@ def _spark(
assert "query.custom_metric" in mock_registry._registered_metrics


@pytest.mark.unit
@pytest.mark.parametrize(
    "input_query,expected_query",
    [
        (
            "SELECT * FROM {batch}",
            "SELECT * FROM iris WHERE datetime_column = '01/12/2024'",
        ),
        (
            "SELECT * FROM {batch} WHERE passenger_count > 7",
            "SELECT * FROM iris WHERE datetime_column = '01/12/2024' AND passenger_count > 7",
        ),
        (
            "SELECT * FROM {batch} WHERE passenger_count > 7 ORDER BY iris.'PetalLengthCm' DESC",
            "SELECT * FROM iris WHERE datetime_column = '01/12/2024' "
            "AND passenger_count > 7 ORDER BY iris.'PetalLengthCm' DESC",
        ),
        (
            "SELECT * FROM {batch} WHERE passenger_count > 7 GROUP BY iris.'Species' DESC",
            "SELECT * FROM iris WHERE datetime_column = '01/12/2024' "
            "AND passenger_count > 7 GROUP BY iris.'Species' DESC",
        ),
    ],
)
def test__get_query_string_with_substituted_batch_parameters(input_query: str, expected_query: str):
    """Check that the batch table and filter are spliced into each query template.

    The batch is a subquery over the text table ``iris`` filtered on
    ``datetime_column``; every case expects ``{batch}`` to become ``iris`` and
    the filter to appear in the resulting WHERE clause.
    """
    # Build the filtered SELECT first, then wrap it as the subquery under test.
    filtered_iris = (
        sa.select("*")
        .select_from(sa.text("iris"))
        .where(sa.text("datetime_column = '01/12/2024'"))
    )
    substituted = QueryMetricProvider._get_query_string_with_substituted_batch_parameters(
        query=input_query,
        batch_subquery=filtered_iris.subquery(),
    )
    assert substituted == expected_query


@pytest.mark.unit
@pytest.mark.parametrize(
"query_parameters,expected_dict",
Expand Down Expand Up @@ -421,7 +383,8 @@ def test__get_parameters_dict_from_query_parameters(
),
(
sa.select("*").select_from(sa.text("my_table")).subquery(),
"SELECT my_column FROM (SELECT * \nFROM my_table) WHERE passenger_count > 7",
"SELECT my_column FROM (SELECT * \nFROM my_table) AS subselect "
"WHERE passenger_count > 7",
),
(
sa.select("*").select_from(sa.text("my_table")),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from datetime import datetime, timezone
from typing import Sequence
from uuid import uuid4

import pandas as pd
import pytest

import great_expectations.expectations as gxe
from tests.integration.conftest import parameterize_batch_for_data_sources
from tests.integration.test_utils.data_source_config import (
BigQueryDatasourceTestConfig,
DatabricksDatasourceTestConfig,
DataSourceTestConfig,
# MSSQLDatasourceTestConfig,
MySQLDatasourceTestConfig,
PostgreSQLDatasourceTestConfig,
SnowflakeDatasourceTestConfig,
SparkFilesystemCsvDatasourceTestConfig,
# SqliteDatasourceTestConfig,
)

# Data sources exercised by the non-partitioned {batch} keyword tests.
# pandas not currently supported by this Expectation
ALL_SUPPORTED_DATA_SOURCES: Sequence[DataSourceTestConfig] = [
    BigQueryDatasourceTestConfig(),
    DatabricksDatasourceTestConfig(),
    # MSSQLDatasourceTestConfig(),  # fix me
    MySQLDatasourceTestConfig(),
    PostgreSQLDatasourceTestConfig(),
    SnowflakeDatasourceTestConfig(),
    SparkFilesystemCsvDatasourceTestConfig(),
    # SqliteDatasourceTestConfig(),  # fix me
]

# Data sources exercised by the partitioner tests: the list above minus Spark.
# pandas and spark not currently supporting partitioners
PARTITIONER_SUPPORTED_DATA_SOURCES: Sequence[DataSourceTestConfig] = [
    BigQueryDatasourceTestConfig(),
    DatabricksDatasourceTestConfig(),
    # MSSQLDatasourceTestConfig(),  # fix me
    MySQLDatasourceTestConfig(),
    PostgreSQLDatasourceTestConfig(),
    SnowflakeDatasourceTestConfig(),
    # SqliteDatasourceTestConfig(),  # fix me
]

# Two-row fixture table shared by every test in this module.
# The rows fall in different calendar months (2024-12-01 and 2024-11-30);
# presumably this lets a monthly batch definition separate them - confirm
# against the partitioner behavior.
DATA = pd.DataFrame(
    {
        "created_at": [
            # tz-aware datetimes are built and then reduced to plain dates via .date()
            datetime(year=2024, month=12, day=1, tzinfo=timezone.utc).date(),
            datetime(year=2024, month=11, day=30, tzinfo=timezone.utc).date(),
        ],
        "quantity": [1, 2],
        "temperature": [75, 92],
        "color": ["red", "red"],
    }
)

# Column handed to add_batch_definition_monthly in the partitioner tests.
DATE_COLUMN = "created_at"

# Queries for the *_success tests, which assert the expectation succeeds.
# Against the full DATA table each of these selects no rows
# (max quantity is 2, max temperature is 92, total quantity is 3).
SUCCESS_QUERIES = [
    "SELECT * FROM {batch} WHERE quantity > 2",
    "SELECT * FROM {batch} WHERE quantity > 2 AND temperature > 91",
    "SELECT * FROM {batch} WHERE quantity > 2 OR temperature > 92",
    "SELECT * FROM {batch} WHERE quantity > 2 ORDER BY quantity DESC",
    "SELECT color FROM {batch} GROUP BY color HAVING SUM(quantity) > 3",
]

# Queries for the *_failure tests, which assert the expectation fails.
# Against the full DATA table each of these selects at least one row.
FAILURE_QUERIES = [
    "SELECT * FROM {batch}",
    "SELECT * FROM {batch} WHERE quantity > 0",
    "SELECT * FROM {batch} WHERE quantity > 0 AND temperature > 74",
    "SELECT * FROM {batch} WHERE quantity > 0 OR temperature > 92",
    "SELECT * FROM {batch} WHERE quantity > 0 ORDER BY quantity DESC",
    "SELECT color FROM {batch} GROUP BY color HAVING SUM(quantity) > 0",
]


@parameterize_batch_for_data_sources(
    data_source_configs=ALL_SUPPORTED_DATA_SOURCES,
    data=DATA,
)
@pytest.mark.parametrize("unexpected_rows_query", SUCCESS_QUERIES)
def test_unexpected_rows_expectation_batch_keyword_success(
    batch_for_datasource,
    unexpected_rows_query,
) -> None:
    """Validate each SUCCESS_QUERIES entry against the batch fixture.

    The validation must report success and must not record a raised exception.
    """
    validation_result = batch_for_datasource.validate(
        gxe.UnexpectedRowsExpectation(
            description="Expect query with {batch} keyword to succeed",
            unexpected_rows_query=unexpected_rows_query,
        )
    )
    assert validation_result.success
    # Guards against a query error being mistaken for an ordinary result.
    assert validation_result.exception_info.get("raised_exception") is False


@parameterize_batch_for_data_sources(
    data_source_configs=ALL_SUPPORTED_DATA_SOURCES,
    data=DATA,
)
@pytest.mark.parametrize("unexpected_rows_query", FAILURE_QUERIES)
def test_unexpected_rows_expectation_batch_keyword_failure(
    batch_for_datasource,
    unexpected_rows_query,
) -> None:
    """Validate each FAILURE_QUERIES entry against the batch fixture.

    The validation must report ``success is False`` without recording a raised
    exception, i.e. the failure comes from the query result, not from an error.
    """
    validation_result = batch_for_datasource.validate(
        gxe.UnexpectedRowsExpectation(
            description="Expect query with {batch} keyword to fail",
            unexpected_rows_query=unexpected_rows_query,
        )
    )
    assert validation_result.success is False
    assert validation_result.exception_info.get("raised_exception") is False


@parameterize_batch_for_data_sources(
    data_source_configs=PARTITIONER_SUPPORTED_DATA_SOURCES,
    data=DATA,
)
@pytest.mark.parametrize("unexpected_rows_query", SUCCESS_QUERIES)
def test_unexpected_rows_expectation_batch_keyword_partitioner_success(
    asset_for_datasource,
    unexpected_rows_query,
) -> None:
    """Validate each SUCCESS_QUERIES entry against a monthly-partitioned batch.

    A monthly batch definition on DATE_COLUMN is added to the asset fixture under
    a random name, and the batch it returns is validated. The validation must
    report success and must not record a raised exception.
    """
    # A uuid4 name is used for the batch definition added to the asset fixture.
    batch = asset_for_datasource.add_batch_definition_monthly(
        name=str(uuid4()), column=DATE_COLUMN
    ).get_batch()
    expectation = gxe.UnexpectedRowsExpectation(
        # Spelling fixed ("paritioner" -> "partitioner") to match the sibling failure test.
        description="Expect query with {batch} keyword and partitioner defined to succeed",
        unexpected_rows_query=unexpected_rows_query,
    )
    result = batch.validate(expectation)
    assert result.success
    assert result.exception_info.get("raised_exception") is False


@parameterize_batch_for_data_sources(
    data_source_configs=PARTITIONER_SUPPORTED_DATA_SOURCES,
    data=DATA,
)
@pytest.mark.parametrize("unexpected_rows_query", FAILURE_QUERIES)
def test_unexpected_rows_expectation_batch_keyword_partitioner_failure(
    asset_for_datasource,
    unexpected_rows_query,
) -> None:
    """Validate each FAILURE_QUERIES entry against a monthly-partitioned batch.

    A monthly batch definition on DATE_COLUMN is added to the asset fixture under
    a random name, and the batch it returns is validated. The validation must
    report ``success is False`` without recording a raised exception.
    """
    monthly_definition = asset_for_datasource.add_batch_definition_monthly(
        name=str(uuid4()), column=DATE_COLUMN
    )
    monthly_batch = monthly_definition.get_batch()
    validation_result = monthly_batch.validate(
        gxe.UnexpectedRowsExpectation(
            description="Expect query with {batch} keyword and partitioner defined to fail",
            unexpected_rows_query=unexpected_rows_query,
        )
    )
    assert validation_result.success is False
    assert validation_result.exception_info.get("raised_exception") is False

0 comments on commit 7b93b5d

Please sign in to comment.