a b/code/train_student.py
1
import pandas as pd
2
from mil_data_generator import *
3
from mil_models_pytorch import*
4
from mil_trainer_torch import *
5
from sklearn.utils.class_weight import compute_class_weight
6
import torch
7
from torchvision import transforms
8
import argparse
9
10
# Seed NumPy's global RNG, Python's `random` module and PyTorch with the same
# fixed value (42) so that repeated runs draw the same random numbers.
for _seed_rng in (np.random.seed, random.seed, torch.manual_seed):
    _seed_rng(42)
del _seed_rng
def _squeeze_keep_batch(features):
    """Drop every singleton dimension of ``features`` except the leading batch one.

    ``torch.squeeze`` alone also removes the batch dimension when a batch holds a
    single image, which breaks the ``(batch, classes)`` shape the loss expects.
    For batches larger than one the result equals ``torch.squeeze(features)``.
    """
    kept = [size for size in features.shape[1:] if size != 1]
    return features.reshape(features.shape[0], *kept)


def _predict_instance(network, image):
    """Return the softmax output ``network`` produces for a single instance.

    Args:
        network: model exposing ``bb`` and ``classifier`` submodules.
        image: array for one instance, without the batch dimension.

    Returns:
        numpy array of probabilities (softmax taken over dim 0 of the
        classifier output for the squeezed features).
    """
    device = next(network.parameters()).device
    with torch.no_grad():  # inference only: do not build an autograd graph
        x = torch.tensor(image).to(device).float().unsqueeze(0)
        features = network.bb(x)
        return torch.softmax(network.classifier(torch.squeeze(features)), 0).cpu().numpy()


def _load_network(path):
    """Load a network that was pickled whole with ``torch.save(network, path)``."""
    import inspect

    # Checkpoints hold complete pickled modules. torch >= 2.6 defaults to
    # ``weights_only=True`` and would refuse them, so opt out explicitly where the
    # argument exists. Unpickling can execute code: only load trusted files.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        return torch.load(path, weights_only=False)
    return torch.load(path)


def _generate_pseudolabels(teacher, X, Yglobal):
    """Pseudolabel every training instance with the teacher network.

    An instance whose ``Yglobal`` row has no positive entry in columns 1.. is
    labelled 0 without querying the teacher. Otherwise it receives the teacher's
    arg-max class only when that class is > 0, is marked 1 in the instance's
    ``Yglobal`` row and has probability > 0.5; every other instance is discarded.

    Args:
        teacher: trained network with ``bb`` and ``classifier`` submodules.
        X: instance images, indexed along the first axis.
        Yglobal: label matrix with one row per instance of ``X`` (presumably the
            parent slide's global labels -- confirm against MILDataset).

    Returns:
        int64 array with one pseudolabel per instance; -1 marks discarded ones.
    """
    teacher.eval()  # inference: no training-mode behaviour such as dropout
    n_instances = X.shape[0]
    labels = np.full(n_instances, -1, dtype=np.int64)
    for i in range(n_instances):
        print(str(i + 1) + '/' + str(n_instances), end='\r')

        if np.max(Yglobal[i, 1:]) == 0:
            labels[i] = 0
            continue

        probs = _predict_instance(teacher, X[i])
        pred = int(np.argmax(probs))
        if pred > 0 and Yglobal[i, pred] == 1 and probs[pred] > 0.5:
            labels[i] = pred
    return labels


def _test_instances(X, Y, network, classes, dir_out, i_iteration):
    """Evaluate ``network`` instance-wise and write a text report.

    Args:
        X: annotated test instance images.
        Y: matching label rows; the true class is ``argmax`` along axis 1.
        network: model with ``bb`` and ``classifier`` submodules.
        classes: class names; the report uses ``['NC'] + classes``.
        dir_out: output directory prefix (concatenated with the file name).
        i_iteration: iteration index used in the report file name.

    Returns:
        Quadratic-weighted Cohen's kappa between true and predicted classes.
    """
    network.eval()
    n_instances = X.shape[0]
    y_pred = []
    for i in range(n_instances):
        print(str(i + 1) + '/' + str(n_instances), end='\r')
        y_pred.append(np.argmax(_predict_instance(network, X[i])))
    y_pred = np.array(y_pred)
    y_true = np.argmax(Y, 1)

    target_names = ['NC'] + classes
    # Fixing `labels` keeps the report and the confusion matrix valid (and the
    # matrix a constant shape) even when a class is absent from y_true and y_pred.
    labels = list(range(len(target_names)))
    cr = classification_report(y_true, y_pred, labels=labels, target_names=target_names, digits=4)
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    k2 = cohen_kappa_score(y_true, y_pred, weights='quadratic')

    with open(dir_out + str(i_iteration) + '_report_student.txt', 'w') as f:
        f.write('Title\n\nClassification Report\n\n{}\n\nConfusion Matrix\n\n{}\n\nKappa\n\n{}\n'.format(cr, cm, k2))

    return k2


def _train_student(X, Y, class_weights, classes, mode, aggregation, include_background,
                   X_test, Y_test, dir_experiment, ii_iteration):
    """Train a fresh VGG19-backed MILArchitecture on pseudolabelled instances.

    After every epoch the student is evaluated on the instance test set (the
    iteration's report file is overwritten each time). Returns the student.
    """
    epochs = 60
    dropout_rate = 0.2
    lr = 1e-2

    # Set student network architecture
    network = MILArchitecture(classes, mode=mode, aggregation=aggregation, backbone='vgg19',
                              include_background=include_background).cuda()
    opt = torch.optim.SGD(network.parameters(), lr=lr)

    # NOTE(review): torchvision transforms applied to a batched tensor draw one set
    # of random parameters per call, so all images of a batch get the same
    # augmentation -- confirm this is intended.
    tranf = torch.nn.Sequential(transforms.RandomHorizontalFlip(),
                                transforms.RandomRotation(degrees=(-45, 45)),
                                transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
                                transforms.ColorJitter(brightness=.5, hue=.3)).cuda()

    training_data = CustomImageDataset(X, Y, transform=False)
    train_dataloader = CustomGenerator(training_data, bs=32, shuffle=True)
    n_steps = len(train_dataloader)

    # Built once instead of on every step. The dropout module is not part of the
    # network, so it stays in training mode even while the network is evaluated.
    dropout = torch.nn.Dropout(dropout_rate)
    criterion = torch.nn.CrossEntropyLoss(weight=torch.tensor(class_weights).cuda().float())

    for i_epoch in range(epochs):
        # Halve the learning rate at the start of (1-based) epochs 25 and 50.
        if (i_epoch + 1) % 25 == 0:
            for g in opt.param_groups:
                g['lr'] = g['lr'] / 2

        network.train()  # _test_instances switched it to eval mode
        l_epoch = 0.0
        for i_step, (x_batch, y_batch) in enumerate(train_dataloader):
            opt.zero_grad()

            x_batch = tranf(x_batch.cuda().float())
            features = _squeeze_keep_batch(network.bb(x_batch))
            logits = network.classifier(dropout(features))
            loss = criterion(logits, y_batch.long().cuda())

            loss.backward()
            opt.step()

            l_step = loss.item()
            l_epoch += l_step

            info = "[INFO] Epoch {}/{}  -- Step {}/{}: Lce={:.6f}".format(
                i_epoch + 1, epochs, i_step + 1, n_steps, l_step)
            print(info, end='\r')

        l_epoch = l_epoch / n_steps

        k2 = _test_instances(X_test, Y_test, network, classes, dir_experiment, ii_iteration)

        info = "[INFO] Epoch {}/{}  -- Step {}/{}: Lce={:.6f}; k2={:.6f}".format(
            i_epoch + 1, epochs, n_steps, n_steps, l_epoch, k2)
        print(info, end='\n')

    return network


def main(args):
    """Distil each saved teacher network into an instance-level student.

    For each of ``iterations`` runs: load the teacher
    ``<dir_experiment><i>_network_weights_best.pth``, pseudolabel the training
    instances with it, train a new student on the retained instances
    (evaluating on the annotated test instances every epoch) and save it as
    ``<dir_experiment><i>_student_network_weights.pth``.

    Args:
        args: parsed command-line namespace; only ``experiment_name`` is read.

    Raises:
        RuntimeError: if a teacher discards every training instance.
    """
    # INPUTS #
    dir_images = '../data/SICAP_MIL/patches/'
    dir_data_frame = '../data/SICAP_MIL/dataframes/gt_global_slides.xlsx'
    dir_data_frame_test = '../data/SICAP_MIL/dataframes/gt_test_patches.xlsx'
    dir_experiment = '../data/results/' + args.experiment_name + '/'

    classes = ['G3', 'G4', 'G5']
    proportions = ['pG3', 'pG4', 'pG5']
    input_shape = (3, 224, 224)
    images_on_ram = True
    pMIL = False
    aggregation = 'max'  # 'max', 'mean', 'attentionMIL', 'mcAttentionMIL'
    mode = 'instance'  # 'embedding', 'instance', 'mixed'
    include_background = True
    iterations = 3

    df = pd.read_excel(dir_data_frame)

    # Set data generators. They do not depend on the iteration, so the images are
    # loaded once rather than once per iteration. Dataset-level augmentation is
    # off; the student applies its own augmentation during training.
    dataset_train = MILDataset(dir_images, df[df['Partition'] == 'train'], classes, bag_id='slide_name',
                               input_shape=input_shape, data_augmentation=False, images_on_ram=images_on_ram,
                               pMIL=pMIL, proportions=proportions)
    data_generator_train = MILDataGenerator(dataset_train, batch_size=1, shuffle=True, max_instances=512)

    dataset_test = MILDataset(dir_images, df[df['Partition'] == 'test'], classes, bag_id='slide_name',
                              input_shape=input_shape, data_augmentation=False, images_on_ram=images_on_ram,
                              pMIL=pMIL, proportions=proportions, dataframe_instances=pd.read_excel(dir_data_frame_test))
    data_generator_test = MILDataGenerator(dataset_test, batch_size=1, shuffle=False, max_instances=512)

    # Test at instance level: keep only instances whose first label column is not -1.
    annotated = data_generator_test.dataset.y_instances[:, 0] != -1
    X_test = data_generator_test.dataset.X[annotated, :, :, :]
    Y_test = data_generator_test.dataset.y_instances[annotated, :]

    X_all = data_generator_train.dataset.X
    Yglobal = data_generator_train.dataset.Yglobal

    for ii_iteration in range(iterations):
        # Pseudolabels on training set from this iteration's teacher.
        teacher = _load_network(dir_experiment + str(ii_iteration) + '_network_weights_best.pth')
        labels = _generate_pseudolabels(teacher, X_all, Yglobal)
        del teacher  # release the teacher before the student is built

        keep = labels != -1
        if not keep.any():
            raise RuntimeError('Teacher {} discarded every training instance; '
                               'nothing to train the student on.'.format(ii_iteration))
        X = X_all[keep, :, :, :]
        Y = labels[keep]

        # NOTE: 'balanced' weighting raises ValueError if any class is absent from Y.
        class_weights = compute_class_weight('balanced', classes=np.arange(len(classes) + 1), y=Y)

        # STUDENT TRAINING
        network = _train_student(X, Y, class_weights, classes, mode, aggregation, include_background,
                                 X_test, Y_test, dir_experiment, ii_iteration)

        torch.save(network, dir_experiment + str(ii_iteration) + '_student_network_weights.pth')
if __name__ == '__main__':
    # Command-line entry point. The experiment name selects the sub-folder of
    # ../data/results/ that main() reads teacher checkpoints from and writes to.
    cli = argparse.ArgumentParser()
    cli.add_argument("--experiment_name", type=str, default="test_test_test")
    main(cli.parse_args())