Diff of /utils.py [000000] .. [fe0e8b]

Switch to unified view

a b/utils.py
1
import anndata
import numpy as np
import pandas as pd
import scanpy as sc
import scipy as s
from scipy.sparse import csr_matrix, issparse
5
6
def load_adata(
    adata_file,
    metadata_file=None,
    normalise=False,
    cells=None,
    cell_column="cell",
    features=None,
    filter_lowly_expressed_genes=False,
    set_colors=False,
    keep_counts=False,
):
    """Load an AnnData file and optionally subset, annotate and normalise it.

    The steps run in this fixed order: read the file, convert ``X`` and all
    layers to CSR sparse matrices, subset cells, subset features, attach
    metadata, filter genes, stash the counts, normalise, set colours.

    Parameters
    ----------
    adata_file
        Path passed to ``sc.read``.
    metadata_file
        Optional tab-separated table with a header row. It must contain the
        column named by ``cell_column``; the table replaces ``adata.obs``.
    normalise
        If true, run ``sc.pp.normalize_total`` followed by ``sc.pp.log1p``.
    cells
        Optional collection of cell identifiers to keep. Identifiers that are
        absent from ``adata.obs.index`` are dropped (a message is printed).
    cell_column
        Name of the metadata column holding the cell identifiers.
    features
        Optional feature selection used to index the variables axis.
    filter_lowly_expressed_genes
        If true, run ``sc.pp.filter_genes(adata, min_counts=10)``.
    set_colors
        If true, fill ``adata.uns['celltype_colors']`` and
        ``adata.uns['stage_colors']`` from a global ``opts`` mapping.
    keep_counts
        If true, copy ``adata.X`` into ``adata.layers["raw"]`` before any
        normalisation is applied.

    Returns
    -------
    The resulting AnnData object.

    Raises
    ------
    AssertionError
        If the metadata does not cover every cell of ``adata`` or its row
        count differs from the number of cells.
    """
    adata = sc.read(adata_file)

    # Convert to sparse matrices
    if not issparse(adata.X):
        adata.X = csr_matrix(adata.X)
    # Snapshot the keys so the layers can be reassigned while looping.
    for layer_name in list(adata.layers.keys()):
        if not issparse(adata.layers[layer_name]):
            adata.layers[layer_name] = csr_matrix(adata.layers[layer_name])

    if cells is not None:
        cells = np.asarray(cells)
        observed = np.isin(cells, adata.obs.index.values)
        # Only report when something is actually missing; the previous
        # `< 1` guard printed "0.00%" when every cell was found and stayed
        # silent when none were.
        if cells.size > 0 and not observed.all():
            missing_fraction = np.mean(~observed)
            print("%.2f%% of cells provided are not observed in the adata, taking the intersect..." % (100 * missing_fraction))
        # np.intersect1d returns the sorted, unique common identifiers.
        cells = np.intersect1d(cells, adata.obs.index.values)
        adata = adata[cells, :]

    if features is not None:
        adata = adata[:, features]

    if metadata_file is not None:
        metadata = pd.read_table(metadata_file, delimiter="\t", header=0).set_index(cell_column, drop=False)
        # Align the metadata rows with the cells currently in `adata`; this
        # also works when no explicit `cells` selection was given.
        metadata = metadata.loc[adata.obs.index.values]
        assert np.all(adata.obs.index.isin(metadata[cell_column])), \
            "metadata does not cover every cell in adata"
        assert metadata.shape[0] == adata.shape[0], \
            "metadata row count differs from the number of cells in adata"
        adata.obs = metadata

    if filter_lowly_expressed_genes:
        sc.pp.filter_genes(adata, min_counts=10)

    if keep_counts:
        # Stored before normalisation so the layer holds the values as loaded.
        adata.layers["raw"] = adata.X.copy()

    if normalise:
        sc.pp.normalize_total(adata, target_sum=None, exclude_highly_expressed=False)
        sc.pp.log1p(adata)

    if set_colors:
        # NOTE(review): `opts` is not defined in this file; it presumably comes
        # from the namespace this module is loaded into -- confirm with callers.
        for key in ("celltype", "stage"):
            palette = opts[key + "_colors"]
            # Palette keys use "_" in place of spaces and slashes.
            adata.uns[key + "_colors"] = [
                palette[label.replace(" ", "_").replace("/", "_")]
                for label in sorted(np.unique(adata.obs[key]))
            ]

    return adata
52
53
def scale(X, x_min, x_max):
    """Linearly rescale ``X`` along axis 0 to the range ``[x_min, x_max]``.

    For 2-D input every column is rescaled independently; for 1-D input the
    whole vector is rescaled. A column (or vector) whose values are all equal
    has zero range and maps to ``x_min`` instead of dividing by zero.

    Parameters
    ----------
    X
        Array-like with ``min``/``max`` methods accepting ``axis=0``; plain
        sequences are converted with ``np.asarray``.
    x_min, x_max
        Lower and upper bound of the target range.

    Returns
    -------
    The rescaled values, with the same shape as ``X``.
    """
    if not hasattr(X, "min"):
        X = np.asarray(X)

    lowest = X.min(axis=0)
    span = X.max(axis=0) - lowest
    # np.where also handles the scalar range produced by 1-D input, where an
    # in-place masked assignment would raise TypeError.
    span = np.where(span == 0, 1, span)
    return x_min + (X - lowest) * (x_max - x_min) / span
58
59
60
# cmap = custom_div_cmap(11, mincol='g', midcol='0.9' ,maxcol='CornflowerBlue')
61
def custom_div_cmap(numcolors=11, name='custom_div_cmap',
                    mincol='blue', midcol='white', maxcol='red'):
    """Build a colormap running from ``mincol`` through ``midcol`` to ``maxcol``.

    With the defaults this is blue to white to red with 11 colors. The three
    anchor colors are handed unchanged to
    ``LinearSegmentedColormap.from_list``, so they may be given in any form
    matplotlib's color conversion understands.

    Parameters
    ----------
    numcolors
        Passed as ``N`` to ``from_list``.
    name
        Name given to the resulting colormap.
    mincol, midcol, maxcol
        Anchor colors, in that order.

    Returns
    -------
    The ``LinearSegmentedColormap`` created by ``from_list``.
    """
    # Imported locally, so matplotlib is only needed when this is called.
    from matplotlib.colors import LinearSegmentedColormap

    anchor_colors = [mincol, midcol, maxcol]
    return LinearSegmentedColormap.from_list(
        name=name,
        colors=anchor_colors,
        N=numcolors,
    )