Skip to content

Commit

Permalink
Bug fix: use FULL OUTER JOIN for dimension-only queries (#863)
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb authored Nov 14, 2023
1 parent 45a970d commit 6dd0fd7
Show file tree
Hide file tree
Showing 37 changed files with 1,878 additions and 67 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20231110-162009.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Use FULL OUTER JOIN for dimension-only queries.
time: 2023-11-10T16:20:09.530487-08:00
custom:
Author: courtneyholcomb
Issue: "863"
54 changes: 9 additions & 45 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,49 +80,7 @@ class DataflowRecipe:
@property
def join_targets(self) -> List[JoinDescription]:
"""Joins to be made to source node."""
join_targets = []
for join_recipe in self.join_linkable_instances_recipes:
# Figure out what elements to filter from the joined node.

# Sanity check - all linkable specs should have a link, or else why would we be joining them.
assert all([len(x.entity_links) > 0 for x in join_recipe.satisfiable_linkable_specs])

# If we're joining something in, then we need the associated entity, partitions, and time dimension
# specs defining the validity window (if necessary)
include_specs: List[LinkableInstanceSpec] = [
LinklessEntitySpec.from_reference(x.entity_links[0]) for x in join_recipe.satisfiable_linkable_specs
]
include_specs.extend([x.node_to_join_dimension_spec for x in join_recipe.join_on_partition_dimensions])
include_specs.extend(
[x.node_to_join_time_dimension_spec for x in join_recipe.join_on_partition_time_dimensions]
)
if join_recipe.validity_window:
include_specs.extend(
[
join_recipe.validity_window.window_start_dimension,
join_recipe.validity_window.window_end_dimension,
]
)

# satisfiable_linkable_specs describes what can be satisfied after the join, so remove the entity
# link when filtering before the join.
# e.g. if the node is used to satisfy "user_id__country", then the node must have the entity
# "user_id" and the "country" dimension so that it can be joined to the measure node.
include_specs.extend([x.without_first_entity_link for x in join_recipe.satisfiable_linkable_specs])
filtered_node_to_join = FilterElementsNode(
parent_node=join_recipe.node_to_join,
include_specs=InstanceSpecSet.create_from_linkable_specs(include_specs),
)
join_targets.append(
JoinDescription(
join_node=filtered_node_to_join,
join_on_entity=join_recipe.join_on_entity,
join_on_partition_dimensions=join_recipe.join_on_partition_dimensions,
join_on_partition_time_dimensions=join_recipe.join_on_partition_time_dimensions,
validity_window=join_recipe.validity_window,
)
)
return join_targets
return [join_recipe.join_description for join_recipe in self.join_linkable_instances_recipes]


@dataclass(frozen=True)
Expand Down Expand Up @@ -485,13 +443,15 @@ def _find_dataflow_recipe(
potential_source_nodes: Sequence[BaseOutput] = self._select_source_nodes_with_measures(
measure_specs=set(measure_spec_properties.measure_specs), source_nodes=source_nodes
)
default_join_type = SqlJoinType.LEFT_OUTER
else:
# Only read nodes can be source nodes for queries without measures
source_nodes = self._read_nodes
source_nodes_to_linkable_specs = self._select_read_nodes_with_linkable_specs(
linkable_specs=linkable_spec_set, read_nodes=source_nodes
)
potential_source_nodes = list(source_nodes_to_linkable_specs.keys())
default_join_type = SqlJoinType.FULL_OUTER

logger.info(f"There are {len(potential_source_nodes)} potential source nodes")

Expand All @@ -518,7 +478,9 @@ def _find_dataflow_recipe(
f"After removing unnecessary nodes, there are {len(nodes_available_for_joins)} nodes available for joins"
)
if DataflowPlanBuilder._contains_multihop_linkables(linkable_specs):
nodes_available_for_joins = node_processor.add_multi_hop_joins(linkable_specs, source_nodes)
nodes_available_for_joins = node_processor.add_multi_hop_joins(
desired_linkable_specs=linkable_specs, nodes=source_nodes, join_type=default_join_type
)
logger.info(
f"After adding multi-hop nodes, there are {len(nodes_available_for_joins)} nodes available for joins:\n"
f"{pformat_big_objects(nodes_available_for_joins)}"
Expand Down Expand Up @@ -552,7 +514,9 @@ def _find_dataflow_recipe(
logger.debug(f"Evaluating source node:\n{pformat_big_objects(source_node=dataflow_dag_as_text(node))}")

start_time = time.time()
evaluation = node_evaluator.evaluate_node(start_node=node, required_linkable_specs=list(linkable_specs))
evaluation = node_evaluator.evaluate_node(
start_node=node, required_linkable_specs=list(linkable_specs), default_join_type=default_join_type
)
logger.info(f"Evaluation of {node} took {time.time() - start_time:.2f}s")

logger.debug(
Expand Down
45 changes: 39 additions & 6 deletions metricflow/dataflow/builder/node_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from metricflow.dataflow.builder.partitions import PartitionJoinResolver
from metricflow.dataflow.dataflow_plan import (
BaseOutput,
FilterElementsNode,
JoinDescription,
PartitionDimensionJoinDescription,
PartitionTimeDimensionJoinDescription,
Expand All @@ -37,10 +38,8 @@
from metricflow.model.semantics.semantic_model_join_evaluator import SemanticModelJoinEvaluator
from metricflow.plan_conversion.instance_converters import CreateValidityWindowJoinDescription
from metricflow.protocols.semantics import SemanticModelAccessor
from metricflow.specs.specs import (
LinkableInstanceSpec,
LinklessEntitySpec,
)
from metricflow.specs.specs import InstanceSpecSet, LinkableInstanceSpec, LinklessEntitySpec
from metricflow.sql.sql_plan import SqlJoinType

logger = logging.getLogger(__name__)

Expand All @@ -61,6 +60,8 @@ class JoinLinkableInstancesRecipe:
# the linkable specs in the node that can help to satisfy the query. e.g. "user_id__country" might be one of the
# "satisfiable_linkable_specs", but "country" is the linkable spec in the node.
satisfiable_linkable_specs: List[LinkableInstanceSpec]
# Join type to use when joining nodes
join_type: SqlJoinType

# The partitions to join on, if there are matching partitions between the start_node and node_to_join.
join_on_partition_dimensions: Tuple[PartitionDimensionJoinDescription, ...]
Expand All @@ -71,12 +72,36 @@ class JoinLinkableInstancesRecipe:
@property
def join_description(self) -> JoinDescription:
"""The recipe as a join description to use in the dataflow plan node."""
# Figure out what elements to filter from the joined node.
include_specs: List[LinkableInstanceSpec] = []
assert all([len(spec.entity_links) > 0 for spec in self.satisfiable_linkable_specs])
include_specs.extend(
[LinklessEntitySpec.from_reference(spec.entity_links[0]) for spec in self.satisfiable_linkable_specs]
)

include_specs.extend([join.node_to_join_dimension_spec for join in self.join_on_partition_dimensions])
include_specs.extend([join.node_to_join_time_dimension_spec for join in self.join_on_partition_time_dimensions])
if self.validity_window:
include_specs.extend(
[self.validity_window.window_start_dimension, self.validity_window.window_end_dimension]
)

# `satisfiable_linkable_specs` describes what can be satisfied after the join, so remove the entity
# link when filtering before the join.
# e.g. if the node is used to satisfy "user_id__country", then the node must have the entity
# "user_id" and the "country" dimension so that it can be joined to the source node.
include_specs.extend([spec.without_first_entity_link for spec in self.satisfiable_linkable_specs])
filtered_node_to_join = FilterElementsNode(
parent_node=self.node_to_join, include_specs=InstanceSpecSet.create_from_linkable_specs(include_specs)
)

return JoinDescription(
join_node=self.node_to_join,
join_node=filtered_node_to_join,
join_on_entity=self.join_on_entity,
join_on_partition_dimensions=self.join_on_partition_dimensions,
join_on_partition_time_dimensions=self.join_on_partition_time_dimensions,
validity_window=self.validity_window,
join_type=self.join_type,
)


Expand Down Expand Up @@ -133,6 +158,7 @@ def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
self,
start_node_instance_set: InstanceSet,
needed_linkable_specs: List[LinkableInstanceSpec],
join_type: SqlJoinType,
) -> List[JoinLinkableInstancesRecipe]:
"""Get nodes that can be joined to get 1 or more of the "needed_linkable_specs".
Expand Down Expand Up @@ -257,6 +283,7 @@ def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
join_on_partition_dimensions=join_on_partition_dimensions,
join_on_partition_time_dimensions=join_on_partition_time_dimensions,
validity_window=validity_window_join_description,
join_type=join_type,
)
)

Expand All @@ -271,6 +298,7 @@ def _find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
def _update_candidates_that_can_satisfy_linkable_specs(
candidates_for_join: List[JoinLinkableInstancesRecipe],
already_satisfisfied_linkable_specs: List[LinkableInstanceSpec],
join_type: SqlJoinType,
) -> List[JoinLinkableInstancesRecipe]:
"""Update / filter candidates_for_join based on linkable instance specs that we have already satisfied.
Expand All @@ -294,6 +322,7 @@ def _update_candidates_that_can_satisfy_linkable_specs(
join_on_partition_dimensions=candidate_for_join.join_on_partition_dimensions,
join_on_partition_time_dimensions=candidate_for_join.join_on_partition_time_dimensions,
validity_window=candidate_for_join.validity_window,
join_type=join_type,
)
)
return sorted(
Expand All @@ -306,6 +335,7 @@ def evaluate_node(
self,
start_node: BaseOutput,
required_linkable_specs: Sequence[LinkableInstanceSpec],
default_join_type: SqlJoinType,
) -> LinkableInstanceSatisfiabilityEvaluation:
"""Evaluates if the "required_linkable_specs" can be realized by joining the "start_node" with other nodes.
Expand Down Expand Up @@ -345,7 +375,9 @@ def evaluate_node(
possibly_joinable_linkable_specs.append(required_linkable_spec)

candidates_for_join = self._find_joinable_candidate_nodes_that_can_satisfy_linkable_specs(
start_node_instance_set=candidate_instance_set, needed_linkable_specs=possibly_joinable_linkable_specs
start_node_instance_set=candidate_instance_set,
needed_linkable_specs=possibly_joinable_linkable_specs,
join_type=default_join_type,
)
join_candidates: List[JoinLinkableInstancesRecipe] = []

Expand Down Expand Up @@ -378,6 +410,7 @@ def evaluate_node(
candidates_for_join = self._update_candidates_that_can_satisfy_linkable_specs(
candidates_for_join=candidates_for_join,
already_satisfisfied_linkable_specs=next_candidate.satisfiable_linkable_specs,
join_type=default_join_type,
)

# The once-possibly-joinable specs are now definitely joinable and no longer need to be searched for.
Expand Down
2 changes: 2 additions & 0 deletions metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ class JoinDescription:

join_node: BaseOutput
join_on_entity: LinklessEntitySpec
join_type: SqlJoinType

join_on_partition_dimensions: Tuple[PartitionDimensionJoinDescription, ...]
join_on_partition_time_dimensions: Tuple[PartitionTimeDimensionJoinDescription, ...]
Expand Down Expand Up @@ -339,6 +340,7 @@ def with_new_parents(self, new_parent_nodes: Sequence[BaseOutput]) -> JoinToBase
join_on_partition_dimensions=old_join_target.join_on_partition_dimensions,
join_on_partition_time_dimensions=old_join_target.join_on_partition_time_dimensions,
validity_window=old_join_target.validity_window,
join_type=old_join_target.join_type,
)
for i, old_join_target in enumerate(self._join_targets)
],
Expand Down
3 changes: 0 additions & 3 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,12 +402,9 @@ def visit_join_to_base_output_node(self, node: JoinToBaseOutputNode) -> SqlDataS
# e.g. a data set has the dimension "listing__country_latest" and "listing" is a primary entity in the
# data set. The next step would create an instance like "listing__listing__country_latest" without this
# filter.

# logger.error(f"before filter is:\n{pformat_big_objects(right_data_set.instance_set.spec_set)}")
right_data_set_instance_set_filtered = FilterLinkableInstancesWithLeadingLink(
entity_link=join_on_entity,
).transform(right_data_set.instance_set)
# logger.error(f"after filter is:\n{pformat_big_objects(right_data_set_instance_set_filtered.spec_set)}")

# After the right data set is joined to the "from" data set, we need to change the links for some of the
# instances that represent the right data set. For example, if the "from" data set contains the "bookings"
Expand Down
14 changes: 8 additions & 6 deletions metricflow/plan_conversion/node_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from metricflow.protocols.semantics import SemanticModelAccessor
from metricflow.specs.spec_set_transforms import ToElementNameSet
from metricflow.specs.specs import InstanceSpecSet, LinkableInstanceSpec, LinklessEntitySpec
from metricflow.sql.sql_plan import SqlJoinType

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -149,9 +150,7 @@ def _node_contains_entity(
return False

def _get_candidates_nodes_for_multi_hop(
self,
desired_linkable_spec: LinkableInstanceSpec,
nodes: Sequence[BaseOutput],
self, desired_linkable_spec: LinkableInstanceSpec, nodes: Sequence[BaseOutput], join_type: SqlJoinType
) -> Sequence[MultiHopJoinCandidate]:
"""Assemble nodes representing all possible one-hop joins."""
if len(desired_linkable_spec.entity_links) > MAX_JOIN_HOPS:
Expand Down Expand Up @@ -249,6 +248,7 @@ def _get_candidates_nodes_for_multi_hop(
),
join_on_partition_dimensions=join_on_partition_dimensions,
join_on_partition_time_dimensions=join_on_partition_time_dimensions,
join_type=join_type,
)
],
),
Expand Down Expand Up @@ -276,16 +276,18 @@ def _get_candidates_nodes_for_multi_hop(
return multi_hop_join_candidates

def add_multi_hop_joins(
self, desired_linkable_specs: Sequence[LinkableInstanceSpec], nodes: Sequence[BaseOutput]
self,
desired_linkable_specs: Sequence[LinkableInstanceSpec],
nodes: Sequence[BaseOutput],
join_type: SqlJoinType,
) -> Sequence[BaseOutput]:
"""Assemble nodes representing all possible one-hop joins."""
all_multi_hop_join_candidates: List[MultiHopJoinCandidate] = []
lineage_for_all_multi_hop_join_candidates: Set[MultiHopJoinCandidateLineage] = set()

for desired_linkable_spec in desired_linkable_specs:
for multi_hop_join_candidate in self._get_candidates_nodes_for_multi_hop(
desired_linkable_spec=desired_linkable_spec,
nodes=nodes,
desired_linkable_spec=desired_linkable_spec, nodes=nodes, join_type=join_type
):
# Dedupe candidates that are the same join.
if multi_hop_join_candidate.lineage not in lineage_for_all_multi_hop_join_candidates:
Expand Down
2 changes: 1 addition & 1 deletion metricflow/plan_conversion/sql_join_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def make_base_output_join_description(
left_source_alias=left_data_set.alias,
right_source_alias=right_data_set.alias,
column_equality_descriptions=column_equality_descriptions,
join_type=SqlJoinType.LEFT_OUTER,
join_type=join_description.join_type,
additional_on_conditions=validity_conditions,
)

Expand Down
3 changes: 2 additions & 1 deletion metricflow/test/dataflow/builder/test_node_data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
MeasureSpec,
)
from metricflow.sql.sql_exprs import SqlColumnReference, SqlColumnReferenceExpression
from metricflow.sql.sql_plan import SqlSelectColumn, SqlSelectStatementNode, SqlTableFromClauseNode
from metricflow.sql.sql_plan import SqlJoinType, SqlSelectColumn, SqlSelectStatementNode, SqlTableFromClauseNode
from metricflow.test.fixtures.model_fixtures import ConsistentIdObjectRepository
from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState
from metricflow.test.snapshot_utils import assert_spec_set_snapshot_equal
Expand Down Expand Up @@ -111,6 +111,7 @@ def test_joined_node_data_set( # noqa: D
join_on_entity=LinklessEntitySpec.from_element_name("user"),
join_on_partition_dimensions=(),
join_on_partition_time_dimensions=(),
join_type=SqlJoinType.LEFT_OUTER,
)
],
)
Expand Down
Loading

0 comments on commit 6dd0fd7

Please sign in to comment.