diff --git a/cmapPy/visualization/cohort_view.py b/cmapPy/visualization/cohort_view.py
index 3d3e685..0b28545 100644
--- a/cmapPy/visualization/cohort_view.py
+++ b/cmapPy/visualization/cohort_view.py
@@ -10,22 +10,6 @@
logger = logging.getLogger()
-def _add_row_percentages(s):
- '''Convert all columns except for "Total" to a string
- that shows the integer count as well as the percentage
- of Total within the row.'''
- s = s + 0
- index = s.index
- assert "Total" in index
- total = s['Total']
- for label, x in s.iteritems():
- if label == "Total":
- continue
- s[label] = r'{:,d} \
- ({:.0%})'.format(int(x), float(x) / total)
- return s
-
-
def cohort_view_table(df,
category_label="category_label",
category_order="category_order",
@@ -78,7 +62,8 @@ def cohort_view_table(df,
# Test comopound fields
cpd_fields = [c for c in df.columns if 'Test subset' in c]
- df['Test Compounds Total'] = df[cpd_fields].sum(1)
+ if len(cpd_fields) != 0:
+ df['Test Compounds Total'] = df[cpd_fields].sum(1)
df['Grand Total'] = df.iloc[:, :num_categories].sum(1)
df = df.T
df.index.name = None
@@ -97,8 +82,6 @@ def _fmt_total_percentages(x, total):
({:.0%})'''.format(int(x), float(x) / total)
return s
-
- from IPython.display import display
def _add_row_percentages(s):
@@ -117,45 +100,6 @@ def _add_row_percentages(s):
({:.0%})'''.format(int(x), float(x) / total)
return s
-#
-# def cohort_view_table(df,
-# category_label="category_label",
-# category_order="category_order",
-# flags=[],
-# flag_display_labels=[],
-# add_percentages=True):
-#
-#
-# df['Total'] = 1
-# columns = ['Total'] + flags
-# df = (
-# df
-# .groupby([category_order, category_label])[columns]
-# .sum()
-# .sort_index(axis=0, level=category_order)
-# .reset_index(level=[category_order])
-# .drop(columns=category_order)
-# )
-#
-# column_names = ["Total"] + flag_display_labels
-# df.columns = column_names
-# df.index.names=['Category']
-#
-# df = df.T
-# num_categories = len(df.columns)
-# print "num_categories: {}".format(num_categories)
-#
-# # Test comopound fields
-# cpd_columns = [c for c in df.columns if 'Test subset' in c]
-# df['Test Compounds Total'] = df[cpd_columns].sum(1)
-# df['Grand Total'] = df.iloc[:,:num_categories].sum(1)
-# df = df.T
-# df.index.name=None
-#
-# if add_percentages:
-# df = df.transform(_add_row_percentages, axis=1)
-# return df
-
def display_cohort_stats_table(table, barplot_column):
font_family = "Roboto"
@@ -164,9 +108,11 @@ def display_cohort_stats_table(table, barplot_column):
# the last "total" sums
group_ids = [x for x in table.index if 'Total' not in x]
+ barplot_max = table.loc[group_ids, barplot_column].sum()
+
# Sum of numbers in Total column (excluding Grand Total, obviously)
total = table.loc['Grand Total', 'Total']
- return (
+ table_stylized = (
table
.style
.format(
@@ -175,10 +121,9 @@ def display_cohort_stats_table(table, barplot_column):
)
.applymap(lambda x : 'text-align:center;')
.applymap(lambda x: "border-left:solid thin #d65f5f", subset=idx[:, barplot_column])
- .bar(subset=idx[group_ids, barplot_column], color='#FFDACF')
+ .bar(subset=idx[group_ids, barplot_column], color='#FFDACF', vmin=0, vmax=barplot_max)
.applymap(lambda x: "padding:0.5em 1em 0.5em 1em")
.applymap(lambda x: "background:#444;color:white;border:solid thin #000;font-weight:bold", subset=idx['Grand Total', :])
- .applymap(lambda x: "border-top:solid thin #aaa", subset=idx['Test Compounds Total', :])
.applymap(lambda x: "border-left:solid thin #ddd", subset=idx[:, 'Total'])
.set_table_styles(
[
@@ -223,4 +168,7 @@ def display_cohort_stats_table(table, barplot_column):
]
)
)
+ if 'Test Compounds Total' in table.index:
+ table_stylized = table_stylized.applymap(lambda x: "border-top:solid thin #aaa", subset=idx['Test Compounds Total', :])
+ return table_stylized
diff --git a/cmapPy/visualization/test_cohort_view.py b/cmapPy/visualization/test_cohort_view.py
index 6fd747f..1d80cc8 100644
--- a/cmapPy/visualization/test_cohort_view.py
+++ b/cmapPy/visualization/test_cohort_view.py
@@ -29,7 +29,7 @@ def testCohortView(self):
flag_display_labels=column_names
)
- print table
+ print(table)
# plt.savefig("./test_files/cohort_view_test.html", dpi=150)