Diff of /utils/utils.py [000000] .. [0fdc30]

import pickle
import math
import collections
from itertools import islice
import pdb

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Sampler, WeightedRandomSampler, RandomSampler, SequentialSampler, sampler
from torchvision import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SubsetSequentialSampler(Sampler):
    """Samples elements sequentially from a given list of indices, without replacement.

    Arguments:
        indices (sequence): a sequence of indices
    """
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)

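# Illustrative usage sketch (not part of the original module): the sampler can be
# paired with a DataLoader to iterate over a fixed subset of indices in order, e.g.
#
#   subset_loader = DataLoader(dataset, batch_size=1,
#                              sampler=SubsetSequentialSampler([0, 5, 7]),
#                              collate_fn=collate_MIL)
#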
def collate_MIL(batch):
    img = torch.cat([item[0] for item in batch], dim=0)
    label = torch.LongTensor([item[1] for item in batch])
    return [img, label]

def collate_MIL_survival(batch):
    img = torch.cat([item[0] for item in batch], dim=0)
    event = torch.LongTensor([item[1] for item in batch])
    time = torch.LongTensor([item[2] for item in batch])  # note: survival times are cast to integers here
    return [img, event, time]

def collate_features(batch):
    img = torch.cat([item[0] for item in batch], dim=0)
    coords = np.vstack([item[1] for item in batch])
    return [img, coords]

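# Illustrative sketch (hypothetical shapes): the collate functions above assume each
# dataset item packs a bag of patch features as its first element, e.g.
#
#   bag_a = (torch.randn(500, 1024), 0)   # 500 patches, 1024-dim features, label 0
#   bag_b = (torch.randn(320, 1024), 1)
#   img, label = collate_MIL([bag_a, bag_b])   # img: [820 x 1024], label: [2]
#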
def get_simple_loader(dataset, batch_size=1, num_workers=1, survival=False):
    collate_fn = collate_MIL_survival if survival else collate_MIL
    kwargs = {'num_workers': num_workers, 'pin_memory': False} if device.type == "cuda" else {}
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler.SequentialSampler(dataset), collate_fn=collate_fn, **kwargs)
    return loader

def get_split_loader(split_dataset, training=False, testing=False, weighted=False, survival=False):
    """
        Return a training, validation, or testing loader for the given split.
        When testing=True, a random 10% subset of the split is sampled sequentially.
    """
    kwargs = {'num_workers': 4} if device.type == "cuda" else {}
    collate_fn = collate_MIL_survival if survival else collate_MIL

    if not testing:
        if training:
            if weighted:
                weights = make_weights_for_balanced_classes_split(split_dataset)
                loader = DataLoader(split_dataset, batch_size=1, sampler=WeightedRandomSampler(weights, len(weights)), collate_fn=collate_fn, **kwargs)
            else:
                loader = DataLoader(split_dataset, batch_size=1, sampler=RandomSampler(split_dataset), collate_fn=collate_fn, **kwargs)
        else:
            loader = DataLoader(split_dataset, batch_size=1, sampler=SequentialSampler(split_dataset), collate_fn=collate_fn, **kwargs)

    else:
        ids = np.random.choice(np.arange(len(split_dataset)), int(len(split_dataset) * 0.1), replace=False)
        loader = DataLoader(split_dataset, batch_size=1, sampler=SubsetSequentialSampler(ids), collate_fn=collate_fn, **kwargs)

    return loader

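# Illustrative usage sketch (the split objects are assumed to come from the project's
# own dataset classes, which expose slide_cls_ids and getlabel for class weighting):
#
#   train_loader = get_split_loader(train_split, training=True, weighted=True)
#   val_loader   = get_split_loader(val_split)
#   test_loader  = get_split_loader(test_split, testing=True)   # random 10% subset
#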
def get_optim(model, args):
    if args.opt == "adam":
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.reg)
    elif args.opt == 'sgd':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.reg)
    else:
        raise NotImplementedError
    return optimizer

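# Illustrative usage sketch: `args` only needs `.opt`, `.lr` and `.reg`, so any
# argparse.Namespace-like object works, e.g.
#
#   from argparse import Namespace
#   optimizer = get_optim(model, Namespace(opt='adam', lr=2e-4, reg=1e-5))
#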
def print_network(net):
    num_params = 0
    num_params_train = 0
    print(net)

    for param in net.parameters():
        n = param.numel()
        num_params += n
        if param.requires_grad:
            num_params_train += n

    print('Total number of parameters: %d' % num_params)
    print('Total number of trainable parameters: %d' % num_params_train)


def generate_split(cls_ids, val_num, test_num, samples, n_splits=5,
    seed=7, label_frac=1.0, custom_test_ids=None):
    """
        Yield (train_ids, val_ids, test_ids) index lists for n_splits random splits,
        sampling val_num[c] validation and test_num[c] test examples per class c.
    """
    indices = np.arange(samples).astype(int)

    if custom_test_ids is not None:
        indices = np.setdiff1d(indices, custom_test_ids)

    np.random.seed(seed)
    for i in range(n_splits):
        all_val_ids = []
        all_test_ids = []
        sampled_train_ids = []

        if custom_test_ids is not None: # pre-built test split, no need to sample
            all_test_ids.extend(custom_test_ids)

        for c in range(len(val_num)):
            possible_indices = np.intersect1d(cls_ids[c], indices) # all indices of this class
            val_ids = np.random.choice(possible_indices, val_num[c], replace=False) # validation ids

            remaining_ids = np.setdiff1d(possible_indices, val_ids) # indices of this class left after validation
            all_val_ids.extend(val_ids)

            if custom_test_ids is None: # sample test split
                test_ids = np.random.choice(remaining_ids, test_num[c], replace=False)
                remaining_ids = np.setdiff1d(remaining_ids, test_ids)
                all_test_ids.extend(test_ids)

            if label_frac == 1:
                sampled_train_ids.extend(remaining_ids)
            else:
                sample_num = math.ceil(len(remaining_ids) * label_frac)
                slice_ids = np.arange(sample_num)
                sampled_train_ids.extend(remaining_ids[slice_ids])

        yield sampled_train_ids, all_val_ids, all_test_ids


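# Illustrative usage sketch (cls_ids is assumed to be a list of per-class index arrays;
# the counts below are hypothetical):
#
#   splits = generate_split(cls_ids, val_num=[10, 10], test_num=[10, 10],
#                           samples=200, n_splits=5)
#   train_ids, val_ids, test_ids = nth(splits, 0)   # materialize the first split
#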
def nth(iterator, n, default=None):
    if n is None:
        return collections.deque(iterator, maxlen=0)
    else:
        return next(islice(iterator, n, None), default)

def calculate_error(Y_hat, Y):
    error = 1. - Y_hat.float().eq(Y.float()).float().mean().item()

    return error

def make_weights_for_balanced_classes_split(dataset):
    N = float(len(dataset))
    weight_per_class = [N / len(dataset.slide_cls_ids[c]) for c in range(len(dataset.slide_cls_ids))]
    weight = [0] * int(N)
    for idx in range(len(dataset)):
        y = dataset.getlabel(idx)
        weight[idx] = weight_per_class[y]

    return torch.DoubleTensor(weight)

def initialize_weights(module):
    for m in module.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None: # guard against bias-free Linear layers
                m.bias.data.zero_()

        elif isinstance(m, nn.BatchNorm1d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
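
# Minimal smoke-test sketch, run only when the module is executed directly. It uses a
# tiny synthetic list-of-tuples "dataset" (a hypothetical stand-in for the project's
# real WSI datasets) just to exercise the collate function, the sequential sampler,
# the optimizer factory, and the weight initializer.
if __name__ == "__main__":
    from argparse import Namespace

    # Fake bags: (patch_features, label) tuples with varying bag sizes.
    toy_dataset = [(torch.randn(n, 16), n % 2) for n in (4, 7, 3, 5)]

    loader = DataLoader(toy_dataset, batch_size=1,
                        sampler=SubsetSequentialSampler([0, 2]),
                        collate_fn=collate_MIL)
    for img, label in loader:
        print('bag:', img.shape, 'label:', label)

    model = nn.Linear(16, 2)
    initialize_weights(model)
    print_network(model)
    optimizer = get_optim(model, Namespace(opt='adam', lr=2e-4, reg=1e-5))
    print(optimizer)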