Diff of /utils/utils.py [000000] .. [fdd588]

import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader, Sampler, WeightedRandomSampler, RandomSampler, SequentialSampler, sampler
import torch.nn.functional as F

import numpy as np
import pdb
import math
from itertools import islice
import collections

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SubsetSequentialSampler(Sampler):
    """Samples elements sequentially from a given list of indices, without replacement.

    Arguments:
        indices (sequence): a sequence of indices
    """
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)

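# Usage sketch (illustrative only, not part of the original module): iterate a
# fixed subset of a map-style dataset in deterministic order. `my_dataset` is a
# hypothetical stand-in for any dataset used with this sampler.
#
#   picked = [0, 5, 12]
#   loader = DataLoader(my_dataset, batch_size=1, sampler=SubsetSequentialSampler(picked))
#   for batch in loader:
#       ...
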
def collate_MIL_mtl_concat(batch):
    """Collate MIL bags: concatenate the bag tensors (item[0]) along dim 0 and
    stack the label, site and sex targets into LongTensors."""
    img = torch.cat([item[0] for item in batch], dim=0)
    label = torch.LongTensor([item[1] for item in batch])
    site = torch.LongTensor([item[2] for item in batch])
    sex = torch.LongTensor([item[3] for item in batch])
    return [img, label, site, sex]

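# Illustrative example (not part of the original module): collating two
# hypothetical bags with 3 and 5 instances of 1024-d features. The feature
# dimension is an assumption for this sketch only.
#
#   bag_a = (torch.randn(3, 1024), 0, 1, 0)
#   bag_b = (torch.randn(5, 1024), 1, 0, 1)
#   img, label, site, sex = collate_MIL_mtl_concat([bag_a, bag_b])
#   # img.shape == (8, 1024); label, site, sex are LongTensors of length 2
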
def get_simple_loader(dataset, batch_size=1):
    kwargs = {'num_workers': 4} if device.type == "cuda" else {}
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler.SequentialSampler(dataset), collate_fn=collate_MIL_mtl_concat, **kwargs)
    return loader

def get_split_loader(split_dataset, training=False, testing=False, weighted=False):
    """
    Return a DataLoader for the given split: an (optionally class-balanced) training
    loader, a sequential validation/test loader, or, when testing=True, a loader over
    a small random subset (~1% of the split) for quick debugging.
    """
    kwargs = {'num_workers': 4} if device.type == "cuda" else {}
    if not testing:
        if training:
            if weighted:
                weights = make_weights_for_balanced_classes_split(split_dataset)
                loader = DataLoader(split_dataset, batch_size=1, sampler=WeightedRandomSampler(weights, len(weights)), collate_fn=collate_MIL_mtl_concat, **kwargs)
            else:
                loader = DataLoader(split_dataset, batch_size=1, sampler=RandomSampler(split_dataset), collate_fn=collate_MIL_mtl_concat, **kwargs)
        else:
            loader = DataLoader(split_dataset, batch_size=1, sampler=SequentialSampler(split_dataset), collate_fn=collate_MIL_mtl_concat, **kwargs)

    else:
        ids = np.random.choice(np.arange(len(split_dataset)), int(len(split_dataset) * 0.01), replace=False)
        loader = DataLoader(split_dataset, batch_size=1, sampler=SubsetSequentialSampler(ids), collate_fn=collate_MIL_mtl_concat, **kwargs)

    return loader

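# Usage sketch (illustrative only; `train_split` and `val_split` are hypothetical
# dataset splits exposing the attributes used above, e.g. slide_cls_ids and
# getlabel for the weighted case):
#
#   train_loader = get_split_loader(train_split, training=True, weighted=True)
#   val_loader = get_split_loader(val_split)
#   for img, label, site, sex in train_loader:
#       ...
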
def get_optim(model, args):
    if args.opt == "adam":
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.reg)
    elif args.opt == "sgd":
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.reg)
    else:
        raise NotImplementedError
    return optimizer

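# Usage sketch (illustrative only): the hyperparameter values below are
# placeholders, not defaults taken from this repository, and `model` is a
# hypothetical nn.Module.
#
#   from argparse import Namespace
#   args = Namespace(opt="adam", lr=1e-4, reg=1e-5)
#   optimizer = get_optim(model, args)
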
def print_network(net):
    num_params = 0
    num_params_train = 0
    print(net)

    for param in net.parameters():
        n = param.numel()
        num_params += n
        if param.requires_grad:
            num_params_train += n

    print('Total number of parameters: %d' % num_params)
    print('Total number of trainable parameters: %d' % num_params_train)

def generate_split(cls_ids, val_num, test_num, samples, n_splits=5,
    seed=7, label_frac=1.0, custom_test_ids=None):
    """Yield n_splits (train_ids, val_ids, test_ids) index splits, drawing
    val_num[c] validation and test_num[c] test cases per class c; only a
    label_frac fraction of the remaining cases is kept for training."""
    indices = np.arange(samples).astype(int)

    if custom_test_ids is not None:
        indices = np.setdiff1d(indices, custom_test_ids)

    np.random.seed(seed)
    for i in range(n_splits):
        all_val_ids = []
        all_test_ids = []
        sampled_train_ids = []

        if custom_test_ids is not None: # pre-built test split, no need to sample
            all_test_ids.extend(custom_test_ids)

        for c in range(len(val_num)):
            possible_indices = np.intersect1d(cls_ids[c], indices) # all indices of this class
            remaining_ids = possible_indices

            if val_num[c] > 0:
                val_ids = np.random.choice(possible_indices, val_num[c], replace=False) # validation ids
                remaining_ids = np.setdiff1d(possible_indices, val_ids) # indices of this class left after validation
                all_val_ids.extend(val_ids)

            if custom_test_ids is None and test_num[c] > 0: # sample test split
                test_ids = np.random.choice(remaining_ids, test_num[c], replace=False)
                remaining_ids = np.setdiff1d(remaining_ids, test_ids)
                all_test_ids.extend(test_ids)

            if label_frac == 1:
                sampled_train_ids.extend(remaining_ids)
            else:
                sample_num = math.ceil(len(remaining_ids) * label_frac)
                slice_ids = np.arange(sample_num)
                sampled_train_ids.extend(remaining_ids[slice_ids])

        yield sampled_train_ids, all_val_ids, all_test_ids

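# Usage sketch (illustrative only): 100 samples from 2 classes, 3 folds with
# 10 validation and 10 test cases per class. `cls_ids[c]` holds the sample
# indices belonging to class c; all numbers here are made up for the example.
#
#   cls_ids = [np.arange(0, 60), np.arange(60, 100)]
#   splits = generate_split(cls_ids, val_num=[10, 10], test_num=[10, 10],
#                           samples=100, n_splits=3, seed=1)
#   train_ids, val_ids, test_ids = nth(splits, 0)   # take the first fold
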
def nth(iterator, n, default=None):
    # returns the n-th item of the iterator (or default); if n is None, the
    # iterator is consumed entirely and an empty deque is returned
    if n is None:
        return collections.deque(iterator, maxlen=0)
    else:
        return next(islice(iterator, n, None), default)

def calculate_error(Y_hat, Y):
    error = 1. - Y_hat.float().eq(Y.float()).float().mean().item()

    return error

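# Worked example (illustrative only): 3 of 4 predictions match, so the error
# is 1 - 3/4 = 0.25.
#
#   Y_hat = torch.tensor([1, 0, 1, 1])
#   Y = torch.tensor([1, 0, 0, 1])
#   calculate_error(Y_hat, Y)   # -> 0.25
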
def make_weights_for_balanced_classes_split(dataset):
    """Assign each sample a weight of N / (size of its class), so that a
    WeightedRandomSampler draws the classes approximately uniformly."""
    N = float(len(dataset))
    weight_per_class = [N / len(dataset.slide_cls_ids[c]) for c in range(len(dataset.slide_cls_ids))]
    weight = [0] * int(N)
    for idx in range(len(dataset)):
        y = dataset.getlabel(idx)
        weight[idx] = weight_per_class[y]

    return torch.DoubleTensor(weight)

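# Worked example (illustrative numbers): with N = 100 slides, 80 of class 0 and
# 20 of class 1, weight_per_class = [100/80, 100/20] = [1.25, 5.0]. Each class
# then carries the same total weight (80 * 1.25 == 20 * 5.0 == 100), so the
# WeightedRandomSampler used in get_split_loader samples both classes roughly
# equally often.
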
def initialize_weights(module):
    for m in module.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:  # guard against layers built with bias=False
                m.bias.data.zero_()
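
# Usage sketch (illustrative only): re-initialize the Linear layers of a small
# hypothetical model before training; layer sizes are placeholders.
#
#   model = nn.Sequential(nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 2))
#   initialize_weights(model)
#   model = model.to(device)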