diff --git a/mindsdb_sql/planner/plan_join.py b/mindsdb_sql/planner/plan_join.py index acd1003c..da20aa54 100644 --- a/mindsdb_sql/planner/plan_join.py +++ b/mindsdb_sql/planner/plan_join.py @@ -104,6 +104,7 @@ def plan(self, query): ): query2 = copy.deepcopy(query) query2.from_table = None + query2.using = None sup_select = QueryStep(query2, from_table=join_step.result) self.planner.plan.add_step(sup_select) return sup_select @@ -429,11 +430,25 @@ def process_predictor(self, item, query_in): # exclude condition el._orig_node.args = [Constant(0), Constant(0)] + # params for model + model_params = None + + if query_in.using is not None: + model_params = {} + for param, value in query_in.using.items(): + if '.' in param: + alias = param.split('.')[0] + if (alias,) in item.aliases: + new_param = '.'.join(param.split('.')[1:]) + model_params[new_param] = value + else: + model_params[param] = value + predictor_step = self.planner.plan.add_step(ApplyPredictorStep( namespace=item.integration, dataframe=data_step.result, predictor=item.table, - params=query_in.using, + params=model_params, row_dict=row_dict, )) self.step_stack.append(predictor_step) diff --git a/tests/test_planner/test_join_predictor.py b/tests/test_planner/test_join_predictor.py index d2d80a4a..517f34ae 100644 --- a/tests/test_planner/test_join_predictor.py +++ b/tests/test_planner/test_join_predictor.py @@ -736,6 +736,9 @@ def test_model_join_model(self): join mindsdb.pred m join mindsdb.pred m2 where m.a = 2 + using m.param1 = 'a', + m2.param2 = 'b', + param3 = 'c' ''' subquery = parse_sql(""" @@ -750,13 +753,15 @@ def test_model_join_model(self): FetchDataframeStep(integration='int', query=parse_sql('select * from tab1 as t')), ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), - predictor=Identifier('pred', alias=Identifier('m')), row_dict={'a': 2}), + predictor=Identifier('pred', alias=Identifier('m')), + row_dict={ 'a': 2 }, params={ 'param1': 'a', 'param3': 'c' }), JoinStep(left=Result(0), right=Result(1), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), join_type=JoinType.JOIN)), ApplyPredictorStep(namespace='mindsdb', dataframe=Result(2), - predictor=Identifier('pred', alias=Identifier('m2'))), + predictor=Identifier('pred', alias=Identifier('m2')), + params={ 'param2': 'b', 'param3': 'c' }), JoinStep(left=Result(2), right=Result(3), query=Join(left=Identifier('tab1'), right=Identifier('tab2'),