diff --git a/cmapPy/visualization/cohort_view.py b/cmapPy/visualization/cohort_view.py index 3d3e685..0b28545 100644 --- a/cmapPy/visualization/cohort_view.py +++ b/cmapPy/visualization/cohort_view.py @@ -10,22 +10,6 @@ logger = logging.getLogger() -def _add_row_percentages(s): - '''Convert all columns except for "Total" to a string - that shows the integer count as well as the percentage - of Total within the row.''' - s = s + 0 - index = s.index - assert "Total" in index - total = s['Total'] - for label, x in s.iteritems(): - if label == "Total": - continue - s[label] = r'{:,d} \ - ({:.0%})'.format(int(x), float(x) / total) - return s - - def cohort_view_table(df, category_label="category_label", category_order="category_order", @@ -78,7 +62,8 @@ def cohort_view_table(df, # Test comopound fields cpd_fields = [c for c in df.columns if 'Test subset' in c] - df['Test Compounds Total'] = df[cpd_fields].sum(1) + if len(cpd_fields) != 0: + df['Test Compounds Total'] = df[cpd_fields].sum(1) df['Grand Total'] = df.iloc[:, :num_categories].sum(1) df = df.T df.index.name = None @@ -97,8 +82,6 @@ def _fmt_total_percentages(x, total): ({:.0%})'''.format(int(x), float(x) / total) return s - - from IPython.display import display def _add_row_percentages(s): @@ -117,45 +100,6 @@ def _add_row_percentages(s): ({:.0%})'''.format(int(x), float(x) / total) return s -# -# def cohort_view_table(df, -# category_label="category_label", -# category_order="category_order", -# flags=[], -# flag_display_labels=[], -# add_percentages=True): -# -# -# df['Total'] = 1 -# columns = ['Total'] + flags -# df = ( -# df -# .groupby([category_order, category_label])[columns] -# .sum() -# .sort_index(axis=0, level=category_order) -# .reset_index(level=[category_order]) -# .drop(columns=category_order) -# ) -# -# column_names = ["Total"] + flag_display_labels -# df.columns = column_names -# df.index.names=['Category'] -# -# df = df.T -# num_categories = len(df.columns) -# print "num_categories: {}".format(num_categories) -# -# # Test comopound fields -# cpd_columns = [c for c in df.columns if 'Test subset' in c] -# df['Test Compounds Total'] = df[cpd_columns].sum(1) -# df['Grand Total'] = df.iloc[:,:num_categories].sum(1) -# df = df.T -# df.index.name=None -# -# if add_percentages: -# df = df.transform(_add_row_percentages, axis=1) -# return df - def display_cohort_stats_table(table, barplot_column): font_family = "Roboto" @@ -164,9 +108,11 @@ def display_cohort_stats_table(table, barplot_column): # the last "total" sums group_ids = [x for x in table.index if 'Total' not in x] + barplot_max = table.loc[group_ids, barplot_column].sum() + # Sum of numbers in Total column (excluding Grand Total, obviously) total = table.loc['Grand Total', 'Total'] - return ( + table_stylized = ( table .style .format( @@ -175,10 +121,9 @@ def display_cohort_stats_table(table, barplot_column): ) .applymap(lambda x : 'text-align:center;') .applymap(lambda x: "border-left:solid thin #d65f5f", subset=idx[:, barplot_column]) - .bar(subset=idx[group_ids, barplot_column], color='#FFDACF') + .bar(subset=idx[group_ids, barplot_column], color='#FFDACF', vmin=0, vmax=barplot_max) .applymap(lambda x: "padding:0.5em 1em 0.5em 1em") .applymap(lambda x: "background:#444;color:white;border:solid thin #000;font-weight:bold", subset=idx['Grand Total', :]) - .applymap(lambda x: "border-top:solid thin #aaa", subset=idx['Test Compounds Total', :]) .applymap(lambda x: "border-left:solid thin #ddd", subset=idx[:, 'Total']) .set_table_styles( [ @@ -223,4 +168,7 @@ def display_cohort_stats_table(table, barplot_column): ] ) ) + if 'Test Compounds Total' in table.index: + table_stylized = table_stylized.applymap(lambda x: "border-top:solid thin #aaa", subset=idx['Test Compounds Total', :]) + return table_stylized diff --git a/cmapPy/visualization/test_cohort_view.py b/cmapPy/visualization/test_cohort_view.py index 6fd747f..1d80cc8 100644 --- a/cmapPy/visualization/test_cohort_view.py +++ b/cmapPy/visualization/test_cohort_view.py @@ -29,7 +29,7 @@ def testCohortView(self): flag_display_labels=column_names ) - print table + print(table) # plt.savefig("./test_files/cohort_view_test.html", dpi=150)