Skip to content

Commit

Permalink
Merge pull request #2714 from martinholmer/tc-sqldb
Browse files (browse the repository at this point in the history)
Add baseline table to output file generated by the tc --sqldb command
  • Loading branch information
jdebacker authored Feb 6, 2024
2 parents b847e4f + d806ac3 commit 51b0941
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 11 deletions.
3 changes: 2 additions & 1 deletion taxcalc/cli/tc.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def cli_tc_main():
default=None)
parser.add_argument('--sqldb',
help=('optional flag that writes SQLite database '
'with dump table containing same output as '
'with two tables (baseline and reform) each '
'containing same output variables as '
'produced by --dump option.'),
default=False,
action="store_true")
Expand Down
43 changes: 33 additions & 10 deletions taxcalc/taxcalcio.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,9 @@ def analyze(self, writing_output_file=False,
calculated variables using their Tax-Calculator names
output_sqldb: boolean
whether or not to write SQLite3 database with dump table
containing same output as written by output_dump to a csv file
whether or not to write SQLite3 database with two tables
(baseline and reform) each containing same output as written
by output_dump to a csv file
Returns
-------
Expand All @@ -449,18 +450,28 @@ def analyze(self, writing_output_file=False,
(mtr_paytax, mtr_inctax,
_) = self.calc.mtr(wrt_full_compensation=False,
calc_all_already_called=True)
self.calc_base.calc_all()
calc_base_calculated = True
(mtr_paytax_base, mtr_inctax_base,
_) = self.calc_base.mtr(wrt_full_compensation=False,
calc_all_already_called=True)
else:
# definitely do not need marginal tax rates
mtr_paytax = None
mtr_inctax = None
mtr_paytax_base = None
mtr_inctax_base = None
# extract output if writing_output_file
if writing_output_file:
self.write_output_file(output_dump, dump_varset,
mtr_paytax, mtr_inctax)
self.write_doc_file()
# optionally write --sqldb output to SQLite3 database
if output_sqldb:
self.write_sqldb_file(dump_varset, mtr_paytax, mtr_inctax)
self.write_sqldb_file(
dump_varset, mtr_paytax, mtr_inctax,
mtr_paytax_base, mtr_inctax_base
)
# optionally write --tables output to text file
if output_tables:
if not calc_base_calculated:
Expand All @@ -480,7 +491,9 @@ def write_output_file(self, output_dump, dump_varset,
Write output to CSV-formatted file.
"""
if output_dump:
outdf = self.dump_output(dump_varset, mtr_inctax, mtr_paytax)
outdf = self.dump_output(
self.calc, dump_varset, mtr_inctax, mtr_paytax
)
column_order = sorted(outdf.columns)
else:
outdf = self.minimal_output()
Expand All @@ -504,15 +517,25 @@ def write_doc_file(self):
with open(doc_fname, 'w') as dfile:
dfile.write(doc)

def write_sqldb_file(self, dump_varset, mtr_paytax, mtr_inctax):
def write_sqldb_file(self, dump_varset, mtr_paytax, mtr_inctax,
mtr_paytax_base, mtr_inctax_base):
"""
Write dump output to SQLite3 database tables baseline and reform.
"""
outdf = self.dump_output(dump_varset, mtr_inctax, mtr_paytax)
assert len(outdf.index) == self.calc.array_len
db_fname = self._output_filename.replace('.csv', '.db')
dbcon = sqlite3.connect(db_fname)
outdf.to_sql('dump', dbcon, if_exists='replace', index=False)
# write baseline table
outdf = self.dump_output(
self.calc_base, dump_varset, mtr_inctax_base, mtr_paytax_base
)
assert len(outdf.index) == self.calc.array_len
outdf.to_sql('baseline', dbcon, if_exists='replace', index=False)
# write reform table
outdf = self.dump_output(
self.calc, dump_varset, mtr_inctax, mtr_paytax
)
assert len(outdf.index) == self.calc.array_len
outdf.to_sql('reform', dbcon, if_exists='replace', index=False)
dbcon.close()
del outdf
gc.collect()
Expand Down Expand Up @@ -687,7 +710,7 @@ def minimal_output(self):
odf = pd.DataFrame(data=odict, columns=varlist)
return odf

def dump_output(self, dump_varset, mtr_inctax, mtr_paytax):
def dump_output(self, calcx, dump_varset, mtr_inctax, mtr_paytax):
"""
Extract dump output and return it as Pandas DataFrame.
"""
Expand All @@ -699,7 +722,7 @@ def dump_output(self, dump_varset, mtr_inctax, mtr_paytax):
# create and return dump output DataFrame
odf = pd.DataFrame()
for varname in varset:
vardata = self.calc.array(varname)
vardata = calcx.array(varname)
if varname in recs_vinfo.INTEGER_VARS:
odf[varname] = vardata
else:
Expand Down

0 comments on commit 51b0941

Please sign in to comment.