--- a +++ b/code/main.py @@ -0,0 +1,110 @@ +import pandas as pd +from mil_data_generator import * +from mil_models_pytorch import* +from mil_trainer_torch import * +import argparse + +np.random.seed(42) +random.seed(42) +torch.manual_seed(42) + + +def main(args): + + metrics = [] + for i_iteration in np.arange(0, args.iterations): + id = str(i_iteration) + '_' + df = pd.read_excel(args.dir_data_frame) + + # Set data generators + dataset_train = MILDataset(args.dir_images, df[df['Partition'] == 'train'], args.classes, + bag_id='slide_name', input_shape=args.input_shape, + data_augmentation=args.data_augmentation, images_on_ram=args.images_on_ram, + pMIL=args.pMIL, proportions=args.proportions) + data_generator_train = MILDataGenerator(dataset_train, batch_size=1, shuffle=True, max_instances=512) + + dataset_val = MILDataset(args.dir_images, df[df['Partition'] == 'val'], args.classes, + bag_id='slide_name', input_shape=args.input_shape, + data_augmentation=args.data_augmentation, images_on_ram=args.images_on_ram, + pMIL=args.pMIL, proportions=args.proportions) + data_generator_val = MILDataGenerator(dataset_val, batch_size=1, shuffle=False, max_instances=512) + + dataset_test = MILDataset(args.dir_images, df[df['Partition'] == 'test'], args.classes, + bag_id='slide_name', input_shape=args.input_shape, + data_augmentation=args.data_augmentation, images_on_ram=args.images_on_ram, + pMIL=args.pMIL, proportions=args.proportions, + dataframe_instances=pd.read_excel(args.dir_data_frame_test)) + data_generator_test = MILDataGenerator(dataset_test, batch_size=1, shuffle=False, max_instances=512) + + # Set network architecture + network = MILArchitecture(args.classes, mode=args.mode, aggregation=args.aggregation, + backbone='vgg19', include_background=args.include_background) + + # Perform training + trainer = MILTrainer(args.dir_results + args.experiment_name + '/', network, + lr=args.lr, pMIL=args.pMIL, margin=args.margin, + alpha_ic=args.alpha_ic, alpha_pc=args.alpha_pc, t_ic=args.t_ic, + t_pc=args.t_pc, alpha_ce=args.alpha_ce, id=id, + early_stopping=args.early_stopping, scheduler=args.scheduler, + virtual_batch_size=args.virtual_batch_size, + criterion=args.criterion, + alpha_H=args.alpha_H) + trainer.train(train_generator=data_generator_train, val_generator=data_generator_val, + test_generator=data_generator_test, epochs=args.epochs) + + metrics.append([list(trainer.metrics.values())[1:]]) + + # Get overall metrics + metrics = np.squeeze(np.array(metrics)) + + mu = np.mean(metrics, axis=0) + std = np.std(metrics, axis=0) + + info = "AUCtest={:.4f}({:.4f}) ; AUCval={:.4f}({:.4f}) ; acc={:.4f}({:.4f}) ; f1-score={:.4f}({:.4f}) ; k2={:.4f}({:.4f})".format( + mu[0], std[0], mu[1], std[1], mu[2], std[2], mu[3], std[3], mu[4], std[4]) + if args.alpha_pc > 0: + info += " ; constrain_cumpliment={:.4f}({:.4f}) ; constrain_proportion={:.4f}({:.4f})".format( + mu[5], std[5], mu[6], std[6]) + + f = open(args.dir_results + args.experiment_name + '/' + 'method_metrics.txt', 'w') + f.write(info) + f.close() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--dir_images", default='../data/SICAP_MIL/patches/', type=str) + parser.add_argument("--dir_data_frame", default='../data/SICAP_MIL/dataframes/gt_global_slides.xlsx', type=str) + parser.add_argument("--dir_data_frame_test", default='../data/SICAP_MIL/dataframes/gt_test_patches.xlsx', type=str) + parser.add_argument("--dir_results", default='../data/results/', type=str) + parser.add_argument("--criterion", default='z', type=str) + parser.add_argument("--experiment_name", default="test_test_test", type=str) + parser.add_argument("--classes", default=['G3', 'G4', 'G5'], type=list) + parser.add_argument("--proportions", default=['Primary', 'Secondary'], type=list) + parser.add_argument("--input_shape", default=(3, 224, 224), type=list) + parser.add_argument("--images_on_ram", default=False, type=bool) + parser.add_argument("--epochs", default=100, type=int) + parser.add_argument("--aggregation", default="max", type=str) + parser.add_argument("--mode", default="instance", type=str) + parser.add_argument("--include_background", default=True, type=bool) + parser.add_argument("--lr", default=1*1e-2, type=float) + parser.add_argument("--pMIL", default=False, type=bool) + + parser.add_argument("--alpha_ce", default=1., type=float) + parser.add_argument("--margin", default=0., type=float) + parser.add_argument("--alpha_ic", default=1, type=float) + parser.add_argument("--alpha_pc", default=1, type=float) + parser.add_argument("--alpha_H", default=0, type=float) + parser.add_argument("--t_ic", default=15, type=float) + parser.add_argument("--t_pc", default=5, type=float) + parser.add_argument("--data_augmentation", default=True, type=bool) + parser.add_argument("--iterations", default=3, type=int) + + parser.add_argument("--early_stopping", default=True, type=bool) + parser.add_argument("--scheduler", default=True, type=bool) + + parser.add_argument("--virtual_batch_size", default=1, type=int) + + args = parser.parse_args() + main(args) +