Source code for cellCnn.plotting


""" Copyright 2016-2017 ETH Zurich, Eirini Arvaniti and Manfred Claassen.

This module contains functions for plotting the results of a CellCnn analysis.

"""

import os
import sys
from collections import Counter
import numpy as np
import pandas as pd
from scipy import stats
from sklearn.cluster import DBSCAN
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import shuffle
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1 import ImageGrid
import seaborn as sns
from cellCnn.utils import mkdir_p
import statsmodels.api as sm
try:
    from cellCnn.utils import create_graph
except ImportError:
    pass


[docs]def plot_results(results, samples, phenotypes, labels, outdir, filter_diff_thres=.2, filter_response_thres=0, response_grad_cutoff=None, stat_test=None, log_yscale=False, group_a='group A', group_b='group B', group_names=None, tsne_ncell=10000, regression=False, clustering=None, add_filter_response=False, percentage_drop_cluster=.1, min_cluster_freq=0.2, show_filters=True): """ Plots the results of a CellCnn analysis. Args: - results : Dictionary containing the results of a CellCnn analysis. - samples : Samples from which to visualize the selected cell populations. - phenotypes : List of phenotypes corresponding to the provided `samples`. - labels : Names of measured markers. - outdir : Output directory where the generated plots will be stored. - filter_diff_thres : Threshold that defines which filters are most discriminative. Given an array ``filter_diff`` of average cell filter response differences between classes, sorted in decreasing order, keep a filter ``i, i > 0`` if it holds that ``filter_diff[i-1] - filter_diff[i] < filter_diff_thres * filter_diff[i-1]``. For regression problems, the array ``filter_diff`` contains Kendall's tau values for each filter. - filter_response_thres : Threshold for choosing a responding cell population. Default is 0. - response_grad_cutoff : Threshold on the gradient of the cell filter response CDF, might be useful for defining the selected cell population. - stat_test: None | 'ttest' | 'mannwhitneyu' Optionally, perform a statistical test on selected cell population frequencies between two groups and report the corresponding p-value on the boxplot figure (see plots description below). Default is None. Currently only used for binary classification problems. - group_a : Name of the first class. - group_b : Name of the second class. - group_names : List of names for the different phenotype classes. - log_yscale : If True, display the y-axis of the boxplot figure (see plots description below) in logarithmic scale. - clustering: None | 'dbscan' | 'louvain' Post-processing option for selected cell populations. Default is None. - tsne_ncell : Number of cells to include in t-SNE calculations and plots. - regression : Whether it is a regression problem. - show_filters : Whether to plot learned filter weights. Returns: A list with the indices and corresponding cell filter response thresholds of selected discriminative filters. \ This function also produces a collection of plots for model interpretation. These plots are stored in `outdir`. They comprise the following: - clustered_filter_weights.pdf : Filter weight vectors from all trained networks that pass a validation accuracy threshold, grouped in clusters via hierarchical clustering. Each row corresponds to a filter. The last column(s) indicate the weight(s) connecting each filter to the output class(es). Indices on the y-axis indicate the filter cluster memberships, as a result of the hierarchical clustering procedure. - consensus_filter_weights.pdf : One representative filter per cluster is chosen (the filter with minimum distance to all other memebers of the cluster). We call these selected filters "consensus filters". - best_net_weights.pdf : Filter weight vectors of the network that achieved the highest validation accuracy. - filter_response_differences.pdf : Difference in cell filter response between classes for each consensus filter. To compute this difference for a filter, we first choose a filter-specific class, that's the class with highest output weight connection to the filter. Then we compute the average cell filter response (value after the pooling layer) for validation samples belonging to the filter-specific class (``v1``) and the average cell filter response for validation samples not belonging to the filter-specific class (``v0``). The difference is computed as ``v1 - v0``. For regression problems, we cannot compute a difference between classes. Instead we compute Kendall's rank correlation coefficient between the predictions of each individual filter (value after the pooling layer) and the true response values. This plot helps decide on a cutoff (``filter_diff_thres`` parameter) for selecting discriminative filters. - tsne_all_cells.png : Marker distribution overlaid on t-SNE map. In addition, the following plots are produced for each selected filter (e.g. filter ``i``): - cdf_filter_i.pdf : Cumulative distribution function of cell filter response for filter ``i``. This plot helps decide on a cutoff (``filter_response_thres`` parameter) for selecting the responding cell population. - selected_population_distribution_filter_i.pdf : Histograms of univariate marker expression profiles for the cell population selected by filter ``i`` vs all cells. - selected_population_frequencies_filter_i.pdf : Boxplot of selected cell population frequencies in samples of the different classes, if running a classification problem. For regression settings, a scatter plot of selected cell population frequencies vs response variable is generated. - tsne_cell_response_filter_i.png : Cell filter response overlaid on t-SNE map. - tsne_selected_cells_filter_i.png : Marker distribution of selected cell population overlaid on t-SNE map. """ # create the output directory mkdir_p(outdir) # number of measured markers nmark = samples[0].shape[1] if results['selected_filters'] is not None: print 'Loading the weights of consensus filters.' filters = results['selected_filters'] else: sys.exit('Consensus filters were not found.') if show_filters: plot_filters(results, labels, outdir) # get discriminative filter indices in consensus matrix keep_idx = discriminative_filters(results, outdir, filter_diff_thres, show_filters=show_filters) # encode the sample and sample-phenotype for each cell sample_sizes = [] per_cell_ids = [] for i, x in enumerate(samples): sample_sizes.append(x.shape[0]) per_cell_ids.append(i * np.ones(x.shape[0])) # for each selected filter, plot the selected cell population x = np.vstack(samples) z = np.hstack(per_cell_ids) if results['scaler'] is not None: x = results['scaler'].transform(x) print 'Computing t-SNE projection...' tsne_idx = np.random.choice(x.shape[0], tsne_ncell) x_for_tsne = x[tsne_idx].copy() x_tsne = TSNE(n_components=2).fit_transform(x_for_tsne) vmin, vmax = np.zeros(x.shape[1]), np.zeros(x.shape[1]) for seq_index in range(x.shape[1]): vmin[seq_index] = np.percentile(x[:, seq_index], 1) vmax[seq_index] = np.percentile(x[:, seq_index], 99) fig_path = os.path.join(outdir, 'tsne_all_cells') plot_tsne_grid(x_tsne, x_for_tsne, fig_path, labels=labels, fig_size=(20, 20), point_size=5) return_filters = [] for i_filter in keep_idx: w = filters[i_filter, :nmark] b = filters[i_filter, nmark] g = np.sum(w.reshape(1, -1) * x, axis=1) + b g = g * (g > 0) # skip a filter if it does not select any cell if np.max(g) <= 0: continue ecdf = sm.distributions.ECDF(g) gx = np.linspace(np.min(g), np.max(g)) gy = ecdf(gx) plt.figure() sns.set_style('whitegrid') a = plt.step(gx, gy) t = filter_response_thres # set a threshold to the CDF gradient? if response_grad_cutoff is not None: by = np.array(a[0].get_ydata())[::-1] bx = np.array(a[0].get_xdata())[::-1] b_diff_idx = np.where(by[:-1] - by[1:] >= response_grad_cutoff)[0] if len(b_diff_idx) > 0: t = bx[b_diff_idx[0]+1] plt.plot((t, t), (np.min(gy), 1.), 'r--') plt.xlabel('Cell filter response') plt.ylabel('Cumulative distribution function (CDF)') sns.despine() plt.savefig(os.path.join(outdir, 'cdf_filter_%d.pdf' % i_filter), format='pdf') plt.clf() plt.close() condition = g > t x1 = x[condition] z1 = z[condition] g1 = g[condition] # skip a filter if it does not select any cell with the new cutoff threshold if x1.shape[0] == 0: continue # else add the filters to selected filters return_filters.append((i_filter, t)) # t-SNE plots for characterizing the selected cell population fig_path = os.path.join(outdir, 'tsne_cell_response_filter_%d.png' % i_filter) plot_2D_map(x_tsne, g[tsne_idx], fig_path, s=5) # overlay marker values on TSNE map for selected cells fig_path = os.path.join(outdir, 'tsne_selected_cells_filter_%d' % i_filter) g_tsne = g[tsne_idx] x_pos = x_for_tsne[g_tsne > t] x_tsne_pos = x_tsne[g_tsne > t] plot_tsne_selection_grid(x_tsne_pos, x_pos, x_tsne, vmin, vmax, fig_path=fig_path, labels=labels, fig_size=(20, 20), s=5, suffix='png') if clustering is None: suffix = 'filter_%d' % i_filter plot_selected_subset(x1, z1, x, labels, sample_sizes, phenotypes, outdir, suffix, stat_test, log_yscale, group_a, group_b, group_names, regression) else: if clustering == 'louvain': print 'Creating a k-NN graph with %d/%d cells...' % (x1.shape[0], x.shape[0]) k = 10 G = create_graph(x1, k, g1, add_filter_response) print 'Identifying cell communities...' cl = G.community_fastgreedy() clusters = np.array(cl.as_clustering().membership) else: print 'Clustering using the dbscan algorithm...' eps = set_dbscan_eps(x1, os.path.join(outdir, 'kNN_distances.png')) cl = DBSCAN(eps=eps, min_samples=5, metric='l1') clusters = cl.fit_predict(x1) # discard outliers, i.e. clusters with very few cells c = Counter(clusters) cluster_ids = [] min_cells = int(min_cluster_freq * x1.shape[0]) for key, val in c.items(): if (key != -1) and (val > min_cells): cluster_ids.append(key) num_clusters = len(cluster_ids) scores = np.zeros(num_clusters) for j in range(num_clusters): cl_id = cluster_ids[j] scores[j] = np.mean(g1[clusters == cl_id]) # keep the communities with high cell filter response sorted_idx = np.argsort(scores)[::-1] scores = scores[sorted_idx] keep_idx_comm = [sorted_idx[0]] for i in range(1, num_clusters): if (scores[i-1] - scores[i]) < percentage_drop_cluster * scores[i-1]: keep_idx_comm.append(sorted_idx[i]) else: break for j in keep_idx_comm: cl_id = cluster_ids[j] xc = x1[clusters == cl_id] zc = z1[clusters == cl_id] suffix = 'filter_%d_cluster_%d' % (i_filter, cl_id) plot_selected_subset(xc, zc, x, labels, sample_sizes, phenotypes, outdir, suffix, stat_test, log_yscale, group_a, group_b, group_names, regression) print 'Done.\n' return return_filters
def discriminative_filters(results, outdir, filter_diff_thres, show_filters=True): mkdir_p(outdir) filters = results['selected_filters'] # select the discriminative filters based on the validation set if 'filter_diff' in results: filter_diff = results['filter_diff'] filter_diff[np.isnan(filter_diff)] = -1 sorted_idx = np.argsort(filter_diff)[::-1] filter_diff = filter_diff[sorted_idx] keep_idx = [sorted_idx[0]] for i in range(0, len(filter_diff)-1): if (filter_diff[i] - filter_diff[i+1]) < filter_diff_thres * filter_diff[i]: keep_idx.append(sorted_idx[i+1]) else: break if show_filters: plt.figure() sns.set_style('whitegrid') plt.plot(range(len(filter_diff)), filter_diff, '--') plt.xticks(range(len(filter_diff)), ['filter %d' % i for i in sorted_idx], rotation='vertical') plt.ylabel('average cell filter response difference between classes') sns.despine() plt.savefig(os.path.join(outdir, 'filter_response_differences.pdf'), format='pdf') plt.clf() plt.close() elif 'filter_tau' in results: filter_diff = results['filter_tau'] filter_diff[np.isnan(filter_diff)] = -1 sorted_idx = np.argsort(filter_diff)[::-1] filter_diff = filter_diff[sorted_idx] keep_idx = [sorted_idx[0]] for i in range(0, len(filter_diff)-1): if (filter_diff[i] - filter_diff[i+1]) < filter_diff_thres * filter_diff[i]: keep_idx.append(sorted_idx[i+1]) else: break if show_filters: plt.figure() sns.set_style('whitegrid') plt.plot(range(len(filter_diff)), filter_diff, '--') plt.xticks(range(len(filter_diff)), ['filter %d' % i for i in sorted_idx], rotation='vertical') plt.ylabel('Kendalls tau') sns.despine() plt.savefig(os.path.join(outdir, 'filter_response_differences.pdf'), format='pdf') plt.clf() plt.close() # if no validation samples were provided, keep all consensus filters else: filters = results['selected_filters'] keep_idx = range(filters.shape[0]) return keep_idx def plot_filters(results, labels, outdir): mkdir_p(outdir) nmark = len(labels) # plot the filter weights of the best network w_best = results['w_best_net'] idx_except_bias = np.array(range(nmark) + range(nmark+1, w_best.shape[1])) nc = w_best.shape[1] - (nmark+1) labels_except_bias = labels + ['out %d' % i for i in range(nc)] w_best = w_best[:, idx_except_bias] fig_path = os.path.join(outdir, 'best_net_weights.pdf') plot_nn_weights(w_best, labels_except_bias, fig_path, fig_size=(10, 10)) # plot the filter clustering cl = results['clustering_result'] cl_w = cl['w'][:, idx_except_bias] fig_path = os.path.join(outdir, 'clustered_filter_weights.pdf') plot_nn_weights(cl_w, labels_except_bias, fig_path, row_linkage=cl['cluster_linkage'], y_labels=cl['cluster_assignments'], fig_size=(10, 10)) # plot the selected filters if results['selected_filters'] is not None: w = results['selected_filters'][:, idx_except_bias] fig_path = os.path.join(outdir, 'consensus_filter_weights.pdf') plot_nn_weights(w, labels_except_bias, fig_path, fig_size=(10, 10)) filters = results['selected_filters'] else: sys.exit('Consensus filters were not found.') def plot_nn_weights(w, x_labels, fig_path, row_linkage=None, y_labels=None, fig_size=(10, 3)): if y_labels is None: y_labels = range(w.shape[0]) if w.shape[0] > 1: plt.figure(figsize=fig_size) clmap = sns.clustermap(pd.DataFrame(w, columns=x_labels), method='average', metric='cosine', row_linkage=row_linkage, col_cluster=False, robust=True, yticklabels=y_labels, cmap="RdBu_r") plt.setp(clmap.ax_heatmap.yaxis.get_majorticklabels(), rotation=0) plt.setp(clmap.ax_heatmap.xaxis.get_majorticklabels(), rotation=90) clmap.cax.set_visible(True) else: plt.figure(figsize=(10, 1.5)) ax = sns.heatmap(pd.DataFrame(w, columns=x_labels), robust=True, yticklabels=y_labels) plt.tight_layout() plt.savefig(fig_path) plt.clf() plt.close() def plot_selected_subset(xc, zc, x, labels, sample_sizes, phenotypes, outdir, suffix, stat_test=None, log_yscale=False, group_a='group A', group_b='group B', group_names=None, regression=False): ks_values = [] nmark = x.shape[1] for j in range(nmark): ks = stats.ks_2samp(xc[:, j], x[:, j]) ks_values.append(ks[0]) # sort markers in decreasing order of KS statistic sorted_idx = np.argsort(np.array(ks_values))[::-1] sorted_labels = [labels[i] for i in sorted_idx] sorted_ks = [('KS = %.2f' % ks_values[i]) for i in sorted_idx] fig_path = os.path.join(outdir, 'selected_population_distribution_%s.pdf' % suffix) plot_marker_distribution([x[:, sorted_idx], xc[:, sorted_idx]], ['all cells', 'selected'], sorted_labels, grid_size=(4, 9), ks_list=sorted_ks, figsize=(24, 10), colors=['blue', 'red'], fig_path=fig_path, hist=False) # for classification, plot a boxplot of per class frequencies # for regression, make a biaxial plot (phenotype vs. frequency) if regression: frequencies = [] for i, (n, y_i) in enumerate(zip(sample_sizes, phenotypes)): freq = 100. * np.sum(zc == i) / n frequencies.append(freq) _fig, ax = plt.subplots(figsize=(2.5, 2.5)) plt.scatter(phenotypes, frequencies) if log_yscale: ax.set_yscale('log') plt.ylim(0, np.max(frequencies) + 1) plt.ylabel("selected population frequency (%)") plt.xlabel("response variable") sns.despine() plt.tight_layout() fig_path = os.path.join(outdir, 'selected_population_frequencies_%s.pdf' % suffix) plt.savefig(fig_path) plt.clf() plt.close() else: n_pheno = len(np.unique(phenotypes)) frequencies = dict() for i, (n, y_i) in enumerate(zip(sample_sizes, phenotypes)): freq = 100. * np.sum(zc == i) / n assert freq <= 100 if y_i in frequencies: frequencies[y_i].append(freq) else: frequencies[y_i] = [freq] # optionally, perform a statistical test if (n_pheno == 2) and (stat_test is not None): freq_a, freq_b = frequencies[0], frequencies[1] if stat_test == 'mannwhitneyu': _t, pval = stats.mannwhitneyu(freq_a, freq_b) elif stat_test == 'ttest': _t, pval = stats.ttest_ind(freq_a, freq_b) else: _t, pval = stats.ttest_ind(freq_a, freq_b) else: pval = None # make a boxplot with error bars if group_names is None: if n_pheno == 2: group_names = [group_a, group_b] else: group_names = ['group %d' % (y_i+1) for y_i in range(n_pheno)] box_grade = [] for group_name, y_i in zip(group_names, range(n_pheno)): box_grade += [group_name] * len(frequencies[y_i]) box_data = np.hstack([np.array(frequencies[y_i]) for y_i in range(n_pheno)]) box = pd.DataFrame(np.array(zip(box_grade, box_data)), columns=['group', 'selected population frequency (%)']) box['selected population frequency (%)'] = \ box['selected population frequency (%)'].astype('float64') _fig, ax = plt.subplots(figsize=(2.5, 2.5)) ax = sns.boxplot(x="group", y="selected population frequency (%)", data=box, width=.5, palette=sns.color_palette('Set2')) ax = sns.swarmplot(x="group", y="selected population frequency (%)", data=box, color=".25") if stat_test is not None: ax.text(.45, 1.1, '%s pval = %.2e' % (stat_test, pval), horizontalalignment='center', transform=ax.transAxes, size=8, weight='bold') if log_yscale: ax.set_yscale('log') plt.ylim(0, np.max(box_data) + 1) sns.despine() plt.tight_layout() fig_path = os.path.join(outdir, 'selected_population_frequencies_%s.pdf' % suffix) plt.savefig(fig_path) plt.clf() plt.close() def plot_marker_distribution(datalist, namelist, labels, grid_size, fig_path=None, letter_size=16, figsize=(9, 9), ks_list=None, colors=None, hist=False): nmark = len(labels) assert len(datalist) == len(namelist) g_i, g_j = grid_size sns.set_style('white') if colors is None: colors = sns.color_palette("Set1", n_colors=len(datalist), desat=.5) fig = plt.figure(figsize=figsize) grid = gridspec.GridSpec(g_i, g_j, wspace=0.1, hspace=.6) for i in range(g_i): for j in range(g_j): seq_index = g_j * i + j if seq_index < nmark: ax = fig.add_subplot(grid[i, j]) if ks_list is not None: ax.text(.5, 1.2, labels[seq_index], fontsize=letter_size, ha='center', transform=ax.transAxes) ax.text(.5, 1.02, ks_list[seq_index], fontsize=letter_size-4, ha='center', transform=ax.transAxes) else: ax.text(.5, 1.1, labels[seq_index], fontsize=letter_size, ha='center', transform=ax.transAxes) for i_name, (name, x) in enumerate(zip(namelist, datalist)): lower = np.percentile(x[:, seq_index], 0.5) upper = np.percentile(x[:, seq_index], 99.5) if seq_index == nmark - 1: if hist: plt.hist(x[:, seq_index], np.linspace(lower, upper, 10), color=colors[i_name], label=name, alpha=.5, normed=True) else: sns.kdeplot(x[:, seq_index], shade=True, color=colors[i_name], label=name, clip=(lower, upper)) else: if hist: plt.hist(x[:, seq_index], np.linspace(lower, upper, 10), color=colors[i_name], label=name, alpha=.5, normed=True) else: sns.kdeplot(x[:, seq_index], shade=True, color=colors[i_name], clip=(lower, upper)) ax.get_yaxis().set_ticks([]) #ax.get_xaxis().set_ticks([-2, 0, 2, 4]) #plt.legend(loc="upper right", prop={'size':letter_size}) plt.legend(bbox_to_anchor=(1.5, 0.9)) sns.despine() if fig_path is not None: plt.savefig(fig_path) plt.close() else: plt.show() def set_dbscan_eps(x, fig_path=None): nbrs = NearestNeighbors(n_neighbors=2, metric='l1').fit(x) distances, _indices = nbrs.kneighbors(x) if fig_path is not None: plt.figure() plt.hist(distances[:, 1], bins=20) plt.savefig(fig_path) plt.clf() plt.close() return np.percentile(distances, 90) def make_biaxial(train_feat, valid_feat, test_feat, train_y, valid_y, test_y, figpath, xlabel=None, ylabel=None, add_legend=False): # make the biaxial figure sns.set_style('white') palette = np.array(sns.color_palette("Set2", 3)) plt.figure(figsize=(3, 3)) ax = plt.subplot(aspect='equal') # the training samples ax.scatter(train_feat[:, 0], train_feat[:, 1], s=30, alpha=.5, c=palette[train_y], marker='>', edgecolors='face') # the validation samples ax.scatter(valid_feat[:, 0], valid_feat[:, 1], s=30, alpha=.5, c=palette[valid_y], marker=(5, 1), edgecolors='face') # the test samples ax.scatter(test_feat[:, 0], test_feat[:, 1], s=30, alpha=.5, c=palette[test_y], marker='o', edgecolors='face') # http://stackoverflow.com/questions/13303928/how-to-make-custom-legend-in-matplotlib a1 = plt.Line2D((0, 1), (0, 0), color=palette[0]) a2 = plt.Line2D((0, 1), (0, 0), color=palette[1]) a3 = plt.Line2D((0, 1), (0, 0), color=palette[2]) a4 = plt.Line2D((0, 1), (0, 0), color='k', marker='>', linestyle='', markersize=8) a5 = plt.Line2D((0, 1), (0, 0), color='k', marker=(5, 1), linestyle='', markersize=8) a6 = plt.Line2D((0, 1), (0, 0), color='k', marker='o', linestyle='', markersize=8) #Create legend from custom artist/label lists if add_legend: first_legend = plt.legend([a1, a2, a3], ['healthy', 'CN', 'CBF'], fontsize=16, loc=1, fancybox=True) plt.gca().add_artist(first_legend) plt.legend([a4, a5, a6], ['train', 'valid', 'test'], fontsize=16, loc=4, fancybox=True) #plt.xlim(-2, 2) #plt.ylim(-2, 2) ax.set_aspect('equal', 'datalim') ax.margins(0.1) if xlabel is not None: plt.xlabel(xlabel, fontsize=12) if ylabel is not None: plt.ylabel(ylabel, fontsize=12) plt.tight_layout() sns.despine() plt.savefig(figpath, format='eps') plt.clf() plt.close() def plot_tsne_grid(z, x, fig_path, labels=None, fig_size=(9, 9), g_j=7, suffix='png', point_size=.1): ncol = x.shape[1] g_i = ncol // g_j if (ncol % g_j == 0) else ncol // g_j + 1 if labels is None: labels = [str(a) for a in range(ncol)] sns.set_style('white') fig = plt.figure(figsize=fig_size) fig.clf() grid = ImageGrid(fig, 111, nrows_ncols=(g_i, g_j), ngrids=None if ncol % g_j == 0 else ncol, aspect=True, direction="row", axes_pad=(0.15, 0.5), add_all=True, label_mode="1", share_all=True, cbar_location="top", cbar_mode="each", cbar_size="8%", cbar_pad="5%", ) for seq_index in range(ncol): ax = grid[seq_index] ax.text(0, .92, labels[seq_index], horizontalalignment='center', transform=ax.transAxes, size=20, weight='bold') vmin = np.percentile(x[:, seq_index], 1) vmax = np.percentile(x[:, seq_index], 99) #sns.kdeplot(z[:, 0], z[:, 1], colors='gray', cmap=None, linewidths=0.5) im = ax.scatter(z[:, 0], z[:, 1], s=point_size, marker='o', c=x[:, seq_index], cmap=cm.jet, alpha=0.5, edgecolors='face', vmin=vmin, vmax=vmax) ax.cax.colorbar(im) clean_axis(ax) ax.grid(False) plt.savefig('.'.join([fig_path, suffix]), format=suffix) plt.clf() plt.close() def plot_tsne_selection_grid(z_pos, x_pos, z_neg, vmin, vmax, fig_path, labels=None, fig_size=(9, 9), g_j=7, s=.5, suffix='png'): ncol = x_pos.shape[1] g_i = ncol // g_j if (ncol % g_j == 0) else ncol // g_j + 1 if labels is None: labels = [str(a) for a in np.range(ncol)] fig = plt.figure(figsize=fig_size) fig.clf() grid = ImageGrid(fig, 111, nrows_ncols=(g_i, g_j), ngrids=None if ncol % g_j == 0 else ncol, aspect=True, direction="row", axes_pad=(0.15, 0.5), add_all=True, label_mode="1", share_all=True, cbar_location="top", cbar_mode="each", cbar_size="8%", cbar_pad="5%", ) for seq_index in range(ncol): ax = grid[seq_index] ax.text(0, .92, labels[seq_index], horizontalalignment='center', transform=ax.transAxes, size=20, weight='bold') a = x_pos[:, seq_index] ax.scatter(z_neg[:, 0], z_neg[:, 1], s=s, marker='o', c='lightgray', alpha=0.5, edgecolors='face') im = ax.scatter(z_pos[:, 0], z_pos[:, 1], s=s, marker='o', c=a, cmap=cm.jet, edgecolors='face', vmin=vmin[seq_index], vmax=vmax[seq_index]) ax.cax.colorbar(im) clean_axis(ax) ax.grid(False) plt.savefig('.'.join([fig_path, suffix]), format=suffix) plt.clf() plt.close() def plot_2D_map(z, feat, fig_path, s=2, plot_contours=False): sns.set_style('white') _fig, ax = plt.subplots(figsize=(5, 5)) if plot_contours: sns.kdeplot(z[:, 0], z[:, 1], colors='lightgray', cmap=None, linewidths=0.5) if issubclass(feat.dtype.type, np.integer): c = np.squeeze(feat) colors = sns.color_palette("Set2", len(np.unique(c))) for i in np.unique(c): plt.scatter(z[c == i, 0], z[c == i, 1], s=s, marker='o', c=colors[i], edgecolors='face', label=str(i)) else: im = ax.scatter(z[:, 0], z[:, 1], s=s, marker='o', c=feat, vmin=np.percentile(feat, 1), cmap=cm.jet, alpha=0.5, edgecolors='face', vmax=np.percentile(feat, 99)) # magic parameters from # http://stackoverflow.com/questions/16702479/matplotlib-colorbar-placement-and-size plt.colorbar(im, fraction=0.046, pad=0.04) clean_axis(ax) ax.grid(False) sns.despine() if issubclass(feat.dtype.type, np.integer): plt.legend(loc="upper left", markerscale=5., scatterpoints=1, fontsize=10) plt.xlabel('tSNE dimension 1', fontsize=20) plt.ylabel('tSNE dimension 2', fontsize=20) plt.savefig(fig_path, format=fig_path.split('.')[-1]) plt.clf() plt.close() def plot_tsne_per_sample(z_list, data_labels, fig_dir, fig_size=(9, 9), density=True, scatter=True, colors=None, pref=''): if colors is None: colors = sns.color_palette("husl", len(z_list)) _fig, ax = plt.subplots(figsize=fig_size) for i, z in enumerate(z_list): ax.scatter(z[:, 0], z[:, 1], s=1, marker='o', c=colors[i], alpha=0.5, edgecolors='face', label=data_labels[i]) clean_axis(ax) ax.grid(False) plt.legend(loc="upper left", markerscale=20., scatterpoints=1, fontsize=10) plt.xlabel('t-SNE dimension 1', fontsize=20) plt.ylabel('t-SNE dimension 2', fontsize=20) plt.savefig(os.path.join(fig_dir, pref+'_tsne_all_samples.png'), format='png') plt.clf() plt.close() # density plots if density: for i, z in enumerate(z_list): _fig = plt.figure(figsize=fig_size) sns.kdeplot(z[:, 0], z[:, 1], n_levels=30, shade=True) plt.title(data_labels[i]) plt.savefig(os.path.join(fig_dir, pref+'tsne_density_%d.png' % i), format='png') plt.clf() plt.close() if scatter: for i, z in enumerate(z_list): _fig = plt.figure(figsize=fig_size) plt.scatter(z[:, 0], z[:, 1], s=1, marker='o', c=colors[i], alpha=0.5, edgecolors='face') plt.title(data_labels[i]) plt.savefig(os.path.join(fig_dir, pref+'tsne_scatterplot_%d.png' % i), format='png') plt.clf() plt.close() def clean_axis(ax): ax.get_xaxis().set_ticks([]) ax.get_yaxis().set_ticks([]) for sp in ax.spines.values(): sp.set_visible(False)