diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 61b2780c..20a15602 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,6 +32,8 @@ jobs: coverage: needs: test runs-on: ubuntu-latest + permissions: + pull-requests: write steps: - uses: actions/checkout@v3 - name: Set up Python 3.8 diff --git a/mindsdb_sql/__about__.py b/mindsdb_sql/__about__.py index e6aef77f..ab33edb8 100644 --- a/mindsdb_sql/__about__.py +++ b/mindsdb_sql/__about__.py @@ -1,6 +1,6 @@ __title__ = 'mindsdb_sql' __package_name__ = 'mindsdb_sql' -__version__ = '0.9.0' +__version__ = '0.10.0' __description__ = "Pure python SQL parser" __email__ = "jorge@mindsdb.com" __author__ = 'MindsDB Inc' diff --git a/mindsdb_sql/parser/ast/__init__.py b/mindsdb_sql/parser/ast/__init__.py index 53393c25..7743b7a7 100644 --- a/mindsdb_sql/parser/ast/__init__.py +++ b/mindsdb_sql/parser/ast/__init__.py @@ -14,6 +14,6 @@ from .delete import * from .drop import * from .create import * +from .variable import * -from mindsdb_sql.parser.dialects.mysql.variable import Variable from mindsdb_sql.parser.dialects.mindsdb.latest import Latest diff --git a/mindsdb_sql/parser/ast/select/__init__.py b/mindsdb_sql/parser/ast/select/__init__.py index e30a0add..8ba33604 100644 --- a/mindsdb_sql/parser/ast/select/__init__.py +++ b/mindsdb_sql/parser/ast/select/__init__.py @@ -1,7 +1,7 @@ from .select import Select from .common_table_expression import CommonTableExpression from .union import Union -from .constant import Constant, NullConstant, SpecialConstant, Last +from .constant import Constant, NullConstant, Last from .star import Star from .identifier import Identifier from .join import Join diff --git a/mindsdb_sql/parser/ast/select/constant.py b/mindsdb_sql/parser/ast/select/constant.py index eef0b463..0b31af1e 100644 --- a/mindsdb_sql/parser/ast/select/constant.py +++ b/mindsdb_sql/parser/ast/select/constant.py @@ -4,16 +4,17 @@ class Constant(ASTNode): - def __init__(self, value, *args, **kwargs): + def __init__(self, value, with_quotes=True, *args, **kwargs): super().__init__(*args, **kwargs) self.value = value + self.with_quotes = with_quotes def to_tree(self, *args, level=0, **kwargs): alias_str = f', alias={self.alias.to_tree()}' if self.alias else '' return indent(level) + f'Constant(value={repr(self.value)}{alias_str})' def get_string(self, *args, **kwargs): - if isinstance(self.value, str): + if isinstance(self.value, str) and self.with_quotes: out_str = f"\'{self.value}\'" elif isinstance(self.value, bool): out_str = 'TRUE' if self.value else 'FALSE' @@ -29,26 +30,12 @@ def __init__(self, *args, **kwargs): super().__init__(value=None, *args, **kwargs) def to_tree(self, *args, level=0, **kwargs): - return '\t'*level + 'NullConstant()' + return '\t'*level + 'NullConstant()' def get_string(self, *args, **kwargs): return 'NULL' -# TODO replace it to just Constant? 
-# DEFAULT -class SpecialConstant(ASTNode): - def __init__(self, name, *args, **kwargs): - super().__init__(*args, **kwargs) - self.name = name - - def to_tree(self, *args, level=0, **kwargs): - return indent(level) + f'SpecialConstant(name={self.name})' - - def get_string(self, *args, **kwargs): - return self.name - - class Last(Constant): def __init__(self, *args, **kwargs): self.value = 'last' diff --git a/mindsdb_sql/parser/ast/select/native_query.py b/mindsdb_sql/parser/ast/select/native_query.py index bc531cce..1268af9c 100644 --- a/mindsdb_sql/parser/ast/select/native_query.py +++ b/mindsdb_sql/parser/ast/select/native_query.py @@ -18,4 +18,7 @@ def to_tree(self, *args, level=0, **kwargs): f'NativeQuery(integration={self.integration.to_string()}, query="{self.query}")' def get_string(self, *args, **kwargs): - return f'{self.integration.to_string()} ({self.query})' + return f'({self.query})' + + def __repr__(self): + return f'{self.__class__.__name__}:{self.integration.to_string()} ({self.query})' diff --git a/mindsdb_sql/parser/ast/set.py b/mindsdb_sql/parser/ast/set.py index b2f02726..f317d892 100644 --- a/mindsdb_sql/parser/ast/set.py +++ b/mindsdb_sql/parser/ast/set.py @@ -6,95 +6,120 @@ class Set(ASTNode): def __init__(self, category=None, - arg=None, + name=None, + value=None, + scope=None, params=None, + set_list=None, *args, **kwargs): super().__init__(*args, **kwargs) - self.category = category - self.arg = arg - self.params = params or {} - def to_tree(self, *args, level=0, **kwargs): - ind = indent(level) - ind1 = indent(level+1) - category_str = f'category={self.category}, ' - arg_str = f'arg={self.arg.to_tree()},' if self.arg else '' - if self.params: - param_str = 'param=' + ', '.join([f'{k}:{v}' for k,v in self.params.items()]) - else: - param_str = '' - out_str = f'{ind}Set(' \ - f'{category_str}' \ - f'{arg_str} ' \ - f'{param_str}' \ - f')' - return out_str + # names / charset / transactions + self.category = category - def get_string(self, *args, **kwargs): - if self.params: - param_str = ' ' + ' '.join([f'{k} {v}' for k, v in self.params.items()]) - else: - param_str = '' - - if isinstance(self.arg, Tuple): - arg_str = ', '.join([str(i) for i in self.arg.items]) - else: - arg_str = f' {str(self.arg)}' if self.arg else '' - return f'SET {self.category if self.category else ""}{arg_str}{param_str}' + # name for variable assignment. category is None in this case + self.name = name + self.value = value + self.params = params or {} -class SetTransaction(ASTNode): - def __init__(self, - isolation_level=None, - access_mode=None, - scope=None, - *args, **kwargs): - super().__init__(*args, **kwargs) + # global / session / ...
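+ # (rendered by render() as a prefix of the item, e.g. SET GLOBAL autocommit=1)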
+ self.scope = scope - if isolation_level is not None: - isolation_level = isolation_level.upper() - if access_mode is not None: - access_mode = access_mode.upper() - if scope is not None: - scope = scope.upper() + # contains all SET subcommands + self.set_list = set_list - self.scope = scope - self.access_mode = access_mode - self.isolation_level = isolation_level def to_tree(self, *args, level=0, **kwargs): - ind = indent(level) - if self.scope is None: - scope_str = '' + if self.set_list is not None: + items = [set.render() for set in self.set_list] else: - scope_str = f'scope={self.scope}, ' + items = self.render() - properties = [] - if self.isolation_level is not None: - properties.append('ISOLATION LEVEL ' + self.isolation_level) - if self.access_mode is not None: - properties.append(self.access_mode) - prop_str = ', '.join(properties) + ind = indent(level) - out_str = f'{ind}SetTransaction(' \ - f'{scope_str}' \ - f'properties=[{prop_str}]' \ - f'\n{ind})' - return out_str + return f'{ind}Set(items={items})' def get_string(self, *args, **kwargs): - properties = [] - if self.isolation_level is not None: - properties.append('ISOLATION LEVEL ' + self.isolation_level) - if self.access_mode is not None: - properties.append(self.access_mode) + return 'SET ' + self.render() - prop_str = ', '.join(properties) + def render(self): + if self.set_list is not None: + render_list = [set.render() for set in self.set_list] + return ', '.join(render_list) - if self.scope is None: - scope_str = '' + if self.params: + param_str = ' ' + ' '.join([f'{k} {v}' for k, v in self.params.items()]) else: - scope_str = self.scope + ' ' + param_str = '' - return f'SET {scope_str}TRANSACTION {prop_str}' + if self.name is not None: + # category should be empty + content = f'{self.name.to_string()}={self.value.to_string()}' + elif self.value is not None: + content = f'{self.category} {self.value.to_string()}' + else: + content = f'{self.category}' + + scope = '' + if self.scope is not None: + scope = f'{self.scope} ' + + return f'{scope}{content}{param_str}' + + +# class SetTransaction(ASTNode): +# def __init__(self, +# isolation_level=None, +# access_mode=None, +# scope=None, +# *args, **kwargs): +# super().__init__(*args, **kwargs) +# +# if isolation_level is not None: +# isolation_level = isolation_level.upper() +# if access_mode is not None: +# access_mode = access_mode.upper() +# if scope is not None: +# scope = scope.upper() +# +# self.scope = scope +# self.access_mode = access_mode +# self.isolation_level = isolation_level +# +# def to_tree(self, *args, level=0, **kwargs): +# ind = indent(level) +# if self.scope is None: +# scope_str = '' +# else: +# scope_str = f'scope={self.scope}, ' +# +# properties = [] +# if self.isolation_level is not None: +# properties.append('ISOLATION LEVEL ' + self.isolation_level) +# if self.access_mode is not None: +# properties.append(self.access_mode) +# prop_str = ', '.join(properties) +# +# out_str = f'{ind}SetTransaction(' \ +# f'{scope_str}' \ +# f'properties=[{prop_str}]' \ +# f'\n{ind})' +# return out_str +# +# def get_string(self, *args, **kwargs): +# properties = [] +# if self.isolation_level is not None: +# properties.append('ISOLATION LEVEL ' + self.isolation_level) +# if self.access_mode is not None: +# properties.append(self.access_mode) +# +# prop_str = ', '.join(properties) +# +# if self.scope is None: +# scope_str = '' +# else: +# scope_str = self.scope + ' ' +# +# return f'SET {scope_str}TRANSACTION {prop_str}' diff --git
a/mindsdb_sql/parser/dialects/mysql/variable.py b/mindsdb_sql/parser/ast/variable.py similarity index 100% rename from mindsdb_sql/parser/dialects/mysql/variable.py rename to mindsdb_sql/parser/ast/variable.py diff --git a/mindsdb_sql/parser/dialects/mindsdb/lexer.py b/mindsdb_sql/parser/dialects/mindsdb/lexer.py index 752dddaf..405fde10 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/lexer.py +++ b/mindsdb_sql/parser/dialects/mindsdb/lexer.py @@ -42,8 +42,10 @@ class MindsDBLexer(Lexer): VARIABLES, SESSION, STATUS, GLOBAL, PROCEDURE, FUNCTION, INDEX, WARNINGS, ENGINES, CHARSET, COLLATION, PLUGINS, CHARACTER, - PERSIST, PERSIST_ONLY, DEFAULT, + PERSIST, PERSIST_ONLY, IF_EXISTS, IF_NOT_EXISTS, COLUMNS, FIELDS, COLLATE, SEARCH_PATH, + VARIABLE, SYSTEM_VARIABLE, + # SELECT Keywords WITH, SELECT, DISTINCT, FROM, WHERE, AS, LIMIT, OFFSET, ASC, DESC, NULLS_FIRST, NULLS_LAST, @@ -170,7 +172,6 @@ class MindsDBLexer(Lexer): PLUGINS = r'\bPLUGINS\b' PERSIST = r'\bPERSIST\b' PERSIST_ONLY = r'\bPERSIST_ONLY\b' - DEFAULT = r'\bDEFAULT\b' IF_EXISTS = r'\bIF[\s]+EXISTS\b' IF_NOT_EXISTS = r'\bIF[\s]+NOT[\s]+EXISTS\b' COLUMNS = r'\bCOLUMNS\b' @@ -295,7 +296,7 @@ class MindsDBLexer(Lexer): def ID(self, t): return t - @_(r'\d+\.\d*') + @_(r'\d+\.\d+') def FLOAT(self, t): return t @@ -303,14 +304,49 @@ def FLOAT(self, t): def INTEGER(self, t): return t - @_(r"'[^']*'") + @_(r"'(?:[^\'\\]|\\.)*'") def QUOTE_STRING(self, t): + t.value = t.value.replace('\\"', '"').replace("\\'", "'") return t - @_(r'"[^"]*"') + @_(r'"(?:[^\"\\]|\\.)*"') def DQUOTE_STRING(self, t): + t.value = t.value.replace('\\"', '"').replace("\\'", "'") return t @_(r'\n+') def ignore_newline(self, t): self.lineno += len(t.value) + + @_(r'@[a-zA-Z_.$]+', + r"@'[a-zA-Z_.$][^']*'", + r"@`[a-zA-Z_.$][^`]*`", + r'@"[a-zA-Z_.$][^"]*"' + ) + def VARIABLE(self, t): + t.value = t.value.lstrip('@') + + if t.value[0] == '"': + t.value = t.value.strip('\"') + elif t.value[0] == "'": + t.value = t.value.strip('\'') + elif t.value[0] == "`": + t.value = t.value.strip('`') + return t + + @_(r'@@[a-zA-Z_.$]+', + r"@@'[a-zA-Z_.$][^']*'", + r"@@`[a-zA-Z_.$][^`]*`", + r'@@"[a-zA-Z_.$][^"]*"' + ) + def SYSTEM_VARIABLE(self, t): + t.value = t.value.lstrip('@') + + if t.value[0] == '"': + t.value = t.value.strip('\"') + elif t.value[0] == "'": + t.value = t.value.strip('\'') + elif t.value[0] == "`": + t.value = t.value.strip('`') + return t + diff --git a/mindsdb_sql/parser/dialects/mindsdb/parser.py b/mindsdb_sql/parser/dialects/mindsdb/parser.py index 41c010cd..616ef566 100644 --- a/mindsdb_sql/parser/dialects/mindsdb/parser.py +++ b/mindsdb_sql/parser/dialects/mindsdb/parser.py @@ -334,42 +334,57 @@ def commit_transaction(self, p): def rollback_transaction(self, p): return RollbackTransaction() - # Set - - @_('SET id identifier', - 'SET id identifier COLLATE constant', - 'SET id identifier COLLATE DEFAULT', - 'SET id constant', - 'SET id constant COLLATE constant', - 'SET id constant COLLATE DEFAULT') + # --- Set --- + @_('SET set_item_list') def set(self, p): - if not p.id.lower() == 'names': - raise ParsingException(f'Expected "SET names", got "SET {p.id}"') - if isinstance(p[2], Constant): - arg = Identifier(p[2].value) + set_list = p[1] + if len(set_list) == 1: + return set_list[0] + return Set(set_list=set_list) + + @_('set_item', + 'set_item_list COMMA set_item') + def set_item_list(self, p): + arr = getattr(p, 'set_item_list', []) + arr.append(p.set_item) + return arr + + # set names + @_('id id', + 'id constant', + 'id id COLLATE constant', + 'id id 
COLLATE id', + 'id constant COLLATE constant', + 'id constant COLLATE id') + def set_item(self, p): + category = p[0] + if category.lower() != 'names': + raise ParsingException(f'Expected "SET names", got "SET {category}"') + if isinstance(p[1], Constant): + value = p[1] else: - # is identifier - arg = p[2] + # is id + value = Constant(p[1], with_quotes=False) params = {} if hasattr(p, 'COLLATE'): - if isinstance(p[4], Constant): - val = p[4] + if isinstance(p[3], Constant): + val = p[3] else: - val = SpecialConstant('DEFAULT') + val = Constant(p[3], with_quotes=False) params['COLLATE'] = val - return Set(category=p.id.lower(), arg=arg, params=params) + return Set(category=category, value=value, params=params) # set charset - @_('SET charset constant', - 'SET charset DEFAULT') - def set(self, p): - if hasattr(p, 'DEFAULT'): - arg = SpecialConstant('DEFAULT') + @_('charset constant', + 'charset id') + def set_item(self, p): + if hasattr(p, 'id'): + arg = Constant(p.id, with_quotes=False) else: arg = p.constant - return Set(category='CHARSET', arg=arg) + return Set(category='CHARSET', value=arg) @_('CHARACTER SET', 'CHARSET', @@ -378,29 +393,30 @@ def charset(self, p): return p[0] # set transaction - @_('SET transact_scope TRANSACTION transact_property_list', - 'SET TRANSACTION transact_property_list') - def set(self, p): + @_('set_scope TRANSACTION transact_property_list', + 'TRANSACTION transact_property_list') + def set_item(self, p): isolation_level = None access_mode = None - transact_scope = getattr(p, 'transact_scope', None) + transact_scope = getattr(p, 'set_scope', None) for prop in p.transact_property_list: if prop['type'] == 'iso_level': isolation_level = prop['value'] else: access_mode = prop['value'] - return SetTransaction( - isolation_level=isolation_level, - access_mode=access_mode, + params = {} + if isolation_level is not None: + params['isolation_level'] = isolation_level + if access_mode is not None: + params['access_mode'] = access_mode + + return Set( + category='TRANSACTION', scope=transact_scope, + params=params ) - @_('GLOBAL', - 'SESSION') - def transact_scope(self, p): - return p[0] - @_('transact_property_list COMMA transact_property') def transact_property_list(self, p): return p.transact_property_list + [p.transact_property] @@ -429,30 +445,29 @@ def transact_level(self, p): def transact_access_mode(self, p): return ' '.join([x for x in p]) - @_('SET expr_list', - 'SET set_modifier expr_list') - def set(self, p): - if len(p.expr_list) == 1: - arg = p.expr_list[0] - else: - arg = Tuple(items=p.expr_list) + @_('identifier EQUALS expr', + 'set_scope identifier EQUALS expr', + 'variable EQUALS expr', + 'set_scope variable EQUALS expr') + def set_item(self, p): - if hasattr(p, 'set_modifier'): - category = p.set_modifier - else: - category = None + scope = None + name = p[0] + if hasattr(p, 'set_scope'): + scope = p.set_scope + name=p[1] - return Set(category=category, arg=arg) + return Set(name=name, value=p.expr, scope=scope) @_('GLOBAL', 'PERSIST', 'PERSIST_ONLY', 'SESSION', ) - def set_modifier(self, p): + def set_scope(self, p): return p[0] - # Show + # --- Show --- @_('show WHERE expr') def show(self, p): command = p.show @@ -822,7 +837,11 @@ def create_anomaly_detection_model(self, p): @_('RETRAIN identifier', 'RETRAIN identifier PREDICT result_columns', 'RETRAIN identifier FROM identifier LPAREN raw_query RPAREN', - 'RETRAIN identifier FROM identifier LPAREN raw_query RPAREN PREDICT result_columns') + 'RETRAIN identifier FROM identifier LPAREN raw_query 
RPAREN PREDICT result_columns', + 'RETRAIN MODEL identifier', + 'RETRAIN MODEL identifier PREDICT result_columns', + 'RETRAIN MODEL identifier FROM identifier LPAREN raw_query RPAREN', + 'RETRAIN MODEL identifier FROM identifier LPAREN raw_query RPAREN PREDICT result_columns') def create_predictor(self, p): query_str = None if hasattr(p, 'raw_query'): @@ -841,7 +860,8 @@ def create_predictor(self, p): targets=getattr(p, 'result_columns', None) ) - @_('FINETUNE identifier FROM identifier LPAREN raw_query RPAREN') + @_('FINETUNE identifier FROM identifier LPAREN raw_query RPAREN', + 'FINETUNE MODEL identifier FROM identifier LPAREN raw_query RPAREN') def create_predictor(self, p): query_str = None if hasattr(p, 'raw_query'): @@ -935,13 +955,15 @@ def database_engine(self, p): return {'identifier':p.identifier, 'engine':engine, 'if_not_exists':p.if_not_exists_or_empty} # UNION / UNION ALL - @_('select UNION select') + @_('select UNION select', + 'union UNION select') def union(self, p): - return Union(left=p.select0, right=p.select1, unique=True) + return Union(left=p[0], right=p[2], unique=True) - @_('select UNION ALL select') + @_('select UNION ALL select', + 'union UNION ALL select',) def union(self, p): - return Union(left=p.select0, right=p.select1, unique=False) + return Union(left=p[0], right=p[3], unique=False) # tableau @_('LPAREN select RPAREN') @@ -1245,8 +1267,7 @@ def result_column(self, p): @_('expr', 'function', - 'window_function', - 'case') + 'window_function') def result_column(self, p): return p[0] @@ -1444,6 +1465,7 @@ def enumeration(self, p): 'parameter', 'constant', 'latest', + 'case', 'function') def expr(self, p): return p[0] @@ -1690,6 +1712,22 @@ def raw_query(self, p): def raw_query(self, p): return p[0] + p[1] + @_('variable') + def table_or_subquery(self, p): + return p.variable + + @_('variable') + def expr(self, p): + return p.variable + + @_('SYSTEM_VARIABLE') + def variable(self, p): + return Variable(value=p.SYSTEM_VARIABLE, is_system_var=True) + + @_('VARIABLE') + def variable(self, p): + return Variable(value=p.VARIABLE) + @_( 'IF_NOT_EXISTS', 'empty' diff --git a/mindsdb_sql/parser/dialects/mysql/__init__.py b/mindsdb_sql/parser/dialects/mysql/__init__.py index b3ce761e..874edeef 100644 --- a/mindsdb_sql/parser/dialects/mysql/__init__.py +++ b/mindsdb_sql/parser/dialects/mysql/__init__.py @@ -1,2 +1 @@ -from .variable import Variable from .show_index import ShowIndex diff --git a/mindsdb_sql/parser/dialects/mysql/parser.py b/mindsdb_sql/parser/dialects/mysql/parser.py index f2e90c74..ee6541c3 100644 --- a/mindsdb_sql/parser/dialects/mysql/parser.py +++ b/mindsdb_sql/parser/dialects/mysql/parser.py @@ -2,7 +2,6 @@ from mindsdb_sql.parser.parser import SQLParser from mindsdb_sql.parser.ast import * from mindsdb_sql.parser.dialects.mysql.lexer import MySQLLexer -from mindsdb_sql.parser.dialects.mysql.variable import Variable from mindsdb_sql.exceptions import ParsingException from mindsdb_sql.parser.utils import ensure_select_keyword_order, JoinType @@ -103,42 +102,58 @@ def commit_transaction(self, p): def rollback_transaction(self, p): return RollbackTransaction() - # Set - @_('SET id identifier') - @_('SET id identifier COLLATE constant') - @_('SET id identifier COLLATE DEFAULT') - @_('SET id constant') - @_('SET id constant COLLATE constant') - @_('SET id constant COLLATE DEFAULT') + # --- Set --- + @_('SET set_item_list') def set(self, p): - if not p.id.lower() == 'names': - raise ParsingException(f'Expected "SET names", got "SET {p.id}"') - if 
isinstance(p[2], Constant): - arg = Identifier(p[2].value) + set_list = p[1] + if len(set_list) == 1: + return set_list[0] + return Set(set_list=set_list) + + @_('set_item', + 'set_item_list COMMA set_item') + def set_item_list(self, p): + arr = getattr(p, 'set_item_list', []) + arr.append(p.set_item) + return arr + + # set names + @_('id id', + 'id constant', + 'id id COLLATE constant', + 'id id COLLATE id', + 'id constant COLLATE constant', + 'id constant COLLATE id') + def set_item(self, p): + category = p[0] + if category.lower() != 'names': + raise ParsingException(f'Expected "SET names", got "SET {category}"') + if isinstance(p[1], Constant): + value = p[1] else: - # is identifier - arg = p[2] + # is id + value = Constant(p[1], with_quotes=False) params = {} if hasattr(p, 'COLLATE'): - if isinstance(p[4], Constant): - val = p[4] + if isinstance(p[3], Constant): + val = p[3] else: - val = SpecialConstant('DEFAULT') + val = Constant(p[3], with_quotes=False) params['COLLATE'] = val - return Set(category=p.id.lower(), arg=arg, params=params) + return Set(category=category, value=value, params=params) # set charset - @_('SET charset constant') - @_('SET charset DEFAULT') - def set(self, p): - if hasattr(p, 'DEFAULT'): - arg = SpecialConstant('DEFAULT') + @_('charset constant', + 'charset id') + def set_item(self, p): + if hasattr(p, 'id'): + arg = Constant(p.id, with_quotes=False) else: arg = p.constant - return Set(category='CHARSET', arg=arg) + return Set(category='CHARSET', value=arg) @_('CHARACTER SET', 'CHARSET', @@ -147,29 +162,30 @@ def charset(self, p): return p[0] # set transaction - @_('SET transact_scope TRANSACTION transact_property_list') - @_('SET TRANSACTION transact_property_list') - def set(self, p): + @_('set_scope TRANSACTION transact_property_list', + 'TRANSACTION transact_property_list') + def set_item(self, p): isolation_level = None access_mode = None - transact_scope = getattr(p, 'transact_scope', None) + transact_scope = getattr(p, 'set_scope', None) for prop in p.transact_property_list: if prop['type'] == 'iso_level': isolation_level = prop['value'] else: access_mode = prop['value'] - return SetTransaction( - isolation_level=isolation_level, - access_mode=access_mode, + params = {} + if isolation_level is not None: + params['isolation_level'] = isolation_level + if access_mode is not None: + params['access_mode'] = access_mode + + return Set( + category='TRANSACTION', scope=transact_scope, + params=params ) - @_('GLOBAL', - 'SESSION') - def transact_scope(self, p): - return p[0] - @_('transact_property_list COMMA transact_property') def transact_property_list(self, p): return p.transact_property_list + [p.transact_property] @@ -182,9 +198,9 @@ def transact_property_list(self, p): 'transact_access_mode') def transact_property(self, p): if hasattr(p, 'transact_level'): - return {'type': 'iso_level', 'value': p.transact_level} + return {'type':'iso_level', 'value':p.transact_level} else: - return {'type': 'access_mode', 'value': p.transact_access_mode} + return {'type':'access_mode', 'value':p.transact_access_mode} @_('REPEATABLE READ', 'READ COMMITTED', @@ -198,30 +214,29 @@ def transact_level(self, p): def transact_access_mode(self, p): return ' '.join([x for x in p]) - @_('SET expr_list') - @_('SET set_modifier expr_list') - def set(self, p): - if len(p.expr_list) == 1: - arg = p.expr_list[0] - else: - arg = Tuple(items=p.expr_list) + @_('identifier EQUALS expr', + 'set_scope identifier EQUALS expr', + 'variable EQUALS expr', + 'set_scope variable EQUALS expr') + 
def set_item(self, p): - if hasattr(p, 'set_modifier'): - category = p.set_modifier - else: - category = None + scope = None + name = p[0] + if hasattr(p, 'set_scope'): + scope = p.set_scope + name=p[1] - return Set(category=category, arg=arg) + return Set(name=name, value=p.expr, scope=scope) @_('GLOBAL', 'PERSIST', 'PERSIST_ONLY', 'SESSION', ) - def set_modifier(self, p): + def set_scope(self, p): return p[0] - # Show + # --- Show --- @_('show WHERE expr') def show(self, p): command = p.show diff --git a/mindsdb_sql/parser/lexer.py b/mindsdb_sql/parser/lexer.py index 1df0ee87..6033689e 100644 --- a/mindsdb_sql/parser/lexer.py +++ b/mindsdb_sql/parser/lexer.py @@ -25,7 +25,7 @@ class SQLLexer(Lexer): VIEW, VARIABLES, SESSION, STATUS, GLOBAL, PROCEDURE, FUNCTION, INDEX, WARNINGS, ENGINES, CHARSET, COLLATION, PLUGINS, CHARACTER, - PERSIST, PERSIST_ONLY, DEFAULT, + PERSIST, PERSIST_ONLY, IF_EXISTS, COLUMNS, FIELDS, COLLATE, # SELECT Keywords @@ -109,7 +109,6 @@ class SQLLexer(Lexer): PLUGINS = r'\bPLUGINS\b' PERSIST = r'\bPERSIST\b' PERSIST_ONLY = r'\bPERSIST_ONLY\b' - DEFAULT = r'\bDEFAULT\b' IF_EXISTS = r'\bIF[\s]+EXISTS\b' COLUMNS = r'\bCOLUMNS\b' FIELDS = r'\bFIELDS\b' diff --git a/mindsdb_sql/planner/plan_join.py b/mindsdb_sql/planner/plan_join.py new file mode 100644 index 00000000..0643946c --- /dev/null +++ b/mindsdb_sql/planner/plan_join.py @@ -0,0 +1,427 @@ +from typing import List +import copy +from dataclasses import dataclass, field + +from mindsdb_sql.exceptions import PlanningException +from mindsdb_sql.parser import ast +from mindsdb_sql.parser.ast import (Select, Identifier, BetweenOperation, Join, Star, BinaryOperation, Constant, + NativeQuery, Parameter) +from mindsdb_sql.planner.steps import (FetchDataframeStep, JoinStep, ApplyPredictorStep, SubSelectStep, QueryStep) +from mindsdb_sql.planner.utils import (query_traversal, filters_to_bin_op) +from mindsdb_sql.planner.plan_join_ts import PlanJoinTSPredictorQuery + + +@dataclass +class TableInfo: + integration: str + table: Identifier + aliases: List[str] = field(default_factory=list) + conditions: List = None + sub_select: ast.ASTNode = None + predictor_info: dict = None + + +class PlanJoin: + + def __init__(self, planner): + self.planner = planner + + def is_timeseries(self, query): + + join = query.from_table + l_predictor = self.planner.get_predictor(join.left) if isinstance(join.left, Identifier) else None + r_predictor = self.planner.get_predictor(join.right) if isinstance(join.right, Identifier) else None + if l_predictor and l_predictor.get('timeseries'): + return True + if r_predictor and r_predictor.get('timeseries'): + return True + + def check_single_integration(self, query): + query_info = self.planner.get_query_info(query) + + # can we send the whole query to the integration? + + # one integration and no mindsdb objects in the query + if ( + len(query_info['mdb_entities']) == 0 + and len(query_info['integrations']) == 1 + and 'files' not in query_info['integrations'] + and 'views' not in query_info['integrations'] + ): + + int_name = list(query_info['integrations'])[0] + # if it is a sql database + if self.planner.integrations.get(int_name, {}).get('class_type') != 'api': + + # send to this integration + return int_name + return None + + def plan(self, query, integration=None): + # FIXME: Tableau workaround, INFORMATION_SCHEMA with Where + # if isinstance(join.right, Identifier) \ + # and self.resolve_database_table(join.right)[0] == 'INFORMATION_SCHEMA': + # pass + + # send join to integration as is?
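+ # (yes, when the query touches exactly one non-API SQL integration and
+ # no mindsdb objects: it then runs there as a single FetchDataframeStep)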
+ integration_to_send = self.check_single_integration(query) + if integration_to_send: + self.planner.prepare_integration_select(integration_to_send, query) + + last_step = self.planner.plan.add_step(FetchDataframeStep(integration=integration_to_send, query=query)) + return last_step + elif self.is_timeseries(query): + return PlanJoinTSPredictorQuery(self.planner).plan(query, integration) + else: + return PlanJoinTablesQuery(self.planner).plan(query) + + +class PlanJoinTablesQuery: + + def __init__(self, planner): + self.planner = planner + + # index to look up tables + self.tables_idx = None + + self.step_stack = None + self.query_context = {} + + def plan(self, query): + self.tables_idx = {} + join_step = self.plan_join_tables(query) + + if ( + query.group_by is not None + or query.order_by is not None + or query.having is not None + or query.distinct is True + or query.where is not None + or query.limit is not None + or query.offset is not None + or len(query.targets) != 1 + or not isinstance(query.targets[0], Star) + ): + query2 = copy.deepcopy(query) + query2.from_table = None + sup_select = QueryStep(query2, from_table=join_step.result) + self.planner.plan.add_step(sup_select) + return sup_select + return join_step + + def resolve_table(self, table): + # gets the integration for the table and the name used to access it + table = copy.deepcopy(table) + # get possible table aliases + aliases = [] + if table.alias is not None: + # to lowercase + parts = tuple(map(str.lower, table.alias.parts)) + aliases.append(parts) + else: + for i in range(0, len(table.parts)): + parts = table.parts[i:] + parts = tuple(map(str.lower, parts)) + aliases.append(parts) + + # try to use default namespace + integration = self.planner.default_namespace + if len(table.parts) > 0: + if table.parts[0] in self.planner.databases: + integration = table.parts.pop(0) + else: + integration = self.planner.default_namespace + + if integration is None and not hasattr(table, 'sub_select'): + raise PlanningException(f'Integration not found for: {table}') + + sub_select = getattr(table, 'sub_select', None) + + return TableInfo(integration, table, aliases, conditions=[], sub_select=sub_select) + + def get_table_for_column(self, column: Identifier): + + # to lowercase + parts = tuple(map(str.lower, column.parts[:-1])) + if parts in self.tables_idx: + return self.tables_idx[parts] + + def get_join_sequence(self, node): + sequence = [] + if isinstance(node, Identifier): + # resolve identifier + + table_info = self.resolve_table(node) + for alias in table_info.aliases: + self.tables_idx[alias] = table_info + + table_info.predictor_info = self.planner.get_predictor(node) + + sequence.append(table_info) + + elif isinstance(node, Join): + # create sequence: 1)table1, 2)table2, 3)join 1 2, 4)table 3, 5)join 3 4 + + # put all tables first + sequence2 = self.get_join_sequence(node.left) + for item in sequence2: + sequence.append(item) + + sequence2 = self.get_join_sequence(node.right) + if len(sequence2) != 1: + raise PlanningException('Unexpected join nesting behavior') + + # put next table + sequence.append(sequence2[0]) + + # put join + sequence.append(node) + + else: + raise NotImplementedError() + return sequence + + def check_node_condition(self, node): + + col_idx = 0 + if len(node.args) == 2: + if not isinstance(node.args[col_idx], Identifier): + # try to use second arg, could be: 'x'=col + col_idx = 1 + + # check the cases: col <op> constant, col BETWEEN constant AND constant + for i, arg in enumerate(node.args): + if i == col_idx: + if not
isinstance(arg, Identifier): + return + else: + if not isinstance(arg, (Constant, Parameter)): + return + + # checks passed; find the table and store the condition + + node2 = copy.deepcopy(node) + + arg1 = node2.args[col_idx] + + if len(arg1.parts) < 2: + return + + table_info = self.get_table_for_column(arg1) + if table_info is None: + raise PlanningException(f'Table not found for identifier: {arg1.to_string()}') + + # keep only column name + arg1.parts = [arg1.parts[-1]] + + node2._orig_node = node + table_info.conditions.append(node2) + + def check_query_conditions(self, query): + # get conditions for tables + binary_ops = [] + + def _check_node_condition(node, **kwargs): + if isinstance(node, BetweenOperation): + self.check_node_condition(node) + + if isinstance(node, BinaryOperation): + binary_ops.append(node.op) + + self.check_node_condition(node) + + query_traversal(query.where, _check_node_condition) + + self.query_context['binary_ops'] = binary_ops + + def check_use_limit(self, query_in, join_sequence): + # push the limit to the first table? + # only if the other items are models + use_limit = False + if query_in.having is None or query_in.group_by is None and query_in.limit is not None: + + join = None + use_limit = True + for item in join_sequence: + if isinstance(item, TableInfo): + if item.predictor_info is None and item.sub_select is None: + if join is not None: + if join.join_type.upper() != 'LEFT JOIN': + use_limit = False + elif isinstance(item, Join): + join = item + self.query_context['use_limit'] = use_limit + + def plan_join_tables(self, query_in): + + # plan all nested selects in 'where' + find_selects = self.planner.get_nested_selects_plan_fnc(self.planner.default_namespace, force=True) + query_in.targets = query_traversal(query_in.targets, find_selects) + query_traversal(query_in.where, find_selects) + + query = copy.deepcopy(query_in) + + # replace sub selects with identifiers that link to the original selects + def replace_subselects(node, **args): + if isinstance(node, Select) or isinstance(node, NativeQuery) or isinstance(node, ast.Data): + name = f't_{id(node)}' + node2 = Identifier(name, alias=node.alias) + + # save in attribute + if isinstance(node, NativeQuery) or isinstance(node, ast.Data): + # wrap in a select + node = Select(targets=[Star()], from_table=node) + node2.sub_select = node + return node2 + + query_traversal(query.from_table, replace_subselects) + + # get all join tables, form join sequence + join_sequence = self.get_join_sequence(query.from_table) + + # find tables for identifiers used in query + def _check_identifiers(node, is_table, **kwargs): + if not is_table and isinstance(node, Identifier): + if len(node.parts) > 1: + table_info = self.get_table_for_column(node) + if table_info is None: + raise PlanningException(f'Table not found for identifier: {node.to_string()}') + + # replace the identifier's name + col_parts = list(table_info.aliases[-1]) + col_parts.append(node.parts[-1]) + node.parts = col_parts + + query_traversal(query, _check_identifiers) + + self.check_query_conditions(query) + + # workaround for 'model join table': swap tables: + if len(join_sequence) == 3 and join_sequence[0].predictor_info is not None: + join_sequence = [join_sequence[1], join_sequence[0], join_sequence[2]] + + self.check_use_limit(query_in, join_sequence) + + # create plan + # TODO add optimization: one integration without predictor + + self.step_stack = [] + for item in join_sequence: + if isinstance(item, TableInfo): + + if item.sub_select is not None: + self.process_subselect(item) + elif
item.predictor_info is not None: + self.process_predictor(item, query_in) + else: + # is table + self.process_table(item, query_in) + + elif isinstance(item, Join): + step_right = self.step_stack.pop() + step_left = self.step_stack.pop() + + new_join = copy.deepcopy(item) + + # TODO + new_join.left = Identifier('tab1') + new_join.right = Identifier('tab2') + new_join.implicit = False + + step = self.planner.plan.add_step(JoinStep(left=step_left.result, right=step_right.result, query=new_join)) + + self.step_stack.append(step) + + query_in.where = query.where + return self.step_stack.pop() + + def process_subselect(self, item): + # is sub select + item.sub_select.alias = None + item.sub_select.parentheses = False + step = self.planner.plan_select(item.sub_select) + + where = filters_to_bin_op(item.conditions) + + # apply table alias + query2 = Select(targets=[Star()], where=where) + if item.table.alias is None: + raise PlanningException(f'Subselect in join has to be aliased: {item.sub_select.to_string()}') + table_name = item.table.alias.parts[-1] + + add_absent_cols = False + if hasattr(item.sub_select, 'from_table') and \ + isinstance(item.sub_select.from_table, ast.Data): + add_absent_cols = True + + step2 = SubSelectStep(query2, step.result, table_name=table_name, add_absent_cols=add_absent_cols) + step2 = self.planner.plan.add_step(step2) + self.step_stack.append(step2) + + def process_table(self, item, query_in): + query2 = Select(from_table=item.table, targets=[Star()]) + # parts = tuple(map(str.lower, table_name.parts)) + conditions = item.conditions + if 'or' in self.query_context['binary_ops']: + # don't use the conditions + conditions = [] + + if self.query_context['use_limit']: + order_by = None + if query_in.order_by is not None: + order_by = [] + # all order columns must be from this table + for col in query_in.order_by: + if self.get_table_for_column(col.field).table != item.table: + order_by = False + col = copy.deepcopy(col) + col.field.parts = [col.field.parts[-1]] + order_by.append(col) + + if order_by is not False: + # copy limit from upper query + query2.limit = query_in.limit + # move offset from upper query + query2.offset = query_in.offset + query_in.offset = None + # copy order + query2.order_by = order_by + + self.query_context['use_limit'] = False + for cond in conditions: + if query2.where is not None: + query2.where = BinaryOperation('and', args=[query2.where, cond]) + else: + query2.where = cond + + # step = self.planner.get_integration_select_step(query2) + step = FetchDataframeStep(integration=item.integration, query=query2) + self.planner.plan.add_step(step) + self.step_stack.append(step) + + def process_predictor(self, item, query_in): + if len(self.step_stack) == 0: + raise NotImplementedError("Predictor can't be first element of join syntax") + if item.predictor_info.get('timeseries'): + raise NotImplementedError("TS predictor is not supported here yet") + data_step = self.step_stack[-1] + row_dict = None + if item.conditions: + row_dict = {} + for el in item.conditions: + if isinstance(el.args[0], Identifier) and el.op == '=': + + if isinstance(el.args[1], (Constant, Parameter)): + row_dict[el.args[0].parts[-1]] = el.args[1].value + + # exclude condition + item.conditions[0]._orig_node.args = [Constant(0), Constant(0)] + + predictor_step = self.planner.plan.add_step(ApplyPredictorStep( + namespace=item.integration, + dataframe=data_step.result, + predictor=item.table, + params=query_in.using, + row_dict=row_dict, + )) + self.step_stack.append(predictor_step)
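A short, illustrative sketch (not part of the diff) of what PlanJoinTablesQuery above produces; plan_query is the package's existing entry point, and the exact kwargs and predictor-metadata shape here are assumptions:

from mindsdb_sql import parse_sql
from mindsdb_sql.planner import plan_query

query = parse_sql('SELECT * FROM db1.t JOIN mindsdb.pred WHERE db1.t.x = 1', dialect='mindsdb')
plan = plan_query(query, integrations=['db1'], default_namespace='mindsdb',
                  predictor_metadata=[{'name': 'pred', 'integration_name': 'mindsdb'}])
for step in plan.steps:
    print(type(step).__name__)  # roughly: FetchDataframeStep, ApplyPredictorStep, JoinStep, QueryStep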
diff --git a/mindsdb_sql/planner/plan_join_ts.py b/mindsdb_sql/planner/plan_join_ts.py new file mode 100644 index 00000000..e0d85c2f --- /dev/null +++ b/mindsdb_sql/planner/plan_join_ts.py @@ -0,0 +1,373 @@ +import copy + +from mindsdb_sql import Latest, OrderBy, NullConstant +from mindsdb_sql.exceptions import PlanningException +from mindsdb_sql.parser.ast import (Select, Identifier, BetweenOperation, Join, Star, BinaryOperation, Constant) +from mindsdb_sql.planner import utils +from mindsdb_sql.planner.steps import (JoinStep, LimitOffsetStep, MultipleSteps, MapReduceStep, + ApplyTimeseriesPredictorStep) +from mindsdb_sql.planner.ts_utils import validate_ts_where_condition, find_time_filter, replace_time_filter, \ + find_and_remove_time_filter, recursively_check_join_identifiers_for_ambiguity +from mindsdb_sql.planner.utils import (query_traversal, ) + + +class PlanJoinTSPredictorQuery: + + def __init__(self, planner): + self.planner = planner + + def adapt_dbt_query(self, query, integration): + orig_query = query + + join = query.from_table + join_left = join.left + + # dbt query. + + # move latest into subquery + moved_conditions = [] + + def move_latest(node, **kwargs): + if isinstance(node, BinaryOperation): + if Latest() in node.args: + for arg in node.args: + if isinstance(arg, Identifier): + # remove table alias + arg.parts = [arg.parts[-1]] + moved_conditions.append(node) + + query_traversal(query.where, move_latest) + + # TODO make project step from query.target + + # TODO support complex query. Only one table is supported at the moment. + # if not isinstance(join_left.from_table, Identifier): + # raise PlanningException(f'Statement not supported: {query.to_string()}') + + # move properties to upper query + query = join_left + + if query.from_table.alias is not None: + table_alias = [query.from_table.alias.parts[0]] + else: + table_alias = query.from_table.parts + + # add latest to query.where + for cond in moved_conditions: + if query.where is not None: + query.where = BinaryOperation('and', args=[query.where, cond]) + else: + query.where = cond + + def add_aliases(node, is_table, **kwargs): + if not is_table and isinstance(node, Identifier): + if len(node.parts) == 1: + # add table alias to field + node.parts = table_alias + node.parts + + query_traversal(query.where, add_aliases) + + if isinstance(query.from_table, Identifier): + # DBT workaround: allow using tables without integration.
+ # if table.parts[0] is not a known integration, take the integration name from the CREATE TABLE command + if ( + integration is not None + and query.from_table.parts[0] not in self.planner.databases + ): + # add integration name to table + query.from_table.parts.insert(0, integration) + + join_left = join_left.from_table + + if orig_query.limit is not None: + if query.limit is None or query.limit.value > orig_query.limit.value: + query.limit = orig_query.limit + query.parentheses = False + query.alias = None + + return query, join_left + + def get_aliased_fields(self, targets): + # get aliases from select targets + aliased_fields = {} + for target in targets: + if target.alias is not None: + aliased_fields[target.alias.to_string()] = target + return aliased_fields + + def plan_fetch_timeseries_partitions(self, query, table, predictor_group_by_names): + targets = [ + Identifier(column) + for column in predictor_group_by_names + ] + + query = Select( + distinct=True, + targets=targets, + from_table=table, + where=query.where, + modifiers=query.modifiers, + ) + select_step = self.planner.plan_integration_select(query) + return select_step + + def plan(self, query, integration=None): + # integration is for dbt only + + join = query.from_table + join_left = join.left + join_right = join.right + + predictor_is_left = False + if self.planner.is_predictor(join_left): + # the predictor is on the left; move it to the right + join_left, join_right = join_right, join_left + predictor_is_left = True + + if self.planner.is_predictor(join_left): + # the left side is also a predictor + raise PlanningException(f'Can\'t join two predictors {join_left} and {join_right}') + + orig_query = query + # dbt query? + if isinstance(join_left, Select) and isinstance(join_left.from_table, Identifier): + query, join_left = self.adapt_dbt_query(query, integration) + + predictor_namespace, predictor = self.planner.get_predictor_namespace_and_name_from_identifier(join_right) + table = join_left + + aliased_fields = self.get_aliased_fields(query.targets) + + recursively_check_join_identifiers_for_ambiguity(query.where) + recursively_check_join_identifiers_for_ambiguity(query.group_by, aliased_fields=aliased_fields) + recursively_check_join_identifiers_for_ambiguity(query.having) + recursively_check_join_identifiers_for_ambiguity(query.order_by, aliased_fields=aliased_fields) + + predictor_steps = self.plan_timeseries_predictor(query, table, predictor_namespace, predictor) + + # add join + # Update reference + + left = Identifier(predictor_steps['predictor'].result.ref_name) + right = Identifier(predictor_steps['data'].result.ref_name) + + if not predictor_is_left: + # swap join + left, right = right, left + new_join = Join(left=left, right=right, join_type=join.join_type) + + left = predictor_steps['predictor'].result + right = predictor_steps['data'].result + if not predictor_is_left: + # swap join + left, right = right, left + + last_step = self.planner.plan.add_step(JoinStep(left=left, right=right, query=new_join)) + + # limit from timeseries + if predictor_steps.get('saved_limit'): + last_step = self.planner.plan.add_step(LimitOffsetStep(dataframe=last_step.result, + limit=predictor_steps['saved_limit'])) + + return self.planner.plan_project(orig_query, last_step.result) + + def plan_timeseries_predictor(self, query, table, predictor_namespace, predictor): + + predictor_metadata = self.planner.get_predictor(predictor) + + predictor_time_column_name = predictor_metadata['order_by_column'] + predictor_group_by_names =
predictor_metadata['group_by_columns'] + if predictor_group_by_names is None: + predictor_group_by_names = [] + predictor_window = predictor_metadata['window'] + + if query.order_by: + raise PlanningException( + f'Can\'t provide ORDER BY to time series predictor, it will be taken from predictor settings. Found: {query.order_by}') + + saved_limit = None + if query.limit is not None: + saved_limit = query.limit.value + + if query.group_by or query.having or query.offset: + raise PlanningException(f'Unsupported query to timeseries predictor: {str(query)}') + + allowed_columns = [predictor_time_column_name.lower()] + if len(predictor_group_by_names) > 0: + allowed_columns += [i.lower() for i in predictor_group_by_names] + validate_ts_where_condition(query.where, allowed_columns=allowed_columns) + + time_filter = find_time_filter(query.where, time_column_name=predictor_time_column_name) + + order_by = [OrderBy(Identifier(parts=[predictor_time_column_name]), direction='DESC')] + + preparation_where = copy.deepcopy(query.where) + + query_modifiers = query.modifiers + + # add {order_by_field} is not null + def add_order_not_null(condition): + order_field_not_null = BinaryOperation(op='is not', args=[ + Identifier(parts=[predictor_time_column_name]), + NullConstant() + ]) + if condition is not None: + condition = BinaryOperation(op='and', args=[ + condition, + order_field_not_null + ]) + else: + condition = order_field_not_null + return condition + + preparation_where2 = copy.deepcopy(preparation_where) + preparation_where = add_order_not_null(preparation_where) + + # Obtain integration selects + if isinstance(time_filter, BetweenOperation): + between_from = time_filter.args[1] + preparation_time_filter = BinaryOperation('<', args=[Identifier(predictor_time_column_name), between_from]) + preparation_where2 = replace_time_filter(preparation_where2, time_filter, preparation_time_filter) + integration_select_1 = Select(targets=[Star()], + from_table=table, + where=add_order_not_null(preparation_where2), + modifiers=query_modifiers, + order_by=order_by, + limit=Constant(predictor_window)) + + integration_select_2 = Select(targets=[Star()], + from_table=table, + where=preparation_where, + modifiers=query_modifiers, + order_by=order_by) + + integration_selects = [integration_select_1, integration_select_2] + elif isinstance(time_filter, BinaryOperation) and time_filter.op == '>' and time_filter.args[1] == Latest(): + integration_select = Select(targets=[Star()], + from_table=table, + where=preparation_where, + modifiers=query_modifiers, + order_by=order_by, + limit=Constant(predictor_window), + ) + integration_select.where = find_and_remove_time_filter(integration_select.where, time_filter) + integration_selects = [integration_select] + elif isinstance(time_filter, BinaryOperation) and time_filter.op == '=': + integration_select = Select(targets=[Star()], + from_table=table, + where=preparation_where, + modifiers=query_modifiers, + order_by=order_by, + limit=Constant(predictor_window), + ) + + if type(time_filter.args[1]) is Latest: + integration_select.where = find_and_remove_time_filter(integration_select.where, time_filter) + else: + time_filter_date = time_filter.args[1] + preparation_time_filter = BinaryOperation( + '<=', + args=[ + Identifier(predictor_time_column_name), + time_filter_date + ] + ) + integration_select.where = add_order_not_null( + replace_time_filter( + preparation_where2, time_filter, preparation_time_filter + ) + ) + time_filter.op = '>' + + integration_selects = 
[integration_select] + elif isinstance(time_filter, BinaryOperation) and time_filter.op in ('>', '>='): + time_filter_date = time_filter.args[1] + preparation_time_filter_op = {'>': '<=', '>=': '<'}[time_filter.op] + + preparation_time_filter = BinaryOperation(preparation_time_filter_op, args=[Identifier(predictor_time_column_name), time_filter_date]) + preparation_where2 = replace_time_filter(preparation_where2, time_filter, preparation_time_filter) + integration_select_1 = Select(targets=[Star()], + from_table=table, + where=add_order_not_null(preparation_where2), + modifiers=query_modifiers, + order_by=order_by, + limit=Constant(predictor_window)) + + integration_select_2 = Select(targets=[Star()], + from_table=table, + where=preparation_where, + modifiers=query_modifiers, + order_by=order_by) + + integration_selects = [integration_select_1, integration_select_2] + else: + integration_select = Select(targets=[Star()], + from_table=table, + where=preparation_where, + modifiers=query_modifiers, + order_by=order_by, + ) + integration_selects = [integration_select] + + if len(predictor_group_by_names) == 0: + # ts query without grouping + # one or multistep + if len(integration_selects) == 1: + select_partition_step = self.planner.get_integration_select_step(integration_selects[0]) + else: + select_partition_step = MultipleSteps( + steps=[self.planner.get_integration_select_step(s) for s in integration_selects], reduce='union') + + # fetch data step + data_step = self.planner.plan.add_step(select_partition_step) + else: + # inject $var to queries + for integration_select in integration_selects: + condition = integration_select.where + for num, column in enumerate(predictor_group_by_names): + cond = BinaryOperation('=', args=[Identifier(column), Constant(f'$var[{column}]')]) + + # join to main condition + if condition is None: + condition = cond + else: + condition = BinaryOperation('and', args=[condition, cond]) + + integration_select.where = condition + # one or multistep + if len(integration_selects) == 1: + select_partition_step = self.planner.get_integration_select_step(integration_selects[0]) + else: + select_partition_step = MultipleSteps( + steps=[self.planner.get_integration_select_step(s) for s in integration_selects], reduce='union') + + # get grouping values + no_time_filter_query = copy.deepcopy(query) + no_time_filter_query.where = find_and_remove_time_filter(no_time_filter_query.where, time_filter) + select_partitions_step = self.plan_fetch_timeseries_partitions(no_time_filter_query, table, predictor_group_by_names) + + # sub-query by every grouping value + map_reduce_step = self.planner.plan.add_step(MapReduceStep(values=select_partitions_step.result, reduce='union', step=select_partition_step)) + data_step = map_reduce_step + + predictor_identifier = utils.get_predictor_name_identifier(predictor) + + params = None + if query.using is not None: + params = query.using + predictor_step = self.planner.plan.add_step( + ApplyTimeseriesPredictorStep( + output_time_filter=time_filter, + namespace=predictor_namespace, + dataframe=data_step.result, + predictor=predictor_identifier, + params=params, + ) + ) + + return { + 'predictor': predictor_step, + 'data': data_step, + 'saved_limit': saved_limit, + } +
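Note (illustrative, not from the diff): for a model with window N, ORDER BY column t and GROUP BY column g, a query such as SELECT * FROM int.tbl JOIN mindsdb.ts_model WHERE t > LATEST is planned above as: fetch the DISTINCT g partitions, fetch the last N rows per partition (ordered by t DESC, filtered through the $var[g] placeholder), union them with a MapReduceStep, then apply ApplyTimeseriesPredictorStep with the saved time filter.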
diff --git a/mindsdb_sql/planner/query_planner.py b/mindsdb_sql/planner/query_planner.py index df5201d7..cfa7a7e8 100644 --- a/mindsdb_sql/planner/query_planner.py +++ b/mindsdb_sql/planner/query_planner.py @@ -1,31 +1,25 @@ import copy -from collections import defaultdict + from mindsdb_sql.exceptions import PlanningException from mindsdb_sql.parser import ast -from mindsdb_sql.parser.ast import (Select, Identifier, Join, Star, BinaryOperation, Constant, OrderBy, - BetweenOperation, Union, NullConstant, CreateTable, Function, Insert, +from mindsdb_sql.parser.ast import (Select, Identifier, Join, Star, BinaryOperation, Constant, Union, CreateTable, + Function, Insert, Update, NativeQuery, Parameter, Delete) - -from mindsdb_sql.parser.dialects.mindsdb.latest import Latest -from mindsdb_sql.planner.steps import (FetchDataframeStep, ProjectStep, JoinStep, ApplyPredictorStep, - ApplyPredictorRowStep, FilterStep, GroupByStep, LimitOffsetStep, OrderByStep, - UnionStep, MapReduceStep, MultipleSteps, ApplyTimeseriesPredictorStep, - GetPredictorColumns, SaveToTable, InsertToTable, UpdateToTable, SubSelectStep, +from mindsdb_sql.planner import utils +from mindsdb_sql.planner.query_plan import QueryPlan +from mindsdb_sql.planner.steps import (FetchDataframeStep, ProjectStep, ApplyPredictorStep, + ApplyPredictorRowStep, UnionStep, GetPredictorColumns, SaveToTable, + InsertToTable, UpdateToTable, SubSelectStep, DeleteStep, DataStep) -from mindsdb_sql.planner.ts_utils import (validate_ts_where_condition, find_time_filter, replace_time_filter, - find_and_remove_time_filter) from mindsdb_sql.planner.utils import (disambiguate_predictor_column_identifier, get_deepest_select, recursively_extract_column_values, - recursively_check_join_identifiers_for_ambiguity, query_traversal, filters_to_bin_op) -from mindsdb_sql.planner.query_plan import QueryPlan -from mindsdb_sql.planner import utils -from .query_prepare import PreparedStatementPlanner - +from mindsdb_sql.planner.plan_join import PlanJoin +from mindsdb_sql.planner.query_prepare import PreparedStatementPlanner -class QueryPlanner(): +class QueryPlanner: def __init__(self, query=None, @@ -170,7 +164,10 @@ def _prepare_integration_select(node, is_table, is_target, parent_query, **kwarg query_traversal(query, _prepare_integration_select) def get_integration_select_step(self, select): - integration_name, table = self.resolve_database_table(select.from_table) + if isinstance(select.from_table, NativeQuery): + integration_name = select.from_table.integration.parts[-1] + else: + integration_name, table = self.resolve_database_table(select.from_table) fetch_df_select = copy.deepcopy(select) self.prepare_integration_select(integration_name, fetch_df_select) @@ -254,7 +251,6 @@ def find_selects(node, **kwargs): return find_selects - def plan_select_identifier(self, query): query_info = self.get_query_info(query) @@ -282,10 +278,29 @@ def plan_select_identifier(self, query): if len(query_info['predictors']) >= 1: # select from predictor return self.plan_select_from_predictor(query) + elif is_api_db: + return self.plan_api_db_select(query) else: # fallback to integration return self.plan_integration_select(query) + def plan_api_db_select(self, query): + # split the select for an api database: + # keep only limit and where; + # the rest goes to the outer select + query2 = Select( + targets=[Star()], + from_table=query.from_table, + where=query.where, + limit=query.limit, + ) + prev_step = self.plan_integration_select(query2) + + # clear limit and where + query.limit = None + query.where = None + return self.plan_sub_select(query, prev_step) + def plan_nested_select(self, select): query_info = self.get_query_info(select) @@ -329,12 +344,7 @@ def plan_mdb_nested_select(self, select): self.plan_select(select2) last_step = self.plan.steps[-1] - sup_select
= self.sub_select_step(select, last_step) - if sup_select is not None: - self.plan.add_step(sup_select) - last_step = sup_select - - return last_step + return self.plan_sub_select(select, last_step) def get_predictor_namespace_and_name_from_identifier(self, identifier): new_identifier = copy.deepcopy(identifier) @@ -479,479 +489,28 @@ def split_filters(node, **kwargs): 'data': integration_select_step, } - def plan_fetch_timeseries_partitions(self, query, table, predictor_group_by_names): - targets = [ - Identifier(column) - for column in predictor_group_by_names - ] - - query = Select( - distinct=True, - targets=targets, - from_table=table, - where=query.where, - modifiers=query.modifiers, - ) - select_step = self.plan_integration_select(query) - return select_step - - def plan_timeseries_predictor(self, query, table, predictor_namespace, predictor): - - predictor_metadata = self.get_predictor(predictor) - - predictor_time_column_name = predictor_metadata['order_by_column'] - predictor_group_by_names = predictor_metadata['group_by_columns'] - if predictor_group_by_names is None: - predictor_group_by_names = [] - predictor_window = predictor_metadata['window'] - - if query.order_by: - raise PlanningException( - f'Can\'t provide ORDER BY to time series predictor, it will be taken from predictor settings. Found: {query.order_by}') - - saved_limit = None - if query.limit is not None: - saved_limit = query.limit.value - - if query.group_by or query.having or query.offset: - raise PlanningException(f'Unsupported query to timeseries predictor: {str(query)}') - - allowed_columns = [predictor_time_column_name.lower()] - if len(predictor_group_by_names) > 0: - allowed_columns += [i.lower() for i in predictor_group_by_names] - validate_ts_where_condition(query.where, allowed_columns=allowed_columns) - - time_filter = find_time_filter(query.where, time_column_name=predictor_time_column_name) - - order_by = [OrderBy(Identifier(parts=[predictor_time_column_name]), direction='DESC')] - - preparation_where = copy.deepcopy(query.where) - - query_modifiers = query.modifiers - - # add {order_by_field} is not null - def add_order_not_null(condition): - order_field_not_null = BinaryOperation(op='is not', args=[ - Identifier(parts=[predictor_time_column_name]), - NullConstant() - ]) - if condition is not None: - condition = BinaryOperation(op='and', args=[ - condition, - order_field_not_null - ]) - else: - condition = order_field_not_null - return condition - - preparation_where2 = copy.deepcopy(preparation_where) - preparation_where = add_order_not_null(preparation_where) - - # Obtain integration selects - if isinstance(time_filter, BetweenOperation): - between_from = time_filter.args[1] - preparation_time_filter = BinaryOperation('<', args=[Identifier(predictor_time_column_name), between_from]) - preparation_where2 = replace_time_filter(preparation_where2, time_filter, preparation_time_filter) - integration_select_1 = Select(targets=[Star()], - from_table=table, - where=add_order_not_null(preparation_where2), - modifiers=query_modifiers, - order_by=order_by, - limit=Constant(predictor_window)) - - integration_select_2 = Select(targets=[Star()], - from_table=table, - where=preparation_where, - modifiers=query_modifiers, - order_by=order_by) - - integration_selects = [integration_select_1, integration_select_2] - elif isinstance(time_filter, BinaryOperation) and time_filter.op == '>' and time_filter.args[1] == Latest(): - integration_select = Select(targets=[Star()], - from_table=table, - 
-            integration_select = Select(targets=[Star()],
-                                        from_table=table,
-                                        where=preparation_where,
-                                        modifiers=query_modifiers,
-                                        order_by=order_by,
-                                        limit=Constant(predictor_window),
-                                        )
-            integration_select.where = find_and_remove_time_filter(integration_select.where, time_filter)
-            integration_selects = [integration_select]
-        elif isinstance(time_filter, BinaryOperation) and time_filter.op == '=':
-            integration_select = Select(targets=[Star()],
-                                        from_table=table,
-                                        where=preparation_where,
-                                        modifiers=query_modifiers,
-                                        order_by=order_by,
-                                        limit=Constant(predictor_window),
-                                        )
-
-            if type(time_filter.args[1]) is Latest:
-                integration_select.where = find_and_remove_time_filter(integration_select.where, time_filter)
-            else:
-                time_filter_date = time_filter.args[1]
-                preparation_time_filter = BinaryOperation(
-                    '<=',
-                    args=[
-                        Identifier(predictor_time_column_name),
-                        time_filter_date
-                    ]
-                )
-                integration_select.where = add_order_not_null(
-                    replace_time_filter(
-                        preparation_where2, time_filter, preparation_time_filter
-                    )
-                )
-                time_filter.op = '>'
-
-            integration_selects = [integration_select]
-        elif isinstance(time_filter, BinaryOperation) and time_filter.op in ('>', '>='):
-            time_filter_date = time_filter.args[1]
-            preparation_time_filter_op = {'>': '<=', '>=': '<'}[time_filter.op]
-
-            preparation_time_filter = BinaryOperation(preparation_time_filter_op, args=[Identifier(predictor_time_column_name), time_filter_date])
-            preparation_where2 = replace_time_filter(preparation_where2, time_filter, preparation_time_filter)
-            integration_select_1 = Select(targets=[Star()],
-                                          from_table=table,
-                                          where=add_order_not_null(preparation_where2),
-                                          modifiers=query_modifiers,
-                                          order_by=order_by,
-                                          limit=Constant(predictor_window))
-
-            integration_select_2 = Select(targets=[Star()],
-                                          from_table=table,
-                                          where=preparation_where,
-                                          modifiers=query_modifiers,
-                                          order_by=order_by)
-
-            integration_selects = [integration_select_1, integration_select_2]
-        else:
-            integration_select = Select(targets=[Star()],
-                                        from_table=table,
-                                        where=preparation_where,
-                                        modifiers=query_modifiers,
-                                        order_by=order_by,
-                                        )
-            integration_selects = [integration_select]
-
-        if len(predictor_group_by_names) == 0:
-            # ts query without grouping
-            # one or multistep
-            if len(integration_selects) == 1:
-                select_partition_step = self.get_integration_select_step(integration_selects[0])
-            else:
-                select_partition_step = MultipleSteps(
-                    steps=[self.get_integration_select_step(s) for s in integration_selects], reduce='union')
-
-            # fetch data step
-            data_step = self.plan.add_step(select_partition_step)
-        else:
-            # inject $var to queries
-            for integration_select in integration_selects:
-                condition = integration_select.where
-                for num, column in enumerate(predictor_group_by_names):
-                    cond = BinaryOperation('=', args=[Identifier(column), Constant(f'$var[{column}]')])
-
-                    # join to main condition
-                    if condition is None:
-                        condition = cond
-                    else:
-                        condition = BinaryOperation('and', args=[condition, cond])
-
-                integration_select.where = condition
-            # one or multistep
-            if len(integration_selects) == 1:
-                select_partition_step = self.get_integration_select_step(integration_selects[0])
-            else:
-                select_partition_step = MultipleSteps(
-                    steps=[self.get_integration_select_step(s) for s in integration_selects], reduce='union')
-
-            # get groping values
-            no_time_filter_query = copy.deepcopy(query)
-            no_time_filter_query.where = find_and_remove_time_filter(no_time_filter_query.where, time_filter)
-            select_partitions_step = self.plan_fetch_timeseries_partitions(no_time_filter_query, table,
-                                                                           predictor_group_by_names)
-
-            # sub-query by every grouping value
-            map_reduce_step = self.plan.add_step(MapReduceStep(values=select_partitions_step.result, reduce='union', step=select_partition_step))
-            data_step = map_reduce_step
-
-        predictor_identifier = utils.get_predictor_name_identifier(predictor)
-
-        params = None
-        if query.using is not None:
-            params = query.using
-        predictor_step = self.plan.add_step(
-            ApplyTimeseriesPredictorStep(
-                output_time_filter=time_filter,
-                namespace=predictor_namespace,
-                dataframe=data_step.result,
-                predictor=predictor_identifier,
-                params=params,
-            )
-        )
-
-        return {
-            'predictor': predictor_step,
-            'data': data_step,
-            'saved_limit': saved_limit,
-        }
-
-    def plan_join_tables(self, query_in):
-        query = copy.deepcopy(query_in)
-
-        # replace sub selects, with identifiers with links to original selects
-        def replace_subselects(node, **args):
-            if isinstance(node, Select) or isinstance(node, NativeQuery) or isinstance(node, ast.Data):
-                name = f't_{id(node)}'
-                node2 = Identifier(name, alias=node.alias)
-
-                # save in attribute
-                if isinstance(node, NativeQuery) or isinstance(node, ast.Data):
-                    # wrap to select
-                    node = Select(targets=[Star()], from_table=node)
-                node2.sub_select = node
-                return node2
-
-        query_traversal(query.from_table, replace_subselects)
-
-        def resolve_table(table):
-            # gets integration for table and name to access to it
-            table = copy.deepcopy(table)
-            # get possible table aliases
-            aliases = []
-            if table.alias is not None:
-                # to lowercase
-                parts = tuple(map(str.lower, table.alias.parts))
-                aliases.append(parts)
-            else:
-                for i in range(0, len(table.parts)):
-                    parts = table.parts[i:]
-                    parts = tuple(map(str.lower, parts))
-                    aliases.append(parts)
-
-            # try to use default namespace
-            integration = self.default_namespace
-            if len(table.parts) > 0:
-                if table.parts[0] in self.databases:
-                    integration = table.parts.pop(0)
-                else:
-                    integration = self.default_namespace
-
-            if integration is None and not hasattr(table, 'sub_select'):
-                raise PlanningException(f'Integration not found for: {table}')
-
-            sub_select = getattr(table, 'sub_select', None)
-
-            return dict(
-                integration=integration,
-                table=table,
-                aliases=aliases,
-                conditions=[],
-                sub_select=sub_select,
-            )
-
-        # get all join tables, form join sequence
-
-        tables_idx = {}
-
-        def get_join_sequence(node):
-            sequence = []
-            if isinstance(node, Identifier):
-                # resolve identifier
-
-                table_info = resolve_table(node)
-                for alias in table_info['aliases']:
-                    tables_idx[alias] = table_info
-
-                table_info['predictor_info'] = self.get_predictor(node)
-
-                sequence.append(table_info)
-
-            elif isinstance(node, Join):
-                # create sequence: 1)table1, 2)table2, 3)join 1 2, 4)table 3, 5)join 3 4
-
-                # put all tables before
-                sequence2 = get_join_sequence(node.left)
-                for item in sequence2:
-                    sequence.append(item)
-
-                sequence2 = get_join_sequence(node.right)
-                if len(sequence2) != 1:
-                    raise PlanningException('Unexpected join nesting behavior')
-
-                # put next table
-                sequence.append(sequence2[0])
-
-                # put join
-                sequence.append(node)
-
-            else:
-                raise NotImplementedError()
-            return sequence
-
-        join_sequence = get_join_sequence(query.from_table)
-
-        # get conditions for tables
-        binary_ops = []
-
-        def _check_identifiers(node, is_table, **kwargs):
-            if not is_table and isinstance(node, Identifier):
-                if len(node.parts) > 1:
-                    parts = tuple(map(str.lower, node.parts[:-1]))
-                    if parts not in tables_idx:
-                        raise PlanningException(f'Table not found for identifier: {node.to_string()}')
-
-                    # replace identifies name
-                    col_parts = list(tables_idx[parts]['aliases'][-1])
-                    col_parts.append(node.parts[-1])
-                    node.parts = col_parts
-
-        query_traversal(query, _check_identifiers)
-
-        def _check_condition(node, **kwargs):
-            if isinstance(node, BinaryOperation):
-                binary_ops.append(node.op)
-
-                node2 = copy.deepcopy(node)
-                arg1, arg2 = node2.args
-                if not isinstance(arg1, Identifier):
-                    arg1, arg2 = arg2, arg1
-
-                if isinstance(arg1, Identifier) and isinstance(arg2, (Constant, Parameter)):
-                    if len(arg1.parts) < 2:
-                        return
-
-                    # to lowercase
-                    parts = tuple(map(str.lower, arg1.parts[:-1]))
-                    if parts not in tables_idx:
-                        raise PlanningException(f'Table not found for identifier: {arg1.to_string()}')
-
-                    # keep only column name
-                    arg1.parts = [arg1.parts[-1]]
-
-                    node2._orig_node = node
-                    tables_idx[parts]['conditions'].append(node2)
-
-        find_selects = self.get_nested_selects_plan_fnc(self.default_namespace, force=True)
-        query_traversal(query.where, find_selects)
-
-        query_traversal(query.where, _check_condition)
-
-        # create plan
-        # TODO add optimization: one integration without predictor
-        step_stack = []
-        for item in join_sequence:
-            if isinstance(item, dict):
-                table_name = item['table']
-                predictor_info = item['predictor_info']
-
-                if item['sub_select'] is not None:
-                    # is sub select
-                    item['sub_select'].alias = None
-                    item['sub_select'].parentheses = False
-                    step = self.plan_select(item['sub_select'])
-
-                    where = filters_to_bin_op(item['conditions'])
-
-                    # apply table alias
-                    query2 = Select(targets=[Star()], where=where)
-                    if item['table'].alias is None:
-                        raise PlanningException(f'Subselect in join have to be aliased: {item["sub_select"].to_string()}')
-                    table_name = item['table'].alias.parts[-1]
-
-                    add_absent_cols = False
-                    if hasattr (item['sub_select'], 'from_table') and\
-                            isinstance(item['sub_select'].from_table, ast.Data):
-                        add_absent_cols = True
-
-                    step2 = SubSelectStep(query2, step.result, table_name=table_name, add_absent_cols=add_absent_cols)
-                    step2 = self.plan.add_step(step2)
-                    step_stack.append(step2)
-                elif predictor_info is not None:
-                    if len(step_stack) == 0:
-                        raise NotImplementedError("Predictor can't be first element of join syntax")
-                    if predictor_info.get('timeseries'):
-                        raise NotImplementedError("TS predictor is not supported here yet")
-                    data_step = step_stack[-1]
-                    row_dict = None
-                    if item['conditions']:
-                        row_dict = {}
-                        for el in item['conditions']:
-                            if isinstance(el.args[0], Identifier) and el.op == '=':
-
-                                if isinstance(el.args[1], (Constant, Parameter)):
-                                    row_dict[el.args[0].parts[-1]] = el.args[1].value
-
-                                    # exclude condition
-                                    item['conditions'][0]._orig_node.args = [Constant(0), Constant(0)]
-
-                    predictor_step = self.plan.add_step(ApplyPredictorStep(
-                        namespace=item['integration'],
-                        dataframe=data_step.result,
-                        predictor=table_name,
-                        params=query.using,
-                        row_dict=row_dict,
-                    ))
-                    step_stack.append(predictor_step)
-                else:
-                    # is table
-                    query2 = Select(from_table=table_name, targets=[Star()])
-                    # parts = tuple(map(str.lower, table_name.parts))
-                    conditions = item['conditions']
-                    if 'or' in binary_ops:
-                        # not use conditions
-                        conditions = []
-
-                    for cond in conditions:
-                        if query2.where is not None:
-                            query2.where = BinaryOperation('and', args=[query2.where, cond])
-                        else:
-                            query2.where = cond
-
-                    # TODO use self.get_integration_select_step(query2)
-                    step = FetchDataframeStep(integration=item['integration'], query=query2)
-                    self.plan.add_step(step)
-                    step_stack.append(step)
-            elif isinstance(item, Join):
-                step_right = step_stack.pop()
-                step_left = step_stack.pop()
-
-                new_join = copy.deepcopy(item)
-
-                # TODO
-                new_join.left = Identifier('tab1')
-                new_join.right = Identifier('tab2')
-                new_join.implicit = False
-
-                step = self.plan.add_step(JoinStep(left=step_left.result, right=step_right.result, query=new_join))
-
-                step_stack.append(step)
-
-        query_in.where = query.where
-        return step_stack.pop()
-
-    def plan_group(self, query, last_step):
-        # ! is not using yet
-
-        # check group
-        funcs = []
-        for t in query.targets:
-            if isinstance(t, Function):
-                funcs.append(t.op.lower())
-        agg_funcs = ['sum', 'min', 'max', 'avg', 'count', 'std']
-
-        if (
-            query.having is not None
-            or query.group_by is not None
-            or set(agg_funcs) & set(funcs)
-        ):
-            # is aggregate
-            group_by_targets = []
-            for t in query.targets:
-                target_copy = copy.deepcopy(t)
-                group_by_targets.append(target_copy)
-            # last_step = self.plan.steps[-1]
-            return GroupByStep(dataframe=last_step.result, columns=query.group_by, targets=group_by_targets)
-
+    # def plan_group(self, query, last_step):
+    #     # ! is not using yet
+    #
+    #     # check group
+    #     funcs = []
+    #     for t in query.targets:
+    #         if isinstance(t, Function):
+    #             funcs.append(t.op.lower())
+    #     agg_funcs = ['sum', 'min', 'max', 'avg', 'count', 'std']
+    #
+    #     if (
+    #         query.having is not None
+    #         or query.group_by is not None
+    #         or set(agg_funcs) & set(funcs)
+    #     ):
+    #         # is aggregate
+    #         group_by_targets = []
+    #         for t in query.targets:
+    #             target_copy = copy.deepcopy(t)
+    #             group_by_targets.append(target_copy)
+    #         # last_step = self.plan.steps[-1]
+    #         return GroupByStep(dataframe=last_step.result, columns=query.group_by, targets=group_by_targets)

     def plan_project(self, query, dataframe, ignore_doubles=False):
         out_identifiers = []
@@ -972,215 +531,6 @@ def plan_project(self, query, dataframe, ignore_doubles=False):
             out_identifiers.append(new_identifier)
         return self.plan.add_step(ProjectStep(dataframe=dataframe, columns=out_identifiers, ignore_doubles=ignore_doubles))

-    def get_aliased_fields(self, targets):
-        # get aliases from select target
-        aliased_fields = {}
-        for target in targets:
-            if target.alias is not None:
-                aliased_fields[target.alias.to_string()] = target
-        return aliased_fields
-
-    def adapt_dbt_query(self, query, integration):
-        orig_query = query
-
-        join = query.from_table
-        join_left = join.left
-
-        # dbt query.
-
-        # move latest into subquery
-        moved_conditions = []
-
-        def move_latest(node, **kwargs):
-            if isinstance(node, BinaryOperation):
-                if Latest() in node.args:
-                    for arg in node.args:
-                        if isinstance(arg, Identifier):
-                            # remove table alias
-                            arg.parts = [arg.parts[-1]]
-                    moved_conditions.append(node)
-
-        query_traversal(query.where, move_latest)
-
-        # TODO make project step from query.target
-
-        # TODO support complex query. Only one table is supported at the moment.
-        # if not isinstance(join_left.from_table, Identifier):
-        #     raise PlanningException(f'Statement not supported: {query.to_string()}')
-
-        # move properties to upper query
-        query = join_left
-
-        if query.from_table.alias is not None:
-            table_alias = [query.from_table.alias.parts[0]]
-        else:
-            table_alias = query.from_table.parts
-
-        # add latest to query.where
-        for cond in moved_conditions:
-            if query.where is not None:
-                query.where = BinaryOperation('and', args=[query.where, cond])
-            else:
-                query.where = cond
-
-        def add_aliases(node, is_table, **kwargs):
-            if not is_table and isinstance(node, Identifier):
-                if len(node.parts) == 1:
-                    # add table alias to field
-                    node.parts = table_alias + node.parts
-
-        query_traversal(query.where, add_aliases)
-
-        if isinstance(query.from_table, Identifier):
-            # DBT workaround: allow use tables without integration.
-            # if table.part[0] not in integration - take integration name from create table command
-            if (
-                integration is not None
-                and query.from_table.parts[0] not in self.databases
-            ):
-                # add integration name to table
-                query.from_table.parts.insert(0, integration)
-
-        join_left = join_left.from_table
-
-        if orig_query.limit is not None:
-            if query.limit is None or query.limit.value > orig_query.limit.value:
-                query.limit = orig_query.limit
-        query.parentheses = False
-        query.alias = None
-
-        return query, join_left
-
-    def plan_join(self, query, integration=None):
-        orig_query = query
-
-        join = query.from_table
-        join_left = join.left
-        join_right = join.right
-
-        if isinstance(join_left, Select) and isinstance(join_left.from_table, Identifier):
-            if self.is_predictor(join_right) and self.get_predictor(join_right).get('timeseries'):
-                query, join_left = self.adapt_dbt_query(query, integration)
-
-        aliased_fields = self.get_aliased_fields(query.targets)
-
-        # check predictor
-        predictor = None
-        table = None
-        predictor_namespace = None
-        predictor_is_left = False
-
-        if not self.is_predictor(join_right):
-            # predictor not in the right, swap
-            join_left, join_right = join_right, join_left
-            predictor_is_left = True
-
-        if self.is_predictor(join_right):
-            # predictor is in the right now
-
-            if self.is_predictor(join_left):
-                # left is predictor too
-
-                raise PlanningException(f'Can\'t join two predictors {str(join_left.parts[0])} and {str(join_left.parts[1])}')
-            elif isinstance(join_left, Identifier):
-                # the left is table
-                predictor_namespace, predictor = self.get_predictor_namespace_and_name_from_identifier(join_right)
-
-                table = join_left
-
-        last_step = None
-        if predictor:
-            # One argument is a table, another is a predictor
-            # Apply mindsdb model to result of last dataframe fetch
-            # Then join results of applying mindsdb with table
-
-            recursively_check_join_identifiers_for_ambiguity(query.where)
-            recursively_check_join_identifiers_for_ambiguity(query.group_by, aliased_fields=aliased_fields)
-            recursively_check_join_identifiers_for_ambiguity(query.having)
-            recursively_check_join_identifiers_for_ambiguity(query.order_by, aliased_fields=aliased_fields)
-
-            if self.get_predictor(predictor).get('timeseries'):
-                predictor_steps = self.plan_timeseries_predictor(query, table, predictor_namespace, predictor)
-            else:
-                predictor_steps = self.plan_predictor(query, table, predictor_namespace, predictor)
-
-            # add join
-            # Update reference
-
-            left = Identifier(predictor_steps['predictor'].result.ref_name)
-            right = Identifier(predictor_steps['data'].result.ref_name)
-
-            if not predictor_is_left:
-                # swap join
-                left, right = right, left
-            new_join = Join(left=left, right=right, join_type=join.join_type)
-
-            left = predictor_steps['predictor'].result
-            right = predictor_steps['data'].result
-            if not predictor_is_left:
-                # swap join
-                left, right = right, left
-
-            last_step = self.plan.add_step(JoinStep(left=left, right=right, query=new_join))
-
-            # limit from timeseries
-            if predictor_steps.get('saved_limit'):
-                last_step = self.plan.add_step(LimitOffsetStep(dataframe=last_step.result,
-                                                               limit=predictor_steps['saved_limit']))
-
-        if predictor is None:
-
-            query_info = self.get_query_info(query)
-
-            # can we send all query to integration?
-            if (
-                len(query_info['mdb_entities']) == 0
-                and len(query_info['integrations']) == 1
-                and 'files' not in query_info['integrations']
-                and 'views' not in query_info['integrations']
-            ):
-                int_name = list(query_info['integrations'])[0]
-                if self.integrations.get(int_name, {}).get('class_type') != 'api':
-                    # if no predictor inside = run as is
-                    self.prepare_integration_select(int_name, query)
-
-                    last_step = self.plan.add_step(FetchDataframeStep(integration=int_name, query=query))
-
-                    return last_step
-
-            # Both arguments are tables, join results of 2 dataframe fetches
-
-            join_step = self.plan_join_tables(query)
-            last_step = join_step
-            if query.where:
-                # FIXME: Tableau workaround, INFORMATION_SCHEMA with Where
-                if isinstance(join.right, Identifier) \
-                        and self.resolve_database_table(join.right)[0] == 'INFORMATION_SCHEMA':
-                    pass
-                else:
-                    last_step = self.plan.add_step(FilterStep(dataframe=last_step.result, query=query.where))
-
-            if query.group_by:
-                group_by_targets = []
-                for t in query.targets:
-                    target_copy = copy.deepcopy(t)
-                    target_copy.alias = None
-                    group_by_targets.append(target_copy)
-                last_step = self.plan.add_step(GroupByStep(dataframe=last_step.result, columns=query.group_by, targets=group_by_targets))
-
-            if query.having:
-                last_step = self.plan.add_step(FilterStep(dataframe=last_step.result, query=query.having))
-
-            if query.order_by:
-                last_step = self.plan.add_step(OrderByStep(dataframe=last_step.result, order_by=query.order_by))
-
-            if query.limit is not None or query.offset is not None:
-                limit = query.limit.value if query.limit is not None else None
-                offset = query.offset.value if query.offset is not None else None
-                last_step = self.plan.add_step(LimitOffsetStep(dataframe=last_step.result, limit=limit, offset=offset))
-
-        return self.plan_project(orig_query, last_step.result)
-
     def plan_create_table(self, query):
         if query.from_select is None:
             raise PlanningException(f'Not implemented "create table": {query.to_string()}')
@@ -1257,26 +607,23 @@ def plan_select(self, query, integration=None):
         elif isinstance(from_table, Select):
             return self.plan_nested_select(query)
         elif isinstance(from_table, Join):
-            return self.plan_join(query, integration=integration)
+            plan_join = PlanJoin(self)
+            return plan_join.plan(query, integration)
         elif isinstance(from_table, NativeQuery):
             integration = from_table.integration.parts[0].lower()
             step = FetchDataframeStep(integration=integration, raw_query=from_table.query)
             last_step = self.plan.add_step(step)
-            sup_select = self.sub_select_step(query, step)
-            if sup_select is not None:
-                last_step = self.plan.add_step(sup_select)
-            return last_step
+            return self.plan_sub_select(query, last_step)
+
         elif isinstance(from_table, ast.Data):
             step = DataStep(from_table.data)
             last_step = self.plan.add_step(step)
-            sup_select = self.sub_select_step(query, step, add_absent_cols=True)
-            if sup_select is not None:
-                last_step = self.plan.add_step(sup_select)
-            return last_step
+            return self.plan_sub_select(query, last_step, add_absent_cols=True)
+
         else:
             raise PlanningException(f'Unsupported from_table {type(from_table)}')

-    def sub_select_step(self, query, prev_step, add_absent_cols=False):
+    def plan_sub_select(self, query, prev_step, add_absent_cols=False):
         if (
                 query.group_by is not None
                 or query.order_by is not None
@@ -1290,18 +637,27 @@ def sub_select_step(self, query, prev_step, add_absent_cols=False):
         ):
             if query.from_table.alias is not None:
                 table_name = query.from_table.alias.parts[-1]
+            elif isinstance(query.from_table, Identifier):
+                table_name = query.from_table.parts[-1]
             else:
                 table_name = None
             query2 = copy.deepcopy(query)
             query2.from_table = None
-            return SubSelectStep(query2, prev_step.result, table_name=table_name, add_absent_cols=add_absent_cols)
+            sup_select = SubSelectStep(query2, prev_step.result, table_name=table_name, add_absent_cols=add_absent_cols)
+            self.plan.add_step(sup_select)
+            return sup_select
+
+        return prev_step

     def plan_union(self, query):
-        query1 = self.plan_select(query.left)
-        query2 = self.plan_select(query.right)
+        if isinstance(query.left, Union):
+            step1 = self.plan_union(query.left)
+        else:
+            # it is select
+            step1 = self.plan_select(query.left)
+        step2 = self.plan_select(query.right)

-        return self.plan.add_step(UnionStep(left=query1.result, right=query2.result, unique=query.unique))
+        return self.plan.add_step(UnionStep(left=step1.result, right=step2.result, unique=query.unique))

     # method for compatibility
     def from_query(self, query=None):
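Not part of the patch: a minimal standalone sketch of the split that the new plan_api_db_select performs, using only the public parser API; the query text and rendered output are illustrative assumptions, not taken from the repository.

    from mindsdb_sql import parse_sql
    from mindsdb_sql.parser.ast import Select, Star

    query = parse_sql("select x from int1.tab2 where x1 = 1 order by x limit 1", dialect='mindsdb')

    # inner select: only where and limit are pushed down to the api integration
    inner = Select(targets=[Star()], from_table=query.from_table,
                   where=query.where, limit=query.limit)

    # outer select: keeps projection and ordering, runs over the fetched dataframe
    query.where = None
    query.limit = None

    print(str(inner))  # roughly: SELECT * FROM int1.tab2 WHERE x1 = 1 LIMIT 1
    print(str(query))  # roughly: SELECT x FROM int1.tab2 ORDER BY x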
diff --git a/mindsdb_sql/planner/steps.py b/mindsdb_sql/planner/steps.py
index e05c1a05..7c933f0a 100644
--- a/mindsdb_sql/planner/steps.py
+++ b/mindsdb_sql/planner/steps.py
@@ -48,7 +48,7 @@ def __init__(self, columns, dataframe, ignore_doubles=False, *args, **kwargs):
         if isinstance(dataframe, Result):
             self.references.append(dataframe)

-
+# TODO remove
 class FilterStep(PlanStep):
     """Filters some dataframe according to a query"""
     def __init__(self, dataframe, query, *args, **kwargs):
@@ -59,7 +59,7 @@ def __init__(self, dataframe, query, *args, **kwargs):
         if isinstance(dataframe, Result):
             self.references.append(dataframe)

-
+# TODO remove
 class GroupByStep(PlanStep):
     """Groups output by columns and computes aggregation functions"""

@@ -102,7 +102,7 @@ def __init__(self, left, right, unique, *args, **kwargs):
         if isinstance(right, Result):
             self.references.append(right)

-
+# TODO remove
 class OrderByStep(PlanStep):
     """Applies sorting to a dataframe"""

@@ -254,6 +254,14 @@ def __init__(self, query, dataframe, table_name=None, add_absent_cols=False, *ar
         self.add_absent_cols = add_absent_cols


+class QueryStep(PlanStep):
+    def __init__(self, query, from_table=None, *args, **kwargs):
+        """Performs query using injected dataframe"""
+        super().__init__(*args, **kwargs)
+        self.query = query
+        self.from_table = from_table
+
+
 class DataStep(PlanStep):
     def __init__(self, data, *args, **kwargs):
         super().__init__(*args, **kwargs)
diff --git a/mindsdb_sql/planner/ts_utils.py b/mindsdb_sql/planner/ts_utils.py
index 158bd1a5..f58338a0 100644
--- a/mindsdb_sql/planner/ts_utils.py
+++ b/mindsdb_sql/planner/ts_utils.py
@@ -1,3 +1,4 @@
+from mindsdb_sql import OrderBy
 from mindsdb_sql.exceptions import PlanningException
 from mindsdb_sql.parser.ast import Identifier, Operation, BinaryOperation, BetweenOperation

@@ -69,4 +70,22 @@ def validate_ts_where_condition(op, allowed_columns, allow_and=True):
     if isinstance(op.args[0], Operation):
         validate_ts_where_condition(op.args[0], allowed_columns, allow_and=True)
     if isinstance(op.args[1], Operation):
-        validate_ts_where_condition(op.args[1], allowed_columns, allow_and=True)
\ No newline at end of file
+        validate_ts_where_condition(op.args[1], allowed_columns, allow_and=True)
+
+
+def recursively_check_join_identifiers_for_ambiguity(item, aliased_fields=None):
+    if item is None:
+        return
+    elif isinstance(item, Identifier):
+        if len(item.parts) == 1:
+            if aliased_fields is not None and item.parts[0] in aliased_fields:
+                # is alias
+                return
+            raise PlanningException(f'Ambiguous identifier {str(item)}, provide table name for operations on a join.')
+    elif isinstance(item, Operation):
+        recursively_check_join_identifiers_for_ambiguity(item.args, aliased_fields=aliased_fields)
+    elif isinstance(item, OrderBy):
+        recursively_check_join_identifiers_for_ambiguity(item.field, aliased_fields=aliased_fields)
+    elif isinstance(item, list):
+        for arg in item:
+            recursively_check_join_identifiers_for_ambiguity(arg, aliased_fields=aliased_fields)
diff --git a/mindsdb_sql/planner/utils.py b/mindsdb_sql/planner/utils.py
index 2d4c32ca..64d7c88b 100644
--- a/mindsdb_sql/planner/utils.py
+++ b/mindsdb_sql/planner/utils.py
@@ -75,24 +75,6 @@ def recursively_extract_column_values(op, row_dict, predictor):
         raise PlanningException(f'Only \'and\' and \'=\' operations allowed in WHERE clause, found: {op.to_tree()}')


-def recursively_check_join_identifiers_for_ambiguity(item, aliased_fields=None):
-    if item is None:
-        return
-    elif isinstance(item, Identifier):
-        if len(item.parts) == 1:
-            if aliased_fields is not None and item.parts[0] in aliased_fields:
-                # is alias
-                return
-            raise PlanningException(f'Ambigous identifier {str(item)}, provide table name for operations on a join.')
-    elif isinstance(item, Operation):
-        recursively_check_join_identifiers_for_ambiguity(item.args, aliased_fields=aliased_fields)
-    elif isinstance(item, OrderBy):
-        recursively_check_join_identifiers_for_ambiguity(item.field, aliased_fields=aliased_fields)
-    elif isinstance(item, list):
-        for arg in item:
-            recursively_check_join_identifiers_for_ambiguity(arg, aliased_fields=aliased_fields)
-
-
 def get_deepest_select(select):
     if not select.from_table or not isinstance(select.from_table, Select):
         return select
@@ -126,7 +108,10 @@ def query_traversal(node, callback, is_table=False, is_target=False, parent_quer
         array = []
         for node2 in node.targets:
             node_out = query_traversal(node2, callback, parent_query=node, is_target=True) or node2
-            array.append(node_out)
+            if isinstance(node_out, list):
+                array.extend(node_out)
+            else:
+                array.append(node_out)
         node.targets = array

         if node.cte is not None:
@@ -293,6 +278,20 @@ def query_traversal(node, callback, is_table=False, is_target=False, parent_quer
             if node_out is not None:
                 node.field = node_out

+    elif isinstance(node, ast.Case):
+        rules = []
+        for condition, result in node.rules:
+            condition2 = query_traversal(condition, callback, parent_query=parent_query)
+            result2 = query_traversal(result, callback, parent_query=parent_query)
+
+            condition = condition if condition2 is None else condition2
+            result = result if result2 is None else result2
+            rules.append([condition, result])
+        node.rules = rules
+        default = query_traversal(node.default, callback, parent_query=parent_query)
+        if default is not None:
+            node.default = default
+
     elif isinstance(node, list):
         array = []
         for node2 in node:
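A minimal sketch of what the new ast.Case branch in query_traversal enables: callbacks now reach the WHEN conditions, THEN results and the ELSE default. The query text and the expected output are illustrative assumptions.

    from mindsdb_sql import parse_sql
    from mindsdb_sql.parser.ast import Constant
    from mindsdb_sql.planner.utils import query_traversal

    query = parse_sql("select case when a = 1 then 'x' else 'y' end from t", dialect='mindsdb')

    found = []

    def collect_constants(node, **kwargs):
        # query_traversal invokes the callback for every node; with the Case
        # branch above, constants inside WHEN/THEN/ELSE are now visited too
        if isinstance(node, Constant):
            found.append(node.value)

    query_traversal(query, collect_constants)
    print(found)  # expected: [1, 'x', 'y']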
diff --git a/mindsdb_sql/render/sqlalchemy_render.py b/mindsdb_sql/render/sqlalchemy_render.py
index d4a8c657..23801cc9 100644
--- a/mindsdb_sql/render/sqlalchemy_render.py
+++ b/mindsdb_sql/render/sqlalchemy_render.py
@@ -392,17 +392,17 @@ def prepare_select(self, node):
                 full=is_full
             )
         elif isinstance(from_table, ast.Union):
-            if not(isinstance(from_table.left, ast.Select) and isinstance(from_table.right, ast.Select)):
-                raise NotImplementedError(f'Unknown UNION {from_table.left.__name__}, {from_table.right.__name__}')
-
-            left = self.prepare_select(from_table.left)
-            right = self.prepare_select(from_table.right)
+            tables = self.extract_union_list(from_table)

             alias = None
             if from_table.alias:
                 alias = self.get_alias(from_table.alias)

-            table = left.union(right).subquery(alias)
+            table1 = tables[0]
+            tables_x = tables[1:]
+
+            table = table1.union(*tables_x).subquery(alias)
+
             query = query.select_from(table)

         elif isinstance(from_table, ast.Select):
@@ -412,6 +412,13 @@ def prepare_select(self, node):
         elif isinstance(from_table, ast.Identifier):
             table = self.to_table(from_table)
             query = query.select_from(table)
+
+        elif isinstance(from_table, ast.NativeQuery):
+            alias = None
+            if from_table.alias:
+                alias = from_table.alias.parts[-1]
+            table = sa.text(from_table.query).columns().subquery(alias)
+            query = query.select_from(table)
         else:
             raise NotImplementedError(f'Select from {from_table}')

@@ -460,6 +467,20 @@ def prepare_select(self, node):

         return query

+    def extract_union_list(self, node):
+        if not (isinstance(node.left, (ast.Select, ast.Union)) and isinstance(node.right, ast.Select)):
+            raise NotImplementedError(
+                f'Unknown UNION {node.left.__class__.__name__}, {node.right.__class__.__name__}')
+
+        tables = []
+        if isinstance(node.left, ast.Union):
+            tables.extend(self.extract_union_list(node.left))
+        else:
+            tables.append(self.prepare_select(node.left))
+        tables.append(self.prepare_select(node.right))
+        return tables
+
+
     def prepare_create_table(self, ast_query):
         columns = []
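The parser builds a left-deep tree for chained UNIONs, i.e. `a UNION b UNION c` parses as Union(Union(a, b), c); extract_union_list above flattens that chain before handing the selects to SQLAlchemy. A standalone sketch of the same recursion over the raw AST (query text assumed for illustration):

    from mindsdb_sql import parse_sql
    from mindsdb_sql.parser.ast import Union

    def flatten_union(node):
        # collect the Select nodes of a left-deep Union chain, in order
        selects = []
        if isinstance(node.left, Union):
            selects.extend(flatten_union(node.left))
        else:
            selects.append(node.left)
        selects.append(node.right)
        return selects

    query = parse_sql("select 1 union select 2 union select 3", dialect='mindsdb')
    print(len(flatten_union(query)))  # 3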
diff --git a/tests/test_parser/test_base_sql/test_base_sql.py b/tests/test_parser/test_base_sql/test_base_sql.py
index 53844d4e..3294b897 100644
--- a/tests/test_parser/test_base_sql/test_base_sql.py
+++ b/tests/test_parser/test_base_sql/test_base_sql.py
@@ -32,3 +32,29 @@ def test_not_equal(self):

         assert str(ast).lower() == str(expected_ast).lower()
         assert ast.to_tree() == expected_ast.to_tree()

+    def test_escaping(self):
+        expected_ast = Select(
+            targets=[Constant(value="a ' \" b")]
+        )
+
+        sql = """
+            select 'a \\' \\" b'
+        """
+
+        ast = parse_sql(sql)
+
+        assert str(ast).lower() == str(expected_ast).lower()
+        assert ast.to_tree() == expected_ast.to_tree()
+
+        # in double quotes
+        sql = """
+            select "a \\' \\" b"
+        """
+
+        ast = parse_sql(sql)
+
+        assert str(ast).lower() == str(expected_ast).lower()
+        assert ast.to_tree() == expected_ast.to_tree()
+
diff --git a/tests/test_parser/test_base_sql/test_describe.py b/tests/test_parser/test_base_sql/test_describe.py
index 60c1d908..4032f754 100644
--- a/tests/test_parser/test_base_sql/test_describe.py
+++ b/tests/test_parser/test_base_sql/test_describe.py
@@ -31,3 +31,28 @@ def test_describe_predictor(self):

         assert str(ast) == str(expected_ast)
         assert ast.to_tree() == expected_ast.to_tree()

+        # describe attr
+        sql = "DESCRIBE MODEL pred.attr"
+        ast = parse_sql(sql, dialect='mindsdb')
+
+        expected_ast = Describe(type='predictor', value=Identifier(parts=['pred', 'attr']))
+
+        assert str(ast) == str(expected_ast)
+
+        # version
+        sql = "DESCRIBE MODEL pred.11"
+        ast = parse_sql(sql, dialect='mindsdb')
+
+        expected_ast = Describe(type='predictor', value=Identifier(parts=['pred', '11']))
+
+        assert str(ast) == str(expected_ast)
+
+        # version and attr
+        sql = "DESCRIBE MODEL pred.11.attr"
+        ast = parse_sql(sql, dialect='mindsdb')
+
+        expected_ast = Describe(type='predictor', value=Identifier(parts=['pred', '11', 'attr']))
+
+        assert str(ast) == str(expected_ast)
+
diff --git a/tests/test_parser/test_base_sql/test_misc_sql_queries.py b/tests/test_parser/test_base_sql/test_misc_sql_queries.py
index bc8df714..68acb3f9 100644
--- a/tests/test_parser/test_base_sql/test_misc_sql_queries.py
+++ b/tests/test_parser/test_base_sql/test_misc_sql_queries.py
@@ -3,22 +3,22 @@
 from mindsdb_sql.parser.ast import *


-@pytest.mark.parametrize('dialect', ['sqlite', 'mysql', 'mindsdb'])
+@pytest.mark.parametrize('dialect', ['mysql', 'mindsdb'])
 class TestMiscQueries:
     def test_set(self, dialect):
         lexer, parser = get_lexer_parser(dialect)

-        sql = "SET NAMES some_name"
+        sql = "SET names some_name"

         ast = parse_sql(sql, dialect=dialect)
-        expected_ast = Set(category="names", arg=Identifier('some_name'))
+        expected_ast = Set(category="names", value=Identifier('some_name'))

         assert ast.to_tree() == expected_ast.to_tree()
         assert str(ast) == str(expected_ast)

         sql = "set character_set_results = NULL"
         ast = parse_sql(sql, dialect=dialect)
-        expected_ast = Set(arg=BinaryOperation('=', args=[Identifier('character_set_results'), NullConstant()]))
+        expected_ast = Set(name=Identifier('character_set_results'), value=NullConstant())

         assert ast.to_tree() == expected_ast.to_tree()
         assert str(ast) == str(expected_ast)
@@ -74,14 +74,8 @@ def test_autocommit(self, dialect):
         ast = parse_sql(sql, dialect=dialect)

         expected_ast = Set(
-            category=None,
-            arg=BinaryOperation(
-                op='=',
-                args=(
-                    Identifier('autocommit'),
-                    Constant(1)
-                )
-            )
+            name=Identifier('autocommit'),
+            value=Constant(1)
         )
         assert ast.to_tree() == expected_ast.to_tree()
         assert str(ast) == str(expected_ast)
@@ -94,29 +88,30 @@ def test_set(self, dialect):
         sql = "set var1 = NULL, var2 = 10"

         ast = parse_sql(sql, dialect=dialect)
-        expected_ast = Set(arg=Tuple(items=[
-            BinaryOperation('=', args=[Identifier('var1'), NullConstant()]),
-            BinaryOperation('=', args=[Identifier('var2'), Constant(10)]),
-        ])
-        )
+        expected_ast = Set(
+            set_list=[
+                Set(name=Identifier('var1'), value=NullConstant()),
+                Set(name=Identifier('var2'), value=Constant(10)),
+            ]
+        )

         assert ast.to_tree() == expected_ast.to_tree()
         assert str(ast) == str(expected_ast)

-        sql = "SET NAMES some_name collate default"
+        sql = "SET NAMES some_name collate DEFAULT"

         ast = parse_sql(sql, dialect=dialect)
-        expected_ast = Set(category="names",
-                           arg=Identifier('some_name'),
+        expected_ast = Set(category="NAMES",
+                           value=Constant('some_name', with_quotes=False),
                            params={'COLLATE': 'DEFAULT'})

         assert ast.to_tree() == expected_ast.to_tree()
         assert str(ast) == str(expected_ast)

-        sql = "SET NAMES some_name collate 'utf8mb4_general_ci'"
+        sql = "SET names some_name collate 'utf8mb4_general_ci'"

         ast = parse_sql(sql, dialect=dialect)
         expected_ast = Set(category="names",
-                           arg=Identifier('some_name'),
+                           value=Constant('some_name', with_quotes=False),
                            params={'COLLATE': Constant('utf8mb4_general_ci')})

         assert ast.to_tree() == expected_ast.to_tree()
         assert str(ast) == str(expected_ast)
@@ -126,14 +121,14 @@ def test_set_charset(self, dialect):
         sql = "SET CHARACTER SET DEFAULT"

         ast = parse_sql(sql, dialect=dialect)
-        expected_ast = Set(category='CHARSET', arg=SpecialConstant('DEFAULT'))
+        expected_ast = Set(category='CHARSET', value=Constant('DEFAULT', with_quotes=False))

         assert ast.to_tree() == expected_ast.to_tree()

         sql = "SET CHARSET DEFAULT"

         ast = parse_sql(sql, dialect=dialect)
-        expected_ast = Set(category='CHARSET', arg=SpecialConstant('DEFAULT'))
+        expected_ast = Set(category='CHARSET', value=Constant('DEFAULT', with_quotes=False))

         assert ast.to_tree() == expected_ast.to_tree()
         assert str(ast) == str(expected_ast)
@@ -141,7 +136,7 @@ def test_set_charset(self, dialect):
         sql = "SET CHARSET 'utf8'"

         ast = parse_sql(sql, dialect=dialect)
-        expected_ast = Set(category='CHARSET', arg=Constant('utf8'))
+        expected_ast = Set(category='CHARSET', value=Constant('utf8'))

         assert ast.to_tree() == expected_ast.to_tree()
         assert str(ast) == str(expected_ast)
@@ -151,10 +146,14 @@ def test_set_transaction(self, dialect):
         sql = "SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE"

         ast = parse_sql(sql, dialect=dialect)
-        expected_ast = SetTransaction(
-            isolation_level='REPEATABLE READ',
-            access_mode='READ WRITE',
-            scope='GLOBAL')
+        expected_ast = Set(
+            category='TRANSACTION',
+            params=dict(
+                isolation_level='REPEATABLE READ',
+                access_mode='READ WRITE',
+            ),
+            scope='GLOBAL'
+        )

         assert ast.to_tree() == expected_ast.to_tree()
         assert str(ast) == str(expected_ast)
@@ -162,10 +161,15 @@ def test_set_transaction(self, dialect):
         sql = "SET SESSION TRANSACTION READ ONLY, ISOLATION LEVEL SERIALIZABLE"

         ast = parse_sql(sql, dialect=dialect)
-        expected_ast = SetTransaction(
-            isolation_level='SERIALIZABLE',
-            access_mode='READ ONLY',
-            scope='SESSION')
+
+        expected_ast = Set(
+            category='TRANSACTION',
+            params=dict(
+                isolation_level='SERIALIZABLE',
+                access_mode='READ ONLY',
+            ),
+            scope='SESSION'
+        )

         assert ast.to_tree() == expected_ast.to_tree()
         assert str(ast) == str(expected_ast)
@@ -173,8 +177,12 @@ def test_set_transaction(self, dialect):
         sql = "SET TRANSACTION ISOLATION LEVEL READ UNCOMMITTED"

         ast = parse_sql(sql, dialect=dialect)
-        expected_ast = SetTransaction(
-            isolation_level='READ UNCOMMITTED'
+
+        expected_ast = Set(
+            category='TRANSACTION',
+            params=dict(
+                isolation_level='READ UNCOMMITTED',
+            )
         )

         assert ast.to_tree() == expected_ast.to_tree()
@@ -183,8 +191,12 @@ def test_set_transaction(self, dialect):
         sql = "SET TRANSACTION READ ONLY"

         ast = parse_sql(sql, dialect=dialect)
-        expected_ast = SetTransaction(
-            access_mode='READ ONLY'
+
+        expected_ast = Set(
+            category='TRANSACTION',
+            params=dict(
+                access_mode='READ ONLY',
+            )
         )

         assert ast.to_tree() == expected_ast.to_tree()
@@ -198,3 +210,14 @@ def test_begin(self, dialect):
         assert ast.to_tree() == expected_ast.to_tree()
         assert str(ast) == str(expected_ast)

+
+class TestMindsdb:
+    def test_charset(self):
+        sql = "SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci"
+
+        ast = parse_sql(sql)
+        expected_ast = Set(category="NAMES",
+                           value=Constant('utf8mb4', with_quotes=False),
+                           params={'COLLATE': Constant('utf8mb4_unicode_ci', with_quotes=False)})
+        assert ast.to_tree() == expected_ast.to_tree()
+        assert str(ast) == str(expected_ast)
+
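A short sketch of the reworked Set node from the caller's side, assuming the fields shown in this patch (name/value for plain assignments, set_list for comma-separated lists); the printed reprs are approximations.

    from mindsdb_sql import parse_sql

    ast = parse_sql("set autocommit = 1", dialect='mindsdb')
    print(ast.name, ast.value)   # Identifier 'autocommit', Constant 1

    ast = parse_sql("set var1 = NULL, var2 = 10", dialect='mindsdb')
    print(len(ast.set_list))     # 2: one nested Set per assignment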
diff --git a/tests/test_parser/test_base_sql/test_select_structure.py b/tests/test_parser/test_base_sql/test_select_structure.py
index 2294936a..b896bcab 100644
--- a/tests/test_parser/test_base_sql/test_select_structure.py
+++ b/tests/test_parser/test_base_sql/test_select_structure.py
@@ -907,45 +907,6 @@ def test_keywords(self, dialect):
         assert ast.to_tree() == expected_ast.to_tree()
         assert str(ast) == str(expected_ast)

-    def test_case(self, dialect):
-        sql = f'''SELECT
-            CASE
-                WHEN R.DELETE_RULE = 'CASCADE' THEN 0
-                WHEN R.DELETE_RULE = 'SET NULL' THEN 2
-                ELSE 3
-            END AS DELETE_RULE
-        FROM INFORMATION_SCHEMA.COLLATIONS'''
-        ast = parse_sql(sql, dialect=dialect)
-
-        expected_ast = Select(
-            targets=[
-                Case(
-                    rules=[
-                        [
-                            BinaryOperation(op='=', args=[
-                                Identifier('R.DELETE_RULE'),
-                                Constant('CASCADE')
-                            ]),
-                            Constant(0)
-                        ],
-                        [
-                            BinaryOperation(op='=', args=[
-                                Identifier('R.DELETE_RULE'),
-                                Constant('SET NULL')
-                            ]),
-                            Constant(2)
-                        ]
-                    ],
-                    default=Constant(3),
-                    alias=Identifier('DELETE_RULE')
-                )
-            ],
-            from_table=Identifier('INFORMATION_SCHEMA.COLLATIONS')
-        )
-
-        assert ast.to_tree() == expected_ast.to_tree()
-        assert str(ast) == str(expected_ast)
-
     def test_table_star(self, dialect):
         sql = f'select *, t.* From table1 '
         ast = parse_sql(sql, dialect=dialect)
@@ -997,3 +958,64 @@ def test_select_function_star(self, dialect):

         assert ast.to_tree() == expected_ast.to_tree()
         assert str(ast) == str(expected_ast)

+
+class TestMindsdb:
+
+    def test_case(self):
+        sql = f'''SELECT
+            CASE
+                WHEN R.DELETE_RULE = 'CASCADE' THEN 0
+                WHEN R.DELETE_RULE = 'SET NULL' THEN 2
+                ELSE 3
+            END AS DELETE_RULE,
+            sum(
+                CASE
+                    WHEN 1 = 1 THEN 1
+                    ELSE 0
+                END
+            )
+        FROM INFORMATION_SCHEMA.COLLATIONS'''
+        ast = parse_sql(sql)
+
+        expected_ast = Select(
+            targets=[
+                Case(
+                    rules=[
+                        [
+                            BinaryOperation(op='=', args=[
+                                Identifier('R.DELETE_RULE'),
+                                Constant('CASCADE')
+                            ]),
+                            Constant(0)
+                        ],
+                        [
+                            BinaryOperation(op='=', args=[
+                                Identifier('R.DELETE_RULE'),
+                                Constant('SET NULL')
+                            ]),
+                            Constant(2)
+                        ]
+                    ],
+                    default=Constant(3),
+                    alias=Identifier('DELETE_RULE')
+                ),
+                Function(
+                    op='sum',
+                    args=[
+                        Case(
+                            rules=[
+                                [
+                                    BinaryOperation(op='=', args=[Constant(1), Constant(1)]),
+                                    Constant(1)
+                                ],
+                            ],
+                            default=Constant(0)
+                        )
+                    ]
+                )
+            ],
+            from_table=Identifier('INFORMATION_SCHEMA.COLLATIONS')
+        )
+
+        assert ast.to_tree() == expected_ast.to_tree()
+        assert str(ast) == str(expected_ast)
diff --git a/tests/test_parser/test_base_sql/test_union.py b/tests/test_parser/test_base_sql/test_union.py
index 8a92ea32..1545e4b0 100644
--- a/tests/test_parser/test_base_sql/test_union.py
+++ b/tests/test_parser/test_base_sql/test_union.py
@@ -4,19 +4,18 @@
 from mindsdb_sql.exceptions import ParsingException


-@pytest.mark.parametrize('dialect', ['sqlite', 'mysql', 'mindsdb'])
 class TestUnion:
-    def test_single_select_error(self, dialect):
+    def test_single_select_error(self):
         sql = "SELECT col FROM tab UNION"
         with pytest.raises(ParsingException):
-            parse_sql(sql, dialect=dialect)
+            parse_sql(sql)

-    def test_union_base(self, dialect):
+    def test_union_base(self):
         sql = """SELECT col1 FROM tab1
         UNION
         SELECT col1 FROM tab2"""

-        ast = parse_sql(sql, dialect=dialect)
+        ast = parse_sql(sql)
         expected_ast = Union(unique=True,
                              left=Select(targets=[Identifier('col1')],
                                          from_table=Identifier(parts=['tab1']),
@@ -28,12 +27,12 @@ def test_union_base(self, dialect):
         assert ast.to_tree() == expected_ast.to_tree()
         assert str(ast) == str(expected_ast)

-    def test_union_all(self, dialect):
+    def test_union_all(self):
         sql = """SELECT col1 FROM tab1
         UNION ALL
         SELECT col1 FROM tab2"""

-        ast = parse_sql(sql, dialect=dialect)
+        ast = parse_sql(sql)
         expected_ast = Union(unique=False,
                              left=Select(targets=[Identifier('col1')],
                                          from_table=Identifier(parts=['tab1']),
@@ -45,25 +44,31 @@ def test_union_all(self, dialect):
         assert ast.to_tree() == expected_ast.to_tree()
         assert str(ast) == str(expected_ast)

-    def test_union_alias(self, dialect):
+    def xtest_union_alias(self):
         sql = """SELECT * FROM (
             SELECT col1 FROM tab1
             UNION
             SELECT col1 FROM tab2
+            UNION
+            SELECT col1 FROM tab3
         ) AS alias"""

-        ast = parse_sql(sql, dialect=dialect)
+        ast = parse_sql(sql)
         expected_ast = Select(targets=[Star()],
-                              from_table=Union(unique=True,
-                                               alias=Identifier('alias'),
-                                               left=Select(
-                                                   targets=[Identifier('col1')],
-                                                   from_table=Identifier(parts=['tab1']),
-                                               ),
-                                               right=Select(targets=[Identifier('col1')],
-                                                            from_table=Identifier(parts=['tab2']),
-                                                            ),
-                                               )
+                              from_table=Union(
+                                  unique=True,
+                                  alias=Identifier('alias'),
+                                  left=Union(
+                                      unique=True,
+                                      left=Select(
+                                          targets=[Identifier('col1')],
+                                          from_table=Identifier(parts=['tab1']),),
+                                      right=Select(targets=[Identifier('col1')],
+                                                   from_table=Identifier(parts=['tab2']),),
+                                  ),
+                                  right=Select(targets=[Identifier('col1')],
+                                               from_table=Identifier(parts=['tab3']),),
+                              )
                               )

         assert ast.to_tree() == expected_ast.to_tree()
         assert str(ast) == str(expected_ast)
diff --git a/tests/test_parser/test_mindsdb/test_finetune_predictor.py b/tests/test_parser/test_mindsdb/test_finetune_predictor.py
index 2cd777f6..fffc7dd9 100644
--- a/tests/test_parser/test_mindsdb/test_finetune_predictor.py
+++ b/tests/test_parser/test_mindsdb/test_finetune_predictor.py
@@ -29,3 +29,10 @@ def test_finetune_predictor_full(self):
         assert ' '.join(str(ast).split()).lower() == sql.lower()
         assert str(ast) == str(expected_ast)
         assert ast.to_tree() == expected_ast.to_tree()

+
+        # with MODEL
+        sql = "FINETUNE MODEL mindsdb.pred FROM integration_name (SELECT * FROM table_1) USING a=1, b=null"
+        ast = parse_sql(sql, dialect='mindsdb')
+
+        assert str(ast) == str(expected_ast)
+        assert ast.to_tree() == expected_ast.to_tree()
diff --git a/tests/test_parser/test_mindsdb/test_retrain_predictor.py b/tests/test_parser/test_mindsdb/test_retrain_predictor.py
index 2941694f..baa6b4e8 100644
--- a/tests/test_parser/test_mindsdb/test_retrain_predictor.py
+++ b/tests/test_parser/test_mindsdb/test_retrain_predictor.py
@@ -26,6 +26,12 @@ def test_retrain_predictor_ok(self):
         assert str(ast) == str(expected_ast)
         assert ast.to_tree() == expected_ast.to_tree()

+        # with model
+        sql = "RETRAIN MODEL mindsdb.pred"
+        ast = parse_sql(sql, dialect='mindsdb')
+        assert str(ast) == str(expected_ast)
+        assert ast.to_tree() == expected_ast.to_tree()
+
     def test_retrain_predictor_full(self):
         sql = """Retrain pred
             FROM integration_name
diff --git a/tests/test_parser/test_mindsdb/test_timeseries.py b/tests/test_parser/test_mindsdb/test_timeseries.py
index aac1b5d7..c1161b80 100644
--- a/tests/test_parser/test_mindsdb/test_timeseries.py
+++ b/tests/test_parser/test_mindsdb/test_timeseries.py
@@ -2,7 +2,6 @@
 from mindsdb_sql.parser.ast import *
 from mindsdb_sql.parser.dialects.mindsdb.latest import Latest
 from mindsdb_sql.parser.utils import JoinType
-from mindsdb_sql.planner.ts_utils import validate_ts_where_condition


 class TestTimeSeries:
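Variable now lives in mindsdb_sql.parser.ast (previously only in the mysql dialect); a minimal sketch of the @var vs @@system_var distinction, assuming the is_system_var flag shown in the new tests below:

    from mindsdb_sql import parse_sql
    from mindsdb_sql.parser.ast import Variable

    ast = parse_sql("SELECT @@version", dialect='mindsdb')
    var = ast.targets[0]
    print(isinstance(var, Variable), var.is_system_var)  # True True

    ast = parse_sql("SELECT @version", dialect='mindsdb')
    print(ast.targets[0].is_system_var)                  # False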
diff --git a/tests/test_parser/test_mindsdb/test_variables.py b/tests/test_parser/test_mindsdb/test_variables.py
new file mode 100644
index 00000000..27333033
--- /dev/null
+++ b/tests/test_parser/test_mindsdb/test_variables.py
@@ -0,0 +1,39 @@
+from mindsdb_sql import parse_sql
+from mindsdb_sql.parser.ast import *
+from mindsdb_sql.parser.ast import Variable
+
+
+class TestMDBParser:
+    def test_select_variable(self):
+        sql = 'SELECT @version'
+        ast = parse_sql(sql)
+        expected_ast = Select(targets=[Variable('version')])
+        assert str(ast).lower() == sql.lower()
+        assert str(ast) == str(expected_ast)
+
+        sql = 'SELECT @@version'
+        ast = parse_sql(sql)
+        expected_ast = Select(targets=[Variable('version', is_system_var=True)])
+        assert str(ast).lower() == sql.lower()
+        assert str(ast) == str(expected_ast)
+
+        sql = "set autocommit=1, global sql_mode=concat(@@sql_mode, ',STRICT_TRANS_TABLES'), NAMES utf8mb4 COLLATE utf8mb4_unicode_ci"
+        ast = parse_sql(sql)
+        expected_ast = Set(
+            set_list=[
+                Set(name=Identifier('autocommit'), value=Constant(1)),
+                Set(name=Identifier('sql_mode'),
+                    scope='global',
+                    value=Function(op='concat', args=[
+                        Variable('sql_mode', is_system_var=True),
+                        Constant(',STRICT_TRANS_TABLES')
+                    ])
+                ),
+                Set(category="NAMES",
+                    value=Constant('utf8mb4', with_quotes=False),
+                    params={'COLLATE': Constant('utf8mb4_unicode_ci', with_quotes=False)})
+            ]
+        )
+
+        assert str(ast).lower() == sql.lower()
+        assert str(ast) == str(expected_ast)
+
diff --git a/tests/test_parser/test_mysql/test_mysql_parser.py b/tests/test_parser/test_mysql/test_mysql_parser.py
index 9e28a899..30c91e18 100644
--- a/tests/test_parser/test_mysql/test_mysql_parser.py
+++ b/tests/test_parser/test_mysql/test_mysql_parser.py
@@ -1,6 +1,6 @@
 from mindsdb_sql import parse_sql
 from mindsdb_sql.parser.ast import Select, Identifier, BinaryOperation, Star
-from mindsdb_sql.parser.dialects.mysql import Variable
+from mindsdb_sql.parser.ast import Variable
 from mindsdb_sql.parser.parser import Show


 class TestMySQLParser:
diff --git a/tests/test_planner/test_injected_data.py b/tests/test_planner/test_injected_data.py
index e0ff341b..abefa917 100644
--- a/tests/test_planner/test_injected_data.py
+++ b/tests/test_planner/test_injected_data.py
@@ -1,9 +1,11 @@
+import copy
+
 from mindsdb_sql.parser.ast import *
 from mindsdb_sql.planner import plan_query
 from mindsdb_sql.planner.query_plan import QueryPlan
 from mindsdb_sql.planner.step_result import Result
 from mindsdb_sql.planner.steps import (FilterStep, DataStep, ProjectStep, JoinStep, ApplyPredictorStep,
-                                       SubSelectStep)
+                                       SubSelectStep, QueryStep)
 from mindsdb_sql.parser.utils import JoinType


@@ -63,6 +65,9 @@ def test_join(self):
             where=BinaryOperation(op='=', args=[Identifier('t.a'), Constant(1)])
         )

+        subquery = copy.deepcopy(query)
+        subquery.from_table = None
+
         plan = plan_query(
             query,
             integrations=['int1'],
@@ -90,12 +95,9 @@ def test_join(self):
                      query=Join(left=Identifier('tab1'),
                                 right=Identifier('tab2'),
                                 join_type=JoinType.JOIN)),
-            FilterStep(dataframe=Result(3), query=BinaryOperation(op='=', args=[Identifier('t.a'), Constant(1)])),
-            ProjectStep(dataframe=Result(4), columns=[Identifier('t.x')])
+            QueryStep(subquery, from_table=Result(3)),
             ],
         )

-        for i in range(len(plan.steps)):
-            print(plan.steps[i], expected_plan.steps[i])
-            assert plan.steps[i] == expected_plan.steps[i]
+        assert plan.steps == expected_plan.steps
diff --git a/tests/test_planner/test_integration_select.py b/tests/test_planner/test_integration_select.py
index 06df7edf..e4c3adf1 100644
--- a/tests/test_planner/test_integration_select.py
+++ b/tests/test_planner/test_integration_select.py
@@ -514,8 +514,9 @@ def test_select_from_table_subselect(self):

     def test_select_from_table_subselect_api_integration(self):
         query = parse_sql('''
-            select * from int1.tab1
+            select x from int1.tab2
             where x1 in (select id from int1.tab1)
+            limit 1
         ''', dialect='mindsdb')

         expected_plan = QueryPlan(
@@ -523,22 +524,33 @@ def test_select_from_table_subselect_api_integration(self):
             steps=[
                 FetchDataframeStep(
                     integration='int1',
-                    query=parse_sql('select tab1.id as id from tab1'),
+                    query=parse_sql('select * from tab1'),
+                ),
+                SubSelectStep(
+                    dataframe=Result(0),
+                    query=parse_sql("select id"),
+                    table_name='tab1'
                 ),
                 FetchDataframeStep(
                     integration='int1',
                     query=Select(
                         targets=[Star()],
-                        from_table=Identifier('tab1'),
+                        from_table=Identifier('tab2'),
                         where=BinaryOperation(
                             op='in',
                             args=[
-                                Identifier(parts=['tab1', 'x1']),
-                                Parameter(Result(0))
+                                Identifier(parts=['tab2', 'x1']),
+                                Parameter(Result(1))
                             ]
-                        )
+                        ),
+                        limit=Constant(1)
                     ),
                 ),
+                SubSelectStep(
+                    dataframe=Result(2),
+                    query=parse_sql("select x"),
+                    table_name='tab2'
+                ),
             ],
         )
@@ -585,7 +597,12 @@ def test_delete_from_table_subselect_api_integration(self):
             steps=[
                 FetchDataframeStep(
                     integration='int1',
-                    query=parse_sql('select tab1.id as id from tab1'),
+                    query=parse_sql('select * from tab1'),
+                ),
+                SubSelectStep(
+                    dataframe=Result(0),
+                    query=parse_sql("select id"),
+                    table_name='tab1'
                 ),
                 DeleteStep(
                     table=Identifier('int1.tab1'),
@@ -593,7 +610,7 @@ def test_delete_from_table_subselect_api_integration(self):
                         op='in',
                         args=[
                             Identifier(parts=['x1']),
-                            Parameter(Result(0))
+                            Parameter(Result(1))
                         ]
                     )
                 ),
diff --git a/tests/test_planner/test_join_predictor.py b/tests/test_planner/test_join_predictor.py
index 391d1267..8999b784 100644
--- a/tests/test_planner/test_join_predictor.py
+++ b/tests/test_planner/test_join_predictor.py
@@ -1,3 +1,5 @@
+import copy
+
 import pytest

 from mindsdb_sql.exceptions import PlanningException
@@ -6,19 +8,22 @@
 from mindsdb_sql.planner.query_plan import QueryPlan
 from mindsdb_sql.planner.step_result import Result
 from mindsdb_sql.planner.steps import (FetchDataframeStep, ProjectStep, JoinStep, ApplyPredictorStep, FilterStep,
-                                       LimitOffsetStep, GroupByStep, SubSelectStep, ApplyPredictorRowStep)
+                                       LimitOffsetStep, QueryStep, SubSelectStep, ApplyPredictorRowStep)
 from mindsdb_sql.parser.utils import JoinType
 from mindsdb_sql import parse_sql


 class TestPlanJoinPredictor:
     def test_join_predictor_plan(self):
-        query = Select(targets=[Identifier('tab1.column1'), Identifier('pred.predicted')],
-                       from_table=Join(left=Identifier('int.tab1'),
-                                       right=Identifier('mindsdb.pred'),
-                                       join_type=JoinType.INNER_JOIN,
-                                       implicit=True)
-                       )
+
+        sql = """
+            select tab1.column1, pred.predicted
+            from int.tab1, mindsdb.pred
+        """
+        query = parse_sql(sql)
+
+        query_step = parse_sql("select tab1.column1, pred.predicted")
+        query_step.from_table = Parameter(Result(2))
+
         expected_plan = QueryPlan(
             steps=[
                 FetchDataframeStep(integration='int',
@@ -27,51 +32,32 @@ def test_join_predictor_plan(self):
                                    ),
                 ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred')),
                 JoinStep(left=Result(0), right=Result(1),
-                         query=Join(left=Identifier('result_0'),
-                                    right=Identifier('result_1'),
+                         query=Join(left=Identifier('tab1'),
+                                    right=Identifier('tab2'),
                                     join_type=JoinType.INNER_JOIN)),
-                ProjectStep(dataframe=Result(2), columns=[Identifier('tab1.column1'), Identifier('pred.predicted')]),
+                QueryStep(parse_sql("select tab1.column1, pred.predicted"), from_table=Result(2)),
             ],
         )

         plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}})
         for i in range(len(plan.steps)):
             assert plan.steps[i] == expected_plan.steps[i]

-
-    def test_predictor_namespace_is_case_insensitive(self):
-        query = Select(targets=[Identifier('tab1.column1'), Identifier('pred.predicted')],
-                       from_table=Join(left=Identifier('int.tab1'),
-                                       right=Identifier('mindsdb.pred'),
-                                       join_type=JoinType.INNER_JOIN,
-                                       implicit=True)
-                       )
-        expected_plan = QueryPlan(
-            steps=[
-                FetchDataframeStep(integration='int',
-                                   query=Select(targets=[Star()],
-                                                from_table=Identifier('tab1')),
-                                   ),
-                ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred')),
-                JoinStep(left=Result(0), right=Result(1),
-                         query=Join(left=Identifier('result_0'),
-                                    right=Identifier('result_1'),
-                                    join_type=JoinType.INNER_JOIN)),
-                ProjectStep(dataframe=Result(2), columns=[Identifier('tab1.column1'), Identifier('pred.predicted')]),
-            ],
-        )
+        # test_predictor_namespace_is_case_insensitive
         plan = plan_query(query, integrations=['int'], predictor_namespace='MINDSDB', predictor_metadata={'pred': {}})

-        assert plan.steps == expected_plan.steps
-
+        for i in range(len(plan.steps)):
+            assert plan.steps[i] == expected_plan.steps[i]

     def test_join_predictor_plan_aliases(self):
-        query = Select(targets=[Identifier('ta.column1'), Identifier('tb.predicted')],
-                       from_table=Join(left=Identifier('int.tab1', alias=Identifier('ta')),
-                                       right=Identifier('mindsdb.pred', alias=Identifier('tb')),
-                                       join_type=JoinType.INNER_JOIN,
-                                       implicit=True)
-                       )
+
+        sql = """
+            select ta.column1, tb.predicted
+            from int.tab1 ta, mindsdb.pred tb
+        """
+        query = parse_sql(sql)
+
         expected_plan = QueryPlan(
             steps=[
                 FetchDataframeStep(integration='int',
@@ -80,10 +66,10 @@ def test_join_predictor_plan_aliases(self):
                                    ),
                 ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred', alias=Identifier('tb'))),
                 JoinStep(left=Result(0), right=Result(1),
-                         query=Join(left=Identifier('result_0'),
-                                    right=Identifier('result_1'),
+                         query=Join(left=Identifier('tab1'),
+                                    right=Identifier('tab2'),
                                     join_type=JoinType.INNER_JOIN)),
-                ProjectStep(dataframe=Result(2), columns=[Identifier('ta.column1'), Identifier('tb.predicted')]),
+                QueryStep(parse_sql("select ta.column1, tb.predicted"), from_table=Result(2)),
             ],
         )
         plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}})

         assert plan.steps == expected_plan.steps

-
-    def test_join_predictor_plan_where(self):
-        query = Select(targets=[Identifier('tab.column1'), Identifier('pred.predicted')],
-                       from_table=Join(left=Identifier('int.tab'),
-                                       right=Identifier('mindsdb.pred'),
-                                       join_type=JoinType.INNER_JOIN,
-                                       implicit=True),
-                       where=BinaryOperation('and', args=[
-                           BinaryOperation('=', args=[Identifier('tab.product_id'), Constant('x')]),
-                           BetweenOperation(args=[Identifier('tab.time'), Constant('2021-01-01'), Constant('2021-01-31')]),
-                       ])
-                       )
+    def test_join_predictor_plan_limit(self):
+
+        sql = """
+            select tab.column1, pred.predicted
+            from int.tab, mindsdb.pred
+            where tab.product_id = 'x' and tab.time between '2021-01-01' and '2021-01-31'
+            order by tab.column2
+            limit 10
+            offset 1
+        """
+        query = parse_sql(sql)
+
+        subquery = copy.deepcopy(query)
+        subquery.from_table = None
+        subquery.offset = None

         expected_plan = QueryPlan(
             steps=[
-                FetchDataframeStep(integration='int',
-                                   query=Select(targets=[Star()],
-                                                from_table=Identifier('tab'),
-                                                where=BinaryOperation('and', args=[
-                                                    BinaryOperation('=',
-                                                                    args=[Identifier('tab.product_id'), Constant('x')]),
-                                                    BetweenOperation(
-                                                        args=[Identifier('tab.time'),
-                                                              Constant('2021-01-01'),
-                                                              Constant('2021-01-31')]),
-                                                ])
-                                                ),
-                                   ),
+                FetchDataframeStep(
+                    integration='int',
+                    query=parse_sql("""
+                        select * from tab
+                        where product_id = 'x' and time between '2021-01-01' and '2021-01-31'
+                        order by column2
+                        limit 10
+                        offset 1
+                    """)
+                ),
                 ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred')),
                 JoinStep(left=Result(0), right=Result(1),
-                         query=Join(left=Identifier('result_0'),
-                                    right=Identifier('result_1'),
+                         query=Join(left=Identifier('tab1'),
+                                    right=Identifier('tab2'),
                                     join_type=JoinType.INNER_JOIN)),
-                ProjectStep(dataframe=Result(2), columns=[Identifier('tab.column1'), Identifier('pred.predicted')]),
+                QueryStep(subquery, from_table=Result(2)),
             ],
         )
         plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}})
@@ -156,7 +143,7 @@ def test_join_predictor_plan_where(self):
         #     with pytest.raises(PlanningException):
         #         plan_query(query, integrations=['postgres_90'], predictor_namespace='mindsdb', predictor_metadata={'hrp3': {}})

-    def test_join_predictor_plan_group_by(self):
+    def test_join_predictor_plan_complex_query(self):
         query = Select(targets=[Identifier('tab.asset'), Identifier('tab.time'), Identifier('pred.predicted')],
                        from_table=Join(left=Identifier('int.tab'),
                                        right=Identifier('mindsdb.pred'),
@@ -166,124 +153,38 @@ def test_join_predictor_plan_group_by(self):
                        having=BinaryOperation('=', args=[Identifier('tab.asset'), Constant('bitcoin')])
                        )

-        expected_plan = QueryPlan(
-            steps=[
-                FetchDataframeStep(integration='int',
-                                   query=Select(targets=[Star()],
-                                                from_table=Identifier('tab'),
-                                                group_by=[Identifier('tab.asset')],
-                                                having=BinaryOperation('=', args=[Identifier('tab.asset'),
-                                                                                  Constant('bitcoin')])
-                                                ),
-                                   ),
-                ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred')),
-                JoinStep(left=Result(0), right=Result(1),
-                         query=Join(left=Identifier('result_0'),
-                                    right=Identifier('result_1'),
-                                    join_type=JoinType.INNER_JOIN)),
-                ProjectStep(dataframe=Result(2), columns=[Identifier('tab.asset'), Identifier('tab.time'), Identifier('pred.predicted')]),
-            ],
-        )
-        plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}})
-
-        assert plan.steps == expected_plan.steps
-
-
-    def test_join_predictor_plan_limit_offset(self):
-        query = Select(targets=[Identifier('tab.column1'), Identifier('pred.predicted')],
-                       from_table=Join(left=Identifier('int.tab'),
-                                       right=Identifier('mindsdb.pred'),
-                                       join_type=JoinType.INNER_JOIN,
-                                       implicit=True),
-                       where=BinaryOperation('=', args=[Identifier('tab.product_id'), Constant('x')]),
-                       limit=Constant(10),
-                       offset=Constant(15),
-                       )
-
-        expected_plan = QueryPlan(
-            steps=[
-                FetchDataframeStep(integration='int',
-                                   query=Select(targets=[Star()],
-                                                from_table=Identifier('tab'),
-                                                where=BinaryOperation('=', args=[Identifier('tab.product_id'), Constant('x')]),
-                                                limit=Constant(10),
-                                                offset=Constant(15),
-                                                ),
-                                   ),
-                ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred')),
-                JoinStep(left=Result(0), right=Result(1),
-                         query=Join(left=Identifier('result_0'),
-                                    right=Identifier('result_1'),
-                                    join_type=JoinType.INNER_JOIN)),
-                ProjectStep(dataframe=Result(2), columns=[Identifier('tab.column1'), Identifier('pred.predicted')]),
-            ],
-        )
-        plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}})
-
-        assert plan.steps == expected_plan.steps
-
-
-    def test_join_predictor_plan_order_by(self):
-        query = Select(targets=[Identifier('tab.column1'), Identifier('pred.predicted')],
-                       from_table=Join(left=Identifier('int.tab'),
-                                       right=Identifier('mindsdb.pred'),
-                                       join_type=JoinType.INNER_JOIN,
-                                       implicit=True),
-                       where=BinaryOperation('=', args=[Identifier('tab.product_id'), Constant('x')]),
-                       limit=Constant(10),
-                       offset=Constant(15),
-                       order_by=[OrderBy(field=Identifier('tab.column1'))]
-                       )
-
-        expected_plan = QueryPlan(
-            steps=[
-                FetchDataframeStep(integration='int',
-                                   query=Select(targets=[Star()],
-                                                from_table=Identifier('tab'),
-                                                where=BinaryOperation('=', args=[Identifier('tab.product_id'), Constant('x')]),
-                                                limit=Constant(10),
-                                                offset=Constant(15),
-                                                order_by=[OrderBy(field=Identifier('tab.column1'))],
-                                                ),
-                                   ),
-                ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred')),
-                JoinStep(left=Result(0), right=Result(1),
-                         query=Join(left=Identifier('result_0'),
-                                    right=Identifier('result_1'),
-                                    join_type=JoinType.INNER_JOIN)),
-                ProjectStep(dataframe=Result(2), columns=[Identifier('tab.column1'), Identifier('pred.predicted')]),
-            ],
-        )
-        plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}})
+        sql = """
+            select t.asset, t.time, m.predicted
+            from int.tab t, mindsdb.pred m
+            where t.col1 = 'x'
+            group by t.asset
+            having t.asset = 'bitcoin'
+            order by t.asset
+            limit 1
+            offset 2
+        """
+        query = parse_sql(sql)

-        assert plan.steps == expected_plan.steps
-
+        subquery = copy.deepcopy(query)
+        subquery.from_table = None

-    def test_join_predictor_plan_predictor_alias(self):
-        query = Select(targets=[Identifier('tab1.column1'), Identifier('pred_alias.predicted')],
-                       from_table=Join(left=Identifier('int.tab1'),
-                                       right=Identifier('mindsdb.pred', alias=Identifier('pred_alias')),
-                                       join_type=JoinType.INNER_JOIN,
-                                       implicit=True)
-                       )
         expected_plan = QueryPlan(
             steps=[
-                FetchDataframeStep(integration='int',
-                                   query=Select(targets=[Star()],
-                                                from_table=Identifier('tab1')),
-                                   ),
-                ApplyPredictorStep(namespace='mindsdb', predictor=Identifier('pred', alias=Identifier('pred_alias')), dataframe=Result(0)),
+                FetchDataframeStep(
+                    integration='int',
+                    query=parse_sql("select * from tab as t where col1 = 'x'")
+                ),
+                ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred', alias=Identifier('m'))),
                 JoinStep(left=Result(0), right=Result(1),
-                         query=Join(left=Identifier('result_0'),
-                                    right=Identifier('result_1'),
+                         query=Join(left=Identifier('tab1'),
+                                    right=Identifier('tab2'),
                                     join_type=JoinType.INNER_JOIN)),
-                ProjectStep(dataframe=Result(2), columns=[Identifier('tab1.column1'), Identifier('pred_alias.predicted')]),
+                QueryStep(subquery, from_table=Result(2)),
             ],
         )
         plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}})

         assert plan.steps == expected_plan.steps

-
     def test_no_predictor_error(self):
         query = Select(targets=[Identifier('tab1.column1'), Identifier('pred.predicted')],
@@ -297,12 +198,12 @@
             plan = plan_query(query, integrations=['int'], predictor_metadata={'pred': {}})

     def test_join_predictor_plan_default_namespace_integration(self):
-        query = Select(targets=[Identifier('tab1.column1'), Identifier('pred.predicted')],
-                       from_table=Join(left=Identifier('tab1'),
-                                       right=Identifier('mindsdb.pred'),
-                                       join_type=JoinType.INNER_JOIN,
-                                       implicit=True)
-                       )
+
+        sql = """
+            select tab1.column1, pred.predicted
+            from tab1, mindsdb.pred
+        """
+        query = parse_sql(sql)
         expected_plan = QueryPlan(
             default_namespace='int',
             steps=[
@@ -312,24 +213,24 @@ def test_join_predictor_plan_default_namespace_integration(self):
                                    ),
                 ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred')),
                 JoinStep(left=Result(0), right=Result(1),
-                         query=Join(left=Identifier('result_0'),
-                                    right=Identifier('result_1'),
+                         query=Join(left=Identifier('tab1'),
+                                    right=Identifier('tab2'),
                                     join_type=JoinType.INNER_JOIN)),
-                ProjectStep(dataframe=Result(2), columns=[Identifier('tab1.column1'), Identifier('pred.predicted')]),
+                QueryStep(parse_sql("select tab1.column1, pred.predicted"), from_table=Result(2)),
             ],
         )
         plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', default_namespace='int', predictor_metadata={'pred': {}})

-        for i in range(len(plan.steps)):
-            assert plan.steps[i] == expected_plan.steps[i]
+        assert plan.steps == expected_plan.steps

     def test_join_predictor_plan_default_namespace_predictor(self):
-        query = Select(targets=[Identifier('tab1.column1'), Identifier('pred.predicted')],
-                       from_table=Join(left=Identifier('int.tab1'),
-                                       right=Identifier('pred'),
-                                       join_type=JoinType.INNER_JOIN,
-                                       implicit=True)
-                       )
+
+        sql = """
+            select tab1.column1, pred.predicted
+            from int.tab1, pred
+        """
+        query = parse_sql(sql)
+
         expected_plan = QueryPlan(
             default_namespace='mindsdb',
             steps=[
@@ -339,10 +240,10 @@ def test_join_predictor_plan_default_namespace_predictor(self):
                                    ),
                 ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred')),
                 JoinStep(left=Result(0), right=Result(1),
-                         query=Join(left=Identifier('result_0'),
-                                    right=Identifier('result_1'),
+                         query=Join(left=Identifier('tab1'),
+                                    right=Identifier('tab2'),
                                     join_type=JoinType.INNER_JOIN)),
-                ProjectStep(dataframe=Result(2), columns=[Identifier('tab1.column1'), Identifier('pred.predicted')]),
+                QueryStep(parse_sql("select tab1.column1, pred.predicted"), from_table=Result(2)),
             ],
         )
         plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', default_namespace='mindsdb', predictor_metadata={'pred': {}})
@@ -363,7 +264,7 @@ def test_nested_select(self):
             limit 1
         '''

-        query = parse_sql(sql, dialect='mindsdb')
+        query = parse_sql(sql)

         expected_plan = QueryPlan(
             default_namespace='mindsdb',
@@ -372,11 +273,12 @@ def test_nested_select(self):
                                    query=parse_sql('select * from covid limit 10')),
                 ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred')),
                 JoinStep(left=Result(0), right=Result(1),
-                         query=Join(left=Identifier('result_0'),
-                                    right=Identifier('result_1'),
+                         query=Join(left=Identifier('tab1'),
+                                    right=Identifier('tab2'),
                                     join_type=JoinType.JOIN)),
-                SubSelectStep(dataframe=Result(2), query=parse_sql('SELECT time limit 1'), table_name='Custom SQL Query'),
-                LimitOffsetStep(dataframe=Result(3), limit=1)
+                QueryStep(Select(targets=[Star()], limit=Constant(10)), from_table=Result(2)),
+                SubSelectStep(dataframe=Result(3), query=parse_sql('SELECT time limit 1'), table_name='Custom SQL Query'),
+
             ],
         )

@@ -387,9 +289,8 @@ def test_nested_select(self):
             default_namespace='mindsdb',
             predictor_metadata={'pred': {}}
         )
-        for i in range(len(plan.steps)):
-            assert plan.steps[i] == expected_plan.steps[i]

+        assert plan.steps == expected_plan.steps

         sql = f'''
             SELECT `time`
@@ -400,7 +301,7 @@ def test_nested_select(self):
             GROUP BY 1
         '''

-        query = parse_sql(sql, dialect='mindsdb')
+        query = parse_sql(sql)

         expected_plan = QueryPlan(
             default_namespace='mindsdb',
@@ -409,8 +310,8 @@ def test_nested_select(self):
                                    query=parse_sql('select * from covid')),
                 ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred')),
                 JoinStep(left=Result(0), right=Result(1),
-                         query=Join(left=Identifier('result_0'),
-                                    right=Identifier('result_1'),
+                         query=Join(left=Identifier('tab1'),
+                                    right=Identifier('tab2'),
                                     join_type=JoinType.JOIN)),
                 SubSelectStep(dataframe=Result(2),
                               query=Select(targets=[Identifier('time')],
                                            group_by=[Constant(1)]),
@@ -425,8 +326,8 @@ def
test_nested_select(self): default_namespace='mindsdb', predictor_metadata={'pred': {}} ) - for i in range(len(plan.steps)): - assert plan.steps[i] == expected_plan.steps[i] + + assert plan.steps == expected_plan.steps def test_subselect(self): @@ -441,7 +342,7 @@ def test_subselect(self): limit 5 ''' - query = parse_sql(sql, dialect='mindsdb') + query = parse_sql(sql) expected_plan = QueryPlan( default_namespace='mindsdb', @@ -454,8 +355,7 @@ def test_subselect(self): query=Join(left=Identifier('tab1'), right=Identifier('tab2'), join_type=JoinType.JOIN)), - LimitOffsetStep(dataframe=Result(3), limit=5), - ProjectStep(dataframe=Result(4), columns=[Star()]) + QueryStep(Select(targets=[Star()], limit=Constant(5)), from_table=Result(3)) ], ) @@ -466,9 +366,7 @@ def test_subselect(self): default_namespace='mindsdb', predictor_metadata={'pred': {}} ) - for i in range(len(plan.steps)): - assert plan.steps[i] == expected_plan.steps[i] - + assert plan.steps == expected_plan.steps # only nested select with limit sql = f''' @@ -481,7 +379,7 @@ def test_subselect(self): join mindsdb.pred ''' - query = parse_sql(sql, dialect='mindsdb') + query = parse_sql(sql) expected_plan = QueryPlan( default_namespace='mindsdb', @@ -494,7 +392,6 @@ def test_subselect(self): query=Join(left=Identifier('tab1'), right=Identifier('tab2'), join_type=JoinType.JOIN)), - ProjectStep(dataframe=Result(3), columns=[Star()]) ], ) @@ -505,9 +402,7 @@ def test_subselect(self): default_namespace='mindsdb', predictor_metadata={'pred': {}} ) - for i in range(len(plan.steps)): - assert plan.steps[i] == expected_plan.steps[i] - + assert plan.steps == expected_plan.steps class TestPredictorWithUsing: @@ -519,16 +414,16 @@ def test_using_join(self): using a=1 ''' - query = parse_sql(sql, dialect='mindsdb') + query = parse_sql(sql) expected_plan = QueryPlan( steps=[ FetchDataframeStep(integration='int', - query=parse_sql('select * from tab1', dialect='mindsdb')), + query=parse_sql('select * from tab1')), ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred'), params={'a': 1}), JoinStep(left=Result(0), right=Result(1), - query=Join(left=Identifier('result_0'), - right=Identifier('result_1'), + query=Join(left=Identifier('tab1'), + right=Identifier('tab2'), join_type=JoinType.JOIN)), ProjectStep(dataframe=Result(2), columns=[Star()]), ], @@ -546,7 +441,7 @@ def test_using_join(self): using a=1 ''' - query = parse_sql(sql, dialect='mindsdb') + query = parse_sql(sql) expected_plan = QueryPlan( steps=[ FetchDataframeStep(integration='int', raw_query='select * from tab1'), @@ -573,7 +468,7 @@ def test_using_one_line(self): select * from mindsdb.pred where x=2 using a=1 ''' - query = parse_sql(sql, dialect='mindsdb') + query = parse_sql(sql) expected_plan = QueryPlan( steps=[ ApplyPredictorRowStep(namespace='mindsdb', predictor=Identifier('pred'), @@ -596,16 +491,16 @@ def test_using_join(self): using a=1 ''' - query = parse_sql(sql, dialect='mindsdb') + query = parse_sql(sql) expected_plan = QueryPlan( steps=[ FetchDataframeStep(integration='int', - query=parse_sql('select * from tab1', dialect='mindsdb')), + query=parse_sql('select * from tab1')), ApplyPredictorStep(namespace='proj', dataframe=Result(0), predictor=Identifier('pred.1'), params={'a': 1}), JoinStep(left=Result(0), right=Result(1), - query=Join(left=Identifier('result_0'), - right=Identifier('result_1'), + query=Join(left=Identifier('tab1'), + right=Identifier('tab2'), join_type=JoinType.JOIN)), ProjectStep(dataframe=Result(2), 
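# --- Illustration (not part of the patch) -------------------------------
# The expected plans around here spell a trailing projection in two
# equivalent ways: parsing a FROM-less SELECT, or constructing the Select
# node directly. Either way QueryStep supplies the input relation through
# `from_table=Result(n)`. A hedged sketch of both spellings:
from mindsdb_sql import parse_sql
from mindsdb_sql.parser.ast import Select, Star, Constant
from mindsdb_sql.planner.step_result import Result
from mindsdb_sql.planner.steps import QueryStep

# As SQL: a FROM-less SELECT parsed straight into the step.
project = QueryStep(parse_sql('select tab1.column1, pred.predicted'),
                    from_table=Result(2))
# As AST, the way test_subselect writes its star-plus-limit projection.
limited = QueryStep(Select(targets=[Star()], limit=Constant(5)),
                    from_table=Result(3))
# -------------------------------------------------------------------------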
columns=[Star()]), ], @@ -624,7 +519,7 @@ def test_using_join(self): join pred.1 using a=1 ''' - query = parse_sql(sql, dialect='mindsdb') + query = parse_sql(sql) plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', default_namespace='proj', predictor_metadata=[{'name': 'pred', 'integration_name': 'proj'}]) @@ -640,29 +535,34 @@ def test_where_using(self): where a.x=1 and p.x=1 and a.y=3 and p.y='' ''' - query = parse_sql(sql, dialect='mindsdb') + subquery = parse_sql(""" + select * from x + where a.x=1 and 0=0 and a.y=3 and p.y='' + """) + subquery.from_table = None + + query = parse_sql(sql) expected_plan = QueryPlan( steps=[ FetchDataframeStep(integration='int', - query=parse_sql('select * from tab1 as a where a.x=1 and a.y=3', dialect='mindsdb')), + query=parse_sql('select * from tab1 as a where x=1 and y=3')), ApplyPredictorStep( namespace='proj', dataframe=Result(0), predictor=Identifier('pred.1', alias=Identifier('p')), row_dict={'x': 1, 'y': ''} ), JoinStep(left=Result(0), right=Result(1), - query=Join(left=Identifier('result_0'), - right=Identifier('result_1'), + query=Join(left=Identifier('tab1'), + right=Identifier('tab2'), join_type=JoinType.JOIN)), - ProjectStep(dataframe=Result(2), columns=[Star()]), + QueryStep(subquery, from_table=Result(2)) ], ) plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata=[{'name': 'pred', 'integration_name': 'proj'}]) - for i in range(len(plan.steps)): - assert plan.steps[i] == expected_plan.steps[i] + assert plan.steps == expected_plan.steps def test_using_one_line(self): @@ -670,7 +570,7 @@ def test_using_one_line(self): select * from proj.pred.1 where x=2 using a=1 ''' - query = parse_sql(sql, dialect='mindsdb') + query = parse_sql(sql) expected_plan = QueryPlan( steps=[ ApplyPredictorRowStep(namespace='proj', predictor=Identifier('pred.1'), @@ -689,7 +589,7 @@ def test_using_one_line(self): sql = ''' select * from pred.1 where x=2 using a=1 ''' - query = parse_sql(sql, dialect='mindsdb') + query = parse_sql(sql) plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', default_namespace='proj', predictor_metadata=[{'name': 'pred', 'integration_name': 'proj'}]) @@ -705,24 +605,30 @@ def test_model_param(self): where m.a=1 and t.b=2 ''' - query = parse_sql(sql, dialect='mindsdb') + query = parse_sql(sql) + + subquery = parse_sql(""" + select * from x + where 0=0 and t.b=2 + """) + subquery.from_table = None + expected_plan = QueryPlan( steps=[ FetchDataframeStep(integration='int', - query=parse_sql('select * from tab1 as t where t.b=2', dialect='mindsdb')), + query=parse_sql('select * from tab1 as t where b=2')), ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred', alias=Identifier('m')), row_dict={'a': 1}), JoinStep(left=Result(0), right=Result(1), - query=Join(left=Identifier('result_0'), - right=Identifier('result_1'), + query=Join(left=Identifier('tab1'), + right=Identifier('tab2'), join_type=JoinType.JOIN)), - ProjectStep(dataframe=Result(2), columns=[Star()]), + QueryStep(subquery, from_table=Result(2)), ], ) plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) - for i in range(len(plan.steps)): - assert plan.steps[i] == expected_plan.steps[i] + assert plan.steps == expected_plan.steps # 3 table sql = ''' @@ -732,13 +638,20 @@ def test_model_param(self): where m.a=1 ''' - query = parse_sql(sql, dialect='mindsdb') + + subquery = parse_sql(""" + select * from 
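# --- Illustration (not part of the patch) -------------------------------
# test_where_using above encodes how the planner splits a mixed WHERE
# clause: equalities on the predictor alias feed ApplyPredictorStep's
# row_dict, table-side conditions are pushed into the fetch SQL, and in
# the expected post-join query the consumed `p.x=1` is replaced by a
# neutral `0=0`. The three artifacts side by side, copied from the test:
from mindsdb_sql import parse_sql

fetch_query = parse_sql("select * from tab1 as a where x=1 and y=3")
row_dict = {'x': 1, 'y': ''}   # predictor-side values, per the test
post_join = parse_sql("select * from x "
                      "where a.x=1 and 0=0 and a.y=3 and p.y=''")
post_join.from_table = None    # executed against the join result instead
# -------------------------------------------------------------------------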
x + where 0=0 + """) + subquery.from_table = None + + query = parse_sql(sql) expected_plan = QueryPlan( steps=[ FetchDataframeStep(integration='int', - query=parse_sql('select * from tab1 as t', dialect='mindsdb')), + query=parse_sql('select * from tab1 as t')), FetchDataframeStep(integration='int', - query=parse_sql('select * from tab2 as t2', dialect='mindsdb')), + query=parse_sql('select * from tab2 as t2')), JoinStep(left=Result(0), right=Result(1), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), @@ -749,48 +662,73 @@ def test_model_param(self): query=Join(left=Identifier('tab1'), right=Identifier('tab2'), join_type=JoinType.JOIN)), - FilterStep(dataframe=Result(4), query=BinaryOperation(op='=', args=[Constant(0), Constant(0)])), + QueryStep(subquery, from_table=Result(4)), ], ) plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) - for i in range(len(plan.steps)): - assert plan.steps[i] == expected_plan.steps[i] + assert plan.steps == expected_plan.steps + + def test_complex_subselect(self): - def test_model_param_subselect(self): sql = ''' - select * from int.tab1 t - join int.tab2 t2 - join mindsdb.pred m - where m.a = (select a from int.tab3 where x=1) - ''' + select t2.x, m.id, (select a from int.tab0 where x=0) from int.tab1 t1 + join int.tab2 t2 on t1.x = t2.x + join mindsdb.pred m + where m.a=(select a from int.tab3 where x=3) + and t2.x=(select a from int.tab4 where x=4) + and t1.b=1 and t2.b=2 and t1.a = t2.a + ''' + + q_table2 = parse_sql('select * from tab2 as t2 where x=0 and b=2 ') + q_table2.where.args[0].args[1] = Parameter(Result(2)) + + subquery = parse_sql(""" + select t2.x, m.id, x + from x + where 0=0 + and t2.x=x + and t1.b=1 and t2.b=2 and t1.a = t2.a + """) + subquery.from_table = None + subquery.targets[2] = Parameter(Result(0)) + subquery.where.args[0].args[0].args[0].args[1].args[1] = Parameter(Result(2)) - query = parse_sql(sql, dialect='mindsdb') + + query = parse_sql(sql) expected_plan = QueryPlan( steps=[ + # nested queries FetchDataframeStep(integration='int', - query=parse_sql('select tab3.a as a from tab3 where tab3.x=1', dialect='mindsdb')), + query=parse_sql('select tab0.a as a from tab0 where tab0.x=0')), FetchDataframeStep(integration='int', - query=parse_sql('select * from tab1 as t', dialect='mindsdb')), + query=parse_sql('select tab3.a as a from tab3 where tab3.x=3')), FetchDataframeStep(integration='int', - query=parse_sql('select * from tab2 as t2', dialect='mindsdb')), - JoinStep(left=Result(1), right=Result(2), + query=parse_sql('select tab4.a as a from tab4 where tab4.x=4')), + # tables + FetchDataframeStep(integration='int', + query=parse_sql('select * from tab1 as t1 where b=1')), + FetchDataframeStep(integration='int', query=q_table2), + JoinStep(left=Result(3), right=Result(4), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), - join_type=JoinType.JOIN)), - ApplyPredictorStep(namespace='mindsdb', dataframe=Result(3), - predictor=Identifier('pred', alias=Identifier('m')), row_dict={'a': Result(step_num=0)}), - JoinStep(left=Result(3), right=Result(4), + join_type=JoinType.JOIN, + condition=BinaryOperation(op='=', args=[Identifier('t1.x'), Identifier('t2.x')]) + ) + ), + # model + ApplyPredictorStep(namespace='mindsdb', dataframe=Result(5), + predictor=Identifier('pred', alias=Identifier('m')), row_dict={'a': Result(1)}), + JoinStep(left=Result(5), right=Result(6), query=Join(left=Identifier('tab1'), right=Identifier('tab2'), join_type=JoinType.JOIN)), - 
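# --- Illustration (not part of the patch) -------------------------------
# test_complex_subselect splices prior step outputs into a parsed query:
# a Parameter node wrapping Result(n) stands in for the scalar a nested
# SELECT will have produced by step n. Sketch of that substitution,
# assuming Parameter is exported from mindsdb_sql.parser.ast as the
# test's bare usage suggests:
from mindsdb_sql import parse_sql
from mindsdb_sql.parser.ast import Parameter
from mindsdb_sql.planner.step_result import Result

q = parse_sql('select * from tab2 as t2 where x=0 and b=2')
# Replace the constant in `x=0` with the value produced by plan step 2.
q.where.args[0].args[1] = Parameter(Result(2))
# -------------------------------------------------------------------------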
FilterStep(dataframe=Result(5), query=BinaryOperation(op='=', args=[Constant(0), Constant(0)])), + QueryStep(subquery, from_table=Result(7)), ], ) plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) - for i in range(len(plan.steps)): - assert plan.steps[i] == expected_plan.steps[i] + assert plan.steps == expected_plan.steps def test_model_join_model(self): sql = ''' @@ -800,11 +738,17 @@ def test_model_join_model(self): where m.a = 2 ''' - query = parse_sql(sql, dialect='mindsdb') + subquery = parse_sql(""" + select * from x + where 0=0 + """) + subquery.from_table = None + + query = parse_sql(sql) expected_plan = QueryPlan( steps=[ FetchDataframeStep(integration='int', - query=parse_sql('select * from tab1 as t', dialect='mindsdb')), + query=parse_sql('select * from tab1 as t')), ApplyPredictorStep(namespace='mindsdb', dataframe=Result(0), predictor=Identifier('pred', alias=Identifier('m')), row_dict={'a': 2}), JoinStep(left=Result(0), right=Result(1), @@ -817,10 +761,9 @@ def test_model_join_model(self): query=Join(left=Identifier('tab1'), right=Identifier('tab2'), join_type=JoinType.JOIN)), - FilterStep(dataframe=Result(4), query=BinaryOperation(op='=', args=[Constant(0), Constant(0)])), + QueryStep(subquery, from_table=Result(4)), ], ) plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) - for i in range(len(plan.steps)): - assert plan.steps[i] == expected_plan.steps[i] \ No newline at end of file + assert plan.steps == expected_plan.steps \ No newline at end of file diff --git a/tests/test_planner/test_join_tables.py b/tests/test_planner/test_join_tables.py index b1e0994e..df15bf42 100644 --- a/tests/test_planner/test_join_tables.py +++ b/tests/test_planner/test_join_tables.py @@ -1,3 +1,5 @@ +import copy + import pytest from mindsdb_sql.exceptions import PlanningException @@ -6,7 +8,7 @@ from mindsdb_sql.planner.query_plan import QueryPlan from mindsdb_sql.planner.step_result import Result from mindsdb_sql.planner.steps import (FetchDataframeStep, ProjectStep, FilterStep, JoinStep, GroupByStep, - LimitOffsetStep, OrderByStep, ApplyPredictorStep, SubSelectStep) + LimitOffsetStep, OrderByStep, ApplyPredictorStep, SubSelectStep, QueryStep) from mindsdb_sql.parser.utils import JoinType from mindsdb_sql import parse_sql @@ -39,8 +41,7 @@ def test_join_tables_plan(self): Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN )), - ProjectStep(dataframe=Result(2), - columns=[Identifier('tab1.column1'), Identifier('tab2.column1'), Identifier('tab2.column2')]), + QueryStep(parse_sql("select tab1.column1, tab2.column1, tab2.column2"), from_table=Result(2)), ], ) @@ -57,6 +58,10 @@ def test_join_tables_where_plan(self): AND (tab1.column3 = tab2.column3) ''') + subquery = copy.deepcopy(query) + subquery.from_table = None + subquery.offset = None + plan = plan_query(query, integrations=['int', 'int2']) expected_plan = QueryPlan(integrations=['int'], steps=[ @@ -72,40 +77,11 @@ def test_join_tables_where_plan(self): Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN )), - FilterStep(dataframe=Result(2), - query=BinaryOperation('and', - args=[ - BinaryOperation('and', parentheses=True, - args=[ - BinaryOperation('=', parentheses=True, - args=[ - Identifier( - 'tab1.column1'), - Constant( - 1)]), - BinaryOperation('=', parentheses=True, - args=[ - Identifier( - 'tab2.column1'), - Constant( - 0)]), - - ] - ), - BinaryOperation('=', parentheses=True, - args=[Identifier( - 
'tab1.column3'), - Identifier( - 'tab2.column3')]), - ] - )), - ProjectStep(dataframe=Result(3), - columns=[Identifier('tab1.column1'), Identifier('tab2.column1'), Identifier('tab2.column2')]), + QueryStep(subquery, from_table=Result(2)), ], ) - for i in range(len(plan.steps)): - assert plan.steps[i] == expected_plan.steps[i] + assert plan.steps == expected_plan.steps def test_join_tables_plan_groupby(self): @@ -122,6 +98,11 @@ def test_join_tables_plan_groupby(self): group_by=[Identifier('tab1.column1'), Identifier('tab2.column1')], having=BinaryOperation(op='=', args=[Identifier('tab1.column1'), Constant(0)]) ) + + subquery = copy.deepcopy(query) + subquery.from_table = None + subquery.offset = None + plan = plan_query(query, integrations=['int', 'int2']) expected_plan = QueryPlan(integrations=['int'], steps = [ @@ -142,15 +123,7 @@ def test_join_tables_plan_groupby(self): Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN )), - GroupByStep(dataframe=Result(2), - targets=[Identifier('tab1.column1'), - Identifier('tab2.column1'), - Function('sum', args=[Identifier('tab2.column2')])], - columns=[Identifier('tab1.column1'), Identifier('tab2.column1')]), - FilterStep(dataframe=Result(3), query=BinaryOperation(op='=', args=[Identifier('tab1.column1'), Constant(0)])), - ProjectStep(dataframe=Result(4), - columns=[Identifier('tab1.column1'), Identifier('tab2.column1'), - Function(op='sum', args=[Identifier('tab2.column2')], alias=Identifier('total'))]), + QueryStep(subquery, from_table=Result(2)), ], ) assert plan.steps == expected_plan.steps @@ -166,13 +139,21 @@ def test_join_tables_plan_limit_offset(self): limit=Constant(10), offset=Constant(15), ) + + subquery = copy.deepcopy(query) + subquery.from_table = None + subquery.offset = None + plan = plan_query(query, integrations=['int', 'int2']) expected_plan = QueryPlan(integrations=['int'], steps = [ FetchDataframeStep(integration='int', query=Select( targets=[Star()], - from_table=Identifier('tab1')), + from_table=Identifier('tab1'), + limit=Constant(10), + offset=Constant(15), + ), ), FetchDataframeStep(integration='int2', query=Select(targets=[Star()], @@ -186,9 +167,7 @@ def test_join_tables_plan_limit_offset(self): Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN )), - LimitOffsetStep(dataframe=Result(2), limit=10, offset=15), - ProjectStep(dataframe=Result(3), - columns=[Identifier('tab1.column1'), Identifier('tab2.column1'), Identifier('tab2.column2')]), + QueryStep(subquery, from_table=Result(2)), ], ) @@ -206,14 +185,18 @@ def test_join_tables_plan_order_by(self): offset=Constant(15), order_by=[OrderBy(field=Identifier('tab1.column1'))], ) + + subquery = copy.deepcopy(query) + subquery.from_table = None + subquery.offset = None + plan = plan_query(query, integrations=['int', 'int2']) expected_plan = QueryPlan(integrations=['int'], steps = [ - FetchDataframeStep(integration='int', - query=Select( - targets=[Star()], - from_table=Identifier('tab1')), - ), + FetchDataframeStep( + integration='int', + query=parse_sql("select * from tab1 order by column1 limit 10 offset 15") + ), FetchDataframeStep(integration='int2', query=Select(targets=[Star()], from_table=Identifier('tab2')), @@ -226,10 +209,7 @@ def test_join_tables_plan_order_by(self): Identifier('tab2.column1')]), join_type=JoinType.INNER_JOIN )), - OrderByStep(dataframe=Result(2), order_by=[OrderBy(field=Identifier('tab1.column1'))]), - LimitOffsetStep(dataframe=Result(3), limit=10, offset=15), - ProjectStep(dataframe=Result(4), - 
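# --- Illustration (not part of the patch) -------------------------------
# The join-tables tests now expect ORDER BY / LIMIT / OFFSET to be pushed
# down into the first FetchDataframeStep's SQL rather than applied by
# separate OrderByStep / LimitOffsetStep nodes. The pushed-down fetch,
# built from SQL exactly as the rewritten expectation does:
from mindsdb_sql import parse_sql
from mindsdb_sql.planner.steps import FetchDataframeStep

fetch = FetchDataframeStep(
    integration='int',
    query=parse_sql('select * from tab1 order by column1 limit 10 offset 15'),
)
# -------------------------------------------------------------------------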
columns=[Identifier('tab1.column1'), Identifier('tab2.column1'), Identifier('tab2.column2')]), + QueryStep(subquery, from_table=Result(2)), ], ) @@ -355,6 +335,9 @@ def test_complex_join_tables(self): where t1.a=1 and t2.b=2 and 1=1 ''', dialect='mindsdb') + subquery = copy.deepcopy(query) + subquery.from_table = None + plan = plan_query(query, integrations=['int1', 'int2', 'proj'], default_namespace='proj', predictor_metadata=[{'name': 'pred', 'integration_name': 'proj'}]) @@ -387,40 +370,11 @@ def test_complex_join_tables(self): args=[Identifier('tbl3.id'), Identifier('t1.id')]), join_type=JoinType.LEFT_JOIN)), - FilterStep(dataframe=Result(6), - query=BinaryOperation(op='and', - args=( - BinaryOperation(op='and', - args=( - BinaryOperation(op='=', - args=( - Identifier(parts=['t1', 'a']), - Constant(value=1) - ) - ), - BinaryOperation(op='=', - args=( - Identifier(parts=['t2', 'b']), - Constant(value=2) - ) - ) - ) - ), - BinaryOperation(op='=', - args=( - Constant(value=1), - Constant(value=1) - ) - ) - ) - ) - ), - ProjectStep(dataframe=Result(7), columns=[Star()]) + QueryStep(subquery, from_table=Result(6)), ] ) - for i in range(len(plan.steps)): - assert plan.steps[i] == expected_plan.steps[i] + assert plan.steps == expected_plan.steps def test_complex_join_tables_subselect(self): query = parse_sql(''' @@ -442,8 +396,8 @@ def test_complex_join_tables_subselect(self): predictor=Identifier('pred', alias=Identifier('m'))), JoinStep(left=Result(1), right=Result(2), - query=Join(left=Identifier('result_1'), - right=Identifier('result_2'), + query=Join(left=Identifier('tab1'), + right=Identifier('tab2'), join_type=JoinType.JOIN)), SubSelectStep(dataframe=Result(3), query=Select(targets=[Star()]), table_name='t2'), JoinStep( @@ -462,8 +416,7 @@ def test_complex_join_tables_subselect(self): ] ) - for i in range(len(plan.steps)): - assert plan.steps[i] == expected_plan.steps[i] + assert plan.steps == expected_plan.steps def test_join_with_select_from_native_query(self): query = parse_sql(''' diff --git a/tests/test_planner/test_plan_union.py b/tests/test_planner/test_plan_union.py index e3b4995b..fcdf1ec8 100644 --- a/tests/test_planner/test_plan_union.py +++ b/tests/test_planner/test_plan_union.py @@ -1,9 +1,10 @@ +from mindsdb_sql import parse_sql from mindsdb_sql.parser.ast import * from mindsdb_sql.planner import plan_query from mindsdb_sql.planner.query_plan import QueryPlan from mindsdb_sql.planner.step_result import Result from mindsdb_sql.planner.steps import (FetchDataframeStep, ProjectStep, JoinStep, ApplyPredictorStep, - UnionStep) + UnionStep, QueryStep) from mindsdb_sql.parser.utils import JoinType @@ -49,19 +50,16 @@ def test_plan_union_queries(self): ), ApplyPredictorStep(namespace='mindsdb', dataframe=Result(1), predictor=Identifier('pred')), JoinStep(left=Result(1), right=Result(2), - query=Join(left=Identifier('result_1'), - right=Identifier('result_2'), + query=Join(left=Identifier('tab1'), + right=Identifier('tab2'), join_type=JoinType.INNER_JOIN)), - ProjectStep(dataframe=Result(3), columns=[Identifier('tab1.column1'), Identifier('pred.predicted', alias=Identifier('predicted'))]), - + QueryStep(parse_sql("select tab1.column1, pred.predicted as predicted"), from_table=Result(3)), # Union UnionStep(left=Result(0), right=Result(4), unique=False), - ], ) plan = plan_query(query, integrations=['int'], predictor_namespace='mindsdb', predictor_metadata={'pred': {}}) - for i in range(len(plan.steps)): - assert plan.steps[i] == expected_plan.steps[i] + assert plan.steps == 
expected_plan.steps diff --git a/tests/test_planner/test_prepared_statement.py b/tests/test_planner/test_prepared_statement.py index e576c683..e0368805 100644 --- a/tests/test_planner/test_prepared_statement.py +++ b/tests/test_planner/test_prepared_statement.py @@ -48,6 +48,7 @@ def execute(self, step): {'name': 'predicted', 'type': 'float'}, {'name': 'target', 'type': 'float'}, {'name': 'sqft', 'type': 'float'}, + {'name': 'x', 'type': 'int'}, ] return self.list_cols_return(step.table, cols) return None diff --git a/tests/test_planner/test_ts_predictor.py b/tests/test_planner/test_ts_predictor.py index 1f602d89..ef7cce5a 100644 --- a/tests/test_planner/test_ts_predictor.py +++ b/tests/test_planner/test_ts_predictor.py @@ -2,7 +2,7 @@ import pytest -from mindsdb_sql import parse_sql +from mindsdb_sql import parse_sql, NativeQuery, OrderBy, NullConstant from mindsdb_sql.exceptions import PlanningException from mindsdb_sql.parser.ast import Select, Star, Identifier, Join, Constant, BinaryOperation, Update, BetweenOperation from mindsdb_sql.parser.dialects.mindsdb.latest import Latest @@ -1462,4 +1462,74 @@ def test_dbt_latest(self): for i in range(len(plan.steps)): # print(plan.steps[i]) # print(expected_plan.steps[i]) + assert plan.steps[i] == expected_plan.steps[i] + + + def test_join_native_query(self): + query = parse_sql(''' + SELECT * + FROM int1 (select * from tab) as t + JOIN pred as m + WHERE t.date > LATEST + ''') + + group_by_column = 'type' + + plan = plan_query( + query, + integrations=['int1'], + default_namespace='proj', + predictor_metadata=[{ + 'name': 'pred', + 'integration_name': 'proj', + 'timeseries': True, + 'window': 10, 'horizon': 10, 'order_by_column': 'date', 'group_by_columns': [group_by_column] + }] + ) + + expected_plan = QueryPlan(steps=[ + FetchDataframeStep( + integration='int1', + query=Select( + targets=[Identifier('t.type', alias=Identifier('type'))], + from_table=NativeQuery(query='select * from tab', integration=Identifier('int1'), alias=Identifier('t')), + distinct=True + ) + ), + MapReduceStep( + values=Result(0), + reduce='union', + step=FetchDataframeStep(integration='int1', + query=Select( + targets=[Star()], + from_table=NativeQuery(query='select * from tab', integration=Identifier('int1'), alias=Identifier('t')), + distinct=False, + limit=Constant(10), + order_by=[OrderBy(field=Identifier('t.date'), direction='DESC')], + where=BinaryOperation('and', args=[ + BinaryOperation('is not', args=[Identifier('t.date'), NullConstant()]), + BinaryOperation('=', args=[Identifier('t.type'), Constant('$var[type]')]), + ]) + ) + ), + ), + ApplyTimeseriesPredictorStep( + namespace='proj', + predictor=Identifier('pred', alias=Identifier('m')), + dataframe=Result(1), + output_time_filter=BinaryOperation('>', args=[Identifier('t.date'), Latest()]), + ), + JoinStep( + left=Result(1), + right=Result(2), + query=Join( + left=Identifier('result_1'), + right=Identifier('result_2'), + join_type=JoinType.JOIN + ) + ) + ]) + + assert len(plan.steps) == len(expected_plan.steps) + for i in range(len(plan.steps)): assert plan.steps[i] == expected_plan.steps[i] \ No newline at end of file diff --git a/tests/test_render/test_sqlalchemyrender.py b/tests/test_render/test_sqlalchemyrender.py index 6f49e7cb..77ff7ba4 100644 --- a/tests/test_render/test_sqlalchemyrender.py +++ b/tests/test_render/test_sqlalchemyrender.py @@ -26,7 +26,7 @@ ) -def parse_sql2(sql, dialect='sqlite'): +def parse_sql2(sql, dialect='mindsdb'): # convert to ast query = parse_sql(sql, dialect)
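# --- Illustration (not part of the patch) -------------------------------
# The new test_join_native_query plans a timeseries predictor over a
# native (integration-side) query. A minimal driver for the same
# scenario; the SQL and metadata fields are taken verbatim from the test,
# and the resulting steps should be the FetchDataframeStep /
# MapReduceStep / ApplyTimeseriesPredictorStep / JoinStep sequence shown
# above:
from mindsdb_sql import parse_sql
from mindsdb_sql.planner import plan_query

query = parse_sql('''
    SELECT *
    FROM int1 (select * from tab) as t
    JOIN pred as m
    WHERE t.date > LATEST
''')
plan = plan_query(
    query,
    integrations=['int1'],
    default_namespace='proj',
    predictor_metadata=[{
        'name': 'pred', 'integration_name': 'proj',
        'timeseries': True, 'window': 10, 'horizon': 10,
        'order_by_column': 'date', 'group_by_columns': ['type'],
    }],
)
for step in plan.steps:
    print(step)
# -------------------------------------------------------------------------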