[69f1e5]: / code / main.py

Download this file

111 lines (89 with data), 5.9 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
import argparse
import os
import random

import numpy as np
import pandas as pd
import torch

from mil_data_generator import *
from mil_models_pytorch import *
from mil_trainer_torch import *
# Seed NumPy's global RNG, Python's `random` module and PyTorch (in that
# order, all with the same value) so repeated runs of the script draw the
# same random numbers.
_SEED = 42
for _seed_rng in (np.random.seed, random.seed, torch.manual_seed):
    _seed_rng(_SEED)
del _seed_rng
def _make_generator(args, df_partition, shuffle, **dataset_kwargs):
    """Wrap one partition of the slide data frame in a MILDataset + MILDataGenerator.

    Args:
        args: parsed command-line namespace; the image directory, classes,
            input shape, augmentation / RAM / pMIL flags and proportions are
            read from it.
        df_partition: rows of the global data frame belonging to one partition.
        shuffle: forwarded to MILDataGenerator.
        **dataset_kwargs: extra keyword arguments forwarded to MILDataset
            (used for ``dataframe_instances`` on the test split).

    Returns:
        A MILDataGenerator with batch_size=1 and max_instances=512.
    """
    # NOTE(review): data_augmentation is forwarded to every split, val/test
    # included, exactly as before. Confirm MILDataset disables augmentation
    # for evaluation; otherwise val/test metrics are stochastic.
    dataset = MILDataset(args.dir_images, df_partition, args.classes,
                         bag_id='slide_name', input_shape=args.input_shape,
                         data_augmentation=args.data_augmentation,
                         images_on_ram=args.images_on_ram,
                         pMIL=args.pMIL, proportions=args.proportions,
                         **dataset_kwargs)
    return MILDataGenerator(dataset, batch_size=1, shuffle=shuffle, max_instances=512)


def _format_metrics(metrics, include_constraints):
    """Summarise per-iteration metrics as ``name=mean(std)`` fields.

    Args:
        metrics: one row per training iteration. Each row's values are assumed
            to follow the label order AUCtest, AUCval, acc, f1-score, k2, then
            the two constraint metrics (verify against MILTrainer.metrics).
        include_constraints: also report the two constraint metrics
            (columns 5 and 6).

    Returns:
        The summary string, with fields separated by `` ; ``.

    Raises:
        ValueError: if *metrics* is empty.
    """
    if len(metrics) == 0:
        raise ValueError("no metrics to summarise (args.iterations must be >= 1)")
    # Reshape explicitly to (iterations, n_metrics) instead of using np.squeeze:
    # squeeze also removes the iteration axis when there is a single
    # iteration, which turns the mean into a scalar that cannot be indexed.
    values = np.asarray(metrics)
    values = values.reshape(values.shape[0], -1)
    mu = np.mean(values, axis=0)
    std = np.std(values, axis=0)
    names = ['AUCtest', 'AUCval', 'acc', 'f1-score', 'k2']
    if include_constraints:
        names += ['constrain_cumpliment', 'constrain_proportion']
    return " ; ".join("{}={:.4f}({:.4f})".format(name, mu[i], std[i])
                      for i, name in enumerate(names))


def main(args):
    """Train the MIL model ``args.iterations`` times and write a metric summary.

    Each iteration builds fresh train/val/test generators and a fresh
    MILArchitecture, trains it with MILTrainer, and records the trainer's
    metrics (all but the first entry). The mean(std) over iterations is written
    to ``<dir_results><experiment_name>/method_metrics.txt``.

    Args:
        args: namespace produced by the command-line parser.

    Raises:
        ValueError: if ``args.iterations`` is smaller than 1.
    """
    # Same directory string (with trailing '/') that MILTrainer receives.
    dir_out = args.dir_results + args.experiment_name + '/'
    # The spreadsheets do not change between iterations, so read them once.
    df = pd.read_excel(args.dir_data_frame)
    df_test_instances = pd.read_excel(args.dir_data_frame_test)

    metrics = []
    for i_iteration in range(args.iterations):
        run_id = str(i_iteration) + '_'  # forwarded to MILTrainer as `id`
        # Set data generators
        data_generator_train = _make_generator(args, df[df['Partition'] == 'train'], shuffle=True)
        data_generator_val = _make_generator(args, df[df['Partition'] == 'val'], shuffle=False)
        # Each iteration gets its own copy of the instance-level frame, so any
        # in-place change MILDataset might make cannot leak into later runs.
        data_generator_test = _make_generator(args, df[df['Partition'] == 'test'], shuffle=False,
                                              dataframe_instances=df_test_instances.copy())
        # Set network architecture
        network = MILArchitecture(args.classes, mode=args.mode, aggregation=args.aggregation,
                                  backbone='vgg19', include_background=args.include_background)
        # Perform training
        trainer = MILTrainer(dir_out, network,
                             lr=args.lr, pMIL=args.pMIL, margin=args.margin,
                             alpha_ic=args.alpha_ic, alpha_pc=args.alpha_pc, t_ic=args.t_ic,
                             t_pc=args.t_pc, alpha_ce=args.alpha_ce, id=run_id,
                             early_stopping=args.early_stopping, scheduler=args.scheduler,
                             virtual_batch_size=args.virtual_batch_size,
                             criterion=args.criterion,
                             alpha_H=args.alpha_H)
        trainer.train(train_generator=data_generator_train, val_generator=data_generator_val,
                      test_generator=data_generator_test, epochs=args.epochs)
        # The first entry of trainer.metrics is excluded from the summary.
        metrics.append(list(trainer.metrics.values())[1:])

    # Get overall metrics
    info = _format_metrics(metrics, include_constraints=args.alpha_pc > 0)
    os.makedirs(dir_out, exist_ok=True)
    with open(os.path.join(dir_out, 'method_metrics.txt'), 'w') as f:
        f.write(info)
def str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``, ``yes``/``no``, ``1``/``0``.

    ``type=bool`` cannot be used for this: ``bool("False")`` is ``True``, so
    every non-empty string would switch the option on.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def parse_arguments(argv=None):
    """Build the command-line parser and parse *argv* (``sys.argv[1:]`` when None).

    List-valued options take space-separated values (e.g. ``--classes G3 G4 G5``,
    ``--input_shape 3 224 224``). ``type=list`` would have split a single
    string into characters. ``input_shape`` is returned as a tuple, matching
    its default.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir_images", default='../data/SICAP_MIL/patches/', type=str)
    parser.add_argument("--dir_data_frame", default='../data/SICAP_MIL/dataframes/gt_global_slides.xlsx', type=str)
    parser.add_argument("--dir_data_frame_test", default='../data/SICAP_MIL/dataframes/gt_test_patches.xlsx', type=str)
    parser.add_argument("--dir_results", default='../data/results/', type=str)
    parser.add_argument("--criterion", default='z', type=str)
    parser.add_argument("--experiment_name", default="test_test_test", type=str)
    parser.add_argument("--classes", default=['G3', 'G4', 'G5'], nargs='+', type=str)
    parser.add_argument("--proportions", default=['Primary', 'Secondary'], nargs='+', type=str)
    parser.add_argument("--input_shape", default=(3, 224, 224), nargs=3, type=int)
    parser.add_argument("--images_on_ram", default=False, type=str2bool)
    parser.add_argument("--epochs", default=100, type=int)
    parser.add_argument("--aggregation", default="max", type=str)
    parser.add_argument("--mode", default="instance", type=str)
    parser.add_argument("--include_background", default=True, type=str2bool)
    parser.add_argument("--lr", default=1*1e-2, type=float)
    parser.add_argument("--pMIL", default=False, type=str2bool)
    parser.add_argument("--alpha_ce", default=1., type=float)
    parser.add_argument("--margin", default=0., type=float)
    parser.add_argument("--alpha_ic", default=1, type=float)
    parser.add_argument("--alpha_pc", default=1, type=float)
    parser.add_argument("--alpha_H", default=0, type=float)
    parser.add_argument("--t_ic", default=15, type=float)
    parser.add_argument("--t_pc", default=5, type=float)
    parser.add_argument("--data_augmentation", default=True, type=str2bool)
    parser.add_argument("--iterations", default=3, type=int)
    parser.add_argument("--early_stopping", default=True, type=str2bool)
    parser.add_argument("--scheduler", default=True, type=str2bool)
    parser.add_argument("--virtual_batch_size", default=1, type=int)
    args = parser.parse_args(argv)
    # nargs=3 yields a list when given on the command line; keep the tuple form.
    args.input_shape = tuple(args.input_shape)
    return args


if __name__ == '__main__':
    main(parse_arguments())