Skip to content

Commit

Permalink
Merge pull request #352 from mindsdb/planner_fixes
Browse files Browse the repository at this point in the history
Mapping 'using' variables to model
  • Loading branch information
ea-rus authored Feb 22, 2024
2 parents ae53487 + bde72e3 commit 9285aa7
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 21 deletions.
17 changes: 16 additions & 1 deletion mindsdb_sql/planner/plan_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def plan(self, query):
):
query2 = copy.deepcopy(query)
query2.from_table = None
query2.using = None
sup_select = QueryStep(query2, from_table=join_step.result)
self.planner.plan.add_step(sup_select)
return sup_select
Expand Down Expand Up @@ -429,11 +430,25 @@ def process_predictor(self, item, query_in):
# exclude condition
el._orig_node.args = [Constant(0), Constant(0)]

# params for model
model_params = None

if query_in.using is not None:
model_params = {}
for param, value in query_in.using.items():
if '.' in param:
alias = param.split('.')[0]
if (alias,) in item.aliases:
new_param = '.'.join(param.split('.')[1:])
model_params[new_param] = value
else:
model_params[param] = value

predictor_step = self.planner.plan.add_step(ApplyPredictorStep(
namespace=item.integration,
dataframe=data_step.result,
predictor=item.table,
params=query_in.using,
params=model_params,
row_dict=row_dict,
))
self.step_stack.append(predictor_step)
41 changes: 23 additions & 18 deletions mindsdb_sql/render/sqlalchemy_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,28 +592,33 @@ def prepare_update(self, ast_query):

return stmt

def get_query(self, ast_query):
if isinstance(ast_query, ast.Select):
stmt = self.prepare_select(ast_query)
elif isinstance(ast_query, ast.Insert):
stmt = self.prepare_insert(ast_query)
elif isinstance(ast_query, ast.Update):
stmt = self.prepare_update(ast_query)
elif isinstance(ast_query, ast.CreateTable):
stmt = self.prepare_create_table(ast_query)
elif isinstance(ast_query, ast.DropTables):
stmt = self.prepare_drop_table(ast_query)
else:
raise NotImplementedError(f'Unknown statement: {ast_query.__class__.__name__}')
return stmt

def get_string(self, ast_query, with_failback=True):
if isinstance(ast_query, (ast.CreateTable, ast.DropTables)):
render_func = render_ddl_query
else:
render_func = render_dml_query

try:
if isinstance(ast_query, ast.Select):
stmt = self.prepare_select(ast_query)
sql = render_dml_query(stmt, self.dialect)
elif isinstance(ast_query, ast.Insert):
stmt = self.prepare_insert(ast_query)
sql = render_dml_query(stmt, self.dialect)
elif isinstance(ast_query, ast.Update):
stmt = self.prepare_update(ast_query)
sql = render_dml_query(stmt, self.dialect)
elif isinstance(ast_query, ast.CreateTable):
stmt = self.prepare_create_table(ast_query)
sql = render_ddl_query(stmt, self.dialect)
elif isinstance(ast_query, ast.DropTables):
stmt = self.prepare_drop_table(ast_query)
sql = render_ddl_query(stmt, self.dialect)
else:
raise NotImplementedError(f'Unknown statement: {ast_query.__class__.__name__}')
stmt = self.get_query(ast_query)

return sql
sql = render_func(stmt, self.dialect)

return sql

except (SQLAlchemyError, NotImplementedError) as e:
if not with_failback:
Expand Down
9 changes: 7 additions & 2 deletions tests/test_planner/test_join_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,9 @@ def test_model_join_model(self):
join mindsdb.pred m
join mindsdb.pred m2
where m.a = 2
using m.param1 = 'a',
m2.param2 = 'b',
param3 = 'c'
'''

subquery = parse_sql("""
Expand All @@ -750,13 +753,15 @@ def test_model_join_model(self):
FetchDataframeStep(integration='int',
query=parse_sql('select * from tab1 as t')),
ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0),
predictor=Identifier('pred', alias=Identifier('m')), row_dict={'a': 2}),
predictor=Identifier('pred', alias=Identifier('m')),
row_dict={ 'a': 2 }, params={ 'param1': 'a', 'param3': 'c' }),
JoinStep(left=Result(0), right=Result(1),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
join_type=JoinType.JOIN)),
ApplyPredictorStep(namespace='mindsdb', dataframe=Result(2),
predictor=Identifier('pred', alias=Identifier('m2'))),
predictor=Identifier('pred', alias=Identifier('m2')),
params={ 'param2': 'b', 'param3': 'c' }),
JoinStep(left=Result(2), right=Result(3),
query=Join(left=Identifier('tab1'),
right=Identifier('tab2'),
Expand Down

0 comments on commit 9285aa7

Please sign in to comment.