|
a |
|
b/code/main.py |
|
|
1 |
import pandas as pd |
|
|
2 |
from mil_data_generator import * |
|
|
3 |
from mil_models_pytorch import* |
|
|
4 |
from mil_trainer_torch import * |
|
|
5 |
import argparse |
|
|
6 |
|
|
|
7 |
np.random.seed(42) |
|
|
8 |
random.seed(42) |
|
|
9 |
torch.manual_seed(42) |
|
|
10 |
|
|
|
11 |
|
|
|
12 |
def main(args): |
|
|
13 |
|
|
|
14 |
metrics = [] |
|
|
15 |
for i_iteration in np.arange(0, args.iterations): |
|
|
16 |
id = str(i_iteration) + '_' |
|
|
17 |
df = pd.read_excel(args.dir_data_frame) |
|
|
18 |
|
|
|
19 |
# Set data generators |
|
|
20 |
dataset_train = MILDataset(args.dir_images, df[df['Partition'] == 'train'], args.classes, |
|
|
21 |
bag_id='slide_name', input_shape=args.input_shape, |
|
|
22 |
data_augmentation=args.data_augmentation, images_on_ram=args.images_on_ram, |
|
|
23 |
pMIL=args.pMIL, proportions=args.proportions) |
|
|
24 |
data_generator_train = MILDataGenerator(dataset_train, batch_size=1, shuffle=True, max_instances=512) |
|
|
25 |
|
|
|
26 |
dataset_val = MILDataset(args.dir_images, df[df['Partition'] == 'val'], args.classes, |
|
|
27 |
bag_id='slide_name', input_shape=args.input_shape, |
|
|
28 |
data_augmentation=args.data_augmentation, images_on_ram=args.images_on_ram, |
|
|
29 |
pMIL=args.pMIL, proportions=args.proportions) |
|
|
30 |
data_generator_val = MILDataGenerator(dataset_val, batch_size=1, shuffle=False, max_instances=512) |
|
|
31 |
|
|
|
32 |
dataset_test = MILDataset(args.dir_images, df[df['Partition'] == 'test'], args.classes, |
|
|
33 |
bag_id='slide_name', input_shape=args.input_shape, |
|
|
34 |
data_augmentation=args.data_augmentation, images_on_ram=args.images_on_ram, |
|
|
35 |
pMIL=args.pMIL, proportions=args.proportions, |
|
|
36 |
dataframe_instances=pd.read_excel(args.dir_data_frame_test)) |
|
|
37 |
data_generator_test = MILDataGenerator(dataset_test, batch_size=1, shuffle=False, max_instances=512) |
|
|
38 |
|
|
|
39 |
# Set network architecture |
|
|
40 |
network = MILArchitecture(args.classes, mode=args.mode, aggregation=args.aggregation, |
|
|
41 |
backbone='vgg19', include_background=args.include_background) |
|
|
42 |
|
|
|
43 |
# Perform training |
|
|
44 |
trainer = MILTrainer(args.dir_results + args.experiment_name + '/', network, |
|
|
45 |
lr=args.lr, pMIL=args.pMIL, margin=args.margin, |
|
|
46 |
alpha_ic=args.alpha_ic, alpha_pc=args.alpha_pc, t_ic=args.t_ic, |
|
|
47 |
t_pc=args.t_pc, alpha_ce=args.alpha_ce, id=id, |
|
|
48 |
early_stopping=args.early_stopping, scheduler=args.scheduler, |
|
|
49 |
virtual_batch_size=args.virtual_batch_size, |
|
|
50 |
criterion=args.criterion, |
|
|
51 |
alpha_H=args.alpha_H) |
|
|
52 |
trainer.train(train_generator=data_generator_train, val_generator=data_generator_val, |
|
|
53 |
test_generator=data_generator_test, epochs=args.epochs) |
|
|
54 |
|
|
|
55 |
metrics.append([list(trainer.metrics.values())[1:]]) |
|
|
56 |
|
|
|
57 |
# Get overall metrics |
|
|
58 |
metrics = np.squeeze(np.array(metrics)) |
|
|
59 |
|
|
|
60 |
mu = np.mean(metrics, axis=0) |
|
|
61 |
std = np.std(metrics, axis=0) |
|
|
62 |
|
|
|
63 |
info = "AUCtest={:.4f}({:.4f}) ; AUCval={:.4f}({:.4f}) ; acc={:.4f}({:.4f}) ; f1-score={:.4f}({:.4f}) ; k2={:.4f}({:.4f})".format( |
|
|
64 |
mu[0], std[0], mu[1], std[1], mu[2], std[2], mu[3], std[3], mu[4], std[4]) |
|
|
65 |
if args.alpha_pc > 0: |
|
|
66 |
info += " ; constrain_cumpliment={:.4f}({:.4f}) ; constrain_proportion={:.4f}({:.4f})".format( |
|
|
67 |
mu[5], std[5], mu[6], std[6]) |
|
|
68 |
|
|
|
69 |
f = open(args.dir_results + args.experiment_name + '/' + 'method_metrics.txt', 'w') |
|
|
70 |
f.write(info) |
|
|
71 |
f.close() |
|
|
72 |
|
|
|
73 |
|
|
|
74 |
if __name__ == '__main__': |
|
|
75 |
parser = argparse.ArgumentParser() |
|
|
76 |
parser.add_argument("--dir_images", default='../data/SICAP_MIL/patches/', type=str) |
|
|
77 |
parser.add_argument("--dir_data_frame", default='../data/SICAP_MIL/dataframes/gt_global_slides.xlsx', type=str) |
|
|
78 |
parser.add_argument("--dir_data_frame_test", default='../data/SICAP_MIL/dataframes/gt_test_patches.xlsx', type=str) |
|
|
79 |
parser.add_argument("--dir_results", default='../data/results/', type=str) |
|
|
80 |
parser.add_argument("--criterion", default='z', type=str) |
|
|
81 |
parser.add_argument("--experiment_name", default="test_test_test", type=str) |
|
|
82 |
parser.add_argument("--classes", default=['G3', 'G4', 'G5'], type=list) |
|
|
83 |
parser.add_argument("--proportions", default=['Primary', 'Secondary'], type=list) |
|
|
84 |
parser.add_argument("--input_shape", default=(3, 224, 224), type=list) |
|
|
85 |
parser.add_argument("--images_on_ram", default=False, type=bool) |
|
|
86 |
parser.add_argument("--epochs", default=100, type=int) |
|
|
87 |
parser.add_argument("--aggregation", default="max", type=str) |
|
|
88 |
parser.add_argument("--mode", default="instance", type=str) |
|
|
89 |
parser.add_argument("--include_background", default=True, type=bool) |
|
|
90 |
parser.add_argument("--lr", default=1*1e-2, type=float) |
|
|
91 |
parser.add_argument("--pMIL", default=False, type=bool) |
|
|
92 |
|
|
|
93 |
parser.add_argument("--alpha_ce", default=1., type=float) |
|
|
94 |
parser.add_argument("--margin", default=0., type=float) |
|
|
95 |
parser.add_argument("--alpha_ic", default=1, type=float) |
|
|
96 |
parser.add_argument("--alpha_pc", default=1, type=float) |
|
|
97 |
parser.add_argument("--alpha_H", default=0, type=float) |
|
|
98 |
parser.add_argument("--t_ic", default=15, type=float) |
|
|
99 |
parser.add_argument("--t_pc", default=5, type=float) |
|
|
100 |
parser.add_argument("--data_augmentation", default=True, type=bool) |
|
|
101 |
parser.add_argument("--iterations", default=3, type=int) |
|
|
102 |
|
|
|
103 |
parser.add_argument("--early_stopping", default=True, type=bool) |
|
|
104 |
parser.add_argument("--scheduler", default=True, type=bool) |
|
|
105 |
|
|
|
106 |
parser.add_argument("--virtual_batch_size", default=1, type=int) |
|
|
107 |
|
|
|
108 |
args = parser.parse_args() |
|
|
109 |
main(args) |
|
|
110 |
|