# dbengine_sqlnet.py
# Forked from DebadityaPal/RoBERTa-NL2SQL.
import records
import re
from babel.numbers import parse_decimal, NumberFormatError
# From original SQLNet code.
# Wonseok modified. 20180607
# Captures everything between the first '(' and the last ')' of a string.
# DBEngine applies it to the CREATE TABLE statement stored in sqlite_master
# to pull out the "col0 type, col1 type, ..." column list.
schema_re = re.compile(r'\((.+)\)')

# Matches a number-like token inside free text: '-34.34', '.4543', '+7', '12'.
# The alternation is wrapped in a non-capturing group so that the optional
# sign applies to plain integers as well as decimals; without the group the
# sign binds to the decimal branch only and '-34' is extracted as '34'.
num_re = re.compile(r'[-+]?(?:\d*\.\d+|\d+)')

# Aggregation keywords, looked up by aggregation index.
# The empty string at index 0 means "no aggregation".
agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']

# Comparison operators, looked up by condition-operator index.
# NOTE(review): 'OP' is not valid SQL; it looks like a placeholder entry
# that is never expected to be selected -- confirm against the data.
cond_ops = ['=', '>', '<', 'OP']
class DBEngine:
    """Build and run single-table SELECT queries against a SQLite database.

    Tables are addressed as ``table_<id>`` and their columns as ``col<index>``.
    A query is described by a selected column index, an index into ``agg_ops``
    and a list of ``(column_index, operator_index, value)`` conditions, where
    the operator index refers to ``cond_ops``. All conditions are AND-ed.
    """

    def __init__(self, fdb):
        """Open a connection to the SQLite database file at path *fdb*."""
        # e.g. fdb = 'data/test.db'
        self.db = records.Database('sqlite:///{}'.format(fdb)).get_connection()

    def execute_query(self, table_id, query, *args, **kwargs):
        """Run *query* on *table_id* and return the list of result values.

        *query* must expose ``sel_index``, ``agg_index`` and ``conditions``
        attributes; any extra arguments are forwarded to :meth:`execute`.
        """
        return self.execute(table_id, query.sel_index, query.agg_index, query.conditions, *args, **kwargs)

    def execute(self, table_id, select_index, aggregation_index, conditions, lower=True):
        """Run the described query and return the list of result values.

        Args:
            table_id: Table identifier; see :meth:`_table_name` for how it is
                turned into the actual table name.
            select_index: Index of the column to select (``col<index>``).
            aggregation_index: Index into ``agg_ops``; 0 means no aggregation.
            conditions: Iterable of ``(column_index, operator_index, value)``.
            lower: When true, string condition values are lower-cased before
                being bound.

        Returns:
            A list with the ``result`` field of every returned row.
        """
        results, _ = self._run(table_id, select_index, aggregation_index, conditions, lower)
        return results

    def execute_return_query(self, table_id, select_index, aggregation_index, conditions, lower=True):
        """Same as :meth:`execute`, but also return the SQL text that was run.

        Returns:
            A ``(results, query)`` tuple. ``query`` still contains the
            ``:name`` placeholders; the bound values are not substituted in.
        """
        return self._run(table_id, select_index, aggregation_index, conditions, lower)

    def show_table(self, table_id):
        """Print the full contents of the table identified by *table_id*."""
        # NOTE: a table name cannot be passed as a bind parameter, so it is
        # concatenated into the statement; table_id must come from a trusted
        # source.
        rows = self.db.query('select * from ' + self._table_name(table_id))
        print(rows.dataset)

    @staticmethod
    def _table_name(table_id):
        """Return *table_id* as a table name: ``'1-2'`` -> ``'table_1_2'``.

        Identifiers that already start with ``'table'`` are returned as-is.
        """
        if not table_id.startswith('table'):
            table_id = 'table_{}'.format(table_id.replace('-', '_'))
        return table_id

    def _load_schema(self, table_name):
        """Return a ``{column_name: declared_type}`` dict for *table_name*.

        The mapping is parsed from the CREATE TABLE statement recorded in
        ``sqlite_master``. Raises ``IndexError`` when the table is unknown.
        """
        table_info = self.db.query(
            'SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_name
        ).all()[0].sql.replace('\n', '')
        schema_str = schema_re.findall(table_info)[0]
        # Each entry is expected to look like "col3 real".
        return {c: t for c, t in (tup.split() for tup in schema_str.split(', '))}

    @staticmethod
    def _coerce_real(val):
        """Convert a condition value for a ``real`` column to ``float``.

        The value is first parsed as an en_US formatted decimal; if that
        fails, the first number-like token found by ``num_re`` is used. When
        neither yields a number the value is returned unchanged, so a
        non-numeric value on a numeric column is bound as-is instead of
        aborting the whole query.
        """
        try:
            return float(parse_decimal(val, locale='en_US'))
        except NumberFormatError:
            pass
        try:
            return float(num_re.findall(val)[0])
        except (IndexError, TypeError, ValueError):
            return val

    def _run(self, table_id, select_index, aggregation_index, conditions, lower):
        """Build the SELECT statement, execute it, return ``(results, query)``."""
        table_name = self._table_name(table_id)
        schema = self._load_schema(table_name)

        select = 'col{}'.format(select_index)
        agg = agg_ops[aggregation_index]
        if agg:
            select = '{}({})'.format(agg, select)

        where_clause = []
        where_map = {}
        for col_index, op, val in conditions:
            column = 'col{}'.format(col_index)
            if lower and isinstance(val, str):
                val = val.lower()
            if schema[column] == 'real' and not isinstance(val, (int, float)):
                val = self._coerce_real(val)

            # Every condition needs its own placeholder. The first condition
            # on a column keeps the plain column name; later conditions on the
            # same column get a numeric suffix so that e.g.
            # "col1 > :col1 AND col1 < :col1_1" binds two distinct values
            # instead of the second one overwriting the first.
            param = column
            suffix = 1
            while param in where_map:
                param = '{}_{}'.format(column, suffix)
                suffix += 1
            where_clause.append('{} {} :{}'.format(column, cond_ops[op], param))
            where_map[param] = val

        where_str = ''
        if where_clause:
            where_str = 'WHERE ' + ' AND '.join(where_clause)
        # NOTE: the table name cannot be a bind parameter and is formatted
        # into the statement; table_id must come from a trusted source.
        query = 'SELECT {} AS result FROM {} {}'.format(select, table_name, where_str)
        out = self.db.query(query, **where_map)
        return [o.result for o in out], query