--- /dev/null
+++ b/train.py
@@ -0,0 +1,147 @@
+import numpy as np
+import os
+from tqdm import tqdm
+import time
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from data.dataset import get_train_val_loader, inverse_normalize, get_test_loader
+from model import UNet
+from utils.Config import opt
+from utils.vis_tool import Visualizer
+from utils.eval_tool import compute_iou, save_pred_result
+import utils.array_tool as at
+
+def train(model, train_loader, criterion, optimizer, epoch, vis):
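+    """Run one training epoch: optimize on each batch, log the average loss, and periodically visualize predictions."""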
+    model.train()
+    batch_loss = 0
+    for batch_idx, sample_batched in enumerate(train_loader):
+        data = sample_batched['image']
+        target = sample_batched['mask']
+        data, target = data.type(opt.dtype), target.type(opt.dtype)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = criterion(output, target)
+        loss.backward()
+        optimizer.step()
+        batch_loss += loss.item()
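+        # every opt.plot_every batches, visualize the first sample: de-normalized input, ground-truth mask, thresholded prediction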
+        if (batch_idx+1) % opt.plot_every == 0:
+            ori_img_ = inverse_normalize(at.tonumpy(data[0]))
+            target_ = at.tonumpy(target[0])
+            pred_ = at.tonumpy(output[0])
+            vis.img('gt_img', ori_img_)
+            vis.img('gt_mask', target_)
+            vis.img('pred_mask', (pred_ >= 0.5).astype(np.float32))
+
+    batch_loss /= (batch_idx+1)
+    print('epoch: ' + str(epoch) + ', train loss: ' + str(batch_loss))
+    with open('logs.txt', 'a') as file:
+        file.write('epoch: ' + str(epoch) + ', train loss: ' + str(batch_loss) + '\n')
+    vis.plot('train loss', batch_loss)
+
+def val(model, val_loader, criterion, epoch, vis):
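+    """Evaluate on the validation set: log the average loss and average IoU, and return the average IoU."""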
+    model.eval()
+    batch_loss = 0
+    avg_iou = 0
+    for batch_idx, sample_batched in enumerate(val_loader):
+        data = sample_batched['image']
+        target = sample_batched['mask']
+        data, target = data.type(opt.dtype), target.type(opt.dtype)
+        with torch.no_grad():
+            output = model(data)
+            loss = criterion(output, target)
+        batch_loss += loss.item()
+        avg_iou += compute_iou(pred_masks=at.tonumpy(output >= 0.5).astype(np.float32), gt_masks=at.tonumpy(target))
+
+    batch_loss /= (batch_idx+1)
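+    # NOTE: assumes compute_iou returns the IoU summed over the batch, so the average is taken over the whole dataset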
+    avg_iou /= len(val_loader.dataset)
+
+    print('epoch: ' + str(epoch) + ', validation loss: ' + str(batch_loss) + ', avg_iou: ' + str(avg_iou))
+    with open('logs.txt', 'a') as file:
+        file.write('epoch: ' + str(epoch) + ', validation loss: ' + str(batch_loss) + ', avg_iou: ' + str(avg_iou) + '\n')
+
+    vis.plot('val loss', batch_loss)
+    vis.plot('validation average IOU', avg_iou)
+    return avg_iou
+
+# train and validation
+def run(model, train_loader, val_loader, criterion, optimizer, vis):
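+    """Train for opt.epochs epochs, validate after each epoch, and checkpoint the model whenever the validation IoU improves."""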
+    best_iou = 0
+    for epoch in range(1, opt.epochs+1):
+        train(model, train_loader, criterion, optimizer, epoch, vis)
+        avg_iou = val(model, val_loader, criterion, epoch, vis)
+        if avg_iou > best_iou:
+            best_iou = avg_iou
+            if opt.save_model:
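+                # assumes the ./checkpoints/ directory already exists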
+                save_path = './checkpoints/RSNA_UNet_' + str(round(best_iou, 3)) + '_' + time.strftime('%m%d%H%M')
+                torch.save(model.state_dict(), save_path)
+
+    if opt.save_model:
+        save_path = './checkpoints/RSNA_UNet_' + str(round(best_iou, 3)) + '_' + time.strftime('%m%d%H%M')
+        torch.save(model.state_dict(), save_path)
+
+# make prediction
+def run_test(model, test_loader):
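+    """Run inference on the test set and return image ids, de-normalized images, and binarized predicted masks."""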
+    model.eval()  # switch layers such as BatchNorm/Dropout to inference behaviour
+    pred_masks = []
+    img_ids = []
+    images = []
+    for batch_idx, sample_batched in enumerate(tqdm(test_loader)):
+        data, img_id = sample_batched['image'], sample_batched['img_id']
+        data = data.type(opt.dtype)
+        with torch.no_grad():
+            output = model(data)
+        output = at.tonumpy(output)
+        for i in range(0, output.shape[0]):
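+            # binarize each predicted probability map at 0.5 and keep it together with its image id and de-normalized input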
+            pred_mask = np.squeeze(output[i])
+            image_id = img_id[i]
+            pred_mask = (pred_mask >= 0.5).astype(np.float32)
+            pred_masks.append(pred_mask)
+            img_ids.append(image_id)
+            ori_img_ = inverse_normalize(at.tonumpy(data[i]))
+            images.append(ori_img_)
+
+    return img_ids, images, pred_masks
+
+if __name__ == '__main__':
+    """Train Unet model"""
+    model = UNet(input_channels=1, nclasses=1)
+    if opt.is_train:
+        # split all data to train and validation, set split = True
+        train_loader, val_loader = get_train_val_loader(opt.root_dir, batch_size=opt.batch_size, val_ratio=0.15,
+                                                        shuffle=True, num_workers=4, pin_memory=False)
+
+        optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)
+        criterion = nn.BCELoss()
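+        # nn.BCELoss expects probabilities, so the UNet forward pass is assumed to end with a sigmoid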
+        vis = Visualizer(env=opt.env)
+
+        if opt.is_cuda:
+            model.cuda()
+            criterion.cuda()
+            if opt.n_gpu > 1:
+                model = nn.DataParallel(model)
+
+        run(model, train_loader, val_loader, criterion, optimizer, vis)
+    else:
+        if opt.is_cuda:
+            model.cuda()
+            if opt.n_gpu > 1:
+                model = nn.DataParallel(model)
+        test_loader = get_test_loader(batch_size=20, shuffle=True,
+                                      num_workers=opt.num_workers,
+                                      pin_memory=opt.pin_memory)
+        # load the model and run test
+        model.load_state_dict(torch.load(os.path.join(opt.checkpoint_dir, 'RSNA_UNet_0.895_09210122')))
+
+        img_ids, images, pred_masks = run_test(model, test_loader)
+
+        save_pred_result(img_ids, images, pred_masks)
\ No newline at end of file