Diff of /train_dtc.py [000000] .. [903821]

import os
import sys
from tqdm import tqdm
from tensorboardX import SummaryWriter
import argparse
import logging
import time
import random
import numpy as np

import torch
import torch.optim as optim
from torchvision import transforms
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, MSELoss
from torch.utils.data import DataLoader

from networks.vnet_sdf import VNet
from utils import ramps, losses
from dataloaders.la_heart import *
from dataloaders.utils import compute_sdf

parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='LA', help='dataset name')
parser.add_argument('--root_path', type=str, default='/data/omnisky/postgraduate/Yb/data_set/LASet/data', help='path to the dataset root')
parser.add_argument('--exp', type=str, default='vnet', help='experiment name')
parser.add_argument('--model', type=str, default='DTC', help='model name')
parser.add_argument('--max_iterations', type=int, default=6000, help='maximum number of training iterations')
parser.add_argument('--batch_size', type=int, default=4, help='batch size per gpu')
parser.add_argument('--labeled_bs', type=int, default=2, help='labeled batch size per gpu')
parser.add_argument('--base_lr', type=float, default=0.01, help='base learning rate')
parser.add_argument('--D_lr', type=float, default=1e-4, help='maximum discriminator learning rate to train')
parser.add_argument('--deterministic', type=int, default=1, help='whether to use deterministic training')
parser.add_argument('--labelnum', type=int, default=25, help='number of labeled samples')
parser.add_argument('--max_samples', type=int, default=123, help='total number of samples')
parser.add_argument('--seed', type=int, default=1337, help='random seed')
parser.add_argument('--consistency_weight', type=float, default=0.1, help='balance factor between supervised loss and consistency loss')
parser.add_argument('--gpu', type=str, default='1', help='GPU to use')
parser.add_argument('--beta', type=float, default=0.3, help='balance factor between regional and SDM losses')
parser.add_argument('--gamma', type=float, default=0.5, help='balance factor between supervised and consistency losses')
# costs
parser.add_argument('--consistency', type=float, default=1.0, help='consistency')
parser.add_argument('--consistency_rampup', type=float, default=40.0, help='consistency_rampup')
args = parser.parse_args()

train_data_path = args.root_path
snapshot_path = "model/{}_{}_{}_labeled/{}".format(args.dataset_name, args.exp, args.labelnum, args.model)

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
batch_size = args.batch_size * len(args.gpu.split(','))
max_iterations = args.max_iterations
base_lr = args.base_lr
labeled_bs = args.labeled_bs

if not args.deterministic:
    cudnn.benchmark = True
    cudnn.deterministic = False
else:
    cudnn.benchmark = False
    cudnn.deterministic = True
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

num_classes = 2
patch_size = (112, 112, 80)

def cal_dice(output, target, eps=1e-3):
    # binary Dice between the thresholded sigmoid output and the ground-truth mask
    output = torch.sigmoid(output)
    output = (output > 0.5).float()
    output = torch.squeeze(output)
    inter = torch.sum(output * target) + eps
    union = torch.sum(output) + torch.sum(target) + eps * 2
    dice = 2 * inter / union
    return dice

def get_current_consistency_weight(epoch):
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_rampup)
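# Note: ramps.sigmoid_rampup is assumed to follow the usual Laine & Aila
# schedule, roughly w(t) = exp(-5 * (1 - t)**2) with t = clip(epoch / rampup, 0, 1),
# so the consistency weight rises smoothly from ~0 to args.consistency.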


if __name__ == "__main__":
    # make logger file
    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)

    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    def create_model(ema=False):
        # Network definition
        net = VNet(n_channels=1, n_classes=num_classes - 1,
                   normalization='batchnorm', has_dropout=True)
        model = net.cuda()
        if ema:
            for param in model.parameters():
                param.detach_()
        return model

    model = create_model()
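    # Reader note (assumption): the VNet from networks.vnet_sdf is expected to
    # return two heads per forward pass -- a tanh-activated signed distance map
    # and the raw segmentation logits (see "outputs_tanh, outputs = model(...)"
    # below). n_classes is num_classes - 1 = 1 because the binary foreground is
    # predicted as a single sigmoid channel.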

    db_train = LAHeart(base_dir=train_data_path,
                       split='train',  # train/val split
                       transform=transforms.Compose([
                           RandomRotFlip(),
                           RandomCrop(patch_size),
                           ToTensor(),
                       ]))
    db_test = LAHeart(base_dir=args.root_path,
                      split='test',
                      transform=transforms.Compose([
                          CenterCrop(patch_size),
                          ToTensor()
                      ]))
    labeled_idxs = list(range(args.labelnum))
    unlabeled_idxs = list(range(args.labelnum, args.max_samples))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, batch_size - labeled_bs)
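    # Assumption: TwoStreamBatchSampler behaves like the mean-teacher sampler,
    # i.e. every batch contains labeled_bs labeled volumes first, followed by
    # batch_size - labeled_bs unlabeled volumes; this ordering is what allows the
    # supervised losses below to slice the first labeled_bs entries of each batch.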

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)
    trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
    test_loader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

    model.train()

    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
    ce_loss = BCEWithLogitsLoss()
    mse_loss = MSELoss()

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(trainloader)))

    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    lr_ = base_lr
    best_dice = 0.0
    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch_num in iterator:
        time1 = time.time()
        for i_batch, sampled_batch in enumerate(trainloader):
            time2 = time.time()
            # print('fetch data cost {}'.format(time2-time1))
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()

            outputs_tanh, outputs = model(volume_batch)
            outputs_soft = torch.sigmoid(outputs)

            # calculate the supervised loss on the labeled part of the batch
            with torch.no_grad():
                # use only the labeled part of the batch so the ground-truth SDM
                # matches outputs[:labeled_bs, 0, ...] in shape
                gt_dis = compute_sdf(label_batch[:labeled_bs].cpu().numpy(), outputs[:labeled_bs, 0, ...].shape)
                gt_dis = torch.from_numpy(gt_dis).float().cuda()
            loss_sdf = mse_loss(outputs_tanh[:labeled_bs, 0, ...], gt_dis)
            loss_seg = ce_loss(outputs[:labeled_bs, 0, ...], label_batch[:labeled_bs].float())
            loss_seg_dice = losses.dice_loss(outputs_soft[:labeled_bs, 0, :, :, :], label_batch[:labeled_bs] == 1)
            supervised_loss = loss_seg_dice + args.beta * loss_sdf
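            # Note: compute_sdf is assumed to return a signed distance map
            # normalized to roughly [-1, 1] (negative inside the object, positive
            # outside), which is why it is regressed against the tanh head with a
            # plain MSE. loss_seg is computed for logging only; supervised_loss
            # combines the Dice and SDM terms.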

            # unsupervised loss
            dis_to_mask = torch.sigmoid(-1500 * outputs_tanh)
            consistency_loss = torch.mean((dis_to_mask - outputs_soft) ** 2)
            consistency_weight = get_current_consistency_weight(iter_num // 150)
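            # Note: sigmoid(-1500 * sdm) is a sharp, differentiable approximation of
            # a step function, converting the predicted signed distance map into a
            # soft foreground mask; its mean-squared difference from outputs_soft is
            # the dual-task consistency term, applied to labeled and unlabeled data
            # alike. iter_num // 150 is assumed to turn iterations into a
            # pseudo-epoch for the sigmoid ramp-up.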

            loss = supervised_loss + consistency_weight * consistency_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_num = iter_num + 1
            writer.add_scalar('lr', lr_, iter_num)
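            # lr_ is logged but never decayed in this script. A step decay in the
            # style of related semi-supervised VNet trainers could be added here,
            # e.g. (sketch, not enabled):
            # if iter_num % 2500 == 0:
            #     lr_ = base_lr * 0.1 ** (iter_num // 2500)
            #     for param_group in optimizer.param_groups:
            #         param_group['lr'] = lr_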
            writer.add_scalar('loss/loss', loss, iter_num)
            writer.add_scalar('loss/loss_seg', loss_seg, iter_num)
            writer.add_scalar('loss/loss_dice', loss_seg_dice, iter_num)
            writer.add_scalar('loss/loss_sdf', loss_sdf, iter_num)
            writer.add_scalar('loss/consistency_weight', consistency_weight, iter_num)
            writer.add_scalar('loss/consistency_loss', consistency_loss, iter_num)

            logging.info('iteration %d : loss : %f, loss_consis: %f, loss_sdf: %f, loss_seg: %f, loss_dice: %f' %
                         (iter_num, loss.item(), consistency_loss.item(), loss_sdf.item(), loss_seg.item(), loss_seg_dice.item()))

            if iter_num >= 800 and iter_num % 200 == 0:
                model.eval()
                with torch.no_grad():
                    dice_sample = 0
                    for sampled_batch in test_loader:
                        img, lbl = sampled_batch['image'].cuda(), sampled_batch['label'].cuda()
                        _, outputs = model(img)
                        dice_once = cal_dice(outputs, lbl).item()
                        print(dice_once)
                        dice_sample += dice_once
                    dice_sample = dice_sample / len(test_loader)
                    print('Average center dice:{:.3f}'.format(dice_sample))

                    if dice_sample > best_dice:
                        best_dice = dice_sample
                        save_mode_path = os.path.join(snapshot_path, 'iter_{}_dice_{}.pth'.format(iter_num, best_dice))
                        save_best_path = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
                        torch.save(model.state_dict(), save_mode_path)
                        torch.save(model.state_dict(), save_best_path)
                        logging.info("save best model to {}".format(save_mode_path))
                    writer.add_scalar('Var_dice/Dice', dice_sample, iter_num)
                    writer.add_scalar('Var_dice/Best_dice', best_dice, iter_num)
                model.train()
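                # Evaluation note: test volumes are center-cropped to patch_size, so
                # this is a quick proxy metric ("center dice"), not a full-volume
                # sliding-window evaluation.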

            if iter_num >= max_iterations:
                break
            time1 = time.time()
        if iter_num >= max_iterations:
            iterator.close()
            break
    writer.close()