import argparse
import datetime
import json
import random
from re import A
import time
from pathlib import Path
import os
from torch.utils.data.dataset import Subset
from datasets.MyData import MyDataset
import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler,random_split
from sklearn.model_selection import KFold
import util.misc as utils
from models.engine import evaluate, train_one_epoch_SAM,test
from models import build_model
from models.sam import SAM
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
def get_args_parser():
parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
parser.add_argument('--lr', default=1e-4, type=float)
parser.add_argument('--batch_size', default=8, type=int)
parser.add_argument('--weight_decay', default=1e-4, type=float)
parser.add_argument('--epochs', default=300, type=int)
parser.add_argument('--lr_drop', default=200, type=int)
parser.add_argument('--clip_max_norm', default=0.1, type=float,
help='gradient clipping max norm')
parser.add_argument('--position_embedding', default='3Dlearned', type=str,
help="Type of positional embedding to use on top of the image features")
# * Transformer
parser.add_argument('--enc_layers', default=6, type=int,
help="Number of encoding layers in the transformer")
parser.add_argument('--dec_layers', default=6, type=int,
help="Number of decoding layers in the transformer")
parser.add_argument('--dim_feedforward', default=2048, type=int,
help="Intermediate size of the feedforward layers in the transformer blocks")
parser.add_argument('--hidden_dim', default=256, type=int,
help="Size of the embeddings (dimension of the transformer)")
parser.add_argument('--dropout', default=0.1, type=float,
help="Dropout applied in the transformer")
parser.add_argument('--nheads', default=8, type=int,
help="Number of attention heads inside the transformer's attentions")
parser.add_argument('--num_queries', default=100, type=int,
help="Number of query slots")
parser.add_argument('--pre_norm', action='store_true',default=False)
parser.add_argument('--pretrained_path', default='', type=str, help="path of pretrained model")
# dataset parameters
parser.add_argument('--dataset_file', default='coco')
parser.add_argument('--output_dir', default='./',
help='path where to save, empty for no saving')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true')
parser.add_argument('--num_workers', default=2, type=int)
parser.add_argument('--kfoldNum', default=5, type=int)
parser.add_argument('--dataDir',type=str,help='path of the data')
parser.add_argument('--externalDataDir',type=str,help='path of the external test data')
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
# cluster parameters
parser.add_argument('--group_Q', action='store_true', default=False)
parser.add_argument('--group_K', action='store_true', default=False)
parser.add_argument('--cuda-devices', default=None)
parser.add_argument('--max_num_cluster', default=64, type=int)
parser.add_argument('--sequence_len', default=15000, type=int)
# gridsearch parameters
parser.add_argument('--withPosEmbedding', action='store_true', default=False)
parser.add_argument('--seq_pool', action='store_true', default=False,help='use attention pooling layer for aggregating patch risk score')
parser.add_argument('--withLN', action='store_true', default=False)
parser.add_argument('--withEmbeddingPreNorm', action='store_true', default=False,help='Pre-normalize the patch representation before feeding them into ViT')
parser.add_argument('--input_pool', action='store_true', default=False)
parser.add_argument('--mixUp', action='store_true', default=False)
parser.add_argument('--SAM', action='store_true', default=False)
return parser
def main(args):
allDataset = MyDataset(root_dir=args.dataDir,sequence_len=args.sequence_len,max_num_cluster=args.max_num_cluster,status = 'test',input_pool=args.input_pool)
kfoldSplits=KFold(n_splits=args.kfoldNum,shuffle=True,random_state=args.seed)
splitIdx = kfoldSplits.split(np.arange(len(allDataset)))
CIndexTest = []
CIndexExternalTest = []
IBSTest = []
IBSExternalTest = []
correlationCoeffTest = []
IPCWCIndexTest = []
IPCWCIndexExternalTest = []
for fold, (train_idx,test_idx) in enumerate(splitIdx):
utils.init_distributed_mode(args)
print(args)
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
model, criterion = build_model(args)
model.to(device)
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)
if args.pretrained_path!='':
pretrainedModel = torch.load(args.pretrained_path)
model_without_ddp.load_state_dict(pretrainedModel['model'], strict=False)
if args.SAM:
base_optimizer = torch.optim.Adam # define an optimizer for the "sharpness-aware" update
optimizer_SAM = SAM(model_without_ddp.parameters(), base_optimizer,lr=args.lr,weight_decay=args.weight_decay)
optimizer = optimizer_SAM
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
else:
optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=args.lr,
weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
best_loss = 1e12
tb_writer = SummaryWriter()
dataset_train,dataset_val = random_split(Subset(allDataset,train_idx),[int(len(train_idx)*0.8),len(train_idx)-int(len(train_idx)*0.8)],generator=torch.Generator().manual_seed(args.seed))
dataset_test = Subset(allDataset,test_idx)
dataset_train_all = Subset(allDataset,train_idx)
if args.distributed:
sampler_train = DistributedSampler(dataset_train)
sampler_val = DistributedSampler(dataset_val, shuffle=False)
sampler_test = DistributedSampler(dataset_test, shuffle=False)
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
sampler_test = torch.utils.data.SequentialSampler(dataset_test)
batch_sampler_train = torch.utils.data.BatchSampler(
sampler_train, args.batch_size, drop_last=False)
data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
collate_fn=None, num_workers=args.num_workers)
data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
drop_last=False,num_workers=args.num_workers)
data_loader_test = DataLoader(dataset_test, args.batch_size, sampler=sampler_test,
drop_last=False, num_workers=args.num_workers)
output_dir = Path(args.output_dir)
if args.kfoldNum>1:
output_dir = output_dir/f'fold{fold}'
output_dir.mkdir(parents=True, exist_ok=True)
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if args.eval:
val_stats = evaluate(model, criterion, data_loader_train, device, args.output_dir,'Validation')
print(f"fold {fold}/{args.kfoldNum} Start training")
start_time = time.time()
# train the model
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
sampler_train.set_epoch(epoch)
train_stats = train_one_epoch_SAM(
model, criterion, data_loader_train, optimizer, device, epoch,fold,tb_writer,
args.clip_max_norm,mixUp=args.mixUp,SAM=args.SAM)
lr_scheduler.step()
train_eval_stats = evaluate(
model, criterion, data_loader_train, device, args.output_dir,'Train')
val_stats = evaluate(
model, criterion, data_loader_val, device, args.output_dir, 'Validation')
test_stats = evaluate(
model, criterion, data_loader_test, device, args.output_dir, 'Test')
is_best = val_stats['loss'] < best_loss
best_loss = min(val_stats['loss'], best_loss)
if args.output_dir:
checkpoint_paths = [output_dir / 'checkpoint.pth']
# extra checkpoint before LR drop and every 100 epochs
if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0:
checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
if is_best:
checkpoint_paths.append(output_dir / f'model_best.pth.tar')
for checkpoint_path in checkpoint_paths:
utils.save_on_master({
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'args': args,
}, checkpoint_path)
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
**{f'train_eval{k}': v for k, v in train_eval_stats.items()},
**{f'val_{k}': v for k, v in val_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
temploss = train_eval_stats
tb_writer.add_scalar('train/loss'+'fold{}'.format(fold), temploss['loss'], epoch)
tb_writer.add_scalar('train/CIndex'+'fold{}'.format(fold), temploss['CIndex'], epoch)
temploss = val_stats
tb_writer.add_scalar('val/loss'+'fold{}'.format(fold), temploss['loss'], epoch)
tb_writer.add_scalar('val/CIndex'+'fold{}'.format(fold), temploss['CIndex'], epoch)
temploss = test_stats
tb_writer.add_scalar('test/loss'+'fold{}'.format(fold), temploss['loss'], epoch)
tb_writer.add_scalar('test/CIndex'+'fold{}'.format(fold), temploss['CIndex'], epoch)
if args.output_dir and utils.is_main_process():
with (output_dir / f"trainingLog_fold{fold}.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
# evaluate the best model on internal and external datasets
bestModelPath = output_dir / f'model_best.pth.tar'
bestCheckpoint = torch.load(bestModelPath)
bestModel, testcriterion = build_model(args)
bestModel.to(device)
bestModel.load_state_dict(bestCheckpoint['model'])
data_loader_train_all = DataLoader(dataset_train_all, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)
dataset_external_test = MyDataset(root_dir=externalDataDir,sequence_len=args.sequence_len,max_num_cluster=args.max_num_cluster,status='externalTest',input_pool=args.input_pool)
dataset_external_test.status = 'externalTest'
data_loader_external_test = DataLoader(dataset_external_test, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)
externalOutputDir = output_dir / 'externalTest'
Path(externalOutputDir).mkdir(parents=True, exist_ok=True)
internalTestBestModelStatus = test(bestModel, testcriterion, data_loader_test, data_loader_train_all, device, output_dir,fold,coxBiomarkerRisk)
externalTestBestModelStatus = test(bestModel, testcriterion, data_loader_external_test, data_loader_train_all, device, externalOutputDir,fold)
log_stats = {
**{f'testBestModel_{k}': v for k, v in internalTestBestModelStatus.items()},
**{f'ExternalTestBestModel_{k}': v for k, v in externalTestBestModelStatus.items()},
'fold': fold
}
CIndexTest.append(internalTestBestModelStatus['CIndex'])
CIndexExternalTest.append(externalTestBestModelStatus['CIndex'])
IPCWCIndexTest.append(internalTestBestModelStatus['IPCWCIndex'])
IPCWCIndexExternalTest.append(externalTestBestModelStatus['IPCWCIndex'])
IBSTest.append(internalTestBestModelStatus['IBSTest'])
IBSExternalTest.append(externalTestBestModelStatus['IBSTest'])
if args.output_dir and utils.is_main_process():
with (output_dir / f"testingLog_fold{fold}.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
output_dir = Path(args.output_dir)
AverageBestModelStatus = {'AverageCIndexTest':np.mean(CIndexTest), 'AverageCIndexExternalTest':np.mean(CIndexExternalTest),\
'AverageIPCWCIndexTest':np.mean(IPCWCIndexTest), 'AverageIPCWCIndexExternalTest':np.mean(IPCWCIndexExternalTest),\
'AverageIBSTest':np.mean(IBSTest),'AverageIBSExternalTest':np.mean(IBSExternalTest),\
'stdCIndexTest':np.std(CIndexTest), 'stdCIndexExternalTest':np.std(CIndexExternalTest),\
'stdIPCWCIndexTest':np.std(IPCWCIndexTest), 'stdIPCWCIndexExternalTest':np.std(IPCWCIndexExternalTest),\
'stdIBSTest':np.std(IBSTest),'stdIBSExternalTest':np.std(IBSExternalTest),'AverageCorrelationCoeffTest':np.mean(correlationCoeffTest)}
log_stats = {
**{f'{k}': v for k, v in AverageBestModelStatus.items()}
}
if args.output_dir and utils.is_main_process():
with (output_dir / "logAverage.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
print(AverageBestModelStatus)
if __name__ == '__main__':
now = datetime.datetime.now()
dt_string = now.strftime("%d%m%Y_%H%M%S")
parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
args = parser.parse_args()
if args.cuda_devices is not None:
os.environ["CUDA_VISIBLE_DEVICES"]=args.cuda_devices
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
hyperparameters = vars(args)
hyperparameter_stas = {
**{f'{k}': v for k, v in hyperparameters.items()}
}
if args.output_dir and utils.is_main_process():
with (Path(args.output_dir) / "logHyperparameters.txt").open("a") as f:
f.write(json.dumps(hyperparameter_stas) + "\n")
main(args)