Skip to content

Commit

Permalink
improve plots
Browse files Browse the repository at this point in the history
  • Loading branch information
smorbieu committed Apr 11, 2016
1 parent 05b8844 commit 59fe980
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 41 deletions.
2 changes: 1 addition & 1 deletion coclust/coclust.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def process_output_labels(args, model):
print(model.column_labels_)

if args.get('output_fuzzy_row_labels', None):
print("Save first cols of ordered BW")
print("Save first cols of ordered BW")

if args.get('output_fuzzy_column_labels', None):
print("Save first cols of ordered BtZ")
Expand Down
59 changes: 50 additions & 9 deletions coclust/utils/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,48 @@
import logging
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patheffects as PathEffects

from sklearn.metrics.cluster import normalized_mutual_info_score as nmi
from sklearn.metrics.cluster import adjusted_rand_score
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import normalize

plt.style.use('ggplot')


def _remove_ticks():
    """Hide the tick marks on all four sides of the current axes.

    Applies to both axes and to both major and minor ticks
    (``axis='both'``, ``which='both'``).
    """
    # Use real booleans instead of the legacy 'on'/'off' strings: newer
    # Matplotlib releases no longer interpret those strings as booleans,
    # and a non-empty string such as 'off' is truthy, so the ticks would
    # stay visible. Booleans are accepted by old and new releases alike.
    plt.tick_params(axis='both', which='both', bottom=False, top=False,
                    right=False, left=False)


def plot_criterion(values, ylabel):
    """Plot criterion values against the iteration index and show the figure.

    Parameters
    ----------
    values :
        Data handed to ``plt.plot`` as the y values; presumably one
        criterion value per iteration (confirm against callers).
    ylabel :
        Label for the y axis. The x axis is always labelled 'Iterations'.
    """
    plt.plot(values, marker='o')

    # Axis labels are independent of each other; x first, then y.
    plt.xlabel('Iterations')
    plt.ylabel(ylabel)

    _remove_ticks()
    plt.show()


def plot_reorganized_matrix(X, model, precision=0.8, markersize=0.9):
    """Show the sparsity pattern of ``X`` with rows and columns grouped by label.

    Rows are reordered by ``np.argsort(model.row_labels_)`` and columns by
    ``np.argsort(model.column_labels_)``, then the result is drawn with
    ``plt.spy``.

    Parameters
    ----------
    X :
        Matrix to display. It must support ``X[rows, :]`` / ``X[:, cols]``
        indexing with integer index arrays (presumably a NumPy array or a
        SciPy sparse matrix -- confirm against callers).
    model :
        Object exposing ``row_labels_`` and ``column_labels_``.
    precision :
        Forwarded to ``plt.spy``.
    markersize :
        Forwarded to ``plt.spy``.
    """
    row_order = np.argsort(model.row_labels_)
    col_order = np.argsort(model.column_labels_)

    # Permute rows first, then columns, in two separate indexing steps so
    # that the index arrays are never combined into a single fancy index.
    reordered = X[row_order, :][:, col_order]

    plt.spy(reordered, precision=precision, markersize=markersize)
    _remove_ticks()
    plt.show()


def plot_convergence(criteria, criterion_name, marker='o'):
    """Plot criterion values against the iteration index and show the figure.

    Parameters
    ----------
    criteria :
        Data handed to ``plt.plot`` as the y values; presumably one
        criterion value per iteration (confirm against callers).
    criterion_name :
        Label for the y axis. The x axis is always labelled 'Iterations'.
    marker :
        Marker style forwarded to ``plt.plot``.
    """
    plt.plot(criteria, marker=marker)
    plt.ylabel(criterion_name)
    plt.xlabel('Iterations')
    # Remove the ticks *before* the single call to show(): showing first
    # would render the figure with ticks still present, and a second
    # show() afterwards would be redundant.
    _remove_ticks()
    plt.show()


def plot_confusion_matrix(cm, colormap=plt.cm.jet, labels='012'):
def plot_confusion_matrix(cm, colormap=plt.get_cmap("viridis"), labels='012'):
conf_arr = np.array(cm)

norm_conf_arr = []
Expand All @@ -52,23 +70,30 @@ def plot_confusion_matrix(cm, colormap=plt.cm.jet, labels='012'):
plt.clf()
ax = fig.add_subplot(111)
ax.set_aspect(1)
res = ax.imshow(np.array(norm_conf_arr), cmap=plt.cm.jet,
res = ax.imshow(np.array(norm_conf_arr), cmap=colormap,
interpolation='nearest')

width, height = conf_arr.shape

for x in np.arange(width):
for y in np.arange(height):
ax.annotate(str(conf_arr[x][y]), xy=(y, x),
ax.annotate(str(conf_arr[x][y]),
xy=(y, x),
horizontalalignment='center',
verticalalignment='center')
verticalalignment='center',
path_effects=[PathEffects.withStroke(linewidth=3,
foreground="w",
alpha=0.7)])

fig.colorbar(res)
plt.xticks(range(width), labels[:width])
plt.yticks(range(height), labels[:height])
_remove_ticks()
plt.show()


def plot_delta_kl(delta, model, colormap=plt.cm.jet, labels='012'):
def plot_delta_kl(delta, model, colormap=plt.get_cmap("viridis"),
labels='012'):

delta_arr = np.round(np.array(delta), decimals=3)

Expand All @@ -89,12 +114,21 @@ def plot_delta_kl(delta, model, colormap=plt.cm.jet, labels='012'):
(nb_docs, nb_terms),
xy=(y, x),
horizontalalignment='center',
verticalalignment='center')
verticalalignment='center',
path_effects=[PathEffects.withStroke(linewidth=3,
foreground="w",
alpha=0.7)])

fig.colorbar(res)
plt.xticks(range(width), labels[:width])
plt.yticks(range(height), labels[:height])

ax = plt.gca()
ax.grid(False)

_remove_ticks()
plt.show()


def plot_top_terms(model, X, terms, n_cluster, n_terms=10,
x_label="number of occurences"):
Expand All @@ -117,13 +151,16 @@ def plot_top_terms(model, X, terms, n_cluster, n_terms=10,
plt.yticks(.4 + pos, terms[max_indices][::-1])

plt.xlabel(x_label)
plt.margins(y=0.05)
_remove_ticks()
plt.show()


def plot_cluster_sizes(model):
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111)
colors = ['r', 'g', 'b']
prop_list = list(plt.rcParams['axes.prop_cycle'])
colors = [prop_list[0]['color'], prop_list[1]['color']]
x = []
y = []
for i in range(model.n_clusters):
Expand All @@ -140,14 +177,18 @@ def plot_cluster_sizes(model):
legend_rects.append(cols[0])
for c in cols:
h = c.get_height()
ax.text(c.get_x() + c.get_width() / 2., 0.98 * h, '%d' % int(h),
ax.text(c.get_x() + c.get_width() / 2., h + 5, '%d' % int(h),
ha='center', va='bottom')
ax.set_xticks(location + (shift / 2.))
ax.set_xticklabels(['coclust-' + str(i) for i in range(model.n_clusters)])
plt.xlabel('Co-clusters')
plt.ylabel('Sizes')
plt.tight_layout()
ax.legend(legend_rects, ('Rows', 'Columns'))

_remove_ticks()
plt.show()


def print_NMI_and_ARI(true_labels, predicted_labels):
print("NMI:", nmi(true_labels, predicted_labels))
Expand Down
15 changes: 7 additions & 8 deletions demo/demo-coclustinfo.ipynb

Large diffs are not rendered by default.

22 changes: 8 additions & 14 deletions demo/demo-coclustmod.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 59fe980

Please sign in to comment.