Diff of /MultiOmiVAE.py [000000] .. [2d53aa]

import torch
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
from earlystoping import Earlystopping
from sklearn import metrics


def MultiOmiVAE(input_path, expr_df, methy_chr_df_list, random_seed=42, no_cuda=False, model_parallelism=True,
                separate_testing=True, batch_size=32, latent_dim=128, learning_rate=1e-3, p1_epoch_num=50,
                p2_epoch_num=100, output_loss_record=True, classifier=True, early_stopping=True):

    torch.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)

    device = torch.device('cuda:0' if not no_cuda and torch.cuda.is_available() else 'cpu')
    parallel = torch.cuda.device_count() > 1 and model_parallelism

    # Sample IDs (in a fixed order) that have both gene expression and DNA methylation data
    sample_id = np.loadtxt(input_path + 'both_samples.tsv', delimiter='\t', dtype='str')

    # Load tumour type labels (encoded as integers)
    label = pd.read_csv(input_path + 'both_samples_tumour_type_digit.tsv', sep='\t', header=0, index_col=0)
    class_num = len(label.tumour_type.unique())
    label_array = label['tumour_type'].values

    if separate_testing:
        # Split samples into training and testing sets, stratified by tumour type;
        # the held-out portion is then split in half into validation and testing sets
        # (roughly 80% train / 10% validation / 10% test overall)
        testset_ratio = 0.2
        valset_ratio = 0.5

        train_index, test_index, train_label, test_label = train_test_split(sample_id, label_array,
                                                                            test_size=testset_ratio,
                                                                            random_state=random_seed,
                                                                            stratify=label_array)
        val_index, test_index, val_label, test_label = train_test_split(test_index, test_label, test_size=valset_ratio,
                                                                        random_state=random_seed, stratify=test_label)

        expr_df_test = expr_df[test_index]
        expr_df_val = expr_df[val_index]
        expr_df_train = expr_df[train_index]

        methy_chr_df_test_list = []
        methy_chr_df_val_list = []
        methy_chr_df_train_list = []
        for chrom_index in range(0, 23):
            methy_chr_df_test = methy_chr_df_list[chrom_index][test_index]
            methy_chr_df_test_list.append(methy_chr_df_test)
            methy_chr_df_val = methy_chr_df_list[chrom_index][val_index]
            methy_chr_df_val_list.append(methy_chr_df_val)
            methy_chr_df_train = methy_chr_df_list[chrom_index][train_index]
            methy_chr_df_train_list.append(methy_chr_df_train)

    # Get multi-omics dataset information
    sample_num = len(sample_id)
    expr_feature_num = expr_df.shape[0]
    methy_feature_num_list = []
    for chrom_index in range(0, 23):
        feature_num = methy_chr_df_list[chrom_index].shape[0]
        methy_feature_num_list.append(feature_num)
    methy_feature_num_array = np.array(methy_feature_num_list)
    methy_feature_num = methy_feature_num_array.sum()
    print('\nNumber of samples: {}'.format(sample_num))
    print('Number of gene expression features: {}'.format(expr_feature_num))
    print('Number of methylation features: {}'.format(methy_feature_num))
    if classifier:
        print('Number of classes: {}'.format(class_num))

    class MultiOmiDataset(Dataset):
        """
        Multi-omics dataset: each item pairs one sample's gene expression profile with its
        23 per-chromosome methylation profiles and its tumour type label.
        Samples are stored column-wise in the input dataframes.
        """

        def __init__(self, expr_df, methy_df_list, labels):
            self.expr_df = expr_df
            self.methy_df_list = methy_df_list
            self.labels = labels

        def __len__(self):
            return self.expr_df.shape[1]

        def __getitem__(self, index):
            omics_data = []
            # Gene expression tensor (index 0)
            expr_line = self.expr_df.iloc[:, index].values
            expr_line_tensor = torch.Tensor(expr_line)
            omics_data.append(expr_line_tensor)
            # Methylation tensors (indices 1-23, one per chromosome)
            for methy_chrom_index in range(0, 23):
                methy_chr_line = self.methy_df_list[methy_chrom_index].iloc[:, index].values
                methy_chr_line_tensor = torch.Tensor(methy_chr_line)
                omics_data.append(methy_chr_line_tensor)
            label = self.labels[index]
            return [omics_data, label]

    # DataSets and DataLoaders
    if separate_testing:
        train_dataset = MultiOmiDataset(expr_df=expr_df_train, methy_df_list=methy_chr_df_train_list, labels=train_label)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=6)
        val_dataset = MultiOmiDataset(expr_df=expr_df_val, methy_df_list=methy_chr_df_val_list, labels=val_label)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=6)
        test_dataset = MultiOmiDataset(expr_df=expr_df_test, methy_df_list=methy_chr_df_test_list, labels=test_label)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=6)
    else:
        train_dataset = MultiOmiDataset(expr_df=expr_df, methy_df_list=methy_chr_df_list, labels=label_array)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=6)
    full_dataset = MultiOmiDataset(expr_df=expr_df, methy_df_list=methy_chr_df_list, labels=label_array)
    full_loader = DataLoader(full_dataset, batch_size=batch_size, num_workers=6)

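    # Each batch drawn from these loaders is a list [omics_data, y]: omics_data is itself a
    # list of 24 tensors (index 0: gene expression, indices 1-23: per-chromosome methylation),
    # each of shape (batch_size, feature_num), and y is a tensor of integer tumour type labels.
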
    # Setting dimensions
    latent_space_dim = latent_dim
    input_dim_expr = expr_feature_num
    input_dim_methy_array = methy_feature_num_array
    level_2_dim_expr = 4096
    level_2_dim_methy = 256
    level_3_dim_expr = 1024
    level_3_dim_methy = 1024
    level_4_dim = 512
    classifier_1_dim = 128
    classifier_2_dim = 64
    classifier_out_dim = class_num

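    # Layer width summary, read from the definitions below:
    #   expression encoder:   input_dim_expr -> 4096 -> 1024
    #   methylation encoder:  one 256-unit block per chromosome, concatenated (23 * 256 = 5888) -> 1024
    #   shared encoder:       1024 + 1024 -> 512 -> latent mean and log-variance (latent_space_dim each)
    #   decoder:              mirrors this path back to the original feature dimensions
    #   classifier:           latent mean -> 128 -> 64 -> class_num
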
    class VAE(nn.Module):
        def __init__(self):
            super(VAE, self).__init__()
            # ENCODER fc layers
            # level 1
            # Methy input for each chromosome
            self.e_fc1_methy_1 = self.fc_layer(input_dim_methy_array[0], level_2_dim_methy)
            self.e_fc1_methy_2 = self.fc_layer(input_dim_methy_array[1], level_2_dim_methy)
            self.e_fc1_methy_3 = self.fc_layer(input_dim_methy_array[2], level_2_dim_methy)
            self.e_fc1_methy_4 = self.fc_layer(input_dim_methy_array[3], level_2_dim_methy)
            self.e_fc1_methy_5 = self.fc_layer(input_dim_methy_array[4], level_2_dim_methy)
            self.e_fc1_methy_6 = self.fc_layer(input_dim_methy_array[5], level_2_dim_methy)
            self.e_fc1_methy_7 = self.fc_layer(input_dim_methy_array[6], level_2_dim_methy)
            self.e_fc1_methy_8 = self.fc_layer(input_dim_methy_array[7], level_2_dim_methy)
            self.e_fc1_methy_9 = self.fc_layer(input_dim_methy_array[8], level_2_dim_methy)
            self.e_fc1_methy_10 = self.fc_layer(input_dim_methy_array[9], level_2_dim_methy)
            self.e_fc1_methy_11 = self.fc_layer(input_dim_methy_array[10], level_2_dim_methy)
            self.e_fc1_methy_12 = self.fc_layer(input_dim_methy_array[11], level_2_dim_methy)
            self.e_fc1_methy_13 = self.fc_layer(input_dim_methy_array[12], level_2_dim_methy)
            self.e_fc1_methy_14 = self.fc_layer(input_dim_methy_array[13], level_2_dim_methy)
            self.e_fc1_methy_15 = self.fc_layer(input_dim_methy_array[14], level_2_dim_methy)
            self.e_fc1_methy_16 = self.fc_layer(input_dim_methy_array[15], level_2_dim_methy)
            self.e_fc1_methy_17 = self.fc_layer(input_dim_methy_array[16], level_2_dim_methy)
            self.e_fc1_methy_18 = self.fc_layer(input_dim_methy_array[17], level_2_dim_methy)
            self.e_fc1_methy_19 = self.fc_layer(input_dim_methy_array[18], level_2_dim_methy)
            self.e_fc1_methy_20 = self.fc_layer(input_dim_methy_array[19], level_2_dim_methy)
            self.e_fc1_methy_21 = self.fc_layer(input_dim_methy_array[20], level_2_dim_methy)
            self.e_fc1_methy_22 = self.fc_layer(input_dim_methy_array[21], level_2_dim_methy)
            self.e_fc1_methy_X = self.fc_layer(input_dim_methy_array[22], level_2_dim_methy)

            # Expr input
            self.e_fc1_expr = self.fc_layer(input_dim_expr, level_2_dim_expr)

            # Level 2
            self.e_fc2_methy = self.fc_layer(level_2_dim_methy*23, level_3_dim_methy)
            self.e_fc2_expr = self.fc_layer(level_2_dim_expr, level_3_dim_expr)
            # self.e_fc2_methy = self.fc_layer(level_2_dim_methy * 23, level_3_dim_methy, dropout=True)
            # self.e_fc2_expr = self.fc_layer(level_2_dim_expr, level_3_dim_expr, dropout=True)

            # Level 3
            self.e_fc3 = self.fc_layer(level_3_dim_methy+level_3_dim_expr, level_4_dim)
            # self.e_fc3 = self.fc_layer(level_3_dim_methy+level_3_dim_expr, level_4_dim, dropout=True)

            # Level 4
            self.e_fc4_mean = self.fc_layer(level_4_dim, latent_space_dim, activation=0)
            self.e_fc4_log_var = self.fc_layer(level_4_dim, latent_space_dim, activation=0)

            # model parallelism
            if parallel:
                self.e_fc1_methy_1.to('cuda:0')
                self.e_fc1_methy_2.to('cuda:0')
                self.e_fc1_methy_3.to('cuda:0')
                self.e_fc1_methy_4.to('cuda:0')
                self.e_fc1_methy_5.to('cuda:0')
                self.e_fc1_methy_6.to('cuda:0')
                self.e_fc1_methy_7.to('cuda:0')
                self.e_fc1_methy_8.to('cuda:0')
                self.e_fc1_methy_9.to('cuda:0')
                self.e_fc1_methy_10.to('cuda:0')
                self.e_fc1_methy_11.to('cuda:0')
                self.e_fc1_methy_12.to('cuda:0')
                self.e_fc1_methy_13.to('cuda:0')
                self.e_fc1_methy_14.to('cuda:0')
                self.e_fc1_methy_15.to('cuda:0')
                self.e_fc1_methy_16.to('cuda:0')
                self.e_fc1_methy_17.to('cuda:0')
                self.e_fc1_methy_18.to('cuda:0')
                self.e_fc1_methy_19.to('cuda:0')
                self.e_fc1_methy_20.to('cuda:0')
                self.e_fc1_methy_21.to('cuda:0')
                self.e_fc1_methy_22.to('cuda:0')
                self.e_fc1_methy_X.to('cuda:0')
                self.e_fc1_expr.to('cuda:0')
                self.e_fc2_methy.to('cuda:0')
                self.e_fc2_expr.to('cuda:0')
                self.e_fc3.to('cuda:0')
                self.e_fc4_mean.to('cuda:0')
                self.e_fc4_log_var.to('cuda:0')

            # DECODER fc layers
            # Level 4
            self.d_fc4 = self.fc_layer(latent_space_dim, level_4_dim)

            # Level 3
            self.d_fc3 = self.fc_layer(level_4_dim, level_3_dim_methy+level_3_dim_expr)
            # self.d_fc3 = self.fc_layer(level_4_dim, level_3_dim_methy+level_3_dim_expr, dropout=True)

            # Level 2
            self.d_fc2_methy = self.fc_layer(level_3_dim_methy, level_2_dim_methy*23)
            self.d_fc2_expr = self.fc_layer(level_3_dim_expr, level_2_dim_expr)
            # self.d_fc2_methy = self.fc_layer(level_3_dim_methy, level_2_dim_methy*23, dropout=True)
            # self.d_fc2_expr = self.fc_layer(level_3_dim_expr, level_2_dim_expr, dropout=True)

            # level 1
            # Methy output for each chromosome
            self.d_fc1_methy_1 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[0], activation=2)
            self.d_fc1_methy_2 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[1], activation=2)
            self.d_fc1_methy_3 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[2], activation=2)
            self.d_fc1_methy_4 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[3], activation=2)
            self.d_fc1_methy_5 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[4], activation=2)
            self.d_fc1_methy_6 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[5], activation=2)
            self.d_fc1_methy_7 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[6], activation=2)
            self.d_fc1_methy_8 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[7], activation=2)
            self.d_fc1_methy_9 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[8], activation=2)
            self.d_fc1_methy_10 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[9], activation=2)
            self.d_fc1_methy_11 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[10], activation=2)
            self.d_fc1_methy_12 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[11], activation=2)
            self.d_fc1_methy_13 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[12], activation=2)
            self.d_fc1_methy_14 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[13], activation=2)
            self.d_fc1_methy_15 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[14], activation=2)
            self.d_fc1_methy_16 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[15], activation=2)
            self.d_fc1_methy_17 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[16], activation=2)
            self.d_fc1_methy_18 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[17], activation=2)
            self.d_fc1_methy_19 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[18], activation=2)
            self.d_fc1_methy_20 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[19], activation=2)
            self.d_fc1_methy_21 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[20], activation=2)
            self.d_fc1_methy_22 = self.fc_layer(level_2_dim_methy, input_dim_methy_array[21], activation=2)
            self.d_fc1_methy_X = self.fc_layer(level_2_dim_methy, input_dim_methy_array[22], activation=2)
            # Expr output
            self.d_fc1_expr = self.fc_layer(level_2_dim_expr, input_dim_expr, activation=2)

            # model parallelism
            if parallel:
                self.d_fc4.to('cuda:1')
                self.d_fc3.to('cuda:1')
                self.d_fc2_methy.to('cuda:1')
                self.d_fc2_expr.to('cuda:1')
                self.d_fc1_methy_1.to('cuda:1')
                self.d_fc1_methy_2.to('cuda:1')
                self.d_fc1_methy_3.to('cuda:1')
                self.d_fc1_methy_4.to('cuda:1')
                self.d_fc1_methy_5.to('cuda:1')
                self.d_fc1_methy_6.to('cuda:1')
                self.d_fc1_methy_7.to('cuda:1')
                self.d_fc1_methy_8.to('cuda:1')
                self.d_fc1_methy_9.to('cuda:1')
                self.d_fc1_methy_10.to('cuda:1')
                self.d_fc1_methy_11.to('cuda:1')
                self.d_fc1_methy_12.to('cuda:1')
                self.d_fc1_methy_13.to('cuda:1')
                self.d_fc1_methy_14.to('cuda:1')
                self.d_fc1_methy_15.to('cuda:1')
                self.d_fc1_methy_16.to('cuda:1')
                self.d_fc1_methy_17.to('cuda:1')
                self.d_fc1_methy_18.to('cuda:1')
                self.d_fc1_methy_19.to('cuda:1')
                self.d_fc1_methy_20.to('cuda:1')
                self.d_fc1_methy_21.to('cuda:1')
                self.d_fc1_methy_22.to('cuda:1')
                self.d_fc1_methy_X.to('cuda:1')
                self.d_fc1_expr.to('cuda:1')

            # CLASSIFIER fc layers
            self.c_fc1 = self.fc_layer(latent_space_dim, classifier_1_dim)
            self.c_fc2 = self.fc_layer(classifier_1_dim, classifier_2_dim)
            # self.c_fc2 = self.fc_layer(classifier_1_dim, classifier_2_dim, dropout=True)
            self.c_fc3 = self.fc_layer(classifier_2_dim, classifier_out_dim, activation=0)

            # model parallelism
            if parallel:
                self.c_fc1.to('cuda:1')
                self.c_fc2.to('cuda:1')
                self.c_fc3.to('cuda:1')

        # Activation - 0: no activation, 1: ReLU, 2: Sigmoid
        def fc_layer(self, in_dim, out_dim, activation=1, dropout=False, dropout_p=0.5):
            if activation == 0:
                layer = nn.Sequential(
                    nn.Linear(in_dim, out_dim),
                    nn.BatchNorm1d(out_dim))
            elif activation == 2:
                layer = nn.Sequential(
                    nn.Linear(in_dim, out_dim),
                    nn.BatchNorm1d(out_dim),
                    nn.Sigmoid())
            else:
                if dropout:
                    layer = nn.Sequential(
                        nn.Linear(in_dim, out_dim),
                        nn.BatchNorm1d(out_dim),
                        nn.ReLU(),
                        nn.Dropout(p=dropout_p))
                else:
                    layer = nn.Sequential(
                        nn.Linear(in_dim, out_dim),
                        nn.BatchNorm1d(out_dim),
                        nn.ReLU())
            return layer

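        # For illustration: fc_layer(512, 128) builds nn.Sequential(nn.Linear(512, 128),
        # nn.BatchNorm1d(128), nn.ReLU()). activation=0 drops the ReLU (used for the latent
        # mean/log-variance heads and the classifier output); activation=2 swaps it for
        # nn.Sigmoid() (used for the reconstruction outputs, which lie in [0, 1]).
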
        def encode(self, x):
            methy_1_level2_layer = self.e_fc1_methy_1(x[1])
            methy_2_level2_layer = self.e_fc1_methy_2(x[2])
            methy_3_level2_layer = self.e_fc1_methy_3(x[3])
            methy_4_level2_layer = self.e_fc1_methy_4(x[4])
            methy_5_level2_layer = self.e_fc1_methy_5(x[5])
            methy_6_level2_layer = self.e_fc1_methy_6(x[6])
            methy_7_level2_layer = self.e_fc1_methy_7(x[7])
            methy_8_level2_layer = self.e_fc1_methy_8(x[8])
            methy_9_level2_layer = self.e_fc1_methy_9(x[9])
            methy_10_level2_layer = self.e_fc1_methy_10(x[10])
            methy_11_level2_layer = self.e_fc1_methy_11(x[11])
            methy_12_level2_layer = self.e_fc1_methy_12(x[12])
            methy_13_level2_layer = self.e_fc1_methy_13(x[13])
            methy_14_level2_layer = self.e_fc1_methy_14(x[14])
            methy_15_level2_layer = self.e_fc1_methy_15(x[15])
            methy_16_level2_layer = self.e_fc1_methy_16(x[16])
            methy_17_level2_layer = self.e_fc1_methy_17(x[17])
            methy_18_level2_layer = self.e_fc1_methy_18(x[18])
            methy_19_level2_layer = self.e_fc1_methy_19(x[19])
            methy_20_level2_layer = self.e_fc1_methy_20(x[20])
            methy_21_level2_layer = self.e_fc1_methy_21(x[21])
            methy_22_level2_layer = self.e_fc1_methy_22(x[22])
            methy_X_level2_layer = self.e_fc1_methy_X(x[23])

            # concat methy tensor together
            methy_level2_layer = torch.cat((methy_1_level2_layer, methy_2_level2_layer, methy_3_level2_layer,
                                            methy_4_level2_layer, methy_5_level2_layer, methy_6_level2_layer,
                                            methy_7_level2_layer, methy_8_level2_layer, methy_9_level2_layer,
                                            methy_10_level2_layer, methy_11_level2_layer, methy_12_level2_layer,
                                            methy_13_level2_layer, methy_14_level2_layer, methy_15_level2_layer,
                                            methy_16_level2_layer, methy_17_level2_layer, methy_18_level2_layer,
                                            methy_19_level2_layer, methy_20_level2_layer, methy_21_level2_layer,
                                            methy_22_level2_layer, methy_X_level2_layer), 1)

            expr_level2_layer = self.e_fc1_expr(x[0])

            methy_level3_layer = self.e_fc2_methy(methy_level2_layer)
            expr_level3_layer = self.e_fc2_expr(expr_level2_layer)
            level_3_layer = torch.cat((methy_level3_layer, expr_level3_layer), 1)

            level_4_layer = self.e_fc3(level_3_layer)

            latent_mean = self.e_fc4_mean(level_4_layer)
            latent_log_var = self.e_fc4_log_var(level_4_layer)

            return latent_mean, latent_log_var

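        # Reparameterisation trick: rather than sampling z ~ N(mean, sigma^2) directly, draw
        # eps ~ N(0, I) and return z = mean + sigma * eps with sigma = exp(0.5 * log_var),
        # which keeps the sampling step differentiable with respect to mean and log_var.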
        def reparameterize(self, mean, log_var):
            sigma = torch.exp(0.5 * log_var)
            eps = torch.randn_like(sigma)
            return mean + eps * sigma

        def decode(self, z):
            level_4_layer = self.d_fc4(z)

            level_3_layer = self.d_fc3(level_4_layer)
            methy_level3_layer = level_3_layer.narrow(1, 0, level_3_dim_methy)
            expr_level3_layer = level_3_layer.narrow(1, level_3_dim_methy, level_3_dim_expr)

            methy_level2_layer = self.d_fc2_methy(methy_level3_layer)
            methy_1_level2_layer = methy_level2_layer.narrow(1, 0, level_2_dim_methy)
            methy_2_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy, level_2_dim_methy)
            methy_3_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*2, level_2_dim_methy)
            methy_4_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*3, level_2_dim_methy)
            methy_5_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*4, level_2_dim_methy)
            methy_6_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*5, level_2_dim_methy)
            methy_7_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*6, level_2_dim_methy)
            methy_8_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*7, level_2_dim_methy)
            methy_9_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*8, level_2_dim_methy)
            methy_10_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*9, level_2_dim_methy)
            methy_11_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*10, level_2_dim_methy)
            methy_12_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*11, level_2_dim_methy)
            methy_13_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*12, level_2_dim_methy)
            methy_14_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*13, level_2_dim_methy)
            methy_15_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*14, level_2_dim_methy)
            methy_16_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*15, level_2_dim_methy)
            methy_17_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*16, level_2_dim_methy)
            methy_18_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*17, level_2_dim_methy)
            methy_19_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*18, level_2_dim_methy)
            methy_20_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*19, level_2_dim_methy)
            methy_21_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*20, level_2_dim_methy)
            methy_22_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*21, level_2_dim_methy)
            methy_X_level2_layer = methy_level2_layer.narrow(1, level_2_dim_methy*22, level_2_dim_methy)

            expr_level2_layer = self.d_fc2_expr(expr_level3_layer)

            recon_x0 = self.d_fc1_expr(expr_level2_layer)

            recon_x1 = self.d_fc1_methy_1(methy_1_level2_layer)
            recon_x2 = self.d_fc1_methy_2(methy_2_level2_layer)
            recon_x3 = self.d_fc1_methy_3(methy_3_level2_layer)
            recon_x4 = self.d_fc1_methy_4(methy_4_level2_layer)
            recon_x5 = self.d_fc1_methy_5(methy_5_level2_layer)
            recon_x6 = self.d_fc1_methy_6(methy_6_level2_layer)
            recon_x7 = self.d_fc1_methy_7(methy_7_level2_layer)
            recon_x8 = self.d_fc1_methy_8(methy_8_level2_layer)
            recon_x9 = self.d_fc1_methy_9(methy_9_level2_layer)
            recon_x10 = self.d_fc1_methy_10(methy_10_level2_layer)
            recon_x11 = self.d_fc1_methy_11(methy_11_level2_layer)
            recon_x12 = self.d_fc1_methy_12(methy_12_level2_layer)
            recon_x13 = self.d_fc1_methy_13(methy_13_level2_layer)
            recon_x14 = self.d_fc1_methy_14(methy_14_level2_layer)
            recon_x15 = self.d_fc1_methy_15(methy_15_level2_layer)
            recon_x16 = self.d_fc1_methy_16(methy_16_level2_layer)
            recon_x17 = self.d_fc1_methy_17(methy_17_level2_layer)
            recon_x18 = self.d_fc1_methy_18(methy_18_level2_layer)
            recon_x19 = self.d_fc1_methy_19(methy_19_level2_layer)
            recon_x20 = self.d_fc1_methy_20(methy_20_level2_layer)
            recon_x21 = self.d_fc1_methy_21(methy_21_level2_layer)
            recon_x22 = self.d_fc1_methy_22(methy_22_level2_layer)
            recon_x23 = self.d_fc1_methy_X(methy_X_level2_layer)

            return [recon_x0, recon_x1, recon_x2, recon_x3, recon_x4, recon_x5, recon_x6, recon_x7, recon_x8, recon_x9,
                    recon_x10, recon_x11, recon_x12, recon_x13, recon_x14, recon_x15, recon_x16, recon_x17, recon_x18,
                    recon_x19, recon_x20, recon_x21, recon_x22, recon_x23]

        def classifier(self, mean):
            level_1_layer = self.c_fc1(mean)
            level_2_layer = self.c_fc2(level_1_layer)
            output_layer = self.c_fc3(level_2_layer)
            return output_layer

        def forward(self, x):
            mean, log_var = self.encode(x)
            z = self.reparameterize(mean, log_var)
            # the classifier takes the latent mean rather than the sampled z
            classifier_x = mean
            if parallel:
                z = z.to('cuda:1')
                classifier_x = classifier_x.to('cuda:1')
            recon_x = self.decode(z)
            pred_y = self.classifier(classifier_x)
            return z, recon_x, mean, log_var, pred_y

    # Instantiate VAE
    if parallel:
        vae_model = VAE()
    else:
        vae_model = VAE().to(device)

    # Early Stopping
    if early_stopping:
        early_stop_ob = Earlystopping()

    # Tensorboard writer
    train_writer = SummaryWriter(log_dir='logs/train')
    val_writer = SummaryWriter(log_dir='logs/val')

    # print the model information
    # print('\nModel information:')
    # print(vae_model)
    total_params = sum(params.numel() for params in vae_model.parameters())
    print('Number of parameters: {}'.format(total_params))

    optimizer = optim.Adam(vae_model.parameters(), lr=learning_rate)

    # Reconstruction losses use binary cross-entropy, so both omics inputs are expected to be
    # scaled to [0, 1] (matching the sigmoid outputs of the decoder)
    def methy_recon_loss(recon_x, x):
        loss = F.binary_cross_entropy(recon_x[1], x[1], reduction='sum')
        for i in range(2, 24):
            loss += F.binary_cross_entropy(recon_x[i], x[i], reduction='sum')
        loss /= 23
        return loss

    def expr_recon_loss(recon_x, x):
        loss = F.binary_cross_entropy(recon_x[0], x[0], reduction='sum')
        return loss

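    # Closed-form KL divergence between the encoder's diagonal Gaussian N(mean, sigma^2)
    # and the standard normal prior N(0, I), summed over latent dimensions and the batch:
    #   KL = -0.5 * sum(1 + log_var - mean^2 - exp(log_var))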
    def kl_loss(mean, log_var):
        loss = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
        return loss

    def classifier_loss(pred_y, y):
        loss = F.cross_entropy(pred_y, y, reduction='sum')
        return loss

    # k_methy_recon = 1
    # k_expr_recon = 1
    # k_kl = 1
    # k_class = 1

    # loss record: rows 0-4 hold the per-epoch training metrics (methy recon, expr recon, KL,
    # classification loss, accuracy), rows 5-9 the validation metrics, and row 10 the best epoch
    loss_array = np.zeros(shape=(11, p1_epoch_num+p2_epoch_num+1))
    # performance metrics: accuracy, precision, recall, F1 (weighted averages)
    metrics_array = np.zeros(4)

503
        vae_model.train()
504
        train_methy_recon = 0
505
        train_expr_recon = 0
506
        train_kl = 0
507
        train_classifier = 0
508
        train_correct_num = 0
509
        train_total_loss = 0
510
        for batch_index, sample in enumerate(train_loader):
511
            data = sample[0]
512
            y = sample[1]
513
            for chr_i in range(24):
514
                data[chr_i] = data[chr_i].to(device)
515
            y = y.to(device)
516
            optimizer.zero_grad()
517
            _, recon_data, mean, log_var, pred_y = vae_model(data)
518
            if parallel:
519
                for chr_i in range(24):
520
                    recon_data[chr_i] = recon_data[chr_i].to('cuda:0')
521
                pred_y = pred_y.to('cuda:0')
522
523
            methy_recon = methy_recon_loss(recon_data, data)
524
            expr_recon = expr_recon_loss(recon_data, data)
525
            kl = kl_loss(mean, log_var)
526
            class_loss = classifier_loss(pred_y, y)
527
            loss = k_methy_recon * methy_recon + k_expr_recon * expr_recon + k_kl * kl + k_c * class_loss
528
529
            loss.backward()
530
531
            with torch.no_grad():
532
                pred_y_softmax = F.softmax(pred_y, dim=1)
533
                _, predicted = torch.max(pred_y_softmax, 1)
534
                correct = (predicted == y).sum().item()
535
536
                train_methy_recon += methy_recon.item()
537
                train_expr_recon += expr_recon.item()
538
                train_kl += kl.item()
539
                train_classifier += class_loss.item()
540
                train_correct_num += correct
541
                train_total_loss += loss.item()
542
543
            optimizer.step()
544
545
            # if batch_index % 15 == 0:
546
            #     print('Epoch {:3d}/{:3d}  ---  [{:5d}/{:5d}] ({:2d}%)\n'
547
            #           'Methy Recon Loss: {:.2f}   Expr Recon Loss: {:.2f}   KL Loss: {:.2f}   '
548
            #           'Classification Loss: {:.2f}\nACC: {:.2f}%'.format(
549
            #         e_index + 1, e_num, batch_index * len(data[0]), len(train_dataset),
550
            #         round(100. * batch_index / len(train_loader)), methy_recon.item() / len(data[0]),
551
            #         expr_recon.item() / len(data[0]), kl.item() / len(data[0]), class_loss.item() / len(data[0]),
552
            #         correct / len(data[0]) * 100))
553
554
        train_methy_recon_ave = train_methy_recon / len(train_dataset)
555
        train_expr_recon_ave = train_expr_recon / len(train_dataset)
556
        train_kl_ave = train_kl / len(train_dataset)
557
        train_classifier_ave = train_classifier / len(train_dataset)
558
        train_accuracy = train_correct_num / len(train_dataset) * 100
559
        train_total_loss_ave = train_total_loss / len(train_dataset)
560
561
        print('Epoch {:3d}/{:3d}\n'
562
              'Training\n'
563
              'Methy Recon Loss: {:.2f}   Expr Recon Loss: {:.2f}   KL Loss: {:.2f}   '
564
              'Classification Loss: {:.2f}\nACC: {:.2f}%'.
565
              format(e_index + 1, e_num, train_methy_recon_ave, train_expr_recon_ave, train_kl_ave,
566
                     train_classifier_ave, train_accuracy))
567
        loss_array[0, e_index] = train_methy_recon_ave
568
        loss_array[1, e_index] = train_expr_recon_ave
569
        loss_array[2, e_index] = train_kl_ave
570
        loss_array[3, e_index] = train_classifier_ave
571
        loss_array[4, e_index] = train_accuracy
572
573
        # TB
574
        train_writer.add_scalar('Total loss', train_total_loss_ave, e_index)
575
        train_writer.add_scalar('Methy recon loss', train_methy_recon_ave, e_index)
576
        train_writer.add_scalar('Expr recon loss', train_expr_recon_ave, e_index)
577
        train_writer.add_scalar('KL loss', train_kl_ave, e_index)
578
        train_writer.add_scalar('Classification loss', train_classifier_ave, e_index)
579
        train_writer.add_scalar('Accuracy', train_accuracy, e_index)
580
581
    if separate_testing:
        def val(e_index, get_metrics=False):
            vae_model.eval()
            val_methy_recon = 0
            val_expr_recon = 0
            val_kl = 0
            val_classifier = 0
            val_correct_num = 0
            val_total_loss = 0
            y_store = torch.tensor([0])
            predicted_store = torch.tensor([0])

            with torch.no_grad():
                for batch_index, sample in enumerate(val_loader):
                    data = sample[0]
                    y = sample[1]
                    for chr_i in range(24):
                        data[chr_i] = data[chr_i].to(device)
                    y = y.to(device)
                    _, recon_data, mean, log_var, pred_y = vae_model(data)
                    if parallel:
                        for chr_i in range(24):
                            recon_data[chr_i] = recon_data[chr_i].to('cuda:0')
                        pred_y = pred_y.to('cuda:0')

                    methy_recon = methy_recon_loss(recon_data, data)
                    expr_recon = expr_recon_loss(recon_data, data)
                    kl = kl_loss(mean, log_var)
                    class_loss = classifier_loss(pred_y, y)
                    loss = methy_recon + expr_recon + kl + class_loss

                    pred_y_softmax = F.softmax(pred_y, dim=1)
                    _, predicted = torch.max(pred_y_softmax, 1)
                    correct = (predicted == y).sum().item()

                    y_store = torch.cat((y_store, y.cpu()))
                    predicted_store = torch.cat((predicted_store, predicted.cpu()))

                    val_methy_recon += methy_recon.item()
                    val_expr_recon += expr_recon.item()
                    val_kl += kl.item()
                    val_classifier += class_loss.item()
                    val_correct_num += correct
                    val_total_loss += loss.item()

            output_y = y_store[1:].numpy()
            output_pred_y = predicted_store[1:].numpy()

            if get_metrics:
                metrics_array[0] = metrics.accuracy_score(output_y, output_pred_y)
                metrics_array[1] = metrics.precision_score(output_y, output_pred_y, average='weighted')
                metrics_array[2] = metrics.recall_score(output_y, output_pred_y, average='weighted')
                metrics_array[3] = metrics.f1_score(output_y, output_pred_y, average='weighted')

            val_methy_recon_ave = val_methy_recon / len(val_dataset)
            val_expr_recon_ave = val_expr_recon / len(val_dataset)
            val_kl_ave = val_kl / len(val_dataset)
            val_classifier_ave = val_classifier / len(val_dataset)
            val_accuracy = val_correct_num / len(val_dataset) * 100
            val_total_loss_ave = val_total_loss / len(val_dataset)

            print('Validation\n'
                  'Methy Recon Loss: {:.2f}   Expr Recon Loss: {:.2f}   KL Loss: {:.2f}   Classification Loss: {:.2f}'
                  '\nACC: {:.2f}%\n'.
                  format(val_methy_recon_ave, val_expr_recon_ave, val_kl_ave, val_classifier_ave, val_accuracy))
            loss_array[5, e_index] = val_methy_recon_ave
            loss_array[6, e_index] = val_expr_recon_ave
            loss_array[7, e_index] = val_kl_ave
            loss_array[8, e_index] = val_classifier_ave
            loss_array[9, e_index] = val_accuracy

            # TB
            val_writer.add_scalar('Total loss', val_total_loss_ave, e_index)
            val_writer.add_scalar('Methy recon loss', val_methy_recon_ave, e_index)
            val_writer.add_scalar('Expr recon loss', val_expr_recon_ave, e_index)
            val_writer.add_scalar('KL loss', val_kl_ave, e_index)
            val_writer.add_scalar('Classification loss', val_classifier_ave, e_index)
            val_writer.add_scalar('Accuracy', val_accuracy, e_index)

            return val_accuracy, output_pred_y

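    # Two-phase training: the first p1_epoch_num epochs optimise the VAE objective only
    # (reconstruction + KL, classification weight k_c=0); the remaining p2_epoch_num epochs
    # optimise the classification loss only (k_methy_recon=k_expr_recon=k_kl=0, k_c=1).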
    print('\nUNSUPERVISED PHASE\n')
    # unsupervised phase
    for epoch_index in range(p1_epoch_num):
        train(e_index=epoch_index, e_num=p1_epoch_num+p2_epoch_num, k_methy_recon=1, k_expr_recon=1, k_kl=1, k_c=0)
        if separate_testing:
            _, out_pred_y = val(epoch_index)

    print('\nSUPERVISED PHASE\n')
    # supervised phase
    epoch_number = p1_epoch_num
    for epoch_index in range(p1_epoch_num, p1_epoch_num+p2_epoch_num):
        epoch_number += 1
        train(e_index=epoch_index, e_num=p1_epoch_num+p2_epoch_num, k_methy_recon=0, k_expr_recon=0, k_kl=0, k_c=1)
        if separate_testing:
            if epoch_index == p1_epoch_num+p2_epoch_num-1:
                val_classification_acc, out_pred_y = val(epoch_index, get_metrics=True)
            else:
                val_classification_acc, out_pred_y = val(epoch_index)
            if early_stopping:
                early_stop_ob(vae_model, val_classification_acc)
                if early_stop_ob.stop_now:
                    print('Early stopping\n')
                    break

    if early_stopping:
        best_epoch = p1_epoch_num + early_stop_ob.best_epoch_num
        loss_array[10, 0] = best_epoch
        print('Load model of Epoch {:d}'.format(best_epoch))
        vae_model.load_state_dict(torch.load('../ssd/checkpoint.pt'))
        _, out_pred_y = val(epoch_number, get_metrics=True)

    # Encode all of the data into the latent space
    print('Encoding all the data into latent space...')
    vae_model.eval()
    with torch.no_grad():
        d_z_store = torch.zeros(1, latent_dim).to(device)
        for i, sample in enumerate(full_loader):
            d = sample[0]
            for chr_i in range(24):
                d[chr_i] = d[chr_i].to(device)
            # use the latent mean (third value returned by forward) as the embedding
            _, _, d_z, _, _ = vae_model(d)
            d_z_store = torch.cat((d_z_store, d_z), 0)
    all_data_z = d_z_store[1:]
    all_data_z_np = all_data_z.cpu().numpy()

    # Output files
    print('Preparing the output files... ')
    input_path_name = input_path.split('/')[-1]
    latent_space_path = 'results/' + input_path_name + str(latent_dim) + 'D_latent_space.tsv'

    # Latent embeddings indexed by sample_id (this sets the row order of the output file)
    all_data_z_df = pd.DataFrame(all_data_z_np, index=sample_id)
    all_data_z_df.to_csv(latent_space_path, sep='\t')

    if separate_testing:
        pred_y_path = 'results/' + input_path_name + str(latent_dim) + 'D_pred_y.tsv'
        np.savetxt(pred_y_path, out_pred_y, delimiter='\t')

        metrics_record_path = 'results/' + input_path_name + str(latent_dim) + 'D_metrics.tsv'
        np.savetxt(metrics_record_path, metrics_array, delimiter='\t')

    if output_loss_record:
        loss_record_path = 'results/' + input_path_name + str(latent_dim) + 'D_loss_record.tsv'
        np.savetxt(loss_record_path, loss_array, delimiter='\t')

    return all_data_z_df
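

# Minimal usage sketch, assuming preprocessed inputs under a hypothetical 'data/' directory:
# 'both_samples.tsv' (one sample ID per line) and 'both_samples_tumour_type_digit.tsv'
# (sample ID to integer tumour type) as read above, plus expression and per-chromosome
# methylation matrices stored as features x samples with values scaled to [0, 1]. The file
# names 'expr.tsv' and 'methy_chr*.tsv' are illustrative only. The function also expects
# writable 'logs/' and 'results/' directories and, with early stopping enabled, the
# checkpoint path '../ssd/checkpoint.pt' that the early-stopping step reloads.
if __name__ == '__main__':
    input_dir = 'data/'
    expr = pd.read_csv(input_dir + 'expr.tsv', sep='\t', header=0, index_col=0)
    chromosomes = [str(c) for c in range(1, 23)] + ['X']
    methy_list = [pd.read_csv(input_dir + 'methy_chr{}.tsv'.format(c), sep='\t', header=0, index_col=0)
                  for c in chromosomes]
    latent_df = MultiOmiVAE(input_path=input_dir, expr_df=expr, methy_chr_df_list=methy_list,
                            latent_dim=128, p1_epoch_num=50, p2_epoch_num=100)
    print(latent_df.shape)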