|
a |
|
b/aggmap/show.py |
|
|
1 |
# -*- coding: utf-8 -*- |
|
|
2 |
""" |
|
|
3 |
Created on Tue Aug 18 13:01:00 2020 |
|
|
4 |
|
|
|
5 |
@author: SHEN WANXIANG |
|
|
6 |
""" |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
|
|
|
10 |
import warnings |
|
|
11 |
warnings.filterwarnings("ignore") |
|
|
12 |
import matplotlib.pyplot as plt |
|
|
13 |
import seaborn as sns |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
def imshow(x_arr, ax, mode = 'dark', color_list = ['#1300ff','#ff0c00','#25ff00', '#d000ff','#e2ff00', |
|
|
17 |
'#00fff6', '#ff8800', '#fccde5','#178b66', '#8a0075'], |
|
|
18 |
x_max = 255, vmin = -1, vmax = 1,): |
|
|
19 |
|
|
|
20 |
|
|
|
21 |
assert x_arr.ndim == 3, 'input must be 3d array!' |
|
|
22 |
w, h, c = x_arr.shape |
|
|
23 |
assert len(color_list) >= c, 'length of the color list should equal or larger than channel numbers' |
|
|
24 |
|
|
|
25 |
x = x_arr.copy() |
|
|
26 |
x[x == 0] = 'nan' |
|
|
27 |
|
|
|
28 |
xxx = x_arr.sum(axis=-1) |
|
|
29 |
xxx[xxx != 0] = 'nan' |
|
|
30 |
|
|
|
31 |
if mode == 'dark': |
|
|
32 |
cmaps = [sns.dark_palette(color, n_colors = 50, reverse=False) for color in color_list] |
|
|
33 |
|
|
|
34 |
else: |
|
|
35 |
cmaps = [sns.light_palette(color, n_colors = 50, reverse=False) for color in color_list] |
|
|
36 |
|
|
|
37 |
for i in range(c): |
|
|
38 |
data = x[:,:,i]/x_max |
|
|
39 |
sns.heatmap(data, cmap = cmaps[i], vmin = vmin, vmax = vmax, |
|
|
40 |
yticklabels=False, xticklabels=False, cbar=False, ax=ax, ) # linewidths=0.005, linecolor = '0.9' |
|
|
41 |
|
|
|
42 |
if mode == 'dark': |
|
|
43 |
sns.heatmap(xxx, vmin=-10000, vmax=1, cmap = 'Greys', yticklabels=False, xticklabels=False, cbar=False, ax=ax) |
|
|
44 |
else: |
|
|
45 |
sns.heatmap(xxx, vmin=0, vmax=1, cmap = 'Greys', yticklabels=False, xticklabels=False, cbar=False, ax=ax) |
|
|
46 |
ax.axhline(y=0, color='grey',lw=2, ls = '--') |
|
|
47 |
ax.axhline(y=data.shape[0], color='grey',lw=2, ls = '--') |
|
|
48 |
ax.autoscale() |
|
|
49 |
ax.axvline(x=data.shape[1], color='grey',lw=2, ls = '--') |
|
|
50 |
ax.axvline(x=0, color='grey',lw=2, ls = '--') |
|
|
51 |
|
|
|
52 |
|
|
|
53 |
def imshow_wrap(x, mode = 'dark', color_list = ['#1300ff','#ff0c00','#25ff00', '#d000ff','#e2ff00', |
|
|
54 |
'#00fff6', '#ff8800', '#fccde5','#178b66', '#8a0075'], |
|
|
55 |
x_max = 255, vmin = -1, vmax = 1,): |
|
|
56 |
|
|
|
57 |
fig, ax = plt.subplots(figsize=(4,4)) |
|
|
58 |
imshow(x.astype(float),mode = mode, color_list = color_list, ax=ax, x_max = x_max, vmin = vmin, vmax=vmax) |