from __future__ import print_function
import argparse
import os
import math
# internal imports
from utils.file_utils import save_pkl, load_pkl
from utils.utils import *
from utils.core_utils import train
from utils.core_utils_mtl import train as train_mtl
from datasets.dataset_generic import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset
from datasets.dataset_mtl import Generic_WSI_MTL_Dataset, Generic_MIL_MTL_Dataset
# pytorch imports
import torch
from torch.utils.data import DataLoader, sampler
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
# Rejection grade:
# binary classifier:
# class 0 - low grade
# class 1 - high grade
#-------------------------------
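# main_grade runs one training job per cross-validation fold: it loads the
# fold's split CSV, trains via utils.core_utils.train, saves the per-fold
# results as split_{i}_results.pkl, and writes AUC/accuracy across folds to a
# summary CSV.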
def main_grade(args):
    print("--------------------------------------------")
    print(" Grade Net (single-task binary classifier) ")
    print("--------------------------------------------")
    # create results directory if necessary
    if not os.path.isdir(args.results_dir):
        os.mkdir(args.results_dir)

    if args.k_start == -1:
        start = 0
    else:
        start = args.k_start
    if args.k_end == -1:
        end = args.k
    else:
        end = args.k_end

    all_test_auc = []
    all_val_auc = []
    all_test_acc = []
    all_val_acc = []
    folds = np.arange(start, end)
    for i in folds:
        seed_torch(args.seed)
        train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False,
            csv_path='{}/splits_{}.csv'.format(args.split_dir, i))
        datasets = (train_dataset, val_dataset, test_dataset)
        results, test_auc, val_auc, test_acc, val_acc = train(datasets, i, args)
        all_test_auc.append(test_auc)
        all_val_auc.append(val_auc)
        all_test_acc.append(test_acc)
        all_val_acc.append(val_acc)
        # write per-fold results to pkl
        filename = os.path.join(args.results_dir, 'split_{}_results.pkl'.format(i))
        save_pkl(filename, results)

    final_df = pd.DataFrame({'folds': folds, 'test_auc': all_test_auc,
        'val_auc': all_val_auc, 'test_acc': all_test_acc, 'val_acc': all_val_acc})

    if len(folds) != args.k:
        save_name = 'summary_partial_{}_{}.csv'.format(start, end)
    else:
        save_name = 'summary.csv'
    final_df.to_csv(os.path.join(args.results_dir, save_name))
# Multi-task classifier for EMB evaluation:
# consists of 3 simultaneous tasks:
#   task 1: cellular vs non-cellular
#   task 2: antibody vs non-antibody
#   task 3: quilty lesion vs no quilty lesion
#-------------------------------------------
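# main_mtl mirrors main_grade, except train_mtl returns AUC/accuracy for each
# of the three binary heads, so metrics are collected per task.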
def main_mtl(args):
    print("-----------------------------------------")
    print(" EMB assessment - multi-task classifier ")
    print("-----------------------------------------")
    # create results directory if necessary
    if not os.path.isdir(args.results_dir):
        os.mkdir(args.results_dir)

    if args.k_start == -1:
        start = 0
    else:
        start = args.k_start
    if args.k_end == -1:
        end = args.k
    else:
        end = args.k_end

    # arrays to collect scores -- replace with a generic container when refactoring
    all_task1_test_auc = []
    all_task1_val_auc = []
    all_task1_test_acc = []
    all_task1_val_acc = []
    all_task2_test_auc = []
    all_task2_val_auc = []
    all_task2_test_acc = []
    all_task2_val_acc = []
    all_task3_test_auc = []
    all_task3_val_auc = []
    all_task3_test_acc = []
    all_task3_val_acc = []

    folds = np.arange(start, end)
    for i in folds:
        seed_torch(args.seed)
        train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False,
            csv_path='{}/splits_{}.csv'.format(args.split_dir, i))
        print('training: {}, validation: {}, testing: {}'.format(len(train_dataset), len(val_dataset), len(test_dataset)))
        datasets = (train_dataset, val_dataset, test_dataset)
        results, \
        task1_test_auc, task1_val_auc, task1_test_acc, task1_val_acc, \
        task2_test_auc, task2_val_auc, task2_test_acc, task2_val_acc, \
        task3_test_auc, task3_val_auc, task3_test_acc, task3_val_acc = train_mtl(datasets, i, args)
        all_task1_test_auc.append(task1_test_auc)
        all_task1_val_auc.append(task1_val_auc)
        all_task1_test_acc.append(task1_test_acc)
        all_task1_val_acc.append(task1_val_acc)
        all_task2_test_auc.append(task2_test_auc)
        all_task2_val_auc.append(task2_val_auc)
        all_task2_test_acc.append(task2_test_acc)
        all_task2_val_acc.append(task2_val_acc)
        all_task3_test_auc.append(task3_test_auc)
        all_task3_val_auc.append(task3_val_auc)
        all_task3_test_acc.append(task3_test_acc)
        all_task3_val_acc.append(task3_val_acc)
        # write per-fold results to pkl
        filename = os.path.join(args.results_dir, 'split_{}_results.pkl'.format(i))
        save_pkl(filename, results)

    final_df = pd.DataFrame({'folds': folds,
        'task1_test_auc': all_task1_test_auc, 'task1_val_auc': all_task1_val_auc,
        'task1_test_acc': all_task1_test_acc, 'task1_val_acc': all_task1_val_acc,
        'task2_test_auc': all_task2_test_auc, 'task2_val_auc': all_task2_val_auc,
        'task2_test_acc': all_task2_test_acc, 'task2_val_acc': all_task2_val_acc,
        'task3_test_auc': all_task3_test_auc, 'task3_val_auc': all_task3_val_auc,
        'task3_test_acc': all_task3_test_acc, 'task3_val_acc': all_task3_val_acc})

    if len(folds) != args.k:
        save_name = 'summary_partial_{}_{}.csv'.format(start, end)
    else:
        save_name = 'summary.csv'
    final_df.to_csv(os.path.join(args.results_dir, save_name))
# Training settings
parser = argparse.ArgumentParser(description='Configurations for WSI Training')
parser.add_argument('--data_root_dir', type=str, default='/media/fedshyvana/ssd1',
                    help='data directory')
parser.add_argument('--max_epochs', type=int, default=200,
                    help='maximum number of epochs to train (default: 200)')
parser.add_argument('--lr', type=float, default=1e-4,
                    help='learning rate (default: 0.0001)')
parser.add_argument('--label_frac', type=float, default=1.0,
                    help='fraction of training labels (default: 1.0)')
parser.add_argument('--bag_weight', type=float, default=0.7,
                    help='clam: weight coefficient for bag-level loss (default: 0.7)')
parser.add_argument('--reg', type=float, default=1e-5,
                    help='weight decay (default: 1e-5)')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed for reproducible experiment (default: 1)')
parser.add_argument('--k', type=int, default=10, help='number of folds (default: 10)')
parser.add_argument('--k_start', type=int, default=-1, help='start fold (default: -1, i.e. first fold)')
parser.add_argument('--k_end', type=int, default=-1, help='end fold (default: -1, i.e. last fold)')
parser.add_argument('--results_dir', default='./results', help='results directory (default: ./results)')
parser.add_argument('--split_dir', type=str, default=None,
                    help='manually specify the set of splits to use, '
                         'instead of inferring from the task and label_frac arguments (default: None)')
parser.add_argument('--log_data', action='store_true', default=False, help='log data using tensorboard')
parser.add_argument('--testing', action='store_true', default=False, help='debugging tool')
parser.add_argument('--subtyping', action='store_true', default=False, help='subtyping problem')
parser.add_argument('--early_stopping', action='store_true', default=False, help='enable early stopping')
parser.add_argument('--opt', type=str, choices=['adam', 'sgd'], default='adam')
parser.add_argument('--drop_out', action='store_true', default=False, help='enable dropout (p=0.25)')
parser.add_argument('--inst_loss', type=str, choices=['svm', 'ce', None], default=None,
                    help='instance-level clustering loss function (default: None)')
parser.add_argument('--bag_loss', type=str, choices=['svm', 'ce'], default='ce',
                    help='slide-level classification loss function (default: ce)')
parser.add_argument('--model_type', type=str, choices=['clam', 'mil', 'clam_simple', 'attention_mil', 'histogram_mil'],
                    default='attention_mil', help='type of model (default: attention_mil)')
parser.add_argument('--exp_code', type=str, help='experiment code for saving results')
parser.add_argument('--weighted_sample', action='store_true', default=False, help='enable weighted sampling')
parser.add_argument('--model_size', type=str, choices=['small', 'big'], default='big', help='size of model (default: big)')
parser.add_argument('--mtl', action='store_true', default=False, help='flag to enable the multi-task problem')
parser.add_argument('--task', type=str, choices=['cardiac-grade', 'cardiac-mtl'], help='which training task to run')
args = parser.parse_args()
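# Example invocations, assuming this script is saved as main.py (the exp_code
# and data_root_dir values below are hypothetical placeholders):
#   python main.py --task cardiac-grade --exp_code grade_run1 \
#       --data_root_dir /path/to/data --weighted_sample --early_stopping
#   python main.py --task cardiac-mtl --mtl --exp_code emb_mtl_run1 \
#       --data_root_dir /path/to/data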
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def seed_torch(seed=7):
    import random
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
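# Seed everything once at module load so dataset construction below is
# reproducible; each fold re-seeds at the top of its loop iteration.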
seed_torch(args.seed)
encoding_size = 1024
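# Dimensionality of the pre-extracted patch features; 1024 matches a
# ResNet-50-style encoder, but that is an assumption about the upstream
# feature extractor (the value is not used further in this script).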
settings = {'num_splits': args.k,
            'k_start': args.k_start,
            'k_end': args.k_end,
            'task': args.task,
            'max_epochs': args.max_epochs,
            'results_dir': args.results_dir,
            'lr': args.lr,
            'experiment': args.exp_code,
            'reg': args.reg,
            'label_frac': args.label_frac,
            'inst_loss': args.inst_loss,
            'bag_loss': args.bag_loss,
            'bag_weight': args.bag_weight,
            'seed': args.seed,
            'model_type': args.model_type,
            'model_size': args.model_size,
            'use_drop_out': args.drop_out,
            'weighted_sample': args.weighted_sample,
            'opt': args.opt}
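# settings is echoed to stdout and dumped to experiment_{exp_code}.txt below,
# leaving a plain-text record of the hyperparameters used for this run.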
print('\nLoad Dataset')
if args.task == 'cardiac-grade':
    args.n_classes = 2
    dataset = Generic_MIL_Dataset(csv_path='dataset_csv/CardiacDummy_Grade.csv',
                                  data_dir=os.path.join(args.data_root_dir, 'features'),
                                  shuffle=False,
                                  seed=args.seed,
                                  print_info=True,
                                  label_dict={'low': 0, 'high': 1},
                                  label_cols=['label_grade'],
                                  patient_strat=False,
                                  ignore=[])
elif args.task == 'cardiac-mtl':
    args.n_classes = [2, 2, 2]
    dataset = Generic_MIL_MTL_Dataset(csv_path='dataset_csv/CardiacDummy_MTL.csv',
                                      data_dir=os.path.join(args.data_root_dir, 'features'),
                                      shuffle=False,
                                      seed=args.seed,
                                      print_info=True,
                                      label_dicts=[{'no_cell': 0, 'cell': 1},
                                                   {'no_amr': 0, 'amr': 1},
                                                   {'no_quilty': 0, 'quilty': 1}],
                                      label_cols=['label_cell', 'label_amr', 'label_quilty'],
                                      patient_strat=False,
                                      ignore=[])
else:
    raise NotImplementedError
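# Illustrative layout for the dataset CSVs above (hypothetical: the exact ID
# columns are whatever the Generic_*_Dataset classes expect; only the label
# columns and values are taken from label_cols/label_dicts above):
#   slide_id,label_cell,label_amr,label_quilty
#   slide_001,cell,no_amr,quilty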
if not os.path.isdir(args.results_dir):
    os.mkdir(args.results_dir)

args.results_dir = os.path.join(args.results_dir, str(args.exp_code) + '_s{}'.format(args.seed))
if not os.path.isdir(args.results_dir):
    os.mkdir(args.results_dir)

if args.split_dir is None:
    args.split_dir = os.path.join('splits', args.task + '_{}'.format(int(args.label_frac * 100)))
else:
    args.split_dir = os.path.join('splits', args.split_dir)
assert os.path.isdir(args.split_dir)
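# split_dir must contain one splits_{i}.csv file per fold; return_splits() in
# the main_* fold loops reads these to build the train/val/test datasets.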
settings.update({'split_dir': args.split_dir})
with open(args.results_dir + '/experiment_{}.txt'.format(args.exp_code), 'w') as f:
    print(settings, file=f)

print("################# Settings ###################")
for key, val in settings.items():
    print("{}: {}".format(key, val))
if __name__ == "__main__":
if args.mtl:
results = main_mtl(args)
else:
results = main_grade(args)
print("finished!")
print("end script")