Diff of /2DNet/src/train.py [000000] .. [9f60b7]


--- a
+++ b/2DNet/src/train.py
@@ -0,0 +1,273 @@
+import os
+import time
+import pandas as pd
+import gc
+import cv2
+import csv
+import random
+import numpy as np
+from sklearn.metrics import roc_auc_score
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim
+from torch.optim.lr_scheduler import ReduceLROnPlateau, MultiStepLR
+import torch.utils.data
+
+import torch.utils.data as data
+from net.models import *
+from dataset.dataset import *
+from tuils.tools import *
+from tuils.lrs_scheduler import WarmRestart, warm_restart, AdamW, RAdam
+from tuils.loss_function import *
+import torch.nn.functional as F
+from collections import OrderedDict
+import warnings
+warnings.filterwarnings('ignore')
+torch.manual_seed(1992)
+torch.cuda.manual_seed(1992)
+np.random.seed(1992)
+random.seed(1992)
+from PIL import ImageFile
+import sklearn
+import copy
+torch.backends.cudnn.benchmark = True
+import argparse
+import albumentations  # used directly by valid_snapshot's val_transform below
+
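+# One full pass over the validation loader: accumulates targets and sigmoid
+# probabilities on the GPU, then returns the mean BCE loss, per-class AUROC
+# and the weighted log loss.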
+def epochVal(model, dataLoader, loss_cls, c_val, val_batch_size):
+    model.eval()
+    lossValNorm = 0
+    valLoss = 0
+
+    outGT = torch.FloatTensor().cuda()
+    outPRED = torch.FloatTensor().cuda()
+    for i, (input, target) in enumerate(dataLoader):
+        if i == 0:
+            ss_time = time.time()
+        print(str(i) + '/' + str(int(len(c_val)/val_batch_size)) + '     ' + str((time.time()-ss_time)/(i+1)), end='\r')
+        target = target.view(-1, 6).contiguous().cuda()
+        outGT = torch.cat((outGT, target), 0)
+        varInput = torch.autograd.Variable(input)
+        varTarget = torch.autograd.Variable(target.contiguous().cuda())
+        varOutput = model(varInput)
+        lossvalue = loss_cls(varOutput, varTarget)
+        valLoss = valLoss + lossvalue.item()
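+        # BCEWithLogitsLoss consumes raw logits, so sigmoid is applied only for
+        # the AUROC / weighted-log-loss metrics computed below.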
+        varOutput = varOutput.sigmoid()
+
+        outPRED = torch.cat((outPRED, varOutput.data), 0)
+        lossValNorm += 1
+
+    valLoss = valLoss / lossValNorm
+
+    auc = computeAUROC(outGT, outPRED, 6)
+    auc = [round(x, 4) for x in auc]
+    loss_list, loss_sum = weighted_log_loss(outPRED, outGT)
+
+    return valLoss, auc, loss_list, loss_sum
+
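+# Train all 5 folds in sequence; train folds are listed by study and
+# validation folds by image (see the fold_5_by_study* paths below).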
+def train(model_name, image_size):
+
+    if not os.path.exists(snapshot_path):
+        os.makedirs(snapshot_path)
+    header = ['Epoch', 'Learning rate', 'Time', 'Train Loss', 'Val Loss']
+
+    if not os.path.isfile(snapshot_path + '/log.csv'):
+        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
+            writer = csv.writer(f)
+            writer.writerow(header)
+    df_all = pd.read_csv(csv_path)
+
+    kfold_path_train = '../data/fold_5_by_study/'
+    kfold_path_val = '../data/fold_5_by_study_image/'
+
+    for num_fold in range(5):
+        print('fold_num:',num_fold)
+
+        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
+            writer = csv.writer(f)
+            writer.writerow([num_fold]) 
+
+        f_train = open(kfold_path_train + 'fold' + str(num_fold) + '/train.txt', 'r')
+        f_val = open(kfold_path_val + 'fold' + str(num_fold) + '/val.txt', 'r')
+        c_train = f_train.readlines()
+        c_val = f_val.readlines()
+        f_train.close()
+        f_val.close()
+        c_train = [s.replace('\n', '') for s in c_train]
+        c_val = [s.replace('\n', '') for s in c_val]     
+
+        # for debug
+        # c_train = c_train[0:1000]
+        # c_val = c_val[0:4000]
+
+        print('train dataset study num:', len(c_train), '  val dataset image num:', len(c_val))
+        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
+            writer = csv.writer(f)
+            writer.writerow(['train dataset:', len(c_train), '  val dataset:', len(c_val)])  
+            writer.writerow(['train_batch_size:', train_batch_size, 'val_batch_size:', val_batch_size])  
+
+        train_transform, val_transform = generate_transforms(image_size)
+        train_loader, val_loader = generate_dataset_loader(df_all, c_train, train_transform, train_batch_size, c_val, val_transform, val_batch_size, workers)
+
+        model = eval(model_name+'()')
+        model = model.cuda()
+
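+        # Adam plus cosine warm restarts (WarmRestart / warm_restart from
+        # tuils.lrs_scheduler); BCEWithLogitsLoss over the 6 output labels.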
+        optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.00002)
+        scheduler = WarmRestart(optimizer, T_max=5, T_mult=1, eta_min=1e-5)
+        model = torch.nn.DataParallel(model)
+        loss_cls = torch.nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]).cuda())
+
+        trMaxEpoch = 80
+        for epochID in range(trMaxEpoch):
+
+            start_time = time.time()
+            model.train()
+            trainLoss = 0
+            lossTrainNorm = 0
+
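+            # LR schedule: hold the initial LR through epoch 10, then step the
+            # cosine warm-restart scheduler each epoch with the restart period
+            # doubling (T_mult=2); the final branch pins the LR at 1e-5.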
+            if epochID < 10:
+                pass
+            elif epochID < 80:
+                if epochID != 10:
+                    scheduler.step()
+                    scheduler = warm_restart(scheduler, T_mult=2) 
+            else:
+                optimizer.param_groups[0]['lr'] = 1e-5
+
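+            # One optimisation pass over this fold's training set; prints batch
+            # index and running seconds-per-batch in place.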
+            for batchID, (input, target) in enumerate(train_loader):
+                if batchID == 0:
+                    ss_time = time.time()
+
+                print(str(batchID) + '/' + str(int(len(c_train)/train_batch_size)) + '     ' + str((time.time()-ss_time)/(batchID+1)), end='\r')
+                varInput = torch.autograd.Variable(input)
+                target = target.view(-1, 6).contiguous().cuda()
+                varTarget = torch.autograd.Variable(target.contiguous().cuda())
+                varOutput = model(varInput)
+                lossvalue = loss_cls(varOutput, varTarget)
+                trainLoss = trainLoss + lossvalue.item()
+                lossTrainNorm = lossTrainNorm + 1
+
+                lossvalue.backward()
+                optimizer.step()
+                optimizer.zero_grad()
+                del lossvalue
+
+            trainLoss = trainLoss / lossTrainNorm
+
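+            # Validate on epoch 0 and every 5th epoch; between validations the
+            # most recent metrics are re-used in the log row below.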
+            if (epochID+1)%5 == 0 or epochID > 79 or epochID == 0:
+                valLoss, auc, loss_list, loss_sum = epochVal(model, val_loader, loss_cls, c_val, val_batch_size)
+        
+            epoch_time = time.time() - start_time
+
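+            # Checkpoint on the validated epochs (every 5th; epoch 0 is skipped).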
+            if (epochID+1)%5 == 0 or epochID > 79:
+                torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'valLoss': valLoss}, snapshot_path + '/model_epoch_' + str(epochID) + '_' + str(num_fold) + '.pth')
+
+            result = [epochID,
+                      round(optimizer.state_dict()['param_groups'][0]['lr'], 6),
+                      round(epoch_time, 0),
+                      round(trainLoss, 5),
+                      round(valLoss, 5),
+                      'auc:', auc,
+                      'loss:',loss_list,
+                      loss_sum]
+
+            print(result)
+
+            with open(snapshot_path + '/log.csv', 'a', newline='') as f:
+                writer = csv.writer(f)
+                writer.writerow(result)  
+
+        del model
+
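+# Evaluate previously saved per-fold checkpoints on their validation folds and
+# append the metrics to a per-checkpoint log file.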
+def valid_snapshot(model_name, image_size):
+    ckpt_dir = r'./DenseNet121_change_avg_256'  # avoid shadowing the built-in dir()
+    if not os.path.exists(snapshot_path):
+        os.makedirs(snapshot_path)
+    header = ['Epoch', 'Learning rate', 'Time', 'Train Loss', 'Val Loss']
+
+    if not os.path.isfile(snapshot_path + '/log.csv'):
+        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
+            writer = csv.writer(f)
+            writer.writerow(header)
+    df_all = pd.read_csv(csv_path)
+
+    kfold_path_val = '../data/fold_5_by_study_image/'
+    loss_cls = torch.nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]).cuda())
+    for num_fold in range(5):
+        print('fold_num:', num_fold)
+
+        ckpt = r'model_epoch_best_'+str(num_fold)+'.pth'
+        ckpt = os.path.join(ckpt_dir, ckpt)
+
+        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
+            writer = csv.writer(f)
+            writer.writerow([num_fold])
+
+        f_val = open(kfold_path_val + 'fold' + str(num_fold) + '/val.txt', 'r')
+        c_val = f_val.readlines()
+        f_val.close()
+        c_val = [s.replace('\n', '') for s in c_val]
+
+        print('  val dataset image num:', len(c_val))
+
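+        # Resize plus fixed per-channel mean/std normalisation, hard-coded here
+        # rather than reusing generate_transforms from training.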
+        val_transform = albumentations.Compose([
+            albumentations.Resize(image_size, image_size),
+            albumentations.Normalize(mean=(0.456, 0.456, 0.456), std=(0.224, 0.224, 0.224), max_pixel_value=255.0,
+                                     p=1.0)
+        ])
+
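+        # Validation dataset; the class name suggests slices are grouped by
+        # study with neighbouring-slice context (defined in dataset.dataset).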
+        val_dataset = RSNA_Dataset_val_by_study_context(df_all, c_val, val_transform)
+
+        val_loader = torch.utils.data.DataLoader(
+            val_dataset,
+            batch_size=val_batch_size,
+            shuffle=False,
+            num_workers=workers,
+            pin_memory=True,
+            drop_last=False)
+
+        model = eval(model_name + '()')
+        model = model.cuda()
+        model = torch.nn.DataParallel(model)
+
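+        # Load after wrapping in DataParallel so the "module." prefix on the
+        # checkpoint keys matches the current state_dict.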
+        if ckpt is not None:
+            print(ckpt)
+            model.load_state_dict(torch.load(ckpt, map_location=lambda storage, loc: storage)["state_dict"])
+
+        valLoss, auc, loss_list, loss_sum = epochVal(model, val_loader, loss_cls, c_val, val_batch_size)
+
+        result = [round(valLoss, 5),
+                  'auc:', auc,
+                  'loss:', loss_list,
+                  loss_sum]
+
+        with open(ckpt + '_log.csv', 'a', newline='') as f:
+            writer = csv.writer(f)
+            writer.writerow(result)
+        print(result)
+
+
+if __name__ == '__main__':
+    csv_path = '../data/stage1_train_cls.csv'
+    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument("-backbone", "--backbone", type=str, default='DenseNet121_change_avg', help='backbone')
+    parser.add_argument("-img_size", "--Image_size", type=int, default=256, help='image_size')
+    parser.add_argument("-tbs", "--train_batch_size", type=int, default=32, help='train_batch_size')
+    parser.add_argument("-vbs", "--val_batch_size", type=int, default=32, help='val_batch_size')
+    parser.add_argument("-save_path", "--model_save_path", type=str,
+                        default='DenseNet169_change_avg', help='model save path')
+    args = parser.parse_args()
+
+    Image_size = args.Image_size
+    train_batch_size = args.train_batch_size
+    val_batch_size = args.val_batch_size
+    workers = 24
+    backbone = args.backbone
+    print(backbone)
+    print('image size:', Image_size)
+    print('train batch size:', train_batch_size)
+    print('val batch size:', val_batch_size)
+    snapshot_path = 'data_test/' + args.model_save_path.replace('\n', '').replace('\r', '')
+    train(backbone, Image_size)
+    # valid_snapshot(backbone, Image_size)
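+    # Example invocation (paths assume the ../data layout used above):
+    #   python train.py -backbone DenseNet121_change_avg -img_size 256 \
+    #       -tbs 32 -vbs 32 -save_path DenseNet121_change_avg_256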