Diff of /util/visualizer.py [000000] .. [03464c]


a b/util/visualizer.py
import os
import time
import numpy as np
import pandas as pd
import sklearn as sk
import sklearn.metrics  # imported explicitly so that sk.metrics is guaranteed to be available below
from sklearn.preprocessing import label_binarize
from util import util
from util import metrics
from torch.utils.tensorboard import SummaryWriter


class Visualizer:
    """
    This class prints and saves logging information
    """

    def __init__(self, param):
        """
        Initialize the Visualizer class
        """
        self.param = param
        self.output_path = os.path.join(param.checkpoints_dir, param.experiment_name)
        tb_dir = os.path.join(self.output_path, 'tb_log')
        util.mkdir(tb_dir)

        if param.isTrain:
            # Create a logging file to store training losses
            self.train_log_filename = os.path.join(self.output_path, 'train_log.txt')
            with open(self.train_log_filename, 'a') as log_file:
                now = time.strftime('%c')
                log_file.write('----------------------- Training Log ({:s}) -----------------------\n'.format(now))

            self.train_summary_filename = os.path.join(self.output_path, 'train_summary.txt')
            with open(self.train_summary_filename, 'a') as log_file:
                now = time.strftime('%c')
                log_file.write('----------------------- Training Summary ({:s}) -----------------------\n'.format(now))

            # Create log folder for TensorBoard
            tb_train_dir = os.path.join(self.output_path, 'tb_log', 'train')
            util.mkdir(tb_train_dir)
            util.clear_dir(tb_train_dir)

            # Create TensorBoard writer
            self.train_writer = SummaryWriter(log_dir=tb_train_dir)

        if param.isTest:
            # Create a logging file to store testing metrics
            self.test_log_filename = os.path.join(self.output_path, 'test_log.txt')
            with open(self.test_log_filename, 'a') as log_file:
                now = time.strftime('%c')
                log_file.write('----------------------- Testing Log ({:s}) -----------------------\n'.format(now))

            self.test_summary_filename = os.path.join(self.output_path, 'test_summary.txt')
            with open(self.test_summary_filename, 'a') as log_file:
                now = time.strftime('%c')
                log_file.write('----------------------- Testing Summary ({:s}) -----------------------\n'.format(now))

            # Create log folder for TensorBoard
            tb_test_dir = os.path.join(self.output_path, 'tb_log', 'test')
            util.mkdir(tb_test_dir)
            util.clear_dir(tb_test_dir)

            # Create TensorBoard writer
            self.test_writer = SummaryWriter(log_dir=tb_test_dir)
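    # Directory layout produced by __init__ above:
    #   <checkpoints_dir>/<experiment_name>/
    #       train_log.txt, train_summary.txt     (plain-text logs, created only when param.isTrain)
    #       test_log.txt, test_summary.txt       (plain-text logs, created only when param.isTest)
    #       tb_log/train, tb_log/test            (TensorBoard event directories)
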
    def print_train_log(self, epoch, iteration, losses_dict, metrics_dict, load_time, comp_time, batch_size, dataset_size, with_time=True):
        """
        Print the training log to the console and save the message to disk

        Parameters:
            epoch (int)                     -- current epoch
            iteration (int)                 -- current training iteration during this epoch
            losses_dict (OrderedDict)       -- training losses stored in the ordered dict
            metrics_dict (OrderedDict)      -- metrics stored in the ordered dict
            load_time (float)               -- data loading time per data point (normalized by batch_size)
            comp_time (float)               -- computational time per data point (normalized by batch_size)
            batch_size (int)                -- batch size of training
            dataset_size (int)              -- size of the training dataset
            with_time (bool)                -- print the running time or not
        """
        data_point_covered = min((iteration + 1) * batch_size, dataset_size)
        if with_time:
            message = '[TRAIN] [Epoch: {:3d}   Iter: {:4d}   Load_t: {:.3f}   Comp_t: {:.3f}]   '.format(epoch, data_point_covered, load_time, comp_time)
        else:
            message = '[TRAIN] [Epoch: {:3d}   Iter: {:4d}]\n'.format(epoch, data_point_covered)
        for name, loss in losses_dict.items():
            message += '{:s}: {:.3f}   '.format(name, loss[-1])
        for name, metric in metrics_dict.items():
            message += '{:s}: {:.3f}   '.format(name, metric)

        print(message)  # print the message

        with open(self.train_log_filename, 'a') as log_file:
            log_file.write(message + '\n')  # save the message
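    # Illustrative example of one line produced by print_train_log (loss/metric names come from
    # the dictionaries passed in; the values here are made up):
    #   [TRAIN] [Epoch:   5   Iter:  256   Load_t: 0.012   Comp_t: 0.034]   loss_total: 0.812   accuracy: 0.731
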
    def print_train_summary(self, epoch, losses_dict, output_dict, train_time, current_lr):
        """
        Print the summary of this training epoch

        Parameters:
            epoch (int)                             -- epoch number of this training model
            losses_dict (OrderedDict)               -- the losses dictionary
            output_dict (OrderedDict)               -- the downstream output dictionary
            train_time (float)                      -- time used for training this epoch
            current_lr (float)                      -- the learning rate of this epoch
        """
        write_message = '{:s}\t'.format(str(epoch))
        print_message = '[TRAIN] [Epoch: {:3d}]\n'.format(int(epoch))

        for name, loss in losses_dict.items():
            write_message += '{:.6f}\t'.format(np.mean(loss))
            print_message += name + ': {:.3f}   '.format(np.mean(loss))
            self.train_writer.add_scalar('loss_'+name, np.mean(loss), epoch)

        metrics_dict = self.get_epoch_metrics(output_dict)
        for name, metric in metrics_dict.items():
            write_message += '{:.6f}\t'.format(metric)
            print_message += name + ': {:.3f}   '.format(metric)
            self.train_writer.add_scalar('metric_'+name, metric, epoch)

        train_time_msg = 'Training time used: {:.3f}s'.format(train_time)
        print_message += '\n' + train_time_msg
        with open(self.train_log_filename, 'a') as log_file:
            log_file.write(train_time_msg + '\n')

        current_lr_msg = 'Learning rate for this epoch: {:.7f}'.format(current_lr)
        print_message += '\n' + current_lr_msg
        self.train_writer.add_scalar('lr', current_lr, epoch)

        with open(self.train_summary_filename, 'a') as log_file:
            log_file.write(write_message + '\n')

        print(print_message)
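    # Note: each call to print_train_summary appends one tab-separated row to train_summary.txt
    # ("<epoch>\t<mean of each loss>\t<each epoch metric>"), so the file can later be parsed with,
    # for example, pandas.read_csv(path, sep='\t', header=None) -- an assumption about downstream
    # use, not something this class does itself.
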
    def print_test_log(self, epoch, iteration, losses_dict, metrics_dict, batch_size, dataset_size):
        """
        Print performance metrics of this iteration to the console and save the message to disk

        Parameters:
            epoch (int)                     -- epoch number of this testing model
            iteration (int)                 -- current testing iteration during this epoch
            losses_dict (OrderedDict)       -- testing losses stored in the ordered dict
            metrics_dict (OrderedDict)      -- metrics stored in the ordered dict
            batch_size (int)                -- batch size of testing
            dataset_size (int)              -- size of the testing dataset
        """
        data_point_covered = min((iteration + 1) * batch_size, dataset_size)
        message = '[TEST] [Epoch: {:3d}   Iter: {:4d}]   '.format(int(epoch), data_point_covered)
        for name, loss in losses_dict.items():
            message += '{:s}: {:.3f}   '.format(name, loss[-1])
        for name, metric in metrics_dict.items():
            message += '{:s}: {:.3f}   '.format(name, metric)

        print(message)

        with open(self.test_log_filename, 'a') as log_file:
            log_file.write(message + '\n')
    def print_test_summary(self, epoch, losses_dict, output_dict, test_time):
        """
        Print the summary of this testing epoch

        Parameters:
            epoch (int)                             -- epoch number of this testing model
            losses_dict (OrderedDict)               -- the losses dictionary
            output_dict (OrderedDict)               -- the downstream output dictionary
            test_time (float)                       -- time used for testing this epoch
        """
        write_message = '{:s}\t'.format(str(epoch))
        print_message = '[TEST] [Epoch: {:3d}]      '.format(int(epoch))

        for name, loss in losses_dict.items():
            # write_message += '{:.6f}\t'.format(np.mean(loss))
            print_message += name + ': {:.3f}   '.format(np.mean(loss))
            self.test_writer.add_scalar('loss_'+name, np.mean(loss), epoch)

        metrics_dict = self.get_epoch_metrics(output_dict)

        for name, metric in metrics_dict.items():
            write_message += '{:.6f}\t'.format(metric)
            print_message += name + ': {:.3f}   '.format(metric)
            self.test_writer.add_scalar('metric_' + name, metric, epoch)

        with open(self.test_summary_filename, 'a') as log_file:
            log_file.write(write_message + '\n')

        test_time_msg = 'Testing time used: {:.3f}s'.format(test_time)
        print_message += '\n' + test_time_msg
        print(print_message)
        with open(self.test_log_filename, 'a') as log_file:
            log_file.write(test_time_msg + '\n')
    def get_epoch_metrics(self, output_dict):
        """
        Get the downstream task metrics for the whole epoch

        Parameters:
            output_dict (OrderedDict)  -- the output dictionary used to compute the downstream task metrics
        """
        if self.param.downstream_task == 'classification':
            y_true = output_dict['y_true'].cpu().numpy()
            y_true_binary = label_binarize(y_true, classes=range(self.param.class_num))
            y_pred = output_dict['y_pred'].cpu().numpy()
            y_prob = output_dict['y_prob'].cpu().numpy()
            if self.param.class_num == 2:
                y_prob = y_prob[:, 1]

            accuracy = sk.metrics.accuracy_score(y_true, y_pred)
            precision = sk.metrics.precision_score(y_true, y_pred, average='macro', zero_division=0)
            recall = sk.metrics.recall_score(y_true, y_pred, average='macro', zero_division=0)
            f1 = sk.metrics.f1_score(y_true, y_pred, average='macro', zero_division=0)
            try:
                auc = sk.metrics.roc_auc_score(y_true_binary, y_prob, multi_class='ovo', average='macro')
            except ValueError:
                auc = -1
                print('ValueError: ROC AUC score is not defined in this case.')

            return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc}
        elif self.param.downstream_task == 'regression':
            y_true = output_dict['y_true'].cpu().numpy()
            y_pred = output_dict['y_pred'].cpu().detach().numpy()

            mse = sk.metrics.mean_squared_error(y_true, y_pred)
            rmse = sk.metrics.mean_squared_error(y_true, y_pred, squared=False)
            mae = sk.metrics.mean_absolute_error(y_true, y_pred)
            medae = sk.metrics.median_absolute_error(y_true, y_pred)
            r2 = sk.metrics.r2_score(y_true, y_pred)

            return {'mse': mse, 'rmse': rmse, 'mae': mae, 'medae': medae, 'r2': r2}

        elif self.param.downstream_task == 'survival':
            metrics_start_time = time.time()

            y_true_E = output_dict['y_true_E'].cpu().numpy()
            y_true_T = output_dict['y_true_T'].cpu().numpy()
            y_pred_risk = output_dict['risk'].cpu().numpy()
            y_pred_survival = output_dict['survival'].cpu().numpy()

            time_points = util.get_time_points(self.param.survival_T_max, self.param.time_num)

            try:
                c_index = metrics.c_index(y_true_T, y_true_E, y_pred_risk)
            except ValueError:
                c_index = -1
                print('ValueError: NaNs detected in input when calculating c-index.')

            try:
                ibs = metrics.ibs(y_true_T, y_true_E, y_pred_survival, time_points)
            except ValueError:
                ibs = -1
                print('ValueError: NaNs detected in input when calculating integrated brier score.')

            metrics_time = time.time() - metrics_start_time
            print('Metrics computing time: {:.3f}s'.format(metrics_time))

            return {'c-index': c_index, 'ibs': ibs}

        elif self.param.downstream_task == 'multitask':
            metrics_start_time = time.time()

            # Survival
            y_true_E = output_dict['y_true_E'].cpu().numpy()
            y_true_T = output_dict['y_true_T'].cpu().numpy()
            y_pred_risk = output_dict['risk'].cpu().numpy()
            y_pred_survival = output_dict['survival'].cpu().numpy()
            time_points = util.get_time_points(self.param.survival_T_max, self.param.time_num)
            try:
                c_index = metrics.c_index(y_true_T, y_true_E, y_pred_risk)
            except ValueError:
                c_index = -1
                print('ValueError: NaNs detected in input when calculating c-index.')
            try:
                ibs = metrics.ibs(y_true_T, y_true_E, y_pred_survival, time_points)
            except ValueError:
                ibs = -1
                print('ValueError: NaNs detected in input when calculating integrated brier score.')

            # Classification
            y_true_cla = output_dict['y_true_cla'].cpu().numpy()
            y_true_cla_binary = label_binarize(y_true_cla, classes=range(self.param.class_num))
            y_pred_cla = output_dict['y_pred_cla'].cpu().numpy()
            y_prob_cla = output_dict['y_prob_cla'].cpu().numpy()
            if self.param.class_num == 2:
                y_prob_cla = y_prob_cla[:, 1]
            accuracy = sk.metrics.accuracy_score(y_true_cla, y_pred_cla)
            precision = sk.metrics.precision_score(y_true_cla, y_pred_cla, average='macro', zero_division=0)
            recall = sk.metrics.recall_score(y_true_cla, y_pred_cla, average='macro', zero_division=0)
            f1 = sk.metrics.f1_score(y_true_cla, y_pred_cla, average='macro', zero_division=0)
            '''
            try:
                auc = sk.metrics.roc_auc_score(y_true_cla_binary, y_prob_cla, multi_class='ovo', average='macro')
            except ValueError:
                auc = -1
                print('ValueError: ROC AUC score is not defined in this case.')
            '''

            # Regression
            y_true_reg = output_dict['y_true_reg'].cpu().numpy()
            y_pred_reg = output_dict['y_pred_reg'].cpu().detach().numpy()
            # mse = sk.metrics.mean_squared_error(y_true_reg, y_pred_reg)
            rmse = sk.metrics.mean_squared_error(y_true_reg, y_pred_reg, squared=False)
            mae = sk.metrics.mean_absolute_error(y_true_reg, y_pred_reg)
            medae = sk.metrics.median_absolute_error(y_true_reg, y_pred_reg)
            r2 = sk.metrics.r2_score(y_true_reg, y_pred_reg)

            metrics_time = time.time() - metrics_start_time
            print('Metrics computing time: {:.3f}s'.format(metrics_time))

            return {'c-index': c_index, 'ibs': ibs, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'rmse': rmse, 'mae': mae, 'medae': medae, 'r2': r2}

        elif self.param.downstream_task == 'alltask':
            metrics_start_time = time.time()

            # Survival
            y_true_E = output_dict['y_true_E'].cpu().numpy()
            y_true_T = output_dict['y_true_T'].cpu().numpy()
            y_pred_risk = output_dict['risk'].cpu().numpy()
            y_pred_survival = output_dict['survival'].cpu().numpy()
            time_points = util.get_time_points(self.param.survival_T_max, self.param.time_num)
            try:
                c_index = metrics.c_index(y_true_T, y_true_E, y_pred_risk)
            except ValueError:
                c_index = -1
                print('ValueError: NaNs detected in input when calculating c-index.')
            try:
                ibs = metrics.ibs(y_true_T, y_true_E, y_pred_survival, time_points)
            except ValueError:
                ibs = -1
                print('ValueError: NaNs detected in input when calculating integrated brier score.')

            # Classification
            accuracy = []
            f1 = []
            auc = []
            for i in range(self.param.task_num - 2):
                y_true_cla = output_dict['y_true_cla'][i].cpu().numpy()
                y_true_cla_binary = label_binarize(y_true_cla, classes=range(self.param.class_num[i]))
                y_pred_cla = output_dict['y_pred_cla'][i].cpu().numpy()
                y_prob_cla = output_dict['y_prob_cla'][i].cpu().numpy()
                if self.param.class_num[i] == 2:
                    y_prob_cla = y_prob_cla[:, 1]
                accuracy.append(sk.metrics.accuracy_score(y_true_cla, y_pred_cla))
                f1.append(sk.metrics.f1_score(y_true_cla, y_pred_cla, average='macro', zero_division=0))
                try:
                    auc.append(sk.metrics.roc_auc_score(y_true_cla_binary, y_prob_cla, multi_class='ovo', average='macro'))
                except ValueError:
                    auc.append(-1)
                    print('ValueError: ROC AUC score is not defined in this case.')

            # Regression
            y_true_reg = output_dict['y_true_reg'].cpu().numpy()
            y_pred_reg = output_dict['y_pred_reg'].cpu().detach().numpy()
            # mse = sk.metrics.mean_squared_error(y_true_reg, y_pred_reg)
            rmse = sk.metrics.mean_squared_error(y_true_reg, y_pred_reg, squared=False)
            # mae = sk.metrics.mean_absolute_error(y_true_reg, y_pred_reg)
            # medae = sk.metrics.median_absolute_error(y_true_reg, y_pred_reg)
            r2 = sk.metrics.r2_score(y_true_reg, y_pred_reg)

            metrics_time = time.time() - metrics_start_time
            print('Metrics computing time: {:.3f}s'.format(metrics_time))

            return {'c-index': c_index, 'ibs': ibs, 'accuracy_1': accuracy[0], 'f1_1': f1[0], 'auc_1': auc[0], 'accuracy_2': accuracy[1], 'f1_2': f1[1], 'auc_2': auc[1], 'accuracy_3': accuracy[2], 'f1_3': f1[2], 'auc_3': auc[2], 'accuracy_4': accuracy[3], 'f1_4': f1[3], 'auc_4': auc[3], 'accuracy_5': accuracy[4], 'f1_5': f1[4], 'auc_5': auc[4], 'rmse': rmse, 'r2': r2}
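    # Keys that get_epoch_metrics expects in output_dict, per downstream task (read off the
    # branches above; the tensors are assumed to be per-sample vectors unless indexed otherwise):
    #   'classification': y_true, y_pred, y_prob (one probability column per class)
    #   'regression':     y_true, y_pred
    #   'survival':       y_true_E, y_true_T, risk, survival
    #   'multitask':      the survival keys plus y_true_cla / y_pred_cla / y_prob_cla and
    #                     y_true_reg / y_pred_reg
    #   'alltask':        as 'multitask', but the *_cla entries are lists with one element per
    #                     classification task (param.task_num - 2 of them)
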
    def save_output_dict(self, output_dict):
        """
        Save the downstream task output to disk

        Parameters:
            output_dict (OrderedDict)  -- the downstream task output dictionary to be saved
        """
        down_path = os.path.join(self.output_path, 'down_output')
        util.mkdir(down_path)
        if self.param.downstream_task == 'classification':
            # Prepare files
            index = output_dict['index'].numpy()
            y_true = output_dict['y_true'].cpu().numpy()
            y_pred = output_dict['y_pred'].cpu().numpy()
            y_prob = output_dict['y_prob'].cpu().numpy()

            sample_list = self.param.sample_list[index]

            # Output files
            y_df = pd.DataFrame({'sample': sample_list, 'y_true': y_true, 'y_pred': y_pred}, index=index)
            y_df_path = os.path.join(down_path, 'y_df.tsv')
            y_df.to_csv(y_df_path, sep='\t')

            prob_df = pd.DataFrame(y_prob, columns=range(self.param.class_num), index=sample_list)
            y_prob_path = os.path.join(down_path, 'y_prob.tsv')
            prob_df.to_csv(y_prob_path, sep='\t')

        elif self.param.downstream_task == 'regression':
            # Prepare files
            index = output_dict['index'].numpy()
            y_true = output_dict['y_true'].cpu().numpy()
            y_pred = np.squeeze(output_dict['y_pred'].cpu().detach().numpy())

            sample_list = self.param.sample_list[index]

            # Output files
            y_df = pd.DataFrame({'sample': sample_list, 'y_true': y_true, 'y_pred': y_pred}, index=index)
            y_df_path = os.path.join(down_path, 'y_df.tsv')
            y_df.to_csv(y_df_path, sep='\t')

        elif self.param.downstream_task == 'survival':
            # Prepare files
            index = output_dict['index'].numpy()
            y_true_E = output_dict['y_true_E'].cpu().numpy()
            y_true_T = output_dict['y_true_T'].cpu().numpy()
            y_pred_risk = output_dict['risk'].cpu().numpy()
            survival_function = output_dict['survival'].cpu().numpy()
            y_out = output_dict['y_out'].cpu().numpy()

            sample_list = self.param.sample_list[index]
            time_points = util.get_time_points(self.param.survival_T_max, self.param.time_num)

            # Output files
            y_df = pd.DataFrame({'sample': sample_list, 'true_T': y_true_T, 'true_E': y_true_E, 'pred_risk': y_pred_risk}, index=index)
            y_df_path = os.path.join(down_path, 'y_df.tsv')
            y_df.to_csv(y_df_path, sep='\t')

            survival_function_df = pd.DataFrame(survival_function, columns=time_points, index=sample_list)
            survival_function_path = os.path.join(down_path, 'survival_function.tsv')
            survival_function_df.to_csv(survival_function_path, sep='\t')

            y_out_df = pd.DataFrame(y_out, index=sample_list)
            y_out_path = os.path.join(down_path, 'y_out.tsv')
            y_out_df.to_csv(y_out_path, sep='\t')

        elif self.param.downstream_task == 'multitask':
            # Survival
            index = output_dict['index'].numpy()
            y_true_E = output_dict['y_true_E'].cpu().numpy()
            y_true_T = output_dict['y_true_T'].cpu().numpy()
            y_pred_risk = output_dict['risk'].cpu().numpy()
            survival_function = output_dict['survival'].cpu().numpy()
            y_out_sur = output_dict['y_out_sur'].cpu().numpy()
            sample_list = self.param.sample_list[index]
            time_points = util.get_time_points(self.param.survival_T_max, self.param.time_num)
            y_df_sur = pd.DataFrame(
                {'sample': sample_list, 'true_T': y_true_T, 'true_E': y_true_E, 'pred_risk': y_pred_risk}, index=index)
            y_df_sur_path = os.path.join(down_path, 'y_df_survival.tsv')
            y_df_sur.to_csv(y_df_sur_path, sep='\t')
            survival_function_df = pd.DataFrame(survival_function, columns=time_points, index=sample_list)
            survival_function_path = os.path.join(down_path, 'survival_function.tsv')
            survival_function_df.to_csv(survival_function_path, sep='\t')
            y_out_sur_df = pd.DataFrame(y_out_sur, index=sample_list)
            y_out_sur_path = os.path.join(down_path, 'y_out_survival.tsv')
            y_out_sur_df.to_csv(y_out_sur_path, sep='\t')

            # Classification
            y_true_cla = output_dict['y_true_cla'].cpu().numpy()
            y_pred_cla = output_dict['y_pred_cla'].cpu().numpy()
            y_prob_cla = output_dict['y_prob_cla'].cpu().numpy()
            y_df_cla = pd.DataFrame({'sample': sample_list, 'y_true': y_true_cla, 'y_pred': y_pred_cla}, index=index)
            y_df_cla_path = os.path.join(down_path, 'y_df_classification.tsv')
            y_df_cla.to_csv(y_df_cla_path, sep='\t')
            prob_cla_df = pd.DataFrame(y_prob_cla, columns=range(self.param.class_num), index=sample_list)
            y_prob_cla_path = os.path.join(down_path, 'y_prob_classification.tsv')
            prob_cla_df.to_csv(y_prob_cla_path, sep='\t')

            # Regression
            y_true_reg = output_dict['y_true_reg'].cpu().numpy()
            y_pred_reg = np.squeeze(output_dict['y_pred_reg'].cpu().detach().numpy())
            y_df_reg = pd.DataFrame({'sample': sample_list, 'y_true': y_true_reg, 'y_pred': y_pred_reg}, index=index)
            y_df_reg_path = os.path.join(down_path, 'y_df_regression.tsv')
            y_df_reg.to_csv(y_df_reg_path, sep='\t')

        elif self.param.downstream_task == 'alltask':
            # Survival
            index = output_dict['index'].numpy()
            y_true_E = output_dict['y_true_E'].cpu().numpy()
            y_true_T = output_dict['y_true_T'].cpu().numpy()
            y_pred_risk = output_dict['risk'].cpu().numpy()
            survival_function = output_dict['survival'].cpu().numpy()
            y_out_sur = output_dict['y_out_sur'].cpu().numpy()
            sample_list = self.param.sample_list[index]
            time_points = util.get_time_points(self.param.survival_T_max, self.param.time_num)
            y_df_sur = pd.DataFrame(
                {'sample': sample_list, 'true_T': y_true_T, 'true_E': y_true_E, 'pred_risk': y_pred_risk}, index=index)
            y_df_sur_path = os.path.join(down_path, 'y_df_survival.tsv')
            y_df_sur.to_csv(y_df_sur_path, sep='\t')
            survival_function_df = pd.DataFrame(survival_function, columns=time_points, index=sample_list)
            survival_function_path = os.path.join(down_path, 'survival_function.tsv')
            survival_function_df.to_csv(survival_function_path, sep='\t')
            y_out_sur_df = pd.DataFrame(y_out_sur, index=sample_list)
            y_out_sur_path = os.path.join(down_path, 'y_out_survival.tsv')
            y_out_sur_df.to_csv(y_out_sur_path, sep='\t')

            # Classification
            for i in range(self.param.task_num - 2):
                y_true_cla = output_dict['y_true_cla'][i].cpu().numpy()
                y_pred_cla = output_dict['y_pred_cla'][i].cpu().numpy()
                y_prob_cla = output_dict['y_prob_cla'][i].cpu().numpy()
                y_df_cla = pd.DataFrame({'sample': sample_list, 'y_true': y_true_cla, 'y_pred': y_pred_cla}, index=index)
                y_df_cla_path = os.path.join(down_path, 'y_df_classification_'+str(i+1)+'.tsv')
                y_df_cla.to_csv(y_df_cla_path, sep='\t')
                prob_cla_df = pd.DataFrame(y_prob_cla, columns=range(self.param.class_num[i]), index=sample_list)
                y_prob_cla_path = os.path.join(down_path, 'y_prob_classification_'+str(i+1)+'.tsv')
                prob_cla_df.to_csv(y_prob_cla_path, sep='\t')

            # Regression
            y_true_reg = output_dict['y_true_reg'].cpu().numpy()
            y_pred_reg = np.squeeze(output_dict['y_pred_reg'].cpu().detach().numpy())
            y_df_reg = pd.DataFrame({'sample': sample_list, 'y_true': y_true_reg, 'y_pred': y_pred_reg}, index=index)
            y_df_reg_path = os.path.join(down_path, 'y_df_regression.tsv')
            y_df_reg.to_csv(y_df_reg_path, sep='\t')
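    # Files written by save_output_dict, all under <output_path>/down_output/ (which ones exist
    # depends on param.downstream_task): y_df.tsv and y_prob.tsv for classification, y_df.tsv for
    # regression, y_df.tsv / survival_function.tsv / y_out.tsv for survival, and the *_survival /
    # *_classification(_i) / *_regression variants for the multitask and alltask settings.
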
    def save_latent_space(self, latent_dict, sample_list):
        """
        Save the latent space matrix to disk

        Parameters:
            latent_dict (OrderedDict)          -- the latent space dictionary
            sample_list (ndarray)              -- the sample list for the latent matrix
        """
        reordered_sample_list = sample_list[latent_dict['index'].astype(int)]
        latent_df = pd.DataFrame(latent_dict['latent'], index=reordered_sample_list)
        output_path = os.path.join(self.param.checkpoints_dir, self.param.experiment_name, 'latent_space.tsv')
        print('Saving the latent space matrix...')
        latent_df.to_csv(output_path, sep='\t')
    @staticmethod
    def print_phase(phase):
        """
        Print the phase information

        Parameters:
            phase (str)             -- the phase of the training process
        """
        if phase == 'p1':
            print('PHASE 1: Unsupervised Phase')
        elif phase == 'p2':
            print('PHASE 2: Supervised Phase')
        elif phase == 'p3':
            print('PHASE 3: Supervised Phase')
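

# Illustrative usage sketch (an assumption about how a training script drives this class, not
# part of the module itself): `param` is the options object with the attributes accessed above,
# and the loss/metric/output dictionaries are maintained by the caller.
#
#   visualizer = Visualizer(param)
#   Visualizer.print_phase('p1')
#   for epoch in range(epoch_count):
#       for i, data in enumerate(train_loader):
#           ...                                           # forward/backward pass and timing
#           visualizer.print_train_log(epoch, i, losses_dict, metrics_dict,
#                                      load_time, comp_time, param.batch_size, dataset_size)
#       visualizer.print_train_summary(epoch, losses_dict, output_dict, train_time, current_lr)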