|
a |
|
b/code/train_student.py |
|
|
1 |
import pandas as pd |
|
|
2 |
from mil_data_generator import * |
|
|
3 |
from mil_models_pytorch import* |
|
|
4 |
from mil_trainer_torch import * |
|
|
5 |
from sklearn.utils.class_weight import compute_class_weight |
|
|
6 |
import torch |
|
|
7 |
from torchvision import transforms |
|
|
8 |
import argparse |
|
|
9 |
|
|
|
10 |
# Seed every random number generator used by the script (NumPy, the stdlib
# `random` module and PyTorch) with the same value so runs are repeatable.
SEED = 42
for _seed_fn in (np.random.seed, random.seed, torch.manual_seed):
    _seed_fn(SEED)
|
|
13 |
|
|
|
14 |
|
|
|
15 |
def _instance_probabilities(network, image):
    """Run a single patch through the network and return its soft-max output.

    The patch is converted to a float CUDA tensor, given a leading batch axis,
    passed through ``network.bb`` and ``network.classifier`` and soft-maxed
    along axis 0 of the squeezed classifier output. Gradients are disabled
    because the result is only used for inference.

    Args:
        network: model exposing ``bb`` and ``classifier`` sub-modules.
        image: array-like holding one patch (assumed channel-first, matching
            ``input_shape`` in ``main`` -- TODO confirm against MILDataset).

    Returns:
        numpy array with the soft-max scores of the patch.
    """
    with torch.no_grad():
        x = torch.tensor(image).cuda().float().unsqueeze(0)
        features = network.bb(x)
        logits = network.classifier(torch.squeeze(features))
        return torch.softmax(logits, 0).cpu().numpy()


def main(args):
    """Train one student network per teacher from instance-level pseudo-labels.

    For each iteration index the function:
      1. builds the train/test MIL datasets from the slide-level dataframe,
      2. loads the teacher saved as ``<iteration>_network_weights_best.pth``,
      3. pseudo-labels every training patch with the teacher, discarding
         patches whose prediction is not usable,
      4. trains a fresh student on the retained patches, writing an
         instance-level test report after every epoch,
      5. saves the student as ``<iteration>_student_network_weights.pth``.

    Args:
        args: parsed command-line namespace; only ``args.experiment_name`` is
            read, to locate the experiment folder under ``../data/results/``.
    """

    # INPUTS #
    dir_images = '../data/SICAP_MIL/patches/'
    dir_data_frame = '../data/SICAP_MIL/dataframes/gt_global_slides.xlsx'
    dir_data_frame_test = '../data/SICAP_MIL/dataframes/gt_test_patches.xlsx'
    dir_experiment = '../data/results/' + args.experiment_name + '/'

    classes = ['G3', 'G4', 'G5']
    proportions = ['pG3', 'pG4', 'pG5']
    input_shape = (3, 224, 224)
    images_on_ram = True
    pMIL = False
    aggregation = 'max'  # 'max', 'mean', 'attentionMIL', 'mcAttentionMIL'
    mode = 'instance'  # 'embedding', 'instance', 'mixed'
    include_background = True
    iterations = 3

    # Student optimisation settings.
    epochs = 60
    dropout_rate = 0.2
    lr = 1e-2

    # Sentinel given to training patches that must not be used by the student.
    discard_label = 10
    # Label 0 is the non-cancerous ('NC') class, followed by one id per grade.
    label_ids = np.arange(len(classes) + 1)

    df = pd.read_excel(dir_data_frame)

    def test_instances(X, Y, network, dir_out, i_iteration):
        """Evaluate ``network`` patch by patch and write a text report.

        Args:
            X: array of test patches, indexed along its first axis.
            Y: one-hot instance labels; reduced with ``argmax`` over axis 1.
            network: model exposing ``bb`` and ``classifier``.
            dir_out: folder prefix where the report is written.
            i_iteration: index used as prefix of the report file name.

        Returns:
            Quadratic-weighted Cohen's kappa between labels and predictions.
        """
        network.eval()
        Yhat = []
        for i in range(X.shape[0]):
            print(str(i + 1) + '/' + str(X.shape[0]), end='\r')
            Yhat.append(np.argmax(_instance_probabilities(network, X[i, :, :, :])))

        Yhat = np.array(Yhat)
        Y = np.argmax(Y, 1)

        cr = classification_report(Y, Yhat, target_names=['NC'] + classes, digits=4)
        cm = confusion_matrix(Y, Yhat)
        k2 = cohen_kappa_score(Y, Yhat, weights='quadratic')

        # Context manager guarantees the report is flushed and closed even if
        # formatting or writing fails.
        with open(dir_out + str(i_iteration) + '_report_student.txt', 'w') as f:
            f.write('Title\n\nClassification Report\n\n{}\n\nConfusion Matrix\n\n{}\n\nKappa\n\n{}\n'.format(cr, cm, k2))

        return k2

    for ii_iteration in range(iterations):

        # Set data generators
        # NOTE(review): both datasets are rebuilt (and the test spreadsheet
        # re-read) on every iteration; kept as in the original because it is
        # not visible here whether MILDataset mutates its inputs.
        dataset_train = MILDataset(dir_images, df[df['Partition'] == 'train'], classes, bag_id='slide_name',
                                   input_shape=input_shape, data_augmentation=False, images_on_ram=images_on_ram,
                                   pMIL=pMIL, proportions=proportions)
        data_generator_train = MILDataGenerator(dataset_train, batch_size=1, shuffle=True, max_instances=512)

        dataset_test = MILDataset(dir_images, df[df['Partition'] == 'test'], classes, bag_id='slide_name',
                                  input_shape=input_shape, data_augmentation=False, images_on_ram=images_on_ram,
                                  pMIL=pMIL, proportions=proportions,
                                  dataframe_instances=pd.read_excel(dir_data_frame_test))
        data_generator_test = MILDataGenerator(dataset_test, batch_size=1, shuffle=False, max_instances=512)

        # Test at instance level: keep only patches whose first label column
        # is not the -1 marker.
        has_label = data_generator_test.dataset.y_instances[:, 0] != -1
        X_test = data_generator_test.dataset.X[has_label, :, :, :]
        Y_test = data_generator_test.dataset.y_instances[has_label, :]

        # Load network (teacher)
        # NOTE(review): this unpickles a full model object; recent PyTorch
        # releases default torch.load to weights_only=True, which rejects such
        # files -- pass weights_only=False there if the checkpoint is trusted.
        teacher = torch.load(dir_experiment + str(ii_iteration) + '_network_weights_best.pth')
        # Pseudo-labels must come from deterministic inference, not from a
        # model left in training mode.
        teacher.eval()

        # Pseudolabels on training set
        Yglobal = data_generator_train.dataset.Yglobal
        X = data_generator_train.dataset.X
        n_train = X.shape[0]

        labels = []
        for i in range(n_train):
            print(str(i + 1) + '/' + str(n_train), end='\r')

            # No positive class in the global label row: the patch is labelled
            # as class 0 without querying the teacher.
            if np.max(Yglobal[i, 1:]) == 0:
                labels.append(0)
                continue

            yhat = _instance_probabilities(teacher, X[i, :, :, :])
            top = int(np.argmax(yhat))

            # Keep a prediction only if it is a non-background class that is
            # present in the global label and predicted with confidence > 0.5.
            if top > 0 and Yglobal[i, top] == 1 and yhat[top] > 0.5:
                labels.append(top)
            else:
                labels.append(discard_label)
        labels = np.array(labels)

        # The teacher is no longer needed; release it before building the student.
        del teacher

        keep = labels != discard_label
        X = X[keep, :, :, :]
        Y = labels[keep]
        # `classes` and `y` are keyword-only in current scikit-learn releases.
        class_weights = compute_class_weight('balanced', classes=label_ids, y=Y)

        # Set student network architecture
        network = MILArchitecture(classes, mode=mode, aggregation=aggregation, backbone='vgg19',
                                  include_background=include_background).cuda()
        opt = torch.optim.SGD(network.parameters(), lr=lr)

        tranf = torch.nn.Sequential(transforms.RandomHorizontalFlip(),
                                    transforms.RandomRotation(degrees=(-45, 45)),
                                    transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
                                    transforms.ColorJitter(brightness=.5, hue=.3)).cuda()

        training_data = CustomImageDataset(X, Y, transform=False)
        train_dataloader = CustomGenerator(training_data, bs=32, shuffle=True)

        # Loop-invariant modules are built once per student instead of per step.
        criterion = torch.nn.CrossEntropyLoss(weight=torch.tensor(class_weights).cuda().float())
        dropout = torch.nn.Dropout(dropout_rate)
        n_steps = len(train_dataloader)

        # STUDENT TRAINING
        for i_epoch in range(epochs):
            l_epoch = 0

            # Halve the learning rate every 25 epochs.
            if (i_epoch + 1) % 25 == 0:
                for g in opt.param_groups:
                    g['lr'] = g['lr'] / 2

            # test_instances() leaves the network in eval mode, so training
            # mode is restored at the start of every epoch.
            network.train()
            for i_iteration, (X_batch, Y_batch) in enumerate(train_dataloader):
                opt.zero_grad()

                X_batch = tranf(X_batch.cuda().float())

                # NOTE(review): torch.squeeze also removes the batch axis when
                # a batch holds a single patch -- confirm CustomGenerator never
                # yields such a batch.
                logits = network.classifier(dropout(torch.squeeze(network.bb(X_batch))))

                L = criterion(logits, Y_batch.type(torch.LongTensor).cuda())

                L.backward()
                opt.step()

                L_iteration = L.item()
                l_epoch += L_iteration

                info = "[INFO] Epoch {}/{} -- Step {}/{}: Lce={:.6f}".format(
                    i_epoch + 1, epochs, i_iteration + 1, n_steps, L_iteration)
                print(info, end='\r')

            l_epoch = l_epoch / n_steps

            k2 = test_instances(X_test, Y_test, network, dir_experiment, ii_iteration)

            info = "[INFO] Epoch {}/{} -- Step {}/{}: Lce={:.6f}; k2={:.6f}".format(
                i_epoch + 1, epochs, n_steps, n_steps, l_epoch, k2)
            print(info, end='\n')

        torch.save(network, dir_experiment + str(ii_iteration) + '_student_network_weights.pth')
|
|
180 |
|
|
|
181 |
|
|
|
182 |
if __name__ == '__main__':
    # Command-line entry point: the only option selects the experiment folder
    # (under ../data/results/) that holds the teacher weights and receives the
    # student outputs.
    cli = argparse.ArgumentParser()
    cli.add_argument("--experiment_name", type=str, default="test_test_test")
    main(cli.parse_args())