Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Renames related to SqlPlan - Part 1 #1581

Merged
merged 5 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions metricflow/dataset/sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from metricflow.dataset.dataset_classes import DataSet
from metricflow.sql.sql_plan import (
SqlQueryPlanNode,
SqlPlanNode,
SqlSelectStatementNode,
)

Expand All @@ -28,7 +28,7 @@ def __init__(
self,
instance_set: InstanceSet,
sql_select_node: Optional[SqlSelectStatementNode] = None,
sql_node: Optional[SqlQueryPlanNode] = None,
sql_node: Optional[SqlPlanNode] = None,
) -> None:
"""Constructor.

Expand All @@ -42,7 +42,7 @@ def __init__(
super().__init__(instance_set=instance_set)

@property
def sql_node(self) -> SqlQueryPlanNode: # noqa: D102
def sql_node(self) -> SqlPlanNode: # noqa: D102
node_to_return = self._sql_select_node or self._sql_node
if node_to_return is None:
raise RuntimeError(
Expand Down
4 changes: 2 additions & 2 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
)
from metricflow.execution.execution_plan import ExecutionPlan, SqlStatement
from metricflow.execution.executor import SequentialPlanExecutor
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlPlanConverter
from metricflow.protocols.sql_client import SqlClient
from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
from metricflow.telemetry.models import TelemetryLevel
Expand Down Expand Up @@ -399,7 +399,7 @@ def __init__(
source_node_builder=source_node_builder,
dataflow_plan_builder_cache=self._dataflow_plan_builder_cache,
)
self._to_sql_query_plan_converter = DataflowToSqlQueryPlanConverter(
self._to_sql_query_plan_converter = DataflowToSqlPlanConverter(
column_association_resolver=self._column_association_resolver,
semantic_manifest_lookup=self._semantic_manifest_lookup,
)
Expand Down
4 changes: 2 additions & 2 deletions metricflow/execution/dataflow_to_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
SqlStatement,
)
from metricflow.plan_conversion.convert_to_sql_plan import ConvertToSqlPlanResult
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlPlanConverter
from metricflow.protocols.sql_client import SqlClient
from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel
from metricflow.sql.render.sql_plan_renderer import SqlPlanRenderResult, SqlQueryPlanRenderer
Expand All @@ -52,7 +52,7 @@ class DataflowToExecutionPlanConverter(DataflowPlanNodeVisitor[ConvertToExecutio

def __init__(
self,
sql_plan_converter: DataflowToSqlQueryPlanConverter,
sql_plan_converter: DataflowToSqlPlanConverter,
sql_plan_renderer: SqlQueryPlanRenderer,
sql_client: SqlClient,
sql_optimization_level: SqlQueryOptimizationLevel,
Expand Down
24 changes: 12 additions & 12 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,21 +126,21 @@
from metricflow.plan_conversion.sql_join_builder import (
AnnotatedSqlDataSet,
ColumnEqualityDescription,
SqlQueryPlanJoinBuilder,
SqlPlanJoinBuilder,
)
from metricflow.protocols.sql_client import SqlEngine
from metricflow.sql.optimizer.optimization_levels import (
SqlGenerationOptionSet,
SqlQueryOptimizationLevel,
)
from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlPlan,
SqlQueryPlanNode,
SqlPlanNode,
SqlSelectColumn,
SqlSelectStatementNode,
SqlTableNode,
Expand All @@ -149,7 +149,7 @@
logger = logging.getLogger(__name__)


class DataflowToSqlQueryPlanConverter:
class DataflowToSqlPlanConverter:
"""Generates an SQL query plan from a node in the metric dataflow plan."""

def __init__(
Expand Down Expand Up @@ -272,7 +272,7 @@ def convert_using_specifics(
dataflow_plan_node: DataflowPlanNode,
sql_query_plan_id: Optional[DagId],
nodes_to_convert_to_cte: FrozenSet[DataflowPlanNode],
optimizers: Sequence[SqlQueryPlanOptimizer],
optimizers: Sequence[SqlPlanOptimizer],
) -> ConvertToSqlPlanResult:
"""Helper method to convert using specific options. The main use case is tests."""
logger.debug(
Expand Down Expand Up @@ -313,7 +313,7 @@ def convert_using_specifics(
),
)

sql_node: SqlQueryPlanNode = data_set.sql_node
sql_node: SqlPlanNode = data_set.sql_node

for optimizer in optimizers:
logger.debug(LazyFormat(lambda: f"Applying optimizer: {optimizer.__class__.__name__}"))
Expand Down Expand Up @@ -539,7 +539,7 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDat
join_spec = self._choose_instance_for_time_spine_join(agg_time_dimension_instances).spec
annotated_parent = parent_data_set.annotate(alias=parent_data_set_alias, metric_time_spec=join_spec)
annotated_time_spine = time_spine_data_set.annotate(alias=time_spine_data_set_alias, metric_time_spec=join_spec)
join_desc = SqlQueryPlanJoinBuilder.make_cumulative_metric_time_range_join_description(
join_desc = SqlPlanJoinBuilder.make_cumulative_metric_time_range_join_description(
node=node, metric_data_set=annotated_parent, time_spine_data_set=annotated_time_spine
)

Expand Down Expand Up @@ -598,7 +598,7 @@ def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> SqlDataSet:
right_data_set_alias = self._next_unique_table_alias()

# Build join description.
sql_join_desc = SqlQueryPlanJoinBuilder.make_base_output_join_description(
sql_join_desc = SqlPlanJoinBuilder.make_base_output_join_description(
left_data_set=AnnotatedSqlDataSet(data_set=from_data_set, alias=from_data_set_alias),
right_data_set=AnnotatedSqlDataSet(data_set=right_data_set, alias=right_data_set_alias),
join_description=join_description,
Expand Down Expand Up @@ -1110,7 +1110,7 @@ def visit_combine_aggregated_outputs_node(self, node: CombineAggregatedOutputsNo
aliases_seen = [from_data_set.alias]
for join_data_set in join_data_sets:
joins_descriptions.append(
SqlQueryPlanJoinBuilder.make_join_description_for_combining_datasets(
SqlPlanJoinBuilder.make_join_description_for_combining_datasets(
from_data_set=from_data_set,
join_data_set=join_data_set,
join_type=join_type,
Expand Down Expand Up @@ -1397,7 +1397,7 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe
)

join_data_set_alias = self._next_unique_table_alias()
sql_join_desc = SqlQueryPlanJoinBuilder.make_column_equality_sql_join_description(
sql_join_desc = SqlPlanJoinBuilder.make_column_equality_sql_join_description(
right_source_node=row_filter_sql_select_node,
left_source_alias=from_data_set_alias,
right_source_alias=join_data_set_alias,
Expand Down Expand Up @@ -1444,7 +1444,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet

# Build join expression.
join_column_name = self._column_association_resolver.resolve_spec(node.join_on_time_dimension_spec).column_name
join_description = SqlQueryPlanJoinBuilder.make_join_to_time_spine_join_description(
join_description = SqlPlanJoinBuilder.make_join_to_time_spine_join_description(
node=node,
time_spine_alias=time_spine_alias,
agg_time_dimension_column_name=join_column_name,
Expand Down Expand Up @@ -1756,7 +1756,7 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S
constant_property_column_names.append((base_property_col_name, conversion_property_col_name))

# Builds the join conditions that are required for a successful conversion
sql_join_description = SqlQueryPlanJoinBuilder.make_join_conversion_join_description(
sql_join_description = SqlPlanJoinBuilder.make_join_conversion_join_description(
node=node,
base_data_set=AnnotatedSqlDataSet(
data_set=base_data_set,
Expand Down
18 changes: 9 additions & 9 deletions metricflow/plan_conversion/sql_join_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class ColumnEqualityDescription:
treat_nulls_as_equal: bool = False


class SqlQueryPlanJoinBuilder:
class SqlPlanJoinBuilder:
"""Helper class for constructing various join components in a SqlPlan."""

@staticmethod
Expand Down Expand Up @@ -194,11 +194,11 @@ def make_base_output_join_description(
)
)

validity_conditions = SqlQueryPlanJoinBuilder._make_validity_window_on_conditions(
validity_conditions = SqlPlanJoinBuilder._make_validity_window_on_conditions(
left_data_set=left_data_set, right_data_set=right_data_set, join_description=join_description
)

return SqlQueryPlanJoinBuilder.make_column_equality_sql_join_description(
return SqlPlanJoinBuilder.make_column_equality_sql_join_description(
right_source_node=right_data_set.data_set.checked_sql_select_node,
left_source_alias=left_data_set.alias,
right_source_alias=right_data_set.alias,
Expand Down Expand Up @@ -253,7 +253,7 @@ def _make_validity_window_on_conditions(
window_end_dimension_name = right_data_set.data_set.column_association_for_time_dimension(
join_description.validity_window.window_end_dimension
).column_name
window_join_condition = SqlQueryPlanJoinBuilder._make_time_window_join_condition(
window_join_condition = SqlPlanJoinBuilder._make_time_window_join_condition(
left_source_alias=left_data_set.alias,
left_source_time_dimension_name=left_data_set_time_dimension_name,
right_source_alias=right_data_set.alias,
Expand Down Expand Up @@ -331,7 +331,7 @@ def make_join_description_for_combining_datasets(
len(column_names) > 0
), "Attempting to do a FULL OUTER JOIN to combine metrics, but no columns were provided for join keys!"
equality_exprs = [
SqlQueryPlanJoinBuilder._make_equality_expression_for_full_outer_join(
SqlPlanJoinBuilder._make_equality_expression_for_full_outer_join(
table_aliases_for_coalesce, join_data_set.alias, colname
)
for colname in column_names
Expand All @@ -352,7 +352,7 @@ def make_join_description_for_combining_datasets(
ColumnEqualityDescription(left_column_alias=name, right_column_alias=name, treat_nulls_as_equal=True)
for name in column_names
]
return SqlQueryPlanJoinBuilder.make_column_equality_sql_join_description(
return SqlPlanJoinBuilder.make_column_equality_sql_join_description(
right_source_node=join_data_set.data_set.checked_sql_select_node,
left_source_alias=from_data_set.alias,
right_source_alias=join_data_set.alias,
Expand Down Expand Up @@ -484,12 +484,12 @@ def make_join_conversion_join_description(
under the condition of being within the time range and other conditions (such as constant properties).
This builds the join description to satisfy all those conditions.
"""
window_condition = SqlQueryPlanJoinBuilder._make_time_range_window_join_condition(
window_condition = SqlPlanJoinBuilder._make_time_range_window_join_condition(
base_data_set=base_data_set,
time_comparison_dataset=conversion_data_set,
window=node.window,
)
return SqlQueryPlanJoinBuilder.make_column_equality_sql_join_description(
return SqlPlanJoinBuilder.make_column_equality_sql_join_description(
right_source_node=conversion_data_set.data_set.checked_sql_select_node,
left_source_alias=base_data_set.alias,
right_source_alias=conversion_data_set.alias,
Expand All @@ -509,7 +509,7 @@ def make_cumulative_metric_time_range_join_description(
Cumulative metrics must be joined against a time spine in a backward-looking fashion, with
a range determined by a time window (delta against metric_time) and optional cumulative grain.
"""
cumulative_join_condition = SqlQueryPlanJoinBuilder._make_time_range_window_join_condition(
cumulative_join_condition = SqlPlanJoinBuilder._make_time_range_window_join_condition(
base_data_set=metric_data_set,
time_comparison_dataset=time_spine_data_set,
window=node.window,
Expand Down
22 changes: 11 additions & 11 deletions metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from typing_extensions import override

from metricflow.sql.optimizer.required_column_aliases import SqlMapRequiredColumnAliasesVisitor
from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer
from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlQueryPlanNode,
SqlQueryPlanNodeVisitor,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableNode,
Expand All @@ -21,7 +21,7 @@
logger = logging.getLogger(__name__)


class SqlColumnPrunerVisitor(SqlQueryPlanNodeVisitor[SqlQueryPlanNode]):
class SqlColumnPrunerVisitor(SqlPlanNodeVisitor[SqlPlanNode]):
"""Removes unnecessary columns from SELECT statements in the SQL query plan.

This requires a set of tagged column aliases that should be kept for each SQL node.
Expand All @@ -38,7 +38,7 @@ def __init__(
"""
self._required_alias_mapping = required_alias_mapping

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanNode: # noqa: D102
# Remove columns that are not needed from this SELECT statement because the parent SELECT statement doesn't
# need them. However, keep columns that are in group bys because that changes the meaning of the query.
# Similarly, if this node is a distinct select node, keep all columns as it may return a different result set.
Expand Down Expand Up @@ -81,29 +81,29 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
distinct=node.distinct,
)

def visit_table_node(self, node: SqlTableNode) -> SqlQueryPlanNode:
def visit_table_node(self, node: SqlTableNode) -> SqlPlanNode:
"""There are no SELECT columns in this node, so pruning cannot apply."""
return node

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlQueryPlanNode:
def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanNode:
"""Pruning cannot be done here since this is an arbitrary user-provided SQL query."""
return node

def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlanNode: # noqa: D102
def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlPlanNode: # noqa: D102
return SqlCreateTableAsNode.create(
sql_table=node.sql_table,
parent_node=node.parent_node.accept(self),
)

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
def visit_cte_node(self, node: SqlCteNode) -> SqlPlanNode:
return node.with_new_select(node.select_statement.accept(self))


class SqlColumnPrunerOptimizer(SqlQueryPlanOptimizer):
class SqlColumnPrunerOptimizer(SqlPlanOptimizer):
"""Removes unnecessary columns in the SELECT statements."""

def optimize(self, node: SqlQueryPlanNode) -> SqlQueryPlanNode: # noqa: D102
def optimize(self, node: SqlPlanNode) -> SqlPlanNode: # noqa: D102
# All columns in the nearest SELECT node need to be kept; otherwise, the meaning of the query changes.
required_select_columns = node.nearest_select_columns({})

Expand Down
6 changes: 3 additions & 3 deletions metricflow/sql/optimizer/optimization_levels.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from metricflow.sql.optimizer.column_pruner import SqlColumnPrunerOptimizer
from metricflow.sql.optimizer.rewriting_sub_query_reducer import SqlRewritingSubQueryReducer
from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer
from metricflow.sql.optimizer.sub_query_reducer import SqlSubQueryReducer
from metricflow.sql.optimizer.table_alias_simplifier import SqlTableAliasSimplifier

Expand Down Expand Up @@ -40,7 +40,7 @@ def __lt__(self, other: SqlQueryOptimizationLevel) -> bool: # noqa: D105
class SqlGenerationOptionSet:
"""Defines the different SQL generation optimizers / options that should be used at each level."""

optimizers: Tuple[SqlQueryPlanOptimizer, ...]
optimizers: Tuple[SqlPlanOptimizer, ...]

# Specifies whether CTEs can be used to simplify generated SQL.
allow_cte: bool
Expand All @@ -49,7 +49,7 @@ class SqlGenerationOptionSet:
def options_for_level( # noqa: D102
level: SqlQueryOptimizationLevel, use_column_alias_in_group_by: bool
) -> SqlGenerationOptionSet:
optimizers: Tuple[SqlQueryPlanOptimizer, ...] = ()
optimizers: Tuple[SqlPlanOptimizer, ...] = ()
allow_cte = False
if level is SqlQueryOptimizationLevel.O0:
pass
Expand Down
10 changes: 5 additions & 5 deletions metricflow/sql/optimizer/required_column_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlQueryPlanNode,
SqlQueryPlanNodeVisitor,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
Expand All @@ -23,7 +23,7 @@
logger = logging.getLogger(__name__)


class SqlMapRequiredColumnAliasesVisitor(SqlQueryPlanNodeVisitor[None]):
class SqlMapRequiredColumnAliasesVisitor(SqlPlanNodeVisitor[None]):
"""To aid column pruning, traverse the SQL-query representation DAG and map the SELECT columns needed at each node.

For example, the query:
Expand Down Expand Up @@ -55,7 +55,7 @@ class SqlMapRequiredColumnAliasesVisitor(SqlQueryPlanNodeVisitor[None]):
) source_0
"""

def __init__(self, start_node: SqlQueryPlanNode, required_column_aliases_in_start_node: FrozenSet[str]) -> None:
def __init__(self, start_node: SqlPlanNode, required_column_aliases_in_start_node: FrozenSet[str]) -> None:
"""Initializer.

Args:
Expand Down Expand Up @@ -114,7 +114,7 @@ def visit_cte_node(self, node: SqlCteNode) -> None:
# Visit parent nodes.
select_statement.accept(self)

def _visit_parents(self, node: SqlQueryPlanNode) -> None:
def _visit_parents(self, node: SqlPlanNode) -> None:
"""Default recursive handler to visit the parents of the given node."""
for parent_node in node.parent_nodes:
parent_node.accept(self)
Expand Down
Loading
Loading