diff --git a/rust/cubesql/cubesql/src/compile/rewrite/converter.rs b/rust/cubesql/cubesql/src/compile/rewrite/converter.rs index 99d5057e706f3..b46c555274e88 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/converter.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/converter.rs @@ -1350,10 +1350,10 @@ impl LanguageToLogicalPlanConverter { LogicalPlanLanguage::Join(params) => { let left_on = match_data_node!(node_by_id, params[2], JoinLeftOn); let right_on = match_data_node!(node_by_id, params[3], JoinRightOn); - let left = self.to_logical_plan(params[0]); - let right = self.to_logical_plan(params[1]); + let left = self.to_logical_plan(params[0])?; + let right = self.to_logical_plan(params[1])?; - if self.is_cube_scan_node(params[0]) && self.is_cube_scan_node(params[1]) { + if Self::have_cube_scan_inside(&left) && Self::have_cube_scan_inside(&right) { if left_on.iter().any(|c| c.name == "__cubeJoinField") || right_on.iter().any(|c| c.name == "__cubeJoinField") { @@ -1370,8 +1370,8 @@ impl LanguageToLogicalPlanConverter { } } - let left = Arc::new(left?); - let right = Arc::new(right?); + let left = Arc::new(left); + let right = Arc::new(right); let join_type = match_data_node!(node_by_id, params[4], JoinJoinType); let join_constraint = match_data_node!(node_by_id, params[5], JoinJoinConstraint); @@ -1394,7 +1394,10 @@ impl LanguageToLogicalPlanConverter { }) } LogicalPlanLanguage::CrossJoin(params) => { - if self.is_cube_scan_node(params[0]) && self.is_cube_scan_node(params[1]) { + let left = self.to_logical_plan(params[0])?; + let right = self.to_logical_plan(params[1])?; + + if Self::have_cube_scan_inside(&left) && Self::have_cube_scan_inside(&right) { return Err(CubeError::internal( "Can not join Cubes. This is most likely due to one of the following reasons:\n\ • one of the cubes contains a group by\n\ @@ -1403,8 +1406,8 @@ impl LanguageToLogicalPlanConverter { )); } - let left = Arc::new(self.to_logical_plan(params[0])?); - let right = Arc::new(self.to_logical_plan(params[1])?); + let left = Arc::new(left); + let right = Arc::new(right); let schema = Arc::new(left.schema().join(right.schema())?); LogicalPlan::CrossJoin(CrossJoin { @@ -2287,16 +2290,18 @@ impl LanguageToLogicalPlanConverter { }) } - fn is_cube_scan_node(&self, node_id: Id) -> bool { - let node_by_id = &self.best_expr; - match node_by_id.index(node_id) { - LogicalPlanLanguage::CubeScan(_) | LogicalPlanLanguage::CubeScanWrapper(_) => { - return true - } - _ => (), + fn have_cube_scan_inside(node: &LogicalPlan) -> bool { + match node { + LogicalPlan::Projection(Projection { input, .. }) + | LogicalPlan::Aggregate(Aggregate { input, .. }) + | LogicalPlan::Filter(Filter { input, .. }) + | LogicalPlan::Sort(Sort { input, .. }) + | LogicalPlan::Limit(Limit { input, .. }) => Self::have_cube_scan_inside(input), + LogicalPlan::Extension(Extension { node }) => { + node.as_any().is::() || node.as_any().is::() + } + _ => false, } - - return false; } } diff --git a/rust/cubesql/cubesql/src/compile/test/test_cube_join.rs b/rust/cubesql/cubesql/src/compile/test/test_cube_join.rs index 7294609b48e68..657e9616ec89b 100644 --- a/rust/cubesql/cubesql/src/compile/test/test_cube_join.rs +++ b/rust/cubesql/cubesql/src/compile/test/test_cube_join.rs @@ -497,8 +497,8 @@ async fn test_join_cubes_on_wrong_field_error() { let query = convert_sql_to_cube_query( &r#" SELECT * - FROM KibanaSampleDataEcommerce - LEFT JOIN Logs ON (KibanaSampleDataEcommerce.has_subscription = Logs.read) + FROM (SELECT customer_gender, has_subscription FROM KibanaSampleDataEcommerce) kibana + LEFT JOIN (SELECT read, content FROM Logs) logs ON (kibana.has_subscription = logs.read) "# .to_string(), meta.clone(), @@ -567,6 +567,7 @@ async fn test_join_cubes_with_aggr_error() { ) } +// TODO it seems this query should not execute: it has join of grouped CubeScan with ungrouped CubeScan by __cubeJoinField #[tokio::test] async fn test_join_cubes_with_postprocessing() { if !Rewriter::sql_push_down_enabled() { @@ -621,6 +622,7 @@ async fn test_join_cubes_with_postprocessing() { ) } +// TODO it seems this query should not execute: it has join of grouped CubeScan with ungrouped CubeScan, and we explicitly try to forbid that #[tokio::test] async fn test_join_cubes_with_postprocessing_and_no_cubejoinfield() { if !Rewriter::sql_push_down_enabled() {