From e12c1b6d5db6c6cef887ec9389979a9e3233eec0 Mon Sep 17 00:00:00 2001 From: SaOligg88 Date: Tue, 23 Feb 2016 10:52:58 +0100 Subject: [PATCH] added create figure --- plotting.py | 143 +++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 130 insertions(+), 13 deletions(-) diff --git a/plotting.py b/plotting.py index 0dd4a4b..a072884 100644 --- a/plotting.py +++ b/plotting.py @@ -12,9 +12,12 @@ def plot_surf_stat_map(coords, faces, stat_map=None, threshold=None, bg_map=None, bg_on_stat=False, alpha='auto', - vmax=None, symmetric_cbar="auto", + vmin=None, vmax=None, + cbar='sequential', # or'diverging' + symmetric_cbar="auto", figsize=None, - labels=None, label_cpal=None, + labels=None, label_col=None, label_cpal=None, + mask=None, mask_lenient=None, **kwargs): import numpy as np @@ -53,6 +56,16 @@ def plot_surf_stat_map(coords, faces, stat_map=None, antialiased=False, color='white') + # where mask is indices of nodes to include: + if mask is not None: + cmask = np.zeros(len(coords)) + cmask[mask] = 1 + cutoff = 2 # include triangles in cortex only if ALL nodes in mask + if mask_lenient: # include triangles in cortex if ANY are in mask + cutoff = 0 + fmask = np.where(cmask[faces].sum(axis=1) > cutoff)[0] + + # If depth_map and/or stat_map are provided, map these onto the surface # set_facecolors function of Poly3DCollection is used as passing the # facecolors argument to plot_trisurf does not seem to work @@ -78,11 +91,18 @@ def plot_surf_stat_map(coords, faces, stat_map=None, stat_map_data = stat_map stat_map_faces = np.mean(stat_map_data[faces], axis=1) - # Call _get_plot_stat_map_params to derive symmetric vmin and vmax - # And colorbar limits depending on symmetric_cbar settings - cbar_vmin, cbar_vmax, vmin, vmax = \ - _get_plot_stat_map_params(stat_map_faces, vmax, - symmetric_cbar, kwargs) + if cbar is 'diverging': + print cbar + # Call _get_plot_stat_map_params to derive symmetric vmin and vmax + # And colorbar limits depending on symmetric_cbar settings + cbar_vmin, cbar_vmax, vmin, vmax = \ + _get_plot_stat_map_params(stat_map_faces, vmax, + symmetric_cbar, kwargs) + if cbar is 'sequential': + if vmin is None: + vmin = stat_map_data.min() + if vmax is None: + vmax = stat_map_data.max() if threshold is not None: kept_indices = np.where(abs(stat_map_faces) >= threshold)[0] @@ -96,9 +116,16 @@ def plot_surf_stat_map(coords, faces, stat_map=None, stat_map_faces = stat_map_faces - vmin stat_map_faces = stat_map_faces / (vmax-vmin) if bg_on_stat: - face_colors = cmap(stat_map_faces) * face_colors + if mask is not None: + face_colors[fmask] = cmap(stat_map_faces)[fmask] * face_colors[fmask] + else: + face_colors = cmap(stat_map_faces) * face_colors else: - face_colors = cmap(stat_map_faces) + if mask is not None: + face_colors[fmask] = cmap(stat_map_faces)[fmask] + + else: + face_colors = cmap(stat_map_faces) if labels is not None: ''' @@ -111,6 +138,8 @@ def plot_surf_stat_map(coords, faces, stat_map=None, valid color names from http://xkcd.com/color/rgb/ ''' if label_cpal is not None: + if label_col is not None: + raise ValueError("Don't use label_cpal and label_col together.") if type(label_cpal) == str: cpal = sns.color_palette(label_cpal, len(labels)) if type(label_cpal) == list: @@ -121,12 +150,18 @@ def plot_surf_stat_map(coords, faces, stat_map=None, except: cpal = sns.xkcd_palette(label_cpal) + + + for n_label, label in enumerate(labels): for n_face, face in enumerate(faces): count = len(set(face).intersection(set(label))) if (count > 0) & (count < 3): if label_cpal is None: - face_colors[n_face,0:3] = sns.xkcd_palette(["black"])[0] + if label_col is not None: + face_colors[n_face,0:3] = sns.xkcd_palette([label_col])[0] + else: + face_colors[n_face,0:3] = sns.xkcd_palette(["black"])[0] else: face_colors[n_face,0:3] = cpal[n_label] @@ -290,14 +325,96 @@ def crop_img(fig, margin=10): kept = {'rows':[], 'cols':[]} for row in range(img.shape[0]): - if len(set(np.ndarray.flatten(img[row,:,:]))) > 3: + if len(set(np.ndarray.flatten(img[row,:,:]))) > 1: kept['rows'].append(row) for col in range(img.shape[1]): - if len(set(np.ndarray.flatten(img[:,col,:]))) > 3: + if len(set(np.ndarray.flatten(img[:,col,:]))) > 1: kept['cols'].append(col) if margin: return img[min(kept['rows'])-margin:max(kept['rows'])+margin, min(kept['cols'])-margin:max(kept['cols'])+margin] else: - return img[kept['rows']][:,kept['cols']] \ No newline at end of file + return img[kept['rows']][:,kept['cols']] + + + + +def create_fig(data=None, labels=None, label_col=None, + hemi=None, surf='pial', + sulc=True, alpha='auto', + cmap='cubehelix', cpal='bright', cbar=False, + dmin=None, dmax=None, + mask=None): + + import nibabel as nib, numpy as np + import matplotlib.pyplot as plt, matplotlib as mpl + from IPython.core.display import Image, display + import os + + fsDir = '/afs/cbs.mpg.de/projects/mar004_lsd-lemon-preproc/freesurfer' + surf_f = '%s/fsaverage5/surf/%s.%s' % (fsDir, hemi, surf) + coords = nib.freesurfer.io.read_geometry(surf_f)[0] + faces = nib.freesurfer.io.read_geometry(surf_f)[1] + if sulc: + sulc_f = '%s/fsaverage5/surf/%s.sulc' % (fsDir, hemi) + sulc = nib.freesurfer.io.read_morph_data(sulc_f) + sulc_bool = True + else: + sulc = None + sulc_bool = False + + # create images + imgs = [] + for azim in [0, 180]: + + if data is not None: + if dmin is None: + dmin = data[np.nonzero(data)].min() + if dmax is None: + dmax = data.max() + fig = plot_surf_stat_map(coords, faces, stat_map=data, + elev=0, azim=azim, + cmap=cmap, + bg_map=sulc,bg_on_stat=sulc_bool, + vmin=dmin, vmax=dmax, + labels=labels, label_col=label_col, + alpha=alpha, + mask=mask, mask_lenient=False) + #label_cpal=cpal) + else: + fig = plot_surf_label(coords, faces, + labels=labels, + elev=0, azim=azim, + bg_map=sulc, + cpal=cpal, + bg_on_labels=sulc_bool, + alpha=alpha) + + # crop image + imgs.append((crop_img(fig, margin=15)),) + plt.close(fig) + + # create figure with color bar + fig = plt.figure() + fig.set_size_inches(8, 4) + + ax1 = plt.subplot2grid((4,60), (0,0), colspan = 26, rowspan =4) + plt.imshow(imgs[0]) + ax1.set_axis_off() + + ax2 = plt.subplot2grid((4,60), (0,28), colspan = 26, rowspan =4) + plt.imshow(imgs[1]) + ax2.set_axis_off() + + if cbar==True and data is not None: + cax = plt.subplot2grid((4,60), (1,59), colspan = 1, rowspan =2) + cmap = plt.cm.get_cmap(cmap) + norm = mpl.colors.Normalize(vmin=dmin, vmax=dmax) + cb = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm) + cb.set_ticks([dmin, dmax]) + + fig.savefig('./tempimage') + plt.close(fig) + display(Image(filename='./tempimage.png', width=800)) + os.remove('./tempimage.png')