Diff of /Cluster-ViT/main.py [000000] .. [15fc01]

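# Cluster-ViT training entry point: k-fold cross-validated training of a
# transformer-based survival model (concordance index / Brier-score metrics),
# with optional SAM (sharpness-aware minimization) and distributed (DDP) support.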
import argparse
import datetime
import json
import random
import time
from pathlib import Path
import os

from torch.utils.data.dataset import Subset
from datasets.MyData import MyDataset
import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler, random_split
from sklearn.model_selection import KFold
import util.misc as utils

from models.engine import evaluate, train_one_epoch_SAM, test
from models import build_model
from models.sam import SAM
from torch.utils.tensorboard import SummaryWriter
import pandas as pd

def get_args_parser():
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--batch_size', default=8, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--lr_drop', default=200, type=int)
    parser.add_argument('--clip_max_norm', default=0.1, type=float,
                        help='gradient clipping max norm')

    parser.add_argument('--position_embedding', default='3Dlearned', type=str,
                        help="Type of positional embedding to use on top of the image features")

    # * Transformer
    parser.add_argument('--enc_layers', default=6, type=int,
                        help="Number of encoding layers in the transformer")
    parser.add_argument('--dec_layers', default=6, type=int,
                        help="Number of decoding layers in the transformer")
    parser.add_argument('--dim_feedforward', default=2048, type=int,
                        help="Intermediate size of the feedforward layers in the transformer blocks")
    parser.add_argument('--hidden_dim', default=256, type=int,
                        help="Size of the embeddings (dimension of the transformer)")
    parser.add_argument('--dropout', default=0.1, type=float,
                        help="Dropout applied in the transformer")
    parser.add_argument('--nheads', default=8, type=int,
                        help="Number of attention heads inside the transformer's attentions")
    parser.add_argument('--num_queries', default=100, type=int,
                        help="Number of query slots")
    parser.add_argument('--pre_norm', action='store_true', default=False)
    parser.add_argument('--pretrained_path', default='', type=str, help="path of pretrained model")

    # dataset parameters
    parser.add_argument('--dataset_file', default='coco')

    parser.add_argument('--output_dir', default='./',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--num_workers', default=2, type=int)
    parser.add_argument('--kfoldNum', default=5, type=int)
    parser.add_argument('--dataDir', type=str, help='path of the data')
    parser.add_argument('--externalDataDir', type=str, help='path of the external test data')

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    # cluster parameters
    parser.add_argument('--group_Q', action='store_true', default=False)
    parser.add_argument('--group_K', action='store_true', default=False)
    parser.add_argument('--cuda-devices', default=None)
    parser.add_argument('--max_num_cluster', default=64, type=int)
    parser.add_argument('--sequence_len', default=15000, type=int)

    # grid-search parameters
    parser.add_argument('--withPosEmbedding', action='store_true', default=False)
    parser.add_argument('--seq_pool', action='store_true', default=False,
                        help='use an attention pooling layer for aggregating patch risk scores')
    parser.add_argument('--withLN', action='store_true', default=False)
    parser.add_argument('--withEmbeddingPreNorm', action='store_true', default=False,
                        help='pre-normalize the patch representations before feeding them into the ViT')
    parser.add_argument('--input_pool', action='store_true', default=False)
    parser.add_argument('--mixUp', action='store_true', default=False)
    parser.add_argument('--SAM', action='store_true', default=False)

    return parser

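# main() runs one full train/evaluate cycle per cross-validation fold: build the
# model, train with best-on-validation-loss checkpointing, then score the best
# checkpoint on the held-out fold and on the external test cohort.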
def main(args):
    # set up distributed mode once per process; torch.distributed process groups
    # must not be re-initialized on every fold
    utils.init_distributed_mode(args)
    print(args)
    device = torch.device(args.device)

    allDataset = MyDataset(root_dir=args.dataDir, sequence_len=args.sequence_len,
                           max_num_cluster=args.max_num_cluster, status='test',
                           input_pool=args.input_pool)
    kfoldSplits = KFold(n_splits=args.kfoldNum, shuffle=True, random_state=args.seed)
    splitIdx = kfoldSplits.split(np.arange(len(allDataset)))
    CIndexTest = []
    CIndexExternalTest = []
    IBSTest = []
    IBSExternalTest = []
    correlationCoeffTest = []
    IPCWCIndexTest = []
    IPCWCIndexExternalTest = []
    for fold, (train_idx, test_idx) in enumerate(splitIdx):
        # fix the seed for reproducibility (offset by rank so that distributed
        # workers are reproducible without being identical)
        seed = args.seed + utils.get_rank()
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        model, criterion = build_model(args)
        model.to(device)

        # keep a handle to the unwrapped model for loading/saving state dicts
        model_without_ddp = model
        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
            model_without_ddp = model.module
        n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print('number of params:', n_parameters)
        if args.pretrained_path != '':
            pretrainedModel = torch.load(args.pretrained_path)
            model_without_ddp.load_state_dict(pretrainedModel['model'], strict=False)

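        # SAM wraps a base optimizer (Adam here) and performs a two-step update
        # per batch: an ascent step to the worst-case nearby weights, then the
        # actual descent step. The two-step logic is presumably handled inside
        # train_one_epoch_SAM, whose implementation is not part of this file.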
        if args.SAM:
            base_optimizer = torch.optim.Adam  # base optimizer for the "sharpness-aware" update
            optimizer = SAM(model_without_ddp.parameters(), base_optimizer,
                            lr=args.lr, weight_decay=args.weight_decay)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
        else:
            optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=args.lr,
                                          weight_decay=args.weight_decay)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)

        best_loss = 1e12
        tb_writer = SummaryWriter()
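        # within each fold, the fold's training indices are split 80/20 into
        # train/validation (seeded for reproducibility); the held-out fold
        # indices form the internal test set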
        dataset_train, dataset_val = random_split(
            Subset(allDataset, train_idx),
            [int(len(train_idx) * 0.8), len(train_idx) - int(len(train_idx) * 0.8)],
            generator=torch.Generator().manual_seed(args.seed))
        dataset_test = Subset(allDataset, test_idx)
        dataset_train_all = Subset(allDataset, train_idx)

        if args.distributed:
            sampler_train = DistributedSampler(dataset_train)
            sampler_val = DistributedSampler(dataset_val, shuffle=False)
            sampler_test = DistributedSampler(dataset_test, shuffle=False)
        else:
            sampler_train = torch.utils.data.RandomSampler(dataset_train)
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
            sampler_test = torch.utils.data.SequentialSampler(dataset_test)

        batch_sampler_train = torch.utils.data.BatchSampler(
            sampler_train, args.batch_size, drop_last=False)

        data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
                                       collate_fn=None, num_workers=args.num_workers)
        data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                     drop_last=False, num_workers=args.num_workers)
        data_loader_test = DataLoader(dataset_test, args.batch_size, sampler=sampler_test,
                                      drop_last=False, num_workers=args.num_workers)

        output_dir = Path(args.output_dir)
        if args.kfoldNum > 1:
            output_dir = output_dir / f'fold{fold}'
            output_dir.mkdir(parents=True, exist_ok=True)

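        # resuming restores the model weights and, unless running in eval mode,
        # the optimizer, LR scheduler, and epoch counter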
        if args.resume:
            checkpoint = torch.load(args.resume, map_location='cpu')
            model_without_ddp.load_state_dict(checkpoint['model'])
            if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
                lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
                args.start_epoch = checkpoint['epoch'] + 1

        if args.eval:
            val_stats = evaluate(model, criterion, data_loader_val, device, args.output_dir, 'Validation')

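        # each epoch: one training pass, then evaluation on the train, validation,
        # and internal test splits; the best model is tracked by validation loss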
        print(f"fold {fold}/{args.kfoldNum} Start training")
        start_time = time.time()
        # train the model
        for epoch in range(args.start_epoch, args.epochs):
            if args.distributed:
                sampler_train.set_epoch(epoch)

            train_stats = train_one_epoch_SAM(
                model, criterion, data_loader_train, optimizer, device, epoch, fold, tb_writer,
                args.clip_max_norm, mixUp=args.mixUp, SAM=args.SAM)

            lr_scheduler.step()

            train_eval_stats = evaluate(
                model, criterion, data_loader_train, device, args.output_dir, 'Train')
            val_stats = evaluate(
                model, criterion, data_loader_val, device, args.output_dir, 'Validation')
            test_stats = evaluate(
                model, criterion, data_loader_test, device, args.output_dir, 'Test')

            is_best = val_stats['loss'] < best_loss
            best_loss = min(val_stats['loss'], best_loss)

            if args.output_dir:
                checkpoint_paths = [output_dir / 'checkpoint.pth']
                # extra checkpoint at each LR drop and every 100 epochs
                if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0:
                    checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
                if is_best:
                    checkpoint_paths.append(output_dir / 'model_best.pth.tar')
                for checkpoint_path in checkpoint_paths:
                    utils.save_on_master({
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }, checkpoint_path)

            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'train_eval_{k}': v for k, v in train_eval_stats.items()},
                         **{f'val_{k}': v for k, v in val_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}

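            # log per-fold loss and concordance-index (CIndex) curves to
            # TensorBoard so the folds can be compared side by side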
            temploss = train_eval_stats
            tb_writer.add_scalar(f'train/lossfold{fold}', temploss['loss'], epoch)
            tb_writer.add_scalar(f'train/CIndexfold{fold}', temploss['CIndex'], epoch)
            temploss = val_stats
            tb_writer.add_scalar(f'val/lossfold{fold}', temploss['loss'], epoch)
            tb_writer.add_scalar(f'val/CIndexfold{fold}', temploss['CIndex'], epoch)
            temploss = test_stats
            tb_writer.add_scalar(f'test/lossfold{fold}', temploss['loss'], epoch)
            tb_writer.add_scalar(f'test/CIndexfold{fold}', temploss['CIndex'], epoch)

            if args.output_dir and utils.is_main_process():
                with (output_dir / f"trainingLog_fold{fold}.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")
        # evaluate the best checkpoint (lowest validation loss) on the internal
        # test fold and on the external test cohort
        bestModelPath = output_dir / 'model_best.pth.tar'
        bestCheckpoint = torch.load(bestModelPath)
        bestModel, testcriterion = build_model(args)
        bestModel.to(device)
        bestModel.load_state_dict(bestCheckpoint['model'])
        data_loader_train_all = DataLoader(dataset_train_all, batch_size=args.batch_size, shuffle=False,
                                           num_workers=args.num_workers, pin_memory=True)
        dataset_external_test = MyDataset(root_dir=args.externalDataDir, sequence_len=args.sequence_len,
                                          max_num_cluster=args.max_num_cluster, status='externalTest',
                                          input_pool=args.input_pool)
        dataset_external_test.status = 'externalTest'
        data_loader_external_test = DataLoader(dataset_external_test, batch_size=args.batch_size, shuffle=False,
                                               num_workers=args.num_workers, pin_memory=True)
        externalOutputDir = output_dir / 'externalTest'
        Path(externalOutputDir).mkdir(parents=True, exist_ok=True)
        internalTestBestModelStatus = test(bestModel, testcriterion, data_loader_test, data_loader_train_all,
                                           device, output_dir, fold)
        externalTestBestModelStatus = test(bestModel, testcriterion, data_loader_external_test,
                                           data_loader_train_all, device, externalOutputDir, fold)
        log_stats = {
            **{f'testBestModel_{k}': v for k, v in internalTestBestModelStatus.items()},
            **{f'ExternalTestBestModel_{k}': v for k, v in externalTestBestModelStatus.items()},
            'fold': fold
        }

        CIndexTest.append(internalTestBestModelStatus['CIndex'])
        CIndexExternalTest.append(externalTestBestModelStatus['CIndex'])
        IPCWCIndexTest.append(internalTestBestModelStatus['IPCWCIndex'])
        IPCWCIndexExternalTest.append(externalTestBestModelStatus['IPCWCIndex'])
        IBSTest.append(internalTestBestModelStatus['IBSTest'])
        IBSExternalTest.append(externalTestBestModelStatus['IBSTest'])

        if args.output_dir and utils.is_main_process():
            with (output_dir / f"testingLog_fold{fold}.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))

    output_dir = Path(args.output_dir)
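    # note: correlationCoeffTest is never appended to in this script, so its
    # mean below evaluates to nan (numpy warns about the mean of an empty list)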
    AverageBestModelStatus = {
        'AverageCIndexTest': np.mean(CIndexTest),
        'AverageCIndexExternalTest': np.mean(CIndexExternalTest),
        'AverageIPCWCIndexTest': np.mean(IPCWCIndexTest),
        'AverageIPCWCIndexExternalTest': np.mean(IPCWCIndexExternalTest),
        'AverageIBSTest': np.mean(IBSTest),
        'AverageIBSExternalTest': np.mean(IBSExternalTest),
        'stdCIndexTest': np.std(CIndexTest),
        'stdCIndexExternalTest': np.std(CIndexExternalTest),
        'stdIPCWCIndexTest': np.std(IPCWCIndexTest),
        'stdIPCWCIndexExternalTest': np.std(IPCWCIndexExternalTest),
        'stdIBSTest': np.std(IBSTest),
        'stdIBSExternalTest': np.std(IBSExternalTest),
        'AverageCorrelationCoeffTest': np.mean(correlationCoeffTest),
    }
    log_stats = dict(AverageBestModelStatus)
    if args.output_dir and utils.is_main_process():
        with (output_dir / "logAverage.txt").open("a") as f:
            f.write(json.dumps(log_stats) + "\n")
    print(AverageBestModelStatus)

if __name__ == '__main__':
    now = datetime.datetime.now()
    dt_string = now.strftime("%d%m%Y_%H%M%S")
    parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()

    if args.cuda_devices is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_devices
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # log all hyperparameters once at startup
    hyperparameters = vars(args)
    hyperparameter_stats = dict(hyperparameters)
    if args.output_dir and utils.is_main_process():
        with (Path(args.output_dir) / "logHyperparameters.txt").open("a") as f:
            f.write(json.dumps(hyperparameter_stats) + "\n")
    main(args)
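
# Example invocation (flags defined in get_args_parser above; the paths below
# are placeholders, not part of the repository):
#   python main.py --dataDir /path/to/internal/data \
#       --externalDataDir /path/to/external/data \
#       --output_dir ./runs --kfoldNum 5 --SAM --seq_pool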