Diff of /aggmap/show.py [000000] .. [9e8054]

Switch to unified view

a b/aggmap/show.py
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Tue Aug 18 13:01:00 2020
4
5
@author: SHEN WANXIANG
6
"""
7
8
9
10
import warnings
11
warnings.filterwarnings("ignore")
12
import matplotlib.pyplot as plt
13
import seaborn as sns
14
15
16
def imshow(x_arr,  ax, mode = 'dark',  color_list = ['#1300ff','#ff0c00','#25ff00', '#d000ff','#e2ff00', 
17
                                                     '#00fff6', '#ff8800', '#fccde5','#178b66', '#8a0075'],
18
           x_max = 255, vmin = -1, vmax = 1,):
19
    
20
    
21
    assert x_arr.ndim == 3, 'input must be 3d array!'
22
    w, h, c = x_arr.shape
23
    assert len(color_list) >= c, 'length of the color list should equal or larger than channel numbers'
24
    
25
    x = x_arr.copy()
26
    x[x == 0] = 'nan'
27
28
    xxx = x_arr.sum(axis=-1)
29
    xxx[xxx != 0] = 'nan'
30
31
    if mode == 'dark':
32
        cmaps = [sns.dark_palette(color, n_colors =  50, reverse=False) for color in color_list]
33
34
    else:
35
        cmaps = [sns.light_palette(color, n_colors =  50, reverse=False) for color in color_list]
36
        
37
    for i in range(c):
38
        data = x[:,:,i]/x_max
39
        sns.heatmap(data, cmap = cmaps[i],  vmin = vmin, vmax = vmax,  
40
                    yticklabels=False, xticklabels=False, cbar=False, ax=ax, ) # linewidths=0.005, linecolor = '0.9'
41
42
    if mode == 'dark':
43
        sns.heatmap(xxx, vmin=-10000, vmax=1, cmap = 'Greys', yticklabels=False, xticklabels=False, cbar=False, ax=ax)
44
    else:
45
        sns.heatmap(xxx, vmin=0, vmax=1, cmap = 'Greys', yticklabels=False, xticklabels=False, cbar=False, ax=ax)
46
        ax.axhline(y=0, color='grey',lw=2, ls =  '--')
47
        ax.axhline(y=data.shape[0], color='grey',lw=2, ls =  '--')
48
        ax.autoscale()
49
        ax.axvline(x=data.shape[1], color='grey',lw=2, ls =  '--')  
50
        ax.axvline(x=0, color='grey',lw=2, ls =  '--')
51
52
53
def imshow_wrap(x,  mode = 'dark', color_list = ['#1300ff','#ff0c00','#25ff00', '#d000ff','#e2ff00', 
54
                                                 '#00fff6', '#ff8800', '#fccde5','#178b66', '#8a0075'], 
55
                x_max = 255, vmin = -1, vmax = 1,):
56
    
57
    fig, ax = plt.subplots(figsize=(4,4))
58
    imshow(x.astype(float),mode = mode, color_list = color_list, ax=ax, x_max = x_max, vmin = vmin, vmax=vmax)