Diff of /main.py [000000] .. [d5c425]

Switch to side-by-side view

--- a
+++ b/main.py
@@ -0,0 +1,313 @@
+import os
+import torch
+import pickle
+
+import numpy as np
+import torch.backends.cudnn as cudnn
+
+from torch.utils.data.dataloader import DataLoader
+from sklearn.model_selection import StratifiedKFold, train_test_split
+from tqdm.auto import tqdm
+# from torchviz import make_dot
+from losses import MultiTaskLoss, CoxLoss
+from datasets import RadDataset
+from models import FusionModelBi, Model
+from utils import *
+from parameters import parse_args
+import scipy.io
+import time as timetime
+# from monai.networks.nets import DenseNet121,HighResNet,SEResNext50
+
+import matplotlib.pyplot as plt 
+
# Module-level default device: first CUDA GPU when available, else CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
def one_epoch(args, split, model, optim, loader, criterion):
    """Run a single pass over ``loader``; train when split == "train", else evaluate.

    Args:
        args: parsed arguments; uses ``args.task`` ("multitask" |
            "classification" | "survival") and ``args.batch_size``.
        split: "train" enables gradient updates; any other value runs in
            eval mode with autograd disabled.
        model: network called as ``model(mod1, mod2)``.
        optim: optimizer; only used when split == "train" (may be None otherwise).
        loader: iterable of (mod1, mod2, grade, time, event, ID) batches.
        criterion: callable ``(task, pred_grade, pred_hazard, grade, time, event) -> loss``.

    Returns:
        (mean loss per sample,
         (preds_grade, preds_hazard, grade, time, event, IDs))
        where the prediction slot of a task that is not computed is None.
    """
    is_train = split == "train"
    if is_train:
        model.train()
    else:
        model.eval()

    # Resolve the device once and move the model once, instead of inside the
    # batch loop. The original hardcoded 'cuda:0', which crashed on CPU-only
    # hosts, and called model.to(device) on every iteration.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    total = 0
    sum_loss = 0.0
    all_preds_grade, all_preds_hazard = [], []
    all_grade, all_time, all_event, all_ID = [], [], [], []

    # Disable autograd outside training: saves memory and avoids building
    # graphs for validation/test batches.
    with torch.set_grad_enabled(is_train):
        for i, (mod1, mod2, grade, time, event, ID) in enumerate(loader):
            print(f"Sample {i}/{len(loader)}")

            mod1, mod2 = mod1.to(device), mod2.to(device)
            grade, time, event = grade.to(device), time.to(device), event.to(device)
            batch = mod1.shape[0]

            pred = model(mod1, mod2)

            # Unpack the model output according to the task. With
            # args.batch_size == 1 a single-task model's output is indexed
            # with [0]; otherwise the trailing singleton dim is squeezed away
            # (same behavior as the original duplicated branches).
            if args.task == "multitask":
                pred_grade, pred_hazard = pred
            elif args.task == "classification":
                pred_grade = pred[0] if args.batch_size == 1 else pred.squeeze()
                pred_hazard = torch.empty(1)  # placeholder; unused by this task
            elif args.task == "survival":
                pred_grade = torch.empty(1)  # placeholder; unused by this task
                pred_hazard = pred[0] if args.batch_size == 1 else pred.squeeze()
            else:
                raise NotImplementedError(
                    f'task method {args.task} is not implemented')

            loss = criterion(args.task, pred_grade, pred_hazard, grade, time, event)

            if is_train:
                optim.zero_grad()
                loss.backward()
                optim.step()

            total += batch
            sum_loss += batch * loss.item()
            # Detach before storing so whole autograd graphs are not kept
            # alive for the entire epoch (the original accumulated them).
            all_preds_grade.append(pred_grade.detach())
            all_preds_hazard.append(pred_hazard.detach())
            all_grade.append(grade)
            all_time.append(time)
            all_event.append(event)
            all_ID.append(ID)

    all_grade = torch.concat(all_grade)
    all_time = torch.concat(all_time)
    all_event = torch.concat(all_event)

    if args.task == "classification":
        all_preds_grade = torch.concat(all_preds_grade)
        return sum_loss / total, (all_preds_grade, None, all_grade, all_time, all_event, all_ID)
    elif args.task == "multitask":
        all_preds_grade = torch.concat(all_preds_grade)
        all_preds_hazard = torch.concat(all_preds_hazard)
        return sum_loss / total, (all_preds_grade, all_preds_hazard, all_grade, all_time, all_event, all_ID)
    else:
        all_preds_hazard = torch.concat(all_preds_hazard)
        return sum_loss / total, (None, all_preds_hazard, all_grade, all_time, all_event, all_ID)
+    
def test(args, device):
    """Evaluate the best-validation checkpoint on the held-out test set.

    Loads the checkpoint written by train_model, re-runs inference on the
    test CSV, prints train/val/test losses and C-indices, and pickles the
    three prediction tuples next to the checkpoint.
    """
    model_name = args.fusion_type+'_'+args.task+'_'+str(args.n_epochs)+'_'+str(args.lr)
    criterion = MultiTaskLoss()
    data_test = extract_csv(os.path.join(
        args.dataroot, "data_table_test.csv"))

    ckpt_dir = os.path.join(args.checkpoints_dir, args.exp_name, model_name)
    # map_location keeps loading working on CPU-only machines even when the
    # checkpoint was written from a CUDA run (the original omitted it and
    # would fail without a GPU).
    checkpoint = torch.load(
        os.path.join(ckpt_dir, f'{model_name}_best_val_cindex.pt'),
        map_location=device)

    # Rebuild the architecture and restore the best weights.
    model = Model(args)
    print(f"The model is saved on epoch: {checkpoint['epoch']}")
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)

    test_set = RadDataset(data_test, args.dataroot, train_flag=False)
    test_loader = DataLoader(test_set, batch_size=args.batch_size,
                             shuffle=False, collate_fn=custom_collate)

    # Train/val results were stored at checkpoint time; only the test split
    # needs a fresh forward pass.
    train_loss = checkpoint['train_loss']
    train_preds = checkpoint['train_pred']
    val_loss = checkpoint['val_loss']
    val_preds = checkpoint['val_pred']
    test_loss, test_preds = one_epoch(args, "test", model, None, test_loader, criterion)

    ci_train, _ = compute_metrics(args, train_preds)
    ci_val, _ = compute_metrics(args, val_preds)
    ci_test, _ = compute_metrics(args, test_preds)

    print(
        f"[Final] Apply model to training set: Loss = {train_loss}, C-Index = {ci_train}")
    print(
        f"[Final] Apply model to validation set: Loss = {val_loss}, C-Index = {ci_val}")
    print(
        f"[Final] Apply model to test set: Loss = {test_loss}, C-Index = {ci_test}")

    # Use context managers so the pickle files are flushed and closed even on
    # error (the original passed open(...) directly and leaked the handles).
    for split_name, preds in (("train", train_preds),
                              ("val", val_preds),
                              ("test", test_preds)):
        with open(os.path.join(ckpt_dir, f'pred_{split_name}.pkl'), 'wb') as fh:
            pickle.dump(preds, fh)
+    
+
+
def train_model(args, data_train, data_val, model, criterion, optim, scheduler, device):
    """Train `model` on `data_train`, validating on `data_val` each epoch.

    Checkpoints the model with the best validation C-index (only after
    epoch 5, to skip the noisy warm-up epochs) to
    ``<checkpoints_dir>/<exp_name>/<model_name>/<model_name>_best_val_cindex.pt``.

    Returns:
        (model, optim, metric_logger) where metric_logger holds loss and
        C-index for the train and val splits — one entry every
        ``args.print_freq`` epochs, not one per epoch.
    """
    model_name = args.fusion_type+'_'+args.task+'_'+str(args.n_epochs)+'_'+str(args.lr)
    save_dir = os.path.join(args.checkpoints_dir, args.exp_name, model_name)
    # Create the checkpoint directory up front; torch.save does not create it.
    os.makedirs(save_dir, exist_ok=True)

    # Re-seed so runs that call train_model directly are reproducible.
    torch.cuda.manual_seed_all(42)
    torch.manual_seed(42)
    np.random.seed(42)

    train_set = RadDataset(data_train, args.dataroot)
    val_set = RadDataset(data_val, args.dataroot, train_flag=False)

    train_loader = DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, collate_fn=custom_collate)
    val_loader = DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False, collate_fn=custom_collate)

    metric_logger = {'train': {'loss': [], 'cindex': []},
                     'val': {'loss': [], 'cindex': []}}

    best_val_cindex = float('-inf')
    for epoch in tqdm(range(args.epoch_count, args.niter + args.n_epochs + 1)):
        print(device)
        loss, preds = one_epoch(args, "train", model, optim, train_loader, criterion)
        scheduler.step()
        vloss, vpreds = one_epoch(args, "val", model, None, val_loader, criterion)

        if epoch % args.print_freq == 0:
            print(f"epoch {epoch}")
            print(f"Learning rate in current epoch: {get_lr(optim)}")

            ci_train, _ = compute_metrics(args, preds)
            metric_logger['train']['loss'].append(loss)
            metric_logger['train']['cindex'].append(ci_train)
            print(f"Training loss = {loss}")
            print(f"Train C-index (survival) = {ci_train}")

            ci_val, _ = compute_metrics(args, vpreds)
            metric_logger['val']['loss'].append(vloss)
            metric_logger['val']['cindex'].append(ci_val)
            print(f"Validation loss = {vloss}")
            print(f"Val C-index (survival) = {ci_val}")

            if (epoch > 5) and (ci_val > best_val_cindex):
                best_val_cindex = ci_val
                # Copy the weights to CPU for a portable checkpoint WITHOUT
                # moving the live model: the original called model.cpu(),
                # which silently shifted all subsequent epochs onto the CPU.
                cpu_state = {k: v.cpu() for k, v in model.state_dict().items()}
                torch.save({
                    'args': args,
                    'epoch': epoch,
                    'model_state_dict': cpu_state,
                    'optimizer_state_dict': optim.state_dict(),
                    'metrics': metric_logger,
                    'train_loss': loss,
                    'train_pred': preds,
                    'val_loss': vloss,
                    'val_pred': vpreds},
                    os.path.join(save_dir, f'{model_name}_best_val_cindex.pt'))

    return model, optim, metric_logger
+
+
def train_val(args, device):
    """Assemble data, model, and optimization, run training, return metrics.

    Reads the train/val CSV tables from ``args.dataroot``, builds the model
    and its optimizer/scheduler, delegates the loop to ``train_model``, and
    returns the resulting metric logger.
    """
    loss_fn = MultiTaskLoss()

    train_rows = extract_csv(os.path.join(
        args.dataroot, "data_table_train.csv"))
    val_rows = extract_csv(os.path.join(
        args.dataroot, "data_table_val.csv"))

    net = Model(args)
    net.to(device)

    optimizer = define_optimizer(args, net)
    lr_sched = define_scheduler(args, optimizer)

    # Summarize the run configuration before training starts.
    print(net)
    print("Number of Trainable Parameters: %d" %
          count_parameters(net))
    print("Optimizer Type:", args.optimizer_type)
    print("Activation Type:", args.act_type)

    _, _, metric_logger = train_model(
        args, train_rows, val_rows, net, loss_fn, optimizer, lr_sched, device)

    return metric_logger
+
+
+
if __name__ == '__main__':
    args = parse_args()
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    print("Using device:", device)

    # Seed everything once up front for reproducibility.
    torch.cuda.manual_seed_all(42)
    torch.manual_seed(42)
    np.random.seed(42)

    metric_logger = train_val(args, device)
    test(args, device)

    model_name = args.fusion_type+'_'+args.task+'_'+str(args.n_epochs)+'_'+str(args.lr)

    # Save results for train, validation, and test sets
    save_results_to_mat("train", args, model_name)
    save_results_to_mat("val", args, model_name)
    save_results_to_mat("test", args, model_name)

    # train_model logs metrics only on epochs where epoch % print_freq == 0,
    # so derive the x-axis from that same condition instead of assuming one
    # point per epoch (the original crashed whenever print_freq != 1).
    xs = [e for e in range(args.epoch_count, args.niter + args.n_epochs + 1)
          if e % args.print_freq == 0]

    plt.figure(figsize=(12, 6))

    # Training/validation loss curves.
    plt.subplot(1, 2, 1)
    plt.plot(xs, metric_logger['train']['loss'], label='Train')
    plt.plot(xs, metric_logger['val']['loss'], label='Validation')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()

    # Training/validation C-index curves.
    plt.subplot(1, 2, 2)
    plt.plot(xs, metric_logger['train']['cindex'], label='Train')
    plt.plot(xs, metric_logger['val']['cindex'], label='Validation')
    plt.xlabel('Epoch')
    plt.ylabel('C-Index')
    plt.title('Training and Validation C-Index')
    plt.legend()

    plt.tight_layout()
    plt.show()
\ No newline at end of file