--- /dev/null
+++ b/CaraNet/train_blood.py
@@ -0,0 +1,203 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Thu Jul 29 17:41:30 2021
+
+@author: angelou
+"""
+
+import torch
+import os
+import argparse
+from datetime import datetime
+from utils.dataloader import get_loader, test_dataset
+from utils.utils import clip_gradient, adjust_lr, AvgMeter
+import torch.nn.functional as F
+import numpy as np
+from CaraNet import caranet
+
+
+def structure_loss(pred, mask):
+    # Weight each pixel by how strongly it differs from its local mean,
+    # so that boundary pixels contribute more to both loss terms.
+    weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
+    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
+    wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))
+
+    pred = torch.sigmoid(pred)
+    inter = ((pred * mask) * weit).sum(dim=(2, 3))
+    union = ((pred + mask) * weit).sum(dim=(2, 3))
+    wiou = 1 - (inter + 1) / (union - inter + 1)
+
+    return (wbce + wiou).mean()
+
+
+def test(model, path):
+    ##### put your data_path of TestDataSet/Kvasir here #####
+    data_path = path
+    #########################################################
+
+    model.eval()
+    image_root = '{}/images/'.format(data_path)
+    gt_root = '{}/masks/'.format(data_path)
+    test_loader = test_dataset(image_root, gt_root, 512)
+    dice_sum = 0.0
+    print('[test_size]', test_loader.size)
+    for i in range(test_loader.size):
+        image, gt, name = test_loader.load_data()
+        gt = np.asarray(gt, np.float32)
+        gt /= (gt.max() + 1e-8)
+        image = image.cuda()
+
+        res5, res3, res2, res1 = model(image)
+        res = res5
+        res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
+        res = res.sigmoid().data.cpu().numpy().squeeze()
+        res = (res - res.min()) / (res.max() - res.min() + 1e-8)
+
+        # Dice coefficient between the normalized prediction and the mask.
+        smooth = 1
+        input_flat = np.reshape(res, (-1))
+        target_flat = np.reshape(gt, (-1))
+        intersection = input_flat * target_flat
+        dice = (2 * intersection.sum() + smooth) / (res.sum() + gt.sum() + smooth)
+
+        dice_sum += float('{:.4f}'.format(dice))
+
+    # Mean Dice over the whole test set.
+    return dice_sum / test_loader.size
+
+
+def train(train_loader, model, optimizer, epoch, test_path):
+    model.train()
+    # ---- multi-scale training ----
+    size_rates = [0.75, 1, 1.25]
+    loss_record1, loss_record2, loss_record3, loss_record5 = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
+    for i, pack in enumerate(train_loader, start=1):
+        for rate in size_rates:
+            optimizer.zero_grad()
+            # ---- data prepare ----
+            images, gts = pack
+            images = images.cuda()
+            gts = gts.cuda()
+            # ---- rescale ----
+            trainsize = int(round(opt.trainsize * rate / 32) * 32)
+            if rate != 1:
+                images = F.interpolate(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
+                gts = F.interpolate(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
+            # ---- forward ----
+            lateral_map_5, lateral_map_3, lateral_map_2, lateral_map_1 = model(images)
+            # ---- loss function ----
+            loss5 = structure_loss(lateral_map_5, gts)
+            loss3 = structure_loss(lateral_map_3, gts)
+            loss2 = structure_loss(lateral_map_2, gts)
+            loss1 = structure_loss(lateral_map_1, gts)
+            loss = loss5 + loss3 + loss2 + loss1
+            # ---- backward ----
+            loss.backward()
+            clip_gradient(optimizer, opt.clip)
+            optimizer.step()
+            # ---- recording loss ----
+            if rate == 1:
+                loss_record5.update(loss5.data, opt.batchsize)
+                loss_record3.update(loss3.data, opt.batchsize)
+                loss_record2.update(loss2.data, opt.batchsize)
+                loss_record1.update(loss1.data, opt.batchsize)
+        # ---- train visualization ----
+        if i % 20 == 0 or i == total_step:
+            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], '
+                  '[lateral-5: {:0.4f}], [lateral-3: {:0.4f}], [lateral-2: {:0.4f}], [lateral-1: {:0.4f}]'
+                  .format(datetime.now(), epoch, opt.epoch, i, total_step,
+                          loss_record5.show(), loss_record3.show(), loss_record2.show(), loss_record1.show()))
+
+    save_path = 'snapshots/{}/'.format(opt.train_save)
+    os.makedirs(save_path, exist_ok=True)
+
+    # ---- evaluate after every epoch and keep the best snapshot ----
+    meandice = test(model, test_path)
+
+    os.makedirs('log', exist_ok=True)
+    with open('log/log.txt', 'a') as fp:
+        fp.write(str(meandice) + '\n')
+
+    best = 0.0
+    if os.path.exists('log/best.txt'):
+        with open('log/best.txt', 'r') as fp:
+            best = float(fp.read())
+
+    if meandice > best:
+        with open('log/best.txt', 'w') as fp:
+            fp.write(str(meandice))
+        torch.save(model.state_dict(), save_path + 'CaraNet-best.pth')
+        print('[Saving Snapshot:]', save_path + 'CaraNet-best.pth', meandice, '[best:]', meandice)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--epoch', type=int,
+                        default=10, help='epoch number')
+
+    parser.add_argument('--lr', type=float,
+                        default=1e-4, help='learning rate')
+
+    parser.add_argument('--optimizer', type=str,
+                        default='Adam', help='choosing optimizer Adam or SGD')
+
+    parser.add_argument('--augmentation', action='store_true',
+                        help='choose to do random flip rotation')
+
+    parser.add_argument('--batchsize', type=int,
+                        default=6, help='training batch size')
+
+    parser.add_argument('--trainsize', type=int,
+                        default=352, help='training dataset size')
+
+    parser.add_argument('--clip', type=float,
+                        default=0.5, help='gradient clipping margin')
+
+    parser.add_argument('--decay_rate', type=float,
+                        default=0.1, help='decay rate of learning rate')
+
+    parser.add_argument('--decay_epoch', type=int,
+                        default=50, help='every n epochs decay learning rate')
+
+    parser.add_argument('--train_path', type=str,
+                        default='/home/data/spleen_blood/CaraNet/TrainDataset/train/', help='path to train dataset')
+
+    parser.add_argument('--test_path', type=str,
+                        default='/home/data/spleen_blood/CaraNet/TestDataset/test/', help='path to testing Kvasir dataset')
+
+    parser.add_argument('--train_save', type=str,
+                        default='', help='subfolder of snapshots/ for saving checkpoints')
+
+    opt = parser.parse_args()
+
+    # ---- build models ----
+    torch.cuda.set_device(4)  # set your gpu device
+    model = caranet().cuda()
+
+    # ---- flops and params ----
+    # from utils.utils import CalParams
+    # x = torch.randn(1, 3, 352, 352).cuda()
+    # CalParams(model, x)
+
+    params = model.parameters()
+
+    if opt.optimizer == 'Adam':
+        optimizer = torch.optim.Adam(params, opt.lr)
+    else:
+        optimizer = torch.optim.SGD(params, opt.lr, weight_decay=1e-4, momentum=0.9)
+
+    print(optimizer)
+
+    image_root = '{}/image/'.format(opt.train_path)
+    gt_root = '{}/mask/'.format(opt.train_path)
+
+    train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize, augmentation=opt.augmentation)
+    total_step = len(train_loader)
+
+    print("#" * 20, "Start Training", "#" * 20)
+
+    for epoch in range(1, opt.epoch + 1):
+        adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
+        train(train_loader, model, optimizer, epoch, opt.test_path)
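
For a quick sanity check of the weighted BCE + weighted IoU term used in the patch, a minimal, self-contained sketch follows. The batch size, single channel, and 352x352 resolution are illustrative assumptions rather than values taken from the script, and the loss body simply restates the patched structure_loss so the snippet runs without the repo:

# smoke_test_structure_loss.py -- hypothetical standalone check, not part of the patch
import torch
import torch.nn.functional as F

def structure_loss(pred, mask):
    # Same weighted BCE + weighted IoU as in CaraNet/train_blood.py above.
    weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))
    pred = torch.sigmoid(pred)
    inter = ((pred * mask) * weit).sum(dim=(2, 3))
    union = ((pred + mask) * weit).sum(dim=(2, 3))
    wiou = 1 - (inter + 1) / (union - inter + 1)
    return (wbce + wiou).mean()

if __name__ == '__main__':
    pred = torch.randn(2, 1, 352, 352)                   # raw logits, as the model heads emit
    mask = (torch.rand(2, 1, 352, 352) > 0.5).float()    # binary ground-truth mask
    print(structure_loss(pred, mask).item())             # a positive scalar

The 31x31 average pool with padding=15 keeps the weight map the same size as the mask, so weit up-weights exactly those pixels whose neighbourhood disagrees with them, i.e. mask boundaries.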