--- /dev/null
+++ b/2DNet/src/train.py
@@ -0,0 +1,273 @@
+import os
+import time
+import csv
+import random
+import argparse
+import warnings
+
+# numpy and albumentations are used below but were previously only pulled in
+# implicitly via the star imports; sklearn.metrics.ranking is a private module
+# that no longer exists, so roc_auc_score is imported from its public location.
+import numpy as np
+import pandas as pd
+import albumentations
+from sklearn.metrics import roc_auc_score
+import torch
+import torch.nn as nn
+import torch.optim
+import torch.backends.cudnn as cudnn
+import torch.utils.data
+
+from net.models import *
+from dataset.dataset import *
+from tuils.tools import *
+from tuils.lrs_scheduler import WarmRestart, warm_restart
+from tuils.loss_function import *
+
+warnings.filterwarnings('ignore')
+
+# Fix all random seeds so folds are reproducible.
+torch.manual_seed(1992)
+torch.cuda.manual_seed(1992)
+np.random.seed(1992)
+random.seed(1992)
+torch.backends.cudnn.benchmark = True
+
+
+def epochVal(model, dataLoader, loss_cls, c_val, val_batch_size):
+    model.eval()
+    lossValNorm = 0
+    valLoss = 0
+
+    outGT = torch.FloatTensor().cuda()
+    outPRED = torch.FloatTensor().cuda()
+    # No gradients are needed during validation; this also keeps memory flat.
+    with torch.no_grad():
+        for i, (image, target) in enumerate(dataLoader):
+            if i == 0:
+                ss_time = time.time()
+            print(str(i) + '/' + str(int(len(c_val) / val_batch_size)) + ' '
+                  + str((time.time() - ss_time) / (i + 1)), end='\r')
+            target = target.view(-1, 6).contiguous().cuda()
+            outGT = torch.cat((outGT, target), 0)
+            varOutput = model(image.cuda())
+            lossvalue = loss_cls(varOutput, target)
+            valLoss = valLoss + lossvalue.item()
+            varOutput = varOutput.sigmoid()
+
+            outPRED = torch.cat((outPRED, varOutput.data), 0)
+            lossValNorm += 1
+
+    valLoss = valLoss / lossValNorm
+
+    # Per-class AUROC plus the weighted log loss over the six labels.
+    auc = computeAUROC(outGT, outPRED, 6)
+    auc = [round(x, 4) for x in auc]
+    loss_list, loss_sum = weighted_log_loss(outPRED, outGT)
+
+    return valLoss, auc, loss_list, loss_sum
+
+
+def train(model_name, image_size):
+
+    if not os.path.exists(snapshot_path):
+        os.makedirs(snapshot_path)
+    header = ['Epoch', 'Learning rate', 'Time', 'Train Loss', 'Val Loss']
+
+    if not os.path.isfile(snapshot_path + '/log.csv'):
+        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
+            writer = csv.writer(f)
+            writer.writerow(header)
+    df_all = pd.read_csv(csv_path)
+
+    kfold_path_train = '../data/fold_5_by_study/'
+    kfold_path_val = '../data/fold_5_by_study_image/'
+
+    for num_fold in range(5):
+        print('fold_num:', num_fold)
+
+        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
+            writer = csv.writer(f)
+            writer.writerow([num_fold])
+
+        # Folds are split by study for training and by image for validation.
+        with open(kfold_path_train + 'fold' + str(num_fold) + '/train.txt', 'r') as f_train:
+            c_train = f_train.readlines()
+        with open(kfold_path_val + 'fold' + str(num_fold) + '/val.txt', 'r') as f_val:
+            c_val = f_val.readlines()
+        c_train = [s.replace('\n', '') for s in c_train]
+        c_val = [s.replace('\n', '') for s in c_val]
+
+        # for debug
+        # c_train = c_train[0:1000]
+        # c_val = c_val[0:4000]
+
+        print('train dataset study num:', len(c_train), ' val dataset image num:', len(c_val))
+        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
+            writer = csv.writer(f)
+            writer.writerow(['train dataset:', len(c_train), ' val dataset:', len(c_val)])
+            writer.writerow(['train_batch_size:', train_batch_size, 'val_batch_size:', val_batch_size])
+
+        train_transform, val_transform = generate_transforms(image_size)
+        train_loader, val_loader = generate_dataset_loader(df_all, c_train, train_transform, train_batch_size, c_val, val_transform, val_batch_size, workers)
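+        # Setup below: the backbone class named by `model_name` is instantiated
+        # from net.models via eval(); training uses Adam with a 10-epoch warm-up
+        # at the base LR, cosine warm restarts afterwards (the restart period
+        # doubles via T_mult=2), and BCEWithLogitsLoss over the six target
+        # labels. The `else` branch that pins the LR to 1e-5 is unreachable
+        # while trMaxEpoch == 80.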
+        model = eval(model_name + '()')
+        model = model.cuda()
+
+        optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.00002)
+        scheduler = WarmRestart(optimizer, T_max=5, T_mult=1, eta_min=1e-5)
+        model = torch.nn.DataParallel(model)
+        loss_cls = torch.nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]).cuda())
+
+        trMaxEpoch = 80
+        for epochID in range(0, trMaxEpoch):
+
+            start_time = time.time()
+            model.train()
+            trainLoss = 0
+            lossTrainNorm = 0  # batch counter; was initialized to 10, which skewed the average
+
+            if epochID < 10:
+                pass  # warm-up: hold the base LR for the first 10 epochs
+            elif epochID < 80:
+                if epochID != 10:
+                    scheduler.step()
+                    scheduler = warm_restart(scheduler, T_mult=2)
+            else:
+                optimizer.param_groups[0]['lr'] = 1e-5
+
+            for batchID, (image, target) in enumerate(train_loader):
+                if batchID == 0:
+                    ss_time = time.time()
+
+                print(str(batchID) + '/' + str(int(len(c_train) / train_batch_size)) + ' '
+                      + str((time.time() - ss_time) / (batchID + 1)), end='\r')
+                target = target.view(-1, 6).contiguous().cuda()
+                varOutput = model(image.cuda())
+                lossvalue = loss_cls(varOutput, target)
+                trainLoss = trainLoss + lossvalue.item()
+                lossTrainNorm = lossTrainNorm + 1
+
+                lossvalue.backward()
+                optimizer.step()
+                optimizer.zero_grad()
+                del lossvalue
+
+            trainLoss = trainLoss / lossTrainNorm
+
+            if (epochID + 1) % 5 == 0 or epochID > 79 or epochID == 0:
+                valLoss, auc, loss_list, loss_sum = epochVal(model, val_loader, loss_cls, c_val, val_batch_size)
+
+            epoch_time = time.time() - start_time
+
+            if (epochID + 1) % 5 == 0 or epochID > 79:
+                torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'valLoss': valLoss},
+                           snapshot_path + '/model_epoch_' + str(epochID) + '_' + str(num_fold) + '.pth')
+
+            # NOTE: on epochs where validation is skipped, valLoss/auc/loss_list/
+            # loss_sum keep their last computed values in this log row.
+            result = [epochID,
+                      round(optimizer.state_dict()['param_groups'][0]['lr'], 6),
+                      round(epoch_time, 0),
+                      round(trainLoss, 5),
+                      round(valLoss, 5),
+                      'auc:', auc,
+                      'loss:', loss_list,
+                      loss_sum]
+
+            print(result)
+
+            with open(snapshot_path + '/log.csv', 'a', newline='') as f:
+                writer = csv.writer(f)
+                writer.writerow(result)
+
+        del model
+
+
+def valid_snapshot(model_name, image_size):
+    ckpt_dir = r'./DenseNet121_change_avg_256'  # renamed from `dir`, which shadowed a builtin
+    if not os.path.exists(snapshot_path):
+        os.makedirs(snapshot_path)
+    header = ['Epoch', 'Learning rate', 'Time', 'Train Loss', 'Val Loss']
+
+    if not os.path.isfile(snapshot_path + '/log.csv'):
+        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
+            writer = csv.writer(f)
+            writer.writerow(header)
+    df_all = pd.read_csv(csv_path)
+
+    kfold_path_val = '../data/fold_5_by_study_image/'
+    loss_cls = torch.nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]).cuda())
+    for num_fold in range(5):
+        print('fold_num:', num_fold)
+
+        ckpt = r'model_epoch_best_' + str(num_fold) + '.pth'
+        ckpt = os.path.join(ckpt_dir, ckpt)
+
+        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
+            writer = csv.writer(f)
+            writer.writerow([num_fold])
+
+        with open(kfold_path_val + 'fold' + str(num_fold) + '/val.txt', 'r') as f_val:
+            c_val = f_val.readlines()
+        c_val = [s.replace('\n', '') for s in c_val]
+
+        print(' val dataset image num:', len(c_val))
+
+        val_transform = albumentations.Compose([
+            albumentations.Resize(image_size, image_size),
+            albumentations.Normalize(mean=(0.456, 0.456, 0.456), std=(0.224, 0.224, 0.224),
+                                     max_pixel_value=255.0, p=1.0)
+        ])
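+        # NOTE: these normalization statistics are hard-coded; they are presumed
+        # to match the validation transform produced by generate_transforms() at
+        # training time, so keep the two in sync.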
+        val_dataset = RSNA_Dataset_val_by_study_context(df_all, c_val, val_transform)
+
+        val_loader = torch.utils.data.DataLoader(
+            val_dataset,
+            batch_size=val_batch_size,
+            shuffle=False,
+            num_workers=workers,
+            pin_memory=True,
+            drop_last=False)
+
+        model = eval(model_name + '()')
+        model = model.cuda()
+        model = torch.nn.DataParallel(model)
+
+        print(ckpt)
+        model.load_state_dict(torch.load(ckpt, map_location=lambda storage, loc: storage)["state_dict"])
+
+        valLoss, auc, loss_list, loss_sum = epochVal(model, val_loader, loss_cls, c_val, val_batch_size)
+
+        result = [round(valLoss, 5),
+                  'auc:', auc,
+                  'loss:', loss_list,
+                  loss_sum]
+
+        with open(ckpt + '_log.csv', 'a', newline='') as f:
+            writer = csv.writer(f)
+            writer.writerow(result)
+        print(result)
+
+
+if __name__ == '__main__':
+    csv_path = '../data/stage1_train_cls.csv'
+    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument("-backbone", "--backbone", type=str, default='DenseNet121_change_avg', help='backbone')
+    parser.add_argument("-img_size", "--Image_size", type=int, default=256, help='image size')
+    parser.add_argument("-tbs", "--train_batch_size", type=int, default=32, help='train batch size')
+    parser.add_argument("-vbs", "--val_batch_size", type=int, default=32, help='val batch size')
+    parser.add_argument("-save_path", "--model_save_path", type=str,
+                        default='DenseNet169_change_avg', help='folder name for snapshots and logs')
+    args = parser.parse_args()
+
+    Image_size = args.Image_size
+    train_batch_size = args.train_batch_size
+    val_batch_size = args.val_batch_size
+    workers = 24
+    backbone = args.backbone
+    print(backbone)
+    print('image size:', Image_size)
+    print('train batch size:', train_batch_size)
+    print('val batch size:', val_batch_size)
+    snapshot_path = 'data_test/' + args.model_save_path.replace('\n', '').replace('\r', '')
+    train(backbone, Image_size)
+    # valid_snapshot(backbone, Image_size)
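+    # Example invocation (flag names as defined above; the save-path value is
+    # illustrative, not a default):
+    #   python train.py -backbone DenseNet121_change_avg -img_size 256 \
+    #       -tbs 32 -vbs 32 -save_path DenseNet121_change_avg_256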