Diff of /utils/utils.py [000000] .. [4cd6c8]

import pickle
import pdb
import math
import collections
from itertools import islice

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Sampler, WeightedRandomSampler, RandomSampler, SequentialSampler, sampler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SubsetSequentialSampler(Sampler):
    """Samples elements sequentially from a given list of indices, without replacement.

    Arguments:
        indices (sequence): a sequence of indices
    """
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)

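# Usage sketch (illustrative only; `features_dataset` and the index list below are
# hypothetical): SubsetSequentialSampler lets a DataLoader visit a fixed subset of
# a dataset in a deterministic order, e.g. to re-score specific bags.
#
#   picked = [3, 17, 42]
#   loader = DataLoader(features_dataset, batch_size=1,
#                       sampler=SubsetSequentialSampler(picked),
#                       collate_fn=collate_MIL)
#   for bag, label in loader:
#       ...  # bags arrive in the order 3, 17, 42
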
def collate_MIL_mtl_sex(batch):
    # each item: (bag of patch features, class label, site label, sex label)
    img = torch.cat([item[0] for item in batch], dim=0)
    label = torch.LongTensor([item[1] for item in batch])
    site = torch.LongTensor([item[2] for item in batch])
    sex = torch.LongTensor([item[3] for item in batch])
    return [img, label, site, sex]

def collate_MIL_mtl(batch):
    # each item: (bag of patch features, task-1 label, task-2 label, task-3 label)
    img = torch.cat([item[0] for item in batch], dim=0)
    label_task1 = torch.LongTensor([item[1] for item in batch])
    label_task2 = torch.LongTensor([item[2] for item in batch])
    label_task3 = torch.LongTensor([item[3] for item in batch])
    return [img, label_task1, label_task2, label_task3]

def collate_MIL(batch):
    # each item: (bag of patch features, class label)
    img = torch.cat([item[0] for item in batch], dim=0)
    label = torch.LongTensor([item[1] for item in batch])
    return [img, label]

def collate_features(batch):
    # each item: (bag of patch features, patch coordinates)
    img = torch.cat([item[0] for item in batch], dim=0)
    coords = np.vstack([item[1] for item in batch])
    return [img, coords]


collate_dict = {'MIL': collate_MIL, 'MIL_mtl': collate_MIL_mtl, 'MIL_mtl_sex': collate_MIL_mtl_sex, 'MIL_sex': collate_MIL_mtl}

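# Shape sketch (toy tensors, not real data): collate_MIL expects each dataset item
# to be a tuple whose first element is a bag of patch features of shape
# [n_patches, feature_dim]; bags are concatenated along dim 0 and labels stacked.
#
#   batch = [(torch.randn(5, 1024), 0), (torch.randn(3, 1024), 1)]
#   img, label = collate_MIL(batch)
#   # img.shape -> torch.Size([8, 1024]),  label -> tensor([0, 1])
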
def get_simple_loader(dataset, batch_size=1, collate_fn='MIL'):
    kwargs = {'num_workers': 32} if device.type == "cuda" else {}
    collate = collate_dict[collate_fn]
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler.SequentialSampler(dataset), collate_fn=collate, **kwargs)
    return loader

def get_split_loader(split_dataset, training=False, testing=False, weighted=False, collate_fn='MIL'):
    """
    Return a training or validation loader for the given split.

    Training uses a random (optionally class-weighted) sampler; validation uses a
    sequential sampler; testing=True iterates over a random ~1% subset only.
    """
    collate = collate_dict[collate_fn]

    kwargs = {'num_workers': 4} if device.type == "cuda" else {}
    if not testing:
        if training:
            if weighted:
                weights = make_weights_for_balanced_classes_split(split_dataset)
                loader = DataLoader(split_dataset, batch_size=1, sampler=WeightedRandomSampler(weights, len(weights)), collate_fn=collate, **kwargs)
            else:
                loader = DataLoader(split_dataset, batch_size=1, sampler=RandomSampler(split_dataset), collate_fn=collate, **kwargs)
        else:
            loader = DataLoader(split_dataset, batch_size=1, sampler=SequentialSampler(split_dataset), collate_fn=collate, **kwargs)

    else:
        ids = np.random.choice(np.arange(len(split_dataset)), int(len(split_dataset) * 0.01), replace=False)
        loader = DataLoader(split_dataset, batch_size=1, sampler=SubsetSequentialSampler(ids), collate_fn=collate, **kwargs)

    return loader

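# Usage sketch (assumes `train_split` / `val_split` are dataset objects that expose
# slide_cls_ids and getlabel(), which the weighted option relies on):
#
#   train_loader = get_split_loader(train_split, training=True, weighted=True, collate_fn='MIL')
#   val_loader = get_split_loader(val_split, collate_fn='MIL')
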
def get_optim(model, args):
    if args.opt == "adam":
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.reg)
    elif args.opt == 'sgd':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.reg)
    else:
        raise NotImplementedError
    return optimizer

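# Usage sketch: `args` only needs the attributes read above (opt, lr, reg), so a
# plain namespace works for quick tests; the values below are illustrative.
#
#   from argparse import Namespace
#   optimizer = get_optim(model, Namespace(opt='adam', lr=2e-4, reg=1e-5))
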
def print_network(net):
    num_params = 0
    num_params_train = 0
    print(net)

    for param in net.parameters():
        n = param.numel()
        num_params += n
        if param.requires_grad:
            num_params_train += n

    print('Total number of parameters: %d' % num_params)
    print('Total number of trainable parameters: %d' % num_params_train)


def generate_split(cls_ids, val_num, test_num, samples, n_splits=5,
    seed=7, label_frac=1.0, custom_test_ids=None):
    indices = np.arange(samples).astype(int)

    if custom_test_ids is not None:
        indices = np.setdiff1d(indices, custom_test_ids)

    np.random.seed(seed)
    for i in range(n_splits):
        all_val_ids = []
        all_test_ids = []
        sampled_train_ids = []

        if custom_test_ids is not None: # pre-built test split, do not need to sample
            all_test_ids.extend(custom_test_ids)

        for c in range(len(val_num)):
            possible_indices = np.intersect1d(cls_ids[c], indices) # all indices of this class
            remaining_ids = possible_indices

            if val_num[c] > 0:
                val_ids = np.random.choice(possible_indices, val_num[c], replace=False) # validation ids
                remaining_ids = np.setdiff1d(possible_indices, val_ids) # indices of this class left after validation
                all_val_ids.extend(val_ids)

            if custom_test_ids is None and test_num[c] > 0: # sample test split
                test_ids = np.random.choice(remaining_ids, test_num[c], replace=False)
                remaining_ids = np.setdiff1d(remaining_ids, test_ids)
                all_test_ids.extend(test_ids)

            if label_frac == 1:
                sampled_train_ids.extend(remaining_ids)
            else:
                # subsample a fraction of the remaining labels for the training set
                sample_num = math.ceil(len(remaining_ids) * label_frac)
                slice_ids = np.arange(sample_num)
                sampled_train_ids.extend(remaining_ids[slice_ids])

        yield sampled_train_ids, all_val_ids, all_test_ids


def nth(iterator, n, default=None):
    if n is None:
        return collections.deque(iterator, maxlen=0)
    else:
        return next(islice(iterator, n, None), default)

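# Usage sketch (hypothetical class-index lists and counts): generate_split yields one
# (train, val, test) split per iteration, and nth() pulls out the k-th split directly.
#
#   splits = generate_split(cls_ids=[[0, 1, 2, 3], [4, 5, 6, 7]],
#                           val_num=[1, 1], test_num=[1, 1],
#                           samples=8, n_splits=10, seed=7)
#   train_ids, val_ids, test_ids = nth(splits, 2)   # take split number 2
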
def calculate_error(Y_hat, Y):
    error = 1. - Y_hat.float().eq(Y.float()).float().mean().item()
    return error

def make_weights_for_balanced_classes_split(dataset):
    # weight each sample by the inverse frequency of its class so that
    # WeightedRandomSampler draws classes roughly uniformly
    N = float(len(dataset))
    weight_per_class = [N / len(dataset.slide_cls_ids[c]) for c in range(len(dataset.slide_cls_ids))]
    weight = [0] * int(N)
    for idx in range(len(dataset)):
        y = dataset.getlabel(idx)
        weight[idx] = weight_per_class[y]

    return torch.DoubleTensor(weight)

def initialize_weights(module):
    for m in module.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            m.bias.data.zero_()

        elif isinstance(m, nn.BatchNorm1d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
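
# Usage sketch: apply to a freshly built model so Linear / BatchNorm1d layers start
# from Xavier-normal weights and zero biases (the toy model below is illustrative).
#
#   net = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 2))
#   initialize_weights(net)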