--- a
+++ b/trainner_baseline.py
@@ -0,0 +1,160 @@
+from model.Models import *
+from dataprocess import *
+import loss.losses as losses
+from metrics import *
+import torch.optim as optim
+import time
+import numpy as np
+import os
+import torch
+from config import Config
+import shutil
+from tqdm import tqdm
+import imageio
+import math
+from bisect import bisect_right
+
+config = Config()
+
+# Select the device that the argument-less .cuda() calls below will target.
+torch.cuda.set_device(config.gpu)
+
+model_name = config.arch
+os.makedirs('result', exist_ok=True)
+if config.resume is False:
+    # A fresh run starts with an empty per-script, per-architecture log file.
+    open('result/' + str(os.path.basename(__file__).split('.')[0]) + '_' + model_name + '.txt', 'w').close()
+
+# Objects shared by train() and test() through module scope.
+model = U_Net().cuda()
+optimizer = optim.Adam(model.parameters(), lr=config.learning_rate, betas=(0.5, 0.999))
+dataloader, dataloader_val = get_dataloader(config, batchsize=config.batch_size, mode='row')
+criterion = losses.init_loss('BCE_logit').cuda()
+
+best_dice = 0  # best mean validation Dice so far
+start_epoch = 0  # first epoch of the training loop
+
+if config.resume:
+    print('==> Resuming from checkpoint..')
+    # Evaluation reads the best snapshot, otherwise the most recent one.
+    checkpoint = torch.load('./checkpoint/' + str(model_name)
+                            + ('_best.pth.tar' if config.evaluate else '.pth.tar'))
+    model.load_state_dict(checkpoint['model'])
+    optimizer.load_state_dict(checkpoint['optimizer'])
+    best_dice = checkpoint['dice']
+    # The __main__ block only calls test() when resuming; this is its epoch label.
+    start_epoch = config.epochs
+
+def adjust_lr(optimizer, epoch, eta_max=0.0001, eta_min=0.):
+    """Set each param group's lr for `epoch` per config.lr_type and return it.
+    'SGDR': cosine eta_max -> eta_min with restarts; 'multistep': x0.1 per milestone
+    <= epoch; any other lr_type keeps config.learning_rate (it used to zero the lr)."""
+    cur_lr = config.learning_rate
+    if config.lr_type == 'SGDR':
+        i = int(math.log2(epoch / config.sgdr_t + 1))  # index of the current cycle
+        T_cur = epoch - config.sgdr_t * (2 ** i - 1)  # epochs since that cycle began
+        T_i = config.sgdr_t * 2 ** i  # cycle length: sgdr_t, doubled every restart
+        cur_lr = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * T_cur / T_i))
+    elif config.lr_type == 'multistep':
+        cur_lr = config.learning_rate * 0.1 ** bisect_right(config.milestones, epoch)
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = cur_lr
+    return cur_lr
+
+def train(epoch):
+    """Run one pass over `dataloader`, updating `model` after every batch.
+
+    Only `medias` (the network input) and `targets_u` (the loss target) are read
+    from each batch, so only those two tensors are moved to the GPU. The epoch
+    summary is appended to the result file without a newline; test() writes the
+    validation metrics and the line break. `epoch` is used for logging only.
+    """
+    model.train()
+    train_loss = 0
+    start_time = time.time()
+    for batch_idx, (_, _, medias, targets_u, _, _) in enumerate(dataloader):
+        iter_start_time = time.time()
+        medias = medias.cuda()
+        targets_u = targets_u.cuda()
+
+        outputs = model(medias)
+        # NOTE(review): the criterion is built with init_loss('BCE_logit'); if that
+        # loss applies its own sigmoid, this one is redundant - confirm in loss.losses.
+        outputs_sig = torch.sigmoid(outputs)
+        loss_all = criterion(outputs_sig, targets_u)
+
+        optimizer.zero_grad()
+        loss_all.backward()
+        optimizer.step()
+
+        train_loss += loss_all.item()
+        print('Epoch:{}\t batch_idx:{}/All_batch:{}\t duration:{:.3f}\t loss_all:{:.3f}'
+          .format(epoch, batch_idx, len(dataloader), time.time()-iter_start_time, loss_all.item()))
+    print('Epoch:{0}\t duration:{1:.3f}\ttrain_loss:{2:.6f}'.format(epoch, time.time()-start_time, train_loss/len(dataloader)))
+
+    # Log the rate held by the optimizer, which is what this epoch really used.
+    cur_lr = optimizer.param_groups[0]['lr']
+    with open('result/' + str(os.path.basename(__file__).split('.')[0]) + '_' + model_name + '.txt', 'a+') as f:
+        f.write('Epoch:{0}\t duration:{1:.3f}\t learning_rate:{2:.6f}\t train_loss:{3:.4f}'
+          .format(epoch, time.time()-start_time, cur_lr, train_loss/len(dataloader)))
+
+def test(epoch):
+    """Evaluate `model` on `dataloader_val`, log metrics and save checkpoints.
+
+    The means of the Dice, IoU and NSD lists are printed after every batch and
+    appended, newline-terminated, to the result file. Unless config.resume is set,
+    the model/optimizer state is written to './checkpoint/<arch>.pth.tar' and copied
+    to '<arch>_best.pth.tar' when the mean Dice exceeds the global `best_dice`.
+    """
+    global best_dice
+    model.eval()
+    dices_all = []
+    ious_all = []
+    nsds_all = []
+    with torch.no_grad():
+        for batch_idx, (_, _, medias, targets_u, _, _) in enumerate(dataloader_val):
+            # Only the network input and its target are needed on the GPU.
+            medias = medias.cuda()
+            targets_u = targets_u.cuda()
+            outputs_final_sig = torch.sigmoid(model(medias))
+
+            # Each metric helper takes the running list and its return value replaces it.
+            dices_all = meandice(outputs_final_sig, targets_u, dices_all)
+            ious_all = meandIoU(outputs_final_sig, targets_u, ious_all)
+            nsds_all = meanNSD(outputs_final_sig, targets_u, nsds_all)
+
+            print('Epoch:{}\tbatch_idx:{}/All_batch:{}\tdice:{:.4f}\tiou:{:.4f}\tnsd:{:.4f}'
+            .format(epoch, batch_idx, len(dataloader_val), np.mean(np.array(dices_all)), np.mean(np.array(ious_all)), np.mean(np.array(nsds_all))))
+        with open('result/' + str(os.path.basename(__file__).split('.')[0]) + '_' + model_name + '.txt', 'a+') as f:
+            f.write('\tdice:{:.4f}\tiou:{:.4f}\tnsd:{:.4f}'.format(np.mean(np.array(dices_all)), np.mean(np.array(ious_all)), np.mean(np.array(nsds_all)))+'\n')
+
+    # Save checkpoint.
+    if config.resume is False:
+        # A plain float keeps pickled numpy scalars out of the checkpoint, which
+        # torch.load(weights_only=True) would refuse to unpickle.
+        dice = float(np.mean(np.array(dices_all)))
+        print('Test accuracy: ', dice)
+        state = {
+            'model': model.state_dict(),
+            'dice': dice,
+            'epoch': epoch,
+            'optimizer': optimizer.state_dict()
+        }
+        os.makedirs('checkpoint', exist_ok=True)
+        torch.save(state, './checkpoint/'+str(model_name)+'.pth.tar')
+
+        if best_dice < dice:
+            best_dice = dice
+            shutil.copyfile('./checkpoint/' + str(model_name) + '.pth.tar',
+                            './checkpoint/' + str(model_name) + '_best.pth.tar')
+        print('Save Successfully')
+        print('------------------------------------------------------------------------')
+
+if __name__ == '__main__':
+    # Resuming only re-evaluates the loaded checkpoint; otherwise train and validate.
+    if not config.resume:
+        for epoch in tqdm(range(start_epoch, config.epochs)):
+            train(epoch)
+            test(epoch)
+    else:
+        test(start_epoch)