Diff of /dataloader.py [000000] .. [e72cf6]

Switch to unified view

a b/dataloader.py
1
2
import numpy as np
3
import torch
4
import anndata as ad
5
import scanpy as sc
6
7
8
import gc
9
10
11
def load_nips_rna_atac_dataset(mod_file_path, gene_encoding):
12
    adata = ad.read_h5ad(mod_file_path)
13
14
    feature_gex_index = np.array(adata.var.feature_types) == 'GEX'
15
    feature_adt_index = np.array(adata.var.feature_types) == 'ATAC'
16
17
    gex = adata[:, feature_gex_index].copy()
18
    atac = adata[:, feature_adt_index].copy()
19
    del adata
20
21
    gc.collect()
22
23
    index = []
24
    for i in range(gex.shape[1]):
25
        if gex.var['gene_id'][i] != gene_encoding['gene_id'][i]:
26
            print('Warning')
27
        else:
28
            A = bool(gene_encoding['is_gene_coding'][i])
29
            index.append(A)
30
31
    gex = gex[:, index].copy()
32
33
    # gex.var.to_csv('./gex_name.csv')
34
    # atac.var.to_csv('./atac_name.csv')
35
36
    adata_mod1 = gex.copy()
37
    adata_mod1.X = adata_mod1.layers['counts']
38
    del gex
39
40
    adata_mod2 = atac.copy()
41
    adata_mod2.X = adata_mod2.layers['counts']
42
    del atac
43
44
    gc.collect()
45
46
    # obs = adata.obs
47
    # adata_mod1 = ad.AnnData(X=adata.layers['counts'][:, feature_gex_index], obs=obs)
48
    # adata_mod2 = ad.AnnData(X=adata.layers['counts'][:, feature_adt_index], obs=obs)
49
50
    adata_mod1_original = ad.AnnData.copy(adata_mod1)
51
    adata_mod2_original = ad.AnnData.copy(adata_mod2)
52
53
    sc.pp.normalize_total(adata_mod1, target_sum=1e4)
54
    sc.pp.log1p(adata_mod1)
55
    sc.pp.highly_variable_genes(adata_mod1)
56
    index = adata_mod1.var['highly_variable'].values
57
58
    adata_mod1 = ad.AnnData.copy(adata_mod1_original)
59
    adata_mod1 = adata_mod1[:, index].copy()
60
61
    del adata_mod1_original
62
    gc.collect()
63
64
    sc.pp.normalize_total(adata_mod2, target_sum=1e4)
65
    sc.pp.log1p(adata_mod2)
66
    sc.pp.highly_variable_genes(adata_mod2)
67
    index = adata_mod2.var['highly_variable'].values
68
69
    adata_mod2 = ad.AnnData.copy(adata_mod2_original)
70
    del adata_mod2_original
71
    gc.collect()
72
73
    adata_mod2 = adata_mod2[:, index].copy()
74
75
    return adata_mod1, adata_mod2
76
77
def prepare_nips_dataset(adata_gex, adata_mod2,
78
                         batch_col = 'batch',
79
                        ):
80
81
    batch_index = np.array(adata_gex.obs[batch_col].values)
82
    unique_batch = list(np.unique(batch_index))
83
    batch_index = np.array([unique_batch.index(xs) for xs in batch_index])
84
85
    obs = adata_gex.obs
86
    obs.insert(obs.shape[1], 'batch_indices', batch_index)
87
    adata_gex = ad.AnnData(X=adata_gex.X, obs=obs)
88
89
    obs = adata_mod2.obs
90
    obs.insert(obs.shape[1], 'batch_indices', batch_index)
91
92
    X = adata_mod2.X
93
    adata_mod2 = ad.AnnData(X=X, obs=obs)
94
95
    Index = np.array(X.sum(1)>0).squeeze()
96
97
    adata_gex = adata_gex[Index]
98
    obs = adata_gex.obs
99
    adata_gex = ad.AnnData(X=adata_gex.X, obs=obs)
100
101
    adata_mod2 = adata_mod2[Index]
102
    obs = adata_mod2.obs
103
    adata_mod2 = ad.AnnData(X=adata_mod2.X, obs=obs)
104
105
    return adata_gex, adata_mod2
106
107
def data_process_moETM(adata_mod1, adata_mod2):
108
    # train/test on the whole
109
    train_adata_mod1 = adata_mod1
110
    train_adata_mod2 = adata_mod2
111
112
    ########################################################
113
    # Training dataset
114
    X_mod1 = np.array(train_adata_mod1.X.todense())
115
    X_mod2 = np.array(train_adata_mod2.X.todense())
116
    batch_index = np.array(train_adata_mod1.obs['batch_indices'])
117
118
    X_mod1 = X_mod1 / X_mod1.sum(1)[:, np.newaxis]
119
    X_mod2 = X_mod2 / X_mod2.sum(1)[:, np.newaxis]
120
121
    X_mod1_train_T = torch.from_numpy(X_mod1).float()
122
    X_mod2_train_T = torch.from_numpy(X_mod2).float()
123
    batch_index_train_T = torch.from_numpy(batch_index).to(torch.int64)
124
125
    del X_mod1, X_mod2, batch_index
126
127
    return X_mod1_train_T, X_mod2_train_T, batch_index_train_T, train_adata_mod1
128
129
def data_process_moETM_split(adata_mod1, adata_mod2, n_sample, test_ratio=0.1):
130
    ###### random split for training and testing
131
    from sklearn.utils import resample
132
    Index = np.arange(0, n_sample)
133
    train_index = resample(Index, n_samples=int(n_sample*(1-test_ratio)), replace=False)
134
    test_index = np.array(list(set(range(n_sample)).difference(train_index)))
135
136
    train_adata_mod1 = adata_mod1[train_index]
137
    obs = train_adata_mod1.obs
138
    X = train_adata_mod1.X
139
    train_adata_mod1 = ad.AnnData(X=X, obs=obs)
140
141
    train_adata_mod2 = adata_mod2[train_index]
142
    obs = train_adata_mod2.obs
143
    X = train_adata_mod2.X
144
    train_adata_mod2 = ad.AnnData(X=X, obs=obs)
145
146
    test_adata_mod1 = adata_mod1[test_index]
147
    obs = test_adata_mod1.obs
148
    X = test_adata_mod1.X
149
    test_adata_mod1 = ad.AnnData(X=X, obs=obs)
150
151
    test_adata_mod2 = adata_mod2[test_index]
152
    obs = test_adata_mod2.obs
153
    X = test_adata_mod2.X
154
    test_adata_mod2 = ad.AnnData(X=X, obs=obs)
155
156
    ########################################################
157
    # Training dataset
158
    X_mod1 = np.array(train_adata_mod1.X.todense())
159
    X_mod2 = np.array(train_adata_mod2.X.todense())
160
    batch_index = np.array(train_adata_mod1.obs['batch_indices'])
161
162
    X_mod1 = X_mod1 / X_mod1.sum(1)[:, np.newaxis]
163
    X_mod2 = X_mod2 / X_mod2.sum(1)[:, np.newaxis]
164
165
    X_mod1_train_T = torch.from_numpy(X_mod1).float()
166
    X_mod2_train_T = torch.from_numpy(X_mod2).float()
167
    batch_index_train_T = torch.from_numpy(batch_index).to(torch.int64).cuda()
168
169
    # Testing dataset
170
    X_mod1 = np.array(test_adata_mod1.X.todense())
171
    X_mod2 = np.array(test_adata_mod2.X.todense())
172
    batch_index = np.array(test_adata_mod1.obs['batch_indices'])
173
174
    X_mod1 = X_mod1 / X_mod1.sum(1)[:, np.newaxis]
175
    X_mod2 = X_mod2 / X_mod2.sum(1)[:, np.newaxis]
176
177
    X_mod1_test_T = torch.from_numpy(X_mod1).float()
178
    X_mod2_test_T = torch.from_numpy(X_mod2).float()
179
    batch_index_test_T = torch.from_numpy(batch_index).to(torch.int64)
180
181
    del X_mod1, X_mod2, batch_index
182
183
    return X_mod1_train_T, X_mod2_train_T, batch_index_train_T, X_mod1_test_T, X_mod2_test_T, batch_index_test_T, test_adata_mod1
184
185
def data_process_moETM_leave_one_batch(adata_mod1, adata_mod2, batch_index_as_test):
186
    #leave one batch for testing
187
    train_index = (adata_mod1.obs['batch_indices'] != batch_index_as_test)
188
    test_index = (adata_mod1.obs['batch_indices'] == batch_index_as_test)
189
190
    train_adata_mod1 = adata_mod1[train_index]
191
    obs = train_adata_mod1.obs
192
    X = train_adata_mod1.X
193
    train_adata_mod1 = ad.AnnData(X=X, obs=obs)
194
195
    train_adata_mod2 = adata_mod2[train_index]
196
    obs = train_adata_mod2.obs
197
    X = train_adata_mod2.X
198
    train_adata_mod2 = ad.AnnData(X=X, obs=obs)
199
200
    test_adata_mod1 = adata_mod1[test_index]
201
    obs = test_adata_mod1.obs
202
    X = test_adata_mod1.X
203
    test_adata_mod1 = ad.AnnData(X=X, obs=obs)
204
205
    test_adata_mod2 = adata_mod2[test_index]
206
    obs = test_adata_mod2.obs
207
    X = test_adata_mod2.X
208
    test_adata_mod2 = ad.AnnData(X=X, obs=obs)
209
210
    ########################################################
211
    # Training dataset
212
    X_mod1 = np.array(train_adata_mod1.X.todense())
213
    X_mod2 = np.array(train_adata_mod2.X.todense())
214
    batch_index = np.array(train_adata_mod1.obs['batch_indices'])
215
216
    ##convert batch index
217
    batch_mapping = {batch: i for i, batch in enumerate(set(batch_index))}
218
    mapped_index = np.array([batch_mapping[batch] for batch in batch_index])
219
    batch_index = mapped_index
220
    
221
    X_mod1 = X_mod1 / X_mod1.sum(1)[:, np.newaxis]
222
    X_mod2 = X_mod2 / X_mod2.sum(1)[:, np.newaxis]
223
224
    X_mod1_train_T = torch.from_numpy(X_mod1).float()
225
    X_mod2_train_T = torch.from_numpy(X_mod2).float()
226
    batch_index_train_T = torch.from_numpy(batch_index).to(torch.int64).cuda()
227
228
    # Testing dataset
229
    X_mod1 = np.array(test_adata_mod1.X.todense())
230
    X_mod2 = np.array(test_adata_mod2.X.todense())
231
    batch_index = np.array(test_adata_mod1.obs['batch_indices'])
232
233
    ##convert batch index
234
    batch_mapping = {batch: i for i, batch in enumerate(set(batch_index))}
235
    mapped_index = np.array([batch_mapping[batch] for batch in batch_index])
236
    batch_index = mapped_index
237
    
238
    X_mod1 = X_mod1 / X_mod1.sum(1)[:, np.newaxis]
239
    X_mod2 = X_mod2 / X_mod2.sum(1)[:, np.newaxis]
240
241
    X_mod1_test_T = torch.from_numpy(X_mod1).float()
242
    X_mod2_test_T = torch.from_numpy(X_mod2).float()
243
    batch_index_test_T = torch.from_numpy(batch_index).to(torch.int64)
244
245
    del X_mod1, X_mod2, batch_index
246
247
    return X_mod1_train_T, X_mod2_train_T, batch_index_train_T, X_mod1_test_T, X_mod2_test_T, batch_index_test_T, test_adata_mod1, train_adata_mod1
248
249
250
def data_process_moETM_cross_prediction(adata_mod1, adata_mod2, n_sample):
251
    from sklearn.utils import resample
252
253
    Index = np.arange(0, n_sample)
254
    train_index = resample(Index, n_samples=n_sample)
255
    test_index = np.array(list(set(range(n_sample)).difference(train_index)))
256
257
    train_adata_mod1 = adata_mod1[train_index]
258
    obs = train_adata_mod1.obs
259
    X = train_adata_mod1.X
260
    train_adata_mod1 = ad.AnnData(X=X, obs=obs)
261
262
    train_adata_mod2 = adata_mod2[train_index]
263
    obs = train_adata_mod2.obs
264
    X = train_adata_mod2.X
265
    train_adata_mod2 = ad.AnnData(X=X, obs=obs)
266
267
    test_adata_mod1 = adata_mod1[test_index]
268
    obs = test_adata_mod1.obs
269
    X = test_adata_mod1.X
270
    test_adata_mod1 = ad.AnnData(X=X, obs=obs)
271
272
    test_adata_mod2 = adata_mod2[test_index]
273
    obs = test_adata_mod2.obs
274
    X = test_adata_mod2.X
275
    test_adata_mod2 = ad.AnnData(X=X, obs=obs)
276
277
    ########################################################
278
    # Training dataset
279
    X_mod1 = np.array(train_adata_mod1.X.todense())
280
    X_mod2 = np.array(train_adata_mod2.X.todense())
281
    batch_index = np.array(train_adata_mod1.obs['batch_indices'])
282
283
    X_mod1 = X_mod1 / X_mod1.sum(1)[:, np.newaxis]
284
    X_mod2 = X_mod2 / X_mod2.sum(1)[:, np.newaxis]
285
286
    X_mod1_train_T = torch.from_numpy(X_mod1).float()
287
    X_mod2_train_T = torch.from_numpy(X_mod2).float()
288
    batch_index_train_T = torch.from_numpy(batch_index).to(torch.int64).cuda()
289
290
    # Testing dataset
291
    X_mod1 = np.array(test_adata_mod1.X.todense())
292
    X_mod2 = np.array(test_adata_mod2.X.todense())
293
    batch_index = np.array(test_adata_mod1.obs['batch_indices'])
294
295
    sum1 = X_mod1.sum(1)
296
    sum2 = X_mod2.sum(1)
297
298
    X_mod1 = X_mod1 / X_mod1.sum(1)[:, np.newaxis]
299
    X_mod2 = X_mod2 / X_mod2.sum(1)[:, np.newaxis]
300
301
    X_mod1_test_T = torch.from_numpy(X_mod1).float()
302
    X_mod2_test_T = torch.from_numpy(X_mod2).float()
303
    batch_index_test_T = torch.from_numpy(batch_index).to(torch.int64)
304
305
306
    del X_mod1, X_mod2, batch_index
307
308
    return X_mod1_train_T, X_mod2_train_T, batch_index_train_T, X_mod1_test_T, X_mod2_test_T, batch_index_test_T, test_adata_mod1, train_adata_mod1, sum1, sum2
309
310
311
def load_nips_dataset_rna_protein_dataset(mod_file_path, gene_encoding, protein_encoding):
312
313
    adata = ad.read_h5ad(mod_file_path)
314
315
    feature_gex_index = np.array(adata.var.feature_types) == 'GEX'
316
    feature_adt_index = np.array(adata.var.feature_types) == 'ADT'
317
318
    adata_mod1 = adata[:, feature_gex_index].copy()
319
    adata_mod2 = adata[:, feature_adt_index].copy()
320
321
    adata_mod1.X = adata_mod1.layers['counts']
322
    adata_mod2.X = adata_mod2.layers['counts']
323
324
    index = []
325
    for i in range(adata_mod1.shape[1]):
326
        if adata_mod1.var.index[i] != gene_encoding['X'][i]:
327
            print('Warning')
328
        else:
329
            index.append(bool(gene_encoding['is_gene_coding'][i]))
330
331
    adata_mod1_original = adata_mod1[:, index].copy()
332
    adata_mod1 = adata_mod1[:, index].copy()
333
334
    sc.pp.normalize_total(adata_mod1, target_sum=1e4)
335
    sc.pp.log1p(adata_mod1)
336
    sc.pp.highly_variable_genes(adata_mod1)  # n_top_genes
337
    index = adata_mod1.var['highly_variable'].values
338
339
    adata_mod1_original = adata_mod1_original[:, index].copy()
340
341
    index = []
342
    for i in range(adata_mod2.shape[1]):
343
        if adata_mod2.var.index[i] != protein_encoding['X'][i]:
344
            print('Warning')
345
        else:
346
            index.append(bool(protein_encoding['is_protein_coding'][i]))
347
348
    adata_mod2 = adata_mod2[:, index].copy()
349
350
    return adata_mod1_original, adata_mod2
351
352
def load_nips_rna_atac_dataset_with_pathway(mod_file_path, gene_encoding, gene_pathway):
353
    adata = ad.read_h5ad(mod_file_path)
354
355
    feature_gex_index = np.array(adata.var.feature_types) == 'GEX'
356
    feature_adt_index = np.array(adata.var.feature_types) == 'ATAC'
357
358
    gex = adata[:, feature_gex_index].copy()
359
    atac = adata[:, feature_adt_index].copy()
360
    del adata
361
362
    gc.collect()
363
364
    gene_pathway_sum = gene_pathway.sum(0)
365
    index = []
366
    for i in range(gex.shape[1]):
367
        if gex.var['gene_id'][i] != gene_encoding['gene_id'][i]:
368
            print('Warning')
369
        else:
370
            A = bool(gene_encoding['is_gene_coding'][i])
371
            B = bool(gene_pathway_sum[i])
372
            index.append(A & B)
373
374
    gex = gex[:, index].copy()
375
    gene_pathway = gene_pathway[:, index].copy()
376
377
    adata_mod1 = gex.copy()
378
    adata_mod1.X = adata_mod1.layers['counts']
379
    del gex
380
381
    adata_mod2 = atac.copy()
382
    adata_mod2.X = adata_mod2.layers['counts']
383
    del atac
384
385
    gc.collect()
386
387
    adata_mod1_original = ad.AnnData.copy(adata_mod1)
388
    adata_mod2_original = ad.AnnData.copy(adata_mod2)
389
390
    sc.pp.normalize_total(adata_mod1, target_sum=1e4)
391
    sc.pp.log1p(adata_mod1)
392
    sc.pp.highly_variable_genes(adata_mod1)
393
    index = adata_mod1.var['highly_variable'].values
394
395
    adata_mod1 = ad.AnnData.copy(adata_mod1_original)
396
    adata_mod1 = adata_mod1[:, index].copy()
397
    gene_pathway = gene_pathway[:, index].copy()
398
399
    del adata_mod1_original
400
    gc.collect()
401
402
    sc.pp.normalize_total(adata_mod2, target_sum=1e4)
403
    sc.pp.log1p(adata_mod2)
404
    sc.pp.highly_variable_genes(adata_mod2)
405
    index = adata_mod2.var['highly_variable'].values
406
407
    adata_mod2 = ad.AnnData.copy(adata_mod2_original)
408
    del adata_mod2_original
409
    gc.collect()
410
411
    adata_mod2 = adata_mod2[:, index].copy()
412
413
    return adata_mod1, adata_mod2, gene_pathway