Diff of /code/main.py [000000] .. [118be8]

Switch to unified view

a b/code/main.py
1
import pandas as pd
2
from mil_data_generator import *
3
from mil_models_pytorch import*
4
from mil_trainer_torch import *
5
import argparse
6
7
np.random.seed(42)
8
random.seed(42)
9
torch.manual_seed(42)
10
11
12
def main(args):
13
14
    metrics = []
15
    for i_iteration in np.arange(0, args.iterations):
16
        id = str(i_iteration) + '_'
17
        df = pd.read_excel(args.dir_data_frame)
18
19
        # Set data generators
20
        dataset_train = MILDataset(args.dir_images, df[df['Partition'] == 'train'], args.classes,
21
                                   bag_id='slide_name', input_shape=args.input_shape,
22
                                   data_augmentation=args.data_augmentation, images_on_ram=args.images_on_ram,
23
                                   pMIL=args.pMIL, proportions=args.proportions)
24
        data_generator_train = MILDataGenerator(dataset_train, batch_size=1, shuffle=True, max_instances=512)
25
26
        dataset_val = MILDataset(args.dir_images, df[df['Partition'] == 'val'], args.classes,
27
                                 bag_id='slide_name', input_shape=args.input_shape,
28
                                 data_augmentation=args.data_augmentation, images_on_ram=args.images_on_ram,
29
                                 pMIL=args.pMIL, proportions=args.proportions)
30
        data_generator_val = MILDataGenerator(dataset_val, batch_size=1, shuffle=False, max_instances=512)
31
32
        dataset_test = MILDataset(args.dir_images, df[df['Partition'] == 'test'], args.classes,
33
                                  bag_id='slide_name', input_shape=args.input_shape,
34
                                  data_augmentation=args.data_augmentation, images_on_ram=args.images_on_ram,
35
                                  pMIL=args.pMIL, proportions=args.proportions,
36
                                  dataframe_instances=pd.read_excel(args.dir_data_frame_test))
37
        data_generator_test = MILDataGenerator(dataset_test, batch_size=1, shuffle=False, max_instances=512)
38
39
        # Set network architecture
40
        network = MILArchitecture(args.classes, mode=args.mode, aggregation=args.aggregation,
41
                                  backbone='vgg19', include_background=args.include_background)
42
43
        # Perform training
44
        trainer = MILTrainer(args.dir_results + args.experiment_name + '/', network,
45
                             lr=args.lr, pMIL=args.pMIL, margin=args.margin,
46
                             alpha_ic=args.alpha_ic, alpha_pc=args.alpha_pc, t_ic=args.t_ic,
47
                             t_pc=args.t_pc, alpha_ce=args.alpha_ce, id=id,
48
                             early_stopping=args.early_stopping, scheduler=args.scheduler,
49
                             virtual_batch_size=args.virtual_batch_size,
50
                             criterion=args.criterion,
51
                             alpha_H=args.alpha_H)
52
        trainer.train(train_generator=data_generator_train, val_generator=data_generator_val,
53
                      test_generator=data_generator_test, epochs=args.epochs)
54
55
        metrics.append([list(trainer.metrics.values())[1:]])
56
57
    # Get overall metrics
58
    metrics = np.squeeze(np.array(metrics))
59
60
    mu = np.mean(metrics, axis=0)
61
    std = np.std(metrics, axis=0)
62
63
    info = "AUCtest={:.4f}({:.4f}) ; AUCval={:.4f}({:.4f})  ; acc={:.4f}({:.4f}) ; f1-score={:.4f}({:.4f}) ; k2={:.4f}({:.4f})".format(
64
          mu[0], std[0], mu[1], std[1], mu[2], std[2], mu[3], std[3], mu[4], std[4])
65
    if args.alpha_pc > 0:
66
        info += " ; constrain_cumpliment={:.4f}({:.4f}) ; constrain_proportion={:.4f}({:.4f})".format(
67
          mu[5], std[5], mu[6], std[6])
68
69
    f = open(args.dir_results + args.experiment_name + '/' + 'method_metrics.txt', 'w')
70
    f.write(info)
71
    f.close()
72
73
74
if __name__ == '__main__':
75
    parser = argparse.ArgumentParser()
76
    parser.add_argument("--dir_images", default='../data/SICAP_MIL/patches/', type=str)
77
    parser.add_argument("--dir_data_frame", default='../data/SICAP_MIL/dataframes/gt_global_slides.xlsx', type=str)
78
    parser.add_argument("--dir_data_frame_test", default='../data/SICAP_MIL/dataframes/gt_test_patches.xlsx', type=str)
79
    parser.add_argument("--dir_results", default='../data/results/', type=str)
80
    parser.add_argument("--criterion", default='z', type=str)
81
    parser.add_argument("--experiment_name", default="test_test_test", type=str)
82
    parser.add_argument("--classes", default=['G3', 'G4', 'G5'], type=list)
83
    parser.add_argument("--proportions", default=['Primary', 'Secondary'], type=list)
84
    parser.add_argument("--input_shape", default=(3, 224, 224), type=list)
85
    parser.add_argument("--images_on_ram", default=False, type=bool)
86
    parser.add_argument("--epochs", default=100, type=int)
87
    parser.add_argument("--aggregation", default="max", type=str)
88
    parser.add_argument("--mode", default="instance", type=str)
89
    parser.add_argument("--include_background", default=True, type=bool)
90
    parser.add_argument("--lr", default=1*1e-2, type=float)
91
    parser.add_argument("--pMIL", default=False, type=bool)
92
93
    parser.add_argument("--alpha_ce", default=1., type=float)
94
    parser.add_argument("--margin", default=0., type=float)
95
    parser.add_argument("--alpha_ic", default=1, type=float)
96
    parser.add_argument("--alpha_pc", default=1, type=float)
97
    parser.add_argument("--alpha_H", default=0, type=float)
98
    parser.add_argument("--t_ic", default=15, type=float)
99
    parser.add_argument("--t_pc", default=5, type=float)
100
    parser.add_argument("--data_augmentation", default=True, type=bool)
101
    parser.add_argument("--iterations", default=3, type=int)
102
103
    parser.add_argument("--early_stopping", default=True, type=bool)
104
    parser.add_argument("--scheduler", default=True, type=bool)
105
106
    parser.add_argument("--virtual_batch_size", default=1, type=int)
107
108
    args = parser.parse_args()
109
    main(args)
110