Diff of /src/preprocess.py [000000] .. [ac720d]

Switch to unified view

a b/src/preprocess.py
1
# Copyright 2017 Goekcen Eraslan
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
# ==============================================================================
15
16
17
from __future__ import absolute_import
18
from __future__ import division
19
from __future__ import print_function
20
21
import pickle, os, numbers
22
23
import numpy as np
24
import scipy as sp
25
import pandas as pd
26
import scanpy as sc
27
from sklearn.model_selection import train_test_split
28
from sklearn.preprocessing import scale
29
import scipy
30
31
#TODO: Fix this
32
class AnnSequence:
33
    def __init__(self, matrix, batch_size, sf=None):
34
        self.matrix = matrix
35
        if sf is None:
36
            self.size_factors = np.ones((self.matrix.shape[0], 1),
37
                                        dtype=np.float32)
38
        else:
39
            self.size_factors = sf
40
        self.batch_size = batch_size
41
42
    def __len__(self):
43
        return len(self.matrix) // self.batch_size
44
45
    def __getitem__(self, idx):
46
        batch = self.matrix[idx*self.batch_size:(idx+1)*self.batch_size]
47
        batch_sf = self.size_factors[idx*self.batch_size:(idx+1)*self.batch_size]
48
49
        # return an (X, Y) pair
50
        return {'count': batch, 'size_factors': batch_sf}, batch
51
52
53
def read_dataset(adata, transpose=False, test_split=False, copy=False):
54
55
    if isinstance(adata, sc.AnnData):
56
        if copy:
57
            adata = adata.copy()
58
    elif isinstance(adata, str):
59
        adata = sc.read(adata)
60
    else:
61
        raise NotImplementedError
62
63
    norm_error = 'Make sure that the dataset (adata.X) contains unnormalized count data.'
64
    assert 'n_count' not in adata.obs, norm_error
65
66
    if adata.X.size < 50e6: # check if adata.X is integer only if array is small
67
        if sp.sparse.issparse(adata.X):
68
            assert (adata.X.astype(int) != adata.X).nnz == 0, norm_error
69
        else:
70
            assert np.all(adata.X.astype(int) == adata.X), norm_error
71
72
    if transpose: adata = adata.transpose()
73
74
    if test_split:
75
        train_idx, test_idx = train_test_split(np.arange(adata.n_obs), test_size=0.1, random_state=42)
76
        spl = pd.Series(['train'] * adata.n_obs)
77
        spl.iloc[test_idx] = 'test'
78
        adata.obs['DCA_split'] = spl.values
79
    else:
80
        adata.obs['DCA_split'] = 'train'
81
82
    adata.obs['DCA_split'] = adata.obs['DCA_split'].astype('category')
83
    print('### Autoencoder: Successfully preprocessed {} genes and {} cells.'.format(adata.n_vars, adata.n_obs))
84
85
    return adata
86
87
def clr_normalize_each_cell(adata):
88
    """Normalize count vector for each cell, i.e. for each row of .X"""
89
90
    def seurat_clr(x):
91
        # TODO: support sparseness
92
        s = np.sum(np.log1p(x[x > 0]))
93
        exp = np.exp(s / len(x))
94
        return np.log1p(x / exp)
95
    
96
    adata.raw = adata.copy()
97
    sc.pp.normalize_per_cell(adata)
98
    adata.obs['size_factors'] = adata.obs.n_counts / np.median(adata.obs.n_counts)
99
100
    # apply to dense or sparse matrix, along axis. returns dense matrix
101
    adata.X = np.apply_along_axis(
102
        seurat_clr, 1, (adata.raw.X.A if scipy.sparse.issparse(adata.raw.X) else adata.raw.X)
103
    )
104
    return adata
105
    
106
def normalize(adata, filter_min_counts=True, size_factors=True, normalize_input=True, logtrans_input=True):
107
108
    if filter_min_counts:
109
        sc.pp.filter_genes(adata, min_counts=1)
110
        sc.pp.filter_cells(adata, min_counts=1)
111
112
    if size_factors or normalize_input or logtrans_input:
113
        adata.raw = adata.copy()
114
    else:
115
        adata.raw = adata
116
117
    if size_factors:
118
        sc.pp.normalize_per_cell(adata)
119
        adata.obs['size_factors'] = adata.obs.n_counts / np.median(adata.obs.n_counts)
120
    else:
121
        adata.obs['size_factors'] = 1.0
122
123
    if logtrans_input:
124
        sc.pp.log1p(adata)
125
126
    if normalize_input:
127
        sc.pp.scale(adata)
128
129
    return adata
130
131
def read_genelist(filename):
132
    genelist = list(set(open(filename, 'rt').read().strip().split('\n')))
133
    assert len(genelist) > 0, 'No genes detected in genelist file'
134
    print('### Autoencoder: Subset of {} genes will be denoised.'.format(len(genelist)))
135
136
    return genelist
137
138
def write_text_matrix(matrix, filename, rownames=None, colnames=None, transpose=False):
139
    if transpose:
140
        matrix = matrix.T
141
        rownames, colnames = colnames, rownames
142
143
    pd.DataFrame(matrix, index=rownames, columns=colnames).to_csv(filename,
144
                                                                  sep='\t',
145
                                                                  index=(rownames is not None),
146
                                                                  header=(colnames is not None),
147
                                                                  float_format='%.6f')
148
def read_pickle(inputfile):
149
    return pickle.load(open(inputfile, "rb"))