--- /dev/null
+++ b/main.py
@@ -0,0 +1,313 @@
+from __future__ import print_function
+
+import argparse
+import pdb
+import os
+import math
+
+# internal imports
+from utils.file_utils import save_pkl, load_pkl
+from utils.utils import *
+from utils.core_utils import train
+from utils.core_utils_mtl import train as train_mtl
+from datasets.dataset_generic import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset
+from datasets.dataset_mtl import Generic_WSI_MTL_Dataset, Generic_MIL_MTL_Dataset
+
+# pytorch imports
+import torch
+from torch.utils.data import DataLoader, sampler
+import torch.nn as nn
+import torch.nn.functional as F
+
+import pandas as pd
+import numpy as np
+
+
+# Rejection grade:
+# binary classifier:
+#   class 0 - low grade
+#   class 1 - high grade
+#-------------------------------
+def main_grade(args):
+    print("--------------------------------------------")
+    print(" Grade Net (single-task binary classifier) ")
+    print("--------------------------------------------")
+
+    # create results directory if necessary
+    if not os.path.isdir(args.results_dir):
+        os.mkdir(args.results_dir)
+
+    if args.k_start == -1:
+        start = 0
+    else:
+        start = args.k_start
+    if args.k_end == -1:
+        end = args.k
+    else:
+        end = args.k_end
+
+    all_test_auc = []
+    all_val_auc = []
+    all_test_acc = []
+    all_val_acc = []
+    folds = np.arange(start, end)
+    for i in folds:
+        seed_torch(args.seed)
+        # 'dataset' is the module-level dataset instantiated below from args.task
+        train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False,
+                csv_path='{}/splits_{}.csv'.format(args.split_dir, i))
+
+        datasets = (train_dataset, val_dataset, test_dataset)
+        results, test_auc, val_auc, test_acc, val_acc = train(datasets, i, args)
+        all_test_auc.append(test_auc)
+        all_val_auc.append(val_auc)
+        all_test_acc.append(test_acc)
+        all_val_acc.append(val_acc)
+        # write results to pkl
+        filename = os.path.join(args.results_dir, 'split_{}_results.pkl'.format(i))
+        save_pkl(filename, results)
+
+    final_df = pd.DataFrame({'folds': folds, 'test_auc': all_test_auc,
+                             'val_auc': all_val_auc, 'test_acc': all_test_acc, 'val_acc': all_val_acc})
+
+    if len(folds) != args.k:
+        save_name = 'summary_partial_{}_{}.csv'.format(start, end)
+    else:
+        save_name = 'summary.csv'
+    final_df.to_csv(os.path.join(args.results_dir, save_name))
+
+
+# Multi-task classifier for EMB evaluation:
+# consists of 3 simultaneous tasks:
+#   task1: cellular vs non-cellular
+#   task2: antibody vs non-antibody
+#   task3: quilty lesion vs no quilty lesion
+#-------------------------------------------
+def main_mtl(args):
+
+    print("----------------------------------------")
+    print(" EMB assessment - multi-task classifier ")
+    print("----------------------------------------")
+
+    # create results directory if necessary
+    if not os.path.isdir(args.results_dir):
+        os.mkdir(args.results_dir)
+
+    if args.k_start == -1:
+        start = 0
+    else:
+        start = args.k_start
+    if args.k_end == -1:
+        end = args.k
+    else:
+        end = args.k_end
+
+    # arrays to collect scores -- replace by a generic one when refactoring
+    all_task1_test_auc = []
+    all_task1_val_auc = []
+    all_task1_test_acc = []
+    all_task1_val_acc = []
+
+    all_task2_test_auc = []
+    all_task2_val_auc = []
+    all_task2_test_acc = []
+    all_task2_val_acc = []
+
+    all_task3_test_auc = []
+    all_task3_val_auc = []
+    all_task3_test_acc = []
+    all_task3_val_acc = []
+
+    folds = np.arange(start, end)
+    for i in folds:
+        seed_torch(args.seed)
+        train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False,
+                csv_path='{}/splits_{}.csv'.format(args.split_dir, i))
+
+        print('training: {}, validation: {}, testing: {}'.format(
+            len(train_dataset), len(val_dataset), len(test_dataset)))
+        datasets = (train_dataset, val_dataset, test_dataset)
+
+        results, \
+        task1_test_auc, task1_val_auc, task1_test_acc, task1_val_acc, \
+        task2_test_auc, task2_val_auc, task2_test_acc, task2_val_acc, \
+        task3_test_auc, task3_val_auc, task3_test_acc, task3_val_acc = train_mtl(datasets, i, args)
+
+        all_task1_test_auc.append(task1_test_auc)
+        all_task1_val_auc.append(task1_val_auc)
+        all_task1_test_acc.append(task1_test_acc)
+        all_task1_val_acc.append(task1_val_acc)
+
+        all_task2_test_auc.append(task2_test_auc)
+        all_task2_val_auc.append(task2_val_auc)
+        all_task2_test_acc.append(task2_test_acc)
+        all_task2_val_acc.append(task2_val_acc)
+
+        all_task3_test_auc.append(task3_test_auc)
+        all_task3_val_auc.append(task3_val_auc)
+        all_task3_test_acc.append(task3_test_acc)
+        all_task3_val_acc.append(task3_val_acc)
+
+        # write results to pkl
+        filename = os.path.join(args.results_dir, 'split_{}_results.pkl'.format(i))
+        save_pkl(filename, results)
+
+    final_df = pd.DataFrame({'folds': folds,
+        'task1_test_auc': all_task1_test_auc, 'task1_val_auc': all_task1_val_auc,
+        'task1_test_acc': all_task1_test_acc, 'task1_val_acc': all_task1_val_acc,
+        'task2_test_auc': all_task2_test_auc, 'task2_val_auc': all_task2_val_auc,
+        'task2_test_acc': all_task2_test_acc, 'task2_val_acc': all_task2_val_acc,
+        'task3_test_auc': all_task3_test_auc, 'task3_val_auc': all_task3_val_auc,
+        'task3_test_acc': all_task3_test_acc, 'task3_val_acc': all_task3_val_acc})
+
+    if len(folds) != args.k:
+        save_name = 'summary_partial_{}_{}.csv'.format(start, end)
+    else:
+        save_name = 'summary.csv'
+    final_df.to_csv(os.path.join(args.results_dir, save_name))
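+
+
+# Sketch only (not wired into main_mtl): one possible shape for the "generic"
+# score collector that the refactoring note inside main_mtl alludes to. The
+# function name and signature here are illustrative assumptions, not part of
+# the original training flow.
+def collect_task_scores(n_tasks, metrics=('test_auc', 'val_auc', 'test_acc', 'val_acc')):
+    """Return a dict like {'task1_test_auc': [], ...} covering n_tasks tasks."""
+    return {'task{}_{}'.format(t, m): []
+            for t in range(1, n_tasks + 1) for m in metrics}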
+
+
+# Training settings
+parser = argparse.ArgumentParser(description='Configurations for WSI Training')
+parser.add_argument('--data_root_dir', type=str, default='/media/fedshyvana/ssd1',
+                    help='data directory')
+parser.add_argument('--max_epochs', type=int, default=200,
+                    help='maximum number of epochs to train (default: 200)')
+parser.add_argument('--lr', type=float, default=1e-4,
+                    help='learning rate (default: 0.0001)')
+parser.add_argument('--label_frac', type=float, default=1.0,
+                    help='fraction of training labels (default: 1.0)')
+parser.add_argument('--bag_weight', type=float, default=0.7,
+                    help='clam: weight coefficient for bag-level loss (default: 0.7)')
+parser.add_argument('--reg', type=float, default=1e-5,
+                    help='weight decay (default: 1e-5)')
+parser.add_argument('--seed', type=int, default=1,
+                    help='random seed for reproducible experiment (default: 1)')
+parser.add_argument('--k', type=int, default=10, help='number of folds (default: 10)')
+parser.add_argument('--k_start', type=int, default=-1, help='start fold (default: -1, first fold)')
+parser.add_argument('--k_end', type=int, default=-1, help='end fold (default: -1, last fold)')
+parser.add_argument('--results_dir', default='./results', help='results directory (default: ./results)')
+parser.add_argument('--split_dir', type=str, default=None,
+                    help='manually specify the set of splits to use, ' +
+                         'instead of inferring from the task and label_frac argument (default: None)')
+parser.add_argument('--log_data', action='store_true', default=False, help='log data using tensorboard')
+parser.add_argument('--testing', action='store_true', default=False, help='debugging tool')
+parser.add_argument('--subtyping', action='store_true', default=False, help='subtyping problem')
+parser.add_argument('--early_stopping', action='store_true', default=False, help='enable early stopping')
+parser.add_argument('--opt', type=str, choices=['adam', 'sgd'], default='adam')
+parser.add_argument('--drop_out', action='store_true', default=False, help='enable dropout (p=0.25)')
+parser.add_argument('--inst_loss', type=str, choices=['svm', 'ce', None], default=None,
+                    help='instance-level clustering loss function (default: None)')
+parser.add_argument('--bag_loss', type=str, choices=['svm', 'ce'], default='ce',
+                    help='slide-level classification loss function (default: ce)')
+parser.add_argument('--model_type', type=str, choices=['clam', 'mil', 'clam_simple', 'attention_mil', 'histogram_mil'],
+                    default='attention_mil', help='type of model (default: attention_mil)')
+parser.add_argument('--exp_code', type=str, help='experiment code for saving results')
+parser.add_argument('--weighted_sample', action='store_true', default=False, help='enable weighted sampling')
+parser.add_argument('--model_size', type=str, choices=['small', 'big'], default='big', help='size of model')
+parser.add_argument('--mtl', action='store_true', default=False, help='flag to enable multi-task problem')
+parser.add_argument('--task', type=str, choices=['cardiac-grade', 'cardiac-mtl'])
+
+
+args = parser.parse_args()
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
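+
+# Illustrative invocation (a sketch; the data path and exp_code below are
+# placeholders, not values shipped with this repository):
+#   python main.py --task cardiac-mtl --mtl --exp_code emb_mtl \
+#       --data_root_dir /path/to/feature_root --k 10 --weighted_sample --early_stopping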
+
+def seed_torch(seed=7):
+    import random
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if device.type == 'cuda':
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
+
+seed_torch(args.seed)
+
+encoding_size = 1024
+settings = {'num_splits': args.k,
+            'k_start': args.k_start,
+            'k_end': args.k_end,
+            'task': args.task,
+            'max_epochs': args.max_epochs,
+            'results_dir': args.results_dir,
+            'lr': args.lr,
+            'experiment': args.exp_code,
+            'reg': args.reg,
+            'label_frac': args.label_frac,
+            'inst_loss': args.inst_loss,
+            'bag_loss': args.bag_loss,
+            'bag_weight': args.bag_weight,
+            'seed': args.seed,
+            'model_type': args.model_type,
+            'model_size': args.model_size,
+            'use_drop_out': args.drop_out,
+            'weighted_sample': args.weighted_sample,
+            'opt': args.opt}
+
+
+print('\nLoad Dataset')
+if args.task == 'cardiac-grade':
+    args.n_classes = 2
+    dataset = Generic_MIL_Dataset(csv_path='dataset_csv/CardiacDummy_Grade.csv',
+                                  data_dir=os.path.join(args.data_root_dir, 'features'),
+                                  shuffle=False,
+                                  seed=args.seed,
+                                  print_info=True,
+                                  label_dict={'low': 0, 'high': 1},
+                                  label_cols=['label_grade'],
+                                  patient_strat=False,
+                                  ignore=[])
+
+elif args.task == 'cardiac-mtl':
+    args.n_classes = [2, 2, 2]
+    dataset = Generic_MIL_MTL_Dataset(csv_path='dataset_csv/CardiacDummy_MTL.csv',
+                                      data_dir=os.path.join(args.data_root_dir, 'features'),
+                                      shuffle=False,
+                                      seed=args.seed,
+                                      print_info=True,
+                                      label_dicts=[{'no_cell': 0, 'cell': 1},
+                                                   {'no_amr': 0, 'amr': 1},
+                                                   {'no_quilty': 0, 'quilty': 1}],
+                                      label_cols=['label_cell', 'label_amr', 'label_quilty'],
+                                      patient_strat=False,
+                                      ignore=[])
+
+else:
+    raise NotImplementedError
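+
+# Note on the CSVs above (an assumption about schema, not spelled out in this
+# file): dataset_csv/CardiacDummy_Grade.csv is expected to carry a 'label_grade'
+# column taking the values 'low'/'high', and dataset_csv/CardiacDummy_MTL.csv
+# the three label_cols listed above. Any further columns (e.g. slide
+# identifiers) depend on the Generic_MIL_*Dataset classes.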
+
+if not os.path.isdir(args.results_dir):
+    os.mkdir(args.results_dir)
+
+args.results_dir = os.path.join(args.results_dir, str(args.exp_code) + '_s{}'.format(args.seed))
+if not os.path.isdir(args.results_dir):
+    os.mkdir(args.results_dir)
+
+if args.split_dir is None:
+    args.split_dir = os.path.join('splits', args.task + '_{}'.format(int(args.label_frac * 100)))
+else:
+    args.split_dir = os.path.join('splits', args.split_dir)
+assert os.path.isdir(args.split_dir)
+
+settings.update({'split_dir': args.split_dir})
+
+
+with open(args.results_dir + '/experiment_{}.txt'.format(args.exp_code), 'w') as f:
+    print(settings, file=f)
+
+print("################# Settings ###################")
+for key, val in settings.items():
+    print("{}: {}".format(key, val))
+
+if __name__ == "__main__":
+    if args.mtl:
+        results = main_mtl(args)
+    else:
+        results = main_grade(args)
+
+    print("finished!")
+    print("end script")
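+
+# Example of a partial run (values are placeholders): with --k 10, passing
+# --k_start 3 --k_end 5 trains folds 3 and 4 only, so main_grade/main_mtl
+# write 'summary_partial_3_5.csv' instead of 'summary.csv':
+#   python main.py --task cardiac-grade --exp_code grade_run --k 10 --k_start 3 --k_end 5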