From 7f11562e2f44c6fc334bbe9e7bd2a6ec6d10f22f Mon Sep 17 00:00:00 2001 From: andrew Date: Fri, 16 Feb 2024 20:52:57 +0300 Subject: [PATCH 1/2] Mapping 'using' variables to model #351 --- mindsdb_sql/planner/plan_join.py | 17 ++++++++++++++++- tests/test_planner/test_join_predictor.py | 9 +++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/mindsdb_sql/planner/plan_join.py b/mindsdb_sql/planner/plan_join.py index acd1003c..da20aa54 100644 --- a/mindsdb_sql/planner/plan_join.py +++ b/mindsdb_sql/planner/plan_join.py @@ -104,6 +104,7 @@ def plan(self, query): ): query2 = copy.deepcopy(query) query2.from_table = None + query2.using = None sup_select = QueryStep(query2, from_table=join_step.result) self.planner.plan.add_step(sup_select) return sup_select @@ -429,11 +430,25 @@ def process_predictor(self, item, query_in): # exclude condition el._orig_node.args = [Constant(0), Constant(0)] + # params for model + model_params = None + + if query_in.using is not None: + model_params = {} + for param, value in query_in.using.items(): + if '.' in param: + alias = param.split('.')[0] + if (alias,) in item.aliases: + new_param = '.'.join(param.split('.')[1:]) + model_params[new_param] = value + else: + model_params[param] = value + predictor_step = self.planner.plan.add_step(ApplyPredictorStep( namespace=item.integration, dataframe=data_step.result, predictor=item.table, - params=query_in.using, + params=model_params, row_dict=row_dict, )) self.step_stack.append(predictor_step) diff --git a/tests/test_planner/test_join_predictor.py b/tests/test_planner/test_join_predictor.py index d2d80a4a..517f34ae 100644 --- a/tests/test_planner/test_join_predictor.py +++ b/tests/test_planner/test_join_predictor.py @@ -736,6 +736,9 @@ def test_model_join_model(self): join mindsdb.pred m join mindsdb.pred m2 where m.a = 2 + using m.param1 = 'a', + m2.param2 = 'b', + param3 = 'c' ''' subquery = parse_sql(""" @@ -750,13 +753,15 @@ def test_model_join_model(self): FetchDataframeStep(integration='int', query=parse_sql('select * from tab1 as t')), ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), - predictor=Identifier('pred', alias=Identifier('m')), row_dict={'a': 2}), + predictor=Identifier('pred', alias=Identifier('m')), + row_dict={ 'a': 2 }, params={ 'param1': 'a', 'param3': 'c' }), JoinStep(left=Result(0), right=Result(1), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), join_type=JoinType.JOIN)), ApplyPredictorStep(namespace='mindsdb', dataframe=Result(2), - predictor=Identifier('pred', alias=Identifier('m2'))), + predictor=Identifier('pred', alias=Identifier('m2')), + params={ 'param2': 'b', 'param3': 'c' }), JoinStep(left=Result(2), right=Result(3), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), From bde72e3d1358b79f665709ef38b7d8d810db93b3 Mon Sep 17 00:00:00 2001 From: andrew Date: Thu, 22 Feb 2024 11:50:06 +0300 Subject: [PATCH 2/2] renderer: get sqlalchemy object --- mindsdb_sql/render/sqlalchemy_render.py | 41 ++++++++++++++----------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/mindsdb_sql/render/sqlalchemy_render.py b/mindsdb_sql/render/sqlalchemy_render.py index c138b113..7e4f4288 100644 --- a/mindsdb_sql/render/sqlalchemy_render.py +++ b/mindsdb_sql/render/sqlalchemy_render.py @@ -592,28 +592,33 @@ def prepare_update(self, ast_query): return stmt + def get_query(self, ast_query): + if isinstance(ast_query, ast.Select): + stmt = self.prepare_select(ast_query) + elif isinstance(ast_query, ast.Insert): + stmt = self.prepare_insert(ast_query) + elif isinstance(ast_query, ast.Update): + stmt = self.prepare_update(ast_query) + elif isinstance(ast_query, ast.CreateTable): + stmt = self.prepare_create_table(ast_query) + elif isinstance(ast_query, ast.DropTables): + stmt = self.prepare_drop_table(ast_query) + else: + raise NotImplementedError(f'Unknown statement: {ast_query.__class__.__name__}') + return stmt + def get_string(self, ast_query, with_failback=True): + if isinstance(ast_query, (ast.CreateTable, ast.DropTables)): + render_func = render_ddl_query + else: + render_func = render_dml_query + try: - if isinstance(ast_query, ast.Select): - stmt = self.prepare_select(ast_query) - sql = render_dml_query(stmt, self.dialect) - elif isinstance(ast_query, ast.Insert): - stmt = self.prepare_insert(ast_query) - sql = render_dml_query(stmt, self.dialect) - elif isinstance(ast_query, ast.Update): - stmt = self.prepare_update(ast_query) - sql = render_dml_query(stmt, self.dialect) - elif isinstance(ast_query, ast.CreateTable): - stmt = self.prepare_create_table(ast_query) - sql = render_ddl_query(stmt, self.dialect) - elif isinstance(ast_query, ast.DropTables): - stmt = self.prepare_drop_table(ast_query) - sql = render_ddl_query(stmt, self.dialect) - else: - raise NotImplementedError(f'Unknown statement: {ast_query.__class__.__name__}') + stmt = self.get_query(ast_query) - return sql + sql = render_func(stmt, self.dialect) + return sql except (SQLAlchemyError, NotImplementedError) as e: if not with_failback: