utils/utils.py

import pickle
import pdb
import math
import collections
from itertools import islice

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Sampler, WeightedRandomSampler, RandomSampler, SequentialSampler, sampler
from torch.utils.data.dataloader import default_collate
from torchvision import transforms

import torch_geometric
from torch_geometric.data import Batch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SubsetSequentialSampler(Sampler):
    """Samples elements sequentially from a given list of indices, without replacement.

    Arguments:
        indices (sequence): a sequence of indices
    """
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)

def collate_MIL(batch):
    img = torch.cat([item[0] for item in batch], dim = 0)
    label = torch.LongTensor([item[1] for item in batch])
    return [img, label]

def collate_features(batch):
    img = torch.cat([item[0] for item in batch], dim = 0)
    coords = np.vstack([item[1] for item in batch])
    return [img, coords]

def collate_MIL_survival(batch):
    img = torch.cat([item[0] for item in batch], dim = 0)
    omic = torch.cat([item[1] for item in batch], dim = 0).type(torch.FloatTensor)
    label = torch.LongTensor([item[2] for item in batch])
    event_time = torch.FloatTensor([item[3] for item in batch])
    c = torch.FloatTensor([item[4] for item in batch])
    return [img, omic, label, event_time, c]

def collate_MIL_survival_cluster(batch):
    img = torch.cat([item[0] for item in batch], dim = 0)
    cluster_ids = torch.cat([item[1] for item in batch], dim = 0).type(torch.LongTensor)
    omic = torch.cat([item[2] for item in batch], dim = 0).type(torch.FloatTensor)
    label = torch.LongTensor([item[3] for item in batch])
    event_time = np.array([item[4] for item in batch])
    c = torch.FloatTensor([item[5] for item in batch])
    return [img, cluster_ids, omic, label, event_time, c]

def collate_MIL_survival_sig(batch):
    img = torch.cat([item[0] for item in batch], dim = 0)
    omic1 = torch.cat([item[1] for item in batch], dim = 0).type(torch.FloatTensor)
    omic2 = torch.cat([item[2] for item in batch], dim = 0).type(torch.FloatTensor)
    omic3 = torch.cat([item[3] for item in batch], dim = 0).type(torch.FloatTensor)
    omic4 = torch.cat([item[4] for item in batch], dim = 0).type(torch.FloatTensor)
    omic5 = torch.cat([item[5] for item in batch], dim = 0).type(torch.FloatTensor)
    omic6 = torch.cat([item[6] for item in batch], dim = 0).type(torch.FloatTensor)

    label = torch.LongTensor([item[7] for item in batch])
    event_time = np.array([item[8] for item in batch])
    c = torch.FloatTensor([item[9] for item in batch])
    return [img, omic1, omic2, omic3, omic4, omic5, omic6, label, event_time, c]

def get_simple_loader(dataset, batch_size=1):
    kwargs = {'num_workers': 4} if device.type == "cuda" else {}
    loader = DataLoader(dataset, batch_size=batch_size, sampler = sampler.SequentialSampler(dataset), collate_fn = collate_MIL, **kwargs)
    return loader

def get_split_loader(split_dataset, training = False, testing = False, weighted = False, mode='coattn', batch_size=1):
    """
        Return either the validation loader or training loader.
    """
    if mode == 'coattn':
        collate = collate_MIL_survival_sig
    elif mode == 'cluster':
        collate = collate_MIL_survival_cluster
    else:
        collate = collate_MIL_survival

    kwargs = {'num_workers': 4} if device.type == "cuda" else {}
    if not testing:
        if training:
            if weighted:
                weights = make_weights_for_balanced_classes_split(split_dataset)
                loader = DataLoader(split_dataset, batch_size=batch_size, sampler = WeightedRandomSampler(weights, len(weights)), collate_fn = collate, **kwargs)
            else:
                loader = DataLoader(split_dataset, batch_size=batch_size, sampler = RandomSampler(split_dataset), collate_fn = collate, **kwargs)
        else:
            loader = DataLoader(split_dataset, batch_size=1, sampler = SequentialSampler(split_dataset), collate_fn = collate, **kwargs)

    else:
        ids = np.random.choice(np.arange(len(split_dataset)), int(len(split_dataset)*0.1), replace = False)
        loader = DataLoader(split_dataset, batch_size=1, sampler = SubsetSequentialSampler(ids), collate_fn = collate, **kwargs)

    return loader

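# Example usage (a hedged sketch; `train_split` and `val_split` are hypothetical dataset splits
# that expose `slide_cls_ids` and `getlabel`, which the weighted sampler below relies on):
# train_loader = get_split_loader(train_split, training=True, weighted=True, mode='coattn', batch_size=1)
# val_loader = get_split_loader(val_split, mode='coattn')
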
def get_optim(model, args):
    if args.opt == "adam":
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.reg)
    elif args.opt == 'sgd':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.reg)
    else:
        raise NotImplementedError
    return optimizer

def print_network(net):
    num_params = 0
    num_params_train = 0
    print(net)

    for param in net.parameters():
        n = param.numel()
        num_params += n
        if param.requires_grad:
            num_params_train += n

    print('Total number of parameters: %d' % num_params)
    print('Total number of trainable parameters: %d' % num_params_train)


def generate_split(cls_ids, val_num, test_num, samples, n_splits = 5,
    seed = 7, label_frac = 1.0, custom_test_ids = None):
    indices = np.arange(samples).astype(int)

    if custom_test_ids is not None:
        indices = np.setdiff1d(indices, custom_test_ids)

    np.random.seed(seed)
    for i in range(n_splits):
        all_val_ids = []
        all_test_ids = []
        sampled_train_ids = []

        if custom_test_ids is not None: # pre-built test split, do not need to sample
            all_test_ids.extend(custom_test_ids)

        for c in range(len(val_num)):
            possible_indices = np.intersect1d(cls_ids[c], indices) # all indices of this class
            remaining_ids = possible_indices

            if val_num[c] > 0:
                val_ids = np.random.choice(possible_indices, val_num[c], replace = False) # validation ids
                remaining_ids = np.setdiff1d(possible_indices, val_ids) # indices of this class left after validation
                all_val_ids.extend(val_ids)

            if custom_test_ids is None and test_num[c] > 0: # sample test split
                test_ids = np.random.choice(remaining_ids, test_num[c], replace = False)
                remaining_ids = np.setdiff1d(remaining_ids, test_ids)
                all_test_ids.extend(test_ids)

            if label_frac == 1:
                sampled_train_ids.extend(remaining_ids)

            else:
                sample_num = math.ceil(len(remaining_ids) * label_frac)
                slice_ids = np.arange(sample_num)
                sampled_train_ids.extend(remaining_ids[slice_ids])

        yield sorted(sampled_train_ids), sorted(all_val_ids), sorted(all_test_ids)


def nth(iterator, n, default=None):
    if n is None:
        return collections.deque(iterator, maxlen=0)
    else:
        return next(islice(iterator, n, None), default)

def calculate_error(Y_hat, Y):
    error = 1. - Y_hat.float().eq(Y.float()).float().mean().item()

    return error

def make_weights_for_balanced_classes_split(dataset):
    N = float(len(dataset))
    weight_per_class = [N/len(dataset.slide_cls_ids[c]) for c in range(len(dataset.slide_cls_ids))]
    weight = [0] * int(N)
    for idx in range(len(dataset)):
        y = dataset.getlabel(idx)
        weight[idx] = weight_per_class[y]

    return torch.DoubleTensor(weight)

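# A quick worked example of the weighting above (illustrative numbers, not from the codebase):
# with N = 100 slides split 80/20 across two classes, weight_per_class = [100/80, 100/20] = [1.25, 5.0],
# so minority-class slides are drawn 4x more often by WeightedRandomSampler and each class contributes
# roughly equally per epoch in expectation.
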
def initialize_weights(module):
    for m in module.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            m.bias.data.zero_()

        elif isinstance(m, nn.BatchNorm1d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)


def dfs_freeze(model):
    for name, child in model.named_children():
        for param in child.parameters():
            param.requires_grad = False
        dfs_freeze(child)


def dfs_unfreeze(model):
    for name, child in model.named_children():
        for param in child.parameters():
            param.requires_grad = True
        dfs_unfreeze(child)


# divide continuous time scale into k discrete bins in total, T_cont \in {[0, a_1), [a_1, a_2), ...., [a_(k-1), inf)}
# Y = T_discrete is the discrete event time:
# Y = 0 if T_cont \in (-inf, 0), Y = 1 if T_cont \in [0, a_1), Y = 2 if T_cont in [a_1, a_2), ..., Y = k if T_cont in [a_(k-1), inf)
# discrete hazards: discrete probability of h(t) = P(Y=t | Y>=t, X), t = 0,1,2,...,k
# S: survival function: P(Y > t | X)
# all patients are alive from (-inf, 0) by definition, so P(Y=0) = 0
# h(0) = 0 ---> do not need to model
# S(0) = P(Y > 0 | X) = 1 ----> do not need to model
'''
Summary: the neural network is the hazard probability function h(t) for t = 1,2,...,k,
corresponding to Y = 1, ..., k. h(t) is the probability that the patient dies in [0, a_1), [a_1, a_2), ..., [a_(k-1), inf)
'''
# def neg_likelihood_loss(hazards, Y, c):
#   batch_size = len(Y)
#   Y = Y.view(batch_size, 1) # ground truth bin, 1,2,...,k
#   c = c.view(batch_size, 1).float() # censorship status, 0 or 1
#   S = torch.cumprod(1 - hazards, dim=1) # survival is the cumulative product of 1 - hazards
#   # without padding, S(1) = S[0], h(1) = h[0]
#   S_padded = torch.cat([torch.ones_like(c), S], 1) # S(0) = 1, all patients are alive from (-inf, 0) by definition
#   # after padding, S(0) = S[0], S(1) = S[1], etc, h(1) = h[0]
#   # h[y] = h(1)
#   # S[1] = S(1)
#   neg_l = - c * torch.log(torch.gather(S_padded, 1, Y)) - (1 - c) * (torch.log(torch.gather(S_padded, 1, Y-1)) + torch.log(hazards[:, Y-1]))
#   neg_l = neg_l.mean()
#   return neg_l


# divide continuous time scale into k discrete bins in total, T_cont \in {[0, a_1), [a_1, a_2), ...., [a_(k-1), inf)}
# Y = T_discrete is the discrete event time:
# Y = -1 if T_cont \in (-inf, 0), Y = 0 if T_cont \in [0, a_1), Y = 1 if T_cont in [a_1, a_2), ..., Y = k-1 if T_cont in [a_(k-1), inf)
# discrete hazards: discrete probability of h(t) = P(Y=t | Y>=t, X), t = -1,0,1,2,...,k-1
# S: survival function: P(Y > t | X)
# all patients are alive from (-inf, 0) by definition, so P(Y=-1) = 0
# h(-1) = 0 ---> do not need to model
# S(-1) = P(Y > -1 | X) = 1 ----> do not need to model
'''
Summary: the neural network is the hazard probability function h(t) for t = 0,1,2,...,k-1,
corresponding to Y = 0,1, ..., k-1. h(t) is the probability that the patient dies in [0, a_1), [a_1, a_2), ..., [a_(k-1), inf)
'''
def nll_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):
    batch_size = len(Y)
    Y = Y.view(batch_size, 1) # ground truth bin, 0,1,...,k-1
    c = c.view(batch_size, 1).float() # censorship status, 0 or 1
    if S is None:
        S = torch.cumprod(1 - hazards, dim=1) # survival is the cumulative product of 1 - hazards
    # without padding, S(0) = S[0], h(0) = h[0]
    S_padded = torch.cat([torch.ones_like(c), S], 1) # S(-1) = 1: all patients are alive on (-inf, 0) by definition
    # after padding, S(0) = S_padded[1], S(1) = S_padded[2], etc.; h(0) = hazards[0]
    uncensored_loss = -(1 - c) * (torch.log(torch.gather(S_padded, 1, Y).clamp(min=eps)) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
    censored_loss = - c * torch.log(torch.gather(S_padded, 1, Y+1).clamp(min=eps))
    neg_l = censored_loss + uncensored_loss
    loss = (1-alpha) * neg_l + alpha * uncensored_loss
    loss = loss.mean()
    return loss

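# A minimal numeric sketch of nll_loss (illustrative values, not from the codebase):
# suppose k = 4 bins and one uncensored patient (c = 0) with ground-truth bin Y = 2 and
# hazards = [[0.1, 0.2, 0.5, 0.9]], so S = cumprod(1 - hazards) = [[0.9, 0.72, 0.36, 0.036]]
# and S_padded = [[1.0, 0.9, 0.72, 0.36, 0.036]].
# The uncensored term is -(log S(1) + log h(2)) = -(log 0.72 + log 0.5): the patient must
# survive through bin 1 and then experience the event in bin 2; the censored term is 0.
# For a censored patient (c = 1) with the same Y, the loss is instead -log S(2) = -log 0.36,
# i.e. the patient is only known to have survived through bin 2.
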
def ce_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):
    batch_size = len(Y)
    Y = Y.view(batch_size, 1) # ground truth bin, 0,1,...,k-1
    c = c.view(batch_size, 1).float() # censorship status, 0 or 1
    if S is None:
        S = torch.cumprod(1 - hazards, dim=1) # survival is the cumulative product of 1 - hazards
    # without padding, S(0) = S[0], h(0) = h[0]
    # after padding, S(0) = S_padded[1], S(1) = S_padded[2], etc.; h(0) = hazards[0]
    S_padded = torch.cat([torch.ones_like(c), S], 1)
    reg = -(1 - c) * (torch.log(torch.gather(S_padded, 1, Y)+eps) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
    ce_l = - c * torch.log(torch.gather(S, 1, Y).clamp(min=eps)) - (1 - c) * torch.log(1 - torch.gather(S, 1, Y).clamp(min=eps))
    loss = (1-alpha) * ce_l + alpha * reg
    loss = loss.mean()
    return loss

# def nll_loss(hazards, Y, c, S=None, alpha=0.4, eps=1e-8):
#   batch_size = len(Y)
#   Y = Y.view(batch_size, 1) # ground truth bin, 1,2,...,k
#   c = c.view(batch_size, 1).float() # censorship status, 0 or 1
#   if S is None:
#       S = 1 - torch.cumsum(hazards, dim=1) # survival is 1 minus the cumulative sum of hazards
#   uncensored_loss = -(1 - c) * (torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
#   censored_loss = - c * torch.log(torch.gather(S, 1, Y).clamp(min=eps))
#   loss = censored_loss + uncensored_loss
#   loss = loss.mean()
#   return loss

class CrossEntropySurvLoss(object):
    def __init__(self, alpha=0.15):
        self.alpha = alpha

    def __call__(self, hazards, S, Y, c, alpha=None):
        if alpha is None:
            return ce_loss(hazards, S, Y, c, alpha=self.alpha)
        else:
            return ce_loss(hazards, S, Y, c, alpha=alpha)

# loss_fn(hazards=hazards, S=S, Y=Y_hat, c=c, alpha=0)
class NLLSurvLoss_dep(object):
    def __init__(self, alpha=0.15):
        self.alpha = alpha

    def __call__(self, hazards, S, Y, c, alpha=None):
        if alpha is None:
            return nll_loss(hazards, S, Y, c, alpha=self.alpha)
        else:
            return nll_loss(hazards, S, Y, c, alpha=alpha)
    # h_padded = torch.cat([torch.zeros_like(c), hazards], 1)
    # reg = - (1 - c) * (torch.log(torch.gather(hazards, 1, Y)) + torch.gather(torch.cumsum(torch.log(1-h_padded), dim=1), 1, Y))


class CoxSurvLoss(object):
    def __call__(self, hazards, S, c, **kwargs):
        # This calculation credit to Travers Ching https://github.com/traversc/cox-nnet
        # Cox-nnet: An artificial neural network method for prognosis prediction of high-throughput omics data
        current_batch_len = len(S)
        R_mat = np.zeros([current_batch_len, current_batch_len], dtype=int)
        for i in range(current_batch_len):
            for j in range(current_batch_len):
                R_mat[i,j] = S[j] >= S[i]

        R_mat = torch.FloatTensor(R_mat).to(device)
        theta = hazards.reshape(-1)
        exp_theta = torch.exp(theta)
        loss_cox = -torch.mean((theta - torch.log(torch.sum(exp_theta*R_mat, dim=1))) * (1-c))
        return loss_cox

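# Note: R_mat[i, j] = 1 when patient j is still at risk at patient i's observed value of S (S[j] >= S[i]),
# so loss_cox above is the negative Cox partial log-likelihood with censored patients' terms zeroed out
# by the (1 - c) factor. A hedged usage sketch, with hypothetical tensors `risk` (per-patient risk
# scores), `times` (observed times), and `censor` (censorship indicators):
# loss = CoxSurvLoss()(hazards=risk, S=times, c=censor)
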
def l1_reg_all(model, reg_type=None):
    l1_reg = None

    for W in model.parameters():
        if l1_reg is None:
            l1_reg = torch.abs(W).sum()
        else:
            l1_reg = l1_reg + torch.abs(W).sum() # torch.abs(W).sum() is equivalent to W.norm(1)
    return l1_reg

def l1_reg_modules(model, reg_type=None):
    l1_reg = 0

    l1_reg += l1_reg_all(model.fc_omic)
    l1_reg += l1_reg_all(model.mm)

    return l1_reg

def l1_reg_omic(model, reg_type=None):
    l1_reg = 0

    if hasattr(model, 'fc_omic'):
        l1_reg += l1_reg_all(model.fc_omic)
    else:
        l1_reg += l1_reg_all(model)

    return l1_reg

def get_custom_exp_code(args):
    r"""
    Updates the argparse.Namespace with a custom experiment code.

    Args:
        - args (Namespace)

    Returns:
        - args (Namespace)
    """
    exp_code = '_'.join(args.split_dir.split('_')[:2])
    dataset_path = 'datasets_csv'
    param_code = ''

    ### Model Type
    if args.model_type == 'porpoise_mmf':
        param_code += 'PorpoiseMMF'
    elif args.model_type == 'porpoise_amil':
        param_code += 'PorpoiseAMIL'
    elif args.model_type == 'max_net' or args.model_type == 'snn':
        param_code += 'SNN'
    elif args.model_type == 'amil':
        param_code += 'AMIL'
    elif args.model_type == 'deepset':
        param_code += 'DS'
    elif args.model_type == 'mi_fcn':
        param_code += 'MIFCN'
    elif args.model_type == 'mcat':
        param_code += 'MCAT'
    else:
        raise NotImplementedError

    ### Loss Function
    param_code += '_%s' % args.bag_loss
    if args.bag_loss in ['nll_surv']:
        param_code += '_a%s' % str(args.alpha_surv)

    ### Learning Rate
    if args.lr != 2e-4:
        param_code += '_lr%s' % format(args.lr, '.0e')

    ### L1-Regularization
    if args.reg_type != 'None':
        param_code += '_%sreg%s' % (args.reg_type, format(args.lambda_reg, '.0e'))

    if args.dropinput:
        param_code += '_drop%s' % str(int(args.dropinput*100))

    param_code += '_%s' % args.which_splits.split("_")[0]

    ### Batch Size
    if args.batch_size != 1:
        param_code += '_b%s' % str(args.batch_size)

    ### Gradient Accumulation
    if args.gc != 1:
        param_code += '_gc%s' % str(args.gc)

    ### Applying Which Features
    if args.apply_sigfeats:
        param_code += '_sig'
        dataset_path += '_sig'
    elif args.apply_mutsig:
        param_code += '_mutsig'
        dataset_path += '_mutsig'

    ### Fusion Operation
    if args.fusion != "None":
        param_code += '_' + args.fusion

    ### Updating
    args.exp_code = exp_code + "_" + param_code
    args.param_code = param_code
    args.dataset_path = dataset_path

    return args