# -*- coding: utf-8 -*-
"""
Created on Thu Jul 29 17:41:30 2021
@author: angelou
"""
import torch
from torch.autograd import Variable
import os
import argparse
from datetime import datetime
from utils.dataloader import get_loader,test_dataset
from utils.utils import clip_gradient, adjust_lr, AvgMeter
import torch.nn.functional as F
import numpy as np
from torchstat import stat
from CaraNet import caranet
def structure_loss(pred, mask):
weit = 1 + 5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
wbce = F.binary_cross_entropy_with_logits(pred, mask, reduce='none')
wbce = (weit*wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))
pred = torch.sigmoid(pred)
inter = ((pred * mask)*weit).sum(dim=(2, 3))
union = ((pred + mask)*weit).sum(dim=(2, 3))
wiou = 1 - (inter + 1)/(union - inter+1)
return (wbce + wiou).mean()
def test(model, path):
##### put your data_path of TestDataSet/Kvasir here #####
data_path = path
#########################################################
model.eval()
image_root = '{}/images/'.format(data_path)
gt_root = '{}/masks/'.format(data_path)
test_loader = test_dataset(image_root, gt_root, 512)
b=0.0
print('[test_size]',test_loader.size)
for i in range(test_loader.size):
image, gt, name = test_loader.load_data()
gt = np.asarray(gt, np.float32)
gt /= (gt.max() + 1e-8)
image = image.cuda()
res5,res3,res2,res1 = model(image)
res = res5
res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False)
res = res.sigmoid().data.cpu().numpy().squeeze()
res = (res - res.min()) / (res.max() - res.min() + 1e-8)
input = res
target = np.array(gt)
N = gt.shape
smooth = 1
input_flat = np.reshape(input,(-1))
target_flat = np.reshape(target,(-1))
intersection = (input_flat*target_flat)
loss = (2 * intersection.sum() + smooth) / (input.sum() + target.sum() + smooth)
a = '{:.4f}'.format(loss)
a = float(a)
b = b + a
return b/60
def train(train_loader, model, optimizer, epoch, test_path):
model.train()
# ---- multi-scale training ----
size_rates = [0.75, 1, 1.25]
loss_record1, loss_record2, loss_record3, loss_record5 = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
for i, pack in enumerate(train_loader, start=1):
for rate in size_rates:
optimizer.zero_grad()
# ---- data prepare ----
images, gts = pack
images = Variable(images).cuda()
gts = Variable(gts).cuda()
# ---- rescale ----
trainsize = int(round(opt.trainsize*rate/32)*32)
if rate != 1:
images = F.upsample(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
gts = F.upsample(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
# ---- forward ----
lateral_map_5,lateral_map_3,lateral_map_2,lateral_map_1 = model(images)
# ---- loss function ----
loss5 = structure_loss(lateral_map_5, gts)
loss3 = structure_loss(lateral_map_3, gts)
loss2 = structure_loss(lateral_map_2, gts)
loss1 = structure_loss(lateral_map_1, gts)
loss = loss5 +loss3 + loss2 + loss1
# ---- backward ----
loss.backward()
clip_gradient(optimizer, opt.clip)
optimizer.step()
# ---- recording loss ----
if rate == 1:
loss_record5.update(loss5.data, opt.batchsize)
loss_record3.update(loss3.data, opt.batchsize)
loss_record2.update(loss2.data, opt.batchsize)
loss_record1.update(loss1.data, opt.batchsize)
# ---- train visualization ----
if i % 20 == 0 or i == total_step:
print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], '
' lateral-5: {:0.4f}], lateral-3: {:0.4f}], lateral-2: {:0.4f}], lateral-1: {:0.4f}]'.
format(datetime.now(), epoch, opt.epoch, i, total_step,
loss_record5.show(),loss_record3.show(),loss_record2.show(),loss_record1.show()))
save_path = 'snapshots/{}/'.format(opt.train_save)
os.makedirs(save_path, exist_ok=True)
if (epoch+1) % 1 == 0:
meandice = test(model,test_path)
fp = open('log/log.txt','a')
fp.write(str(meandice)+'\n')
fp.close()
fp = open('log/best.txt','r')
best = fp.read()
fp.close()
if meandice > float(best):
fp = open('log/best.txt','w')
fp.write(str(meandice))
fp.close()
# best = meandice
fp = open('log/best.txt','r')
best = fp.read()
fp.close()
torch.save(model.state_dict(), save_path + 'CaraNet-best.pth' )
print('[Saving Snapshot:]', save_path + 'CaraNet-best.pth',meandice,'[best:]',best)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int,
default=10, help='epoch number')
parser.add_argument('--lr', type=float,
default=1e-4, help='learning rate')
parser.add_argument('--optimizer', type=str,
default='Adam', help='choosing optimizer Adam or SGD')
parser.add_argument('--augmentation',
default=False, help='choose to do random flip rotation')
parser.add_argument('--batchsize', type=int,
default=6, help='training batch size')
parser.add_argument('--trainsize', type=int,
default=352, help='training dataset size')
parser.add_argument('--clip', type=float,
default=0.5, help='gradient clipping margin')
parser.add_argument('--decay_rate', type=float,
default=0.1, help='decay rate of learning rate')
parser.add_argument('--decay_epoch', type=int,
default=50, help='every n epochs decay learning rate')
parser.add_argument('--train_path', type=str,
default='/home/data/spleen_blood/CaraNet/TrainDataset/train/', help='path to train dataset')
parser.add_argument('--test_path', type=str,
default='/home/data/spleen_blood/CaraNet/TestDataset/test/' , help='path to testing Kvasir dataset')
parser.add_argument('--train_save', type=str,
default='')
opt = parser.parse_args()
# ---- build models ----
torch.cuda.set_device(4) # set your gpu device
model = caranet().cuda()
# ---- flops and params ----
# from utils.utils import CalParams
# x = torch.randn(1, 3, 352, 352).cuda()
# CalParams(model, x)
params = model.parameters()
if opt.optimizer == 'Adam':
optimizer = torch.optim.Adam(params, opt.lr)
else:
optimizer = torch.optim.SGD(params, opt.lr, weight_decay = 1e-4, momentum = 0.9)
print(optimizer)
image_root = '{}/image/'.format(opt.train_path)
gt_root = '{}/mask/'.format(opt.train_path)
train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize, augmentation = opt.augmentation)
total_step = len(train_loader)
print("#"*20, "Start Training", "#"*20)
for epoch in range(1, opt.epoch):
adjust_lr(optimizer, opt.lr, epoch, 0.1, 200)
train(train_loader, model, optimizer, epoch, opt.test_path)