Diff of /train_URPC.py [000000] .. [903821]

b/train_URPC.py
import os
import sys
from tqdm import tqdm
from tensorboardX import SummaryWriter
import shutil
import argparse
import logging
import time
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from networks.unet_urpc import unet_3D_dv_semi
from utils import ramps, losses
from dataloaders.la_heart import *


parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='LA', help='dataset name')
parser.add_argument('--root_path', type=str, default='/***/data_set/LASet/data', help='root path of the dataset')
parser.add_argument('--exp', type=str, default='vnet', help='experiment name')
parser.add_argument('--model', type=str, default='URPC', help='model name')
parser.add_argument('--max_iterations', type=int, default=6000, help='maximum number of training iterations')
parser.add_argument('--batch_size', type=int, default=4, help='batch size per GPU')
parser.add_argument('--labeled_bs', type=int, default=2, help='labeled batch size per GPU')
parser.add_argument('--labelnum', type=int, default=25, help='number of labeled samples')
parser.add_argument('--max_samples', type=int, default=123, help='total number of training samples')
parser.add_argument('--base_lr', type=float, default=0.01, help='base learning rate')
parser.add_argument('--deterministic', type=int, default=1, help='whether to use deterministic training')
parser.add_argument('--seed', type=int, default=1337, help='random seed')
parser.add_argument('--gpu', type=str, default='0', help='GPU to use')
### consistency costs
parser.add_argument('--ema_decay', type=float, default=0.99, help='EMA decay')
parser.add_argument('--consistency_type', type=str, default="mse", help='consistency loss type')
parser.add_argument('--consistency', type=float, default=0.1, help='weight of the consistency loss')
parser.add_argument('--consistency_rampup', type=float, default=40.0, help='consistency ramp-up length')
args = parser.parse_args()
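# Example invocation (illustrative; point --root_path at the local LASet data directory):
#   python train_URPC.py --gpu 0 --labelnum 25 --max_iterations 6000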

num_classes = 2
patch_size = (112, 112, 80)
snapshot_path = "model/{}_{}_{}_labeled/{}".format(args.dataset_name, args.exp, args.labelnum, args.model)

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
batch_size = args.batch_size * len(args.gpu.split(','))
max_iterations = args.max_iterations
base_lr = args.base_lr
labeled_bs = args.labeled_bs

if args.deterministic:
    cudnn.benchmark = False
    cudnn.deterministic = True
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

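# Binary Dice on hard predictions: argmax over the class dimension, then overlap
# between the predicted and target foreground masks (num_classes == 2).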
def cal_dice(output, target, eps=1e-3):
    output = torch.argmax(output, dim=1)
    inter = torch.sum(output * target) + eps
    union = torch.sum(output) + torch.sum(target) + eps * 2
    dice = 2 * inter / union
    return dice

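# Note: ramps.sigmoid_rampup (utils/ramps.py) is assumed to follow the usual
# mean-teacher schedule, exp(-5 * (1 - t/T)^2), rising from 0 to 1 over
# consistency_rampup pseudo-epochs.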
def get_current_consistency_weight(epoch):
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_rampup)


if __name__ == "__main__":
    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)

    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

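    # unet_3D_dv_semi returns four predictions (one per decoder scale); LAHeart
    # provides the LA training and test splits used below.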
    model = unet_3D_dv_semi(n_classes=num_classes, in_channels=1).cuda()
    db_train = LAHeart(base_dir=args.root_path,
                       split='train',
                       transform=transforms.Compose([
                           RandomRotFlip(),
                           RandomCrop(patch_size),
                           ToTensor(),
                       ]))
    db_test = LAHeart(base_dir=args.root_path,
                      split='test',
                      transform=transforms.Compose([
                          CenterCrop(patch_size),
                          ToTensor()
                      ]))

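    # The first labelnum cases are treated as labeled and the rest as unlabeled;
    # TwoStreamBatchSampler draws labeled_bs labeled and (batch_size - labeled_bs)
    # unlabeled samples for every batch, with the labeled samples placed first.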
    labeled_idxs = list(range(args.labelnum))
    unlabeled_idxs = list(range(args.labelnum, args.max_samples))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, batch_size - labeled_bs)

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
    test_loader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

    model.train()
    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)

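    # Supervised heads use CE + Dice; KLDivLoss with reduction='none' keeps the
    # per-voxel divergence that is used below as an uncertainty estimate.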
    ce_loss = nn.CrossEntropyLoss()
    dice_loss = losses.DiceLoss(num_classes)
    kl_distance = nn.KLDivLoss(reduction='none')

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(trainloader)))

    iter_num = 0
    best_dice = 0
    max_epoch = max_iterations // len(trainloader) + 1
    model.train()
    for epoch_num in tqdm(range(max_epoch), ncols=70):
        time1 = time.time()
        for i_batch, sampled_batch in enumerate(trainloader):
            time2 = time.time()
            # print('fetch data cost {}'.format(time2-time1))
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            unlabeled_volume_batch = volume_batch[labeled_bs:]

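            # URPC forward pass: the network returns predictions from four decoder
            # scales; each scale is supervised on the labeled half of the batch and
            # regularized toward the scale-averaged prediction further below.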
            outputs_aux1, outputs_aux2, outputs_aux3, outputs_aux4 = model(volume_batch)
            outputs_aux1_soft = torch.softmax(outputs_aux1, dim=1)
            outputs_aux2_soft = torch.softmax(outputs_aux2, dim=1)
            outputs_aux3_soft = torch.softmax(outputs_aux3, dim=1)
            outputs_aux4_soft = torch.softmax(outputs_aux4, dim=1)

            loss_ce_aux1 = ce_loss(outputs_aux1[:args.labeled_bs],
                                   label_batch[:args.labeled_bs])
            loss_ce_aux2 = ce_loss(outputs_aux2[:args.labeled_bs],
                                   label_batch[:args.labeled_bs])
            loss_ce_aux3 = ce_loss(outputs_aux3[:args.labeled_bs],
                                   label_batch[:args.labeled_bs])
            loss_ce_aux4 = ce_loss(outputs_aux4[:args.labeled_bs],
                                   label_batch[:args.labeled_bs])

            loss_dice_aux1 = dice_loss(
                outputs_aux1_soft[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1))
            loss_dice_aux2 = dice_loss(
                outputs_aux2_soft[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1))
            loss_dice_aux3 = dice_loss(
                outputs_aux3_soft[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1))
            loss_dice_aux4 = dice_loss(
                outputs_aux4_soft[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1))

            supervised_loss = (loss_ce_aux1 + loss_ce_aux2 + loss_ce_aux3 + loss_ce_aux4 +
                               loss_dice_aux1 + loss_dice_aux2 + loss_dice_aux3 + loss_dice_aux4) / 8

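            # Uncertainty-rectified consistency on the unlabeled half: the mean of the
            # four softmax outputs is the target, and the per-voxel KL divergence to it
            # (variance_aux*) down-weights voxels where a scale disagrees with the mean.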
            preds = (outputs_aux1_soft +
                     outputs_aux2_soft + outputs_aux3_soft + outputs_aux4_soft) / 4

            variance_aux1 = torch.sum(kl_distance(
                torch.log(outputs_aux1_soft[args.labeled_bs:]), preds[args.labeled_bs:]), dim=1, keepdim=True)
            exp_variance_aux1 = torch.exp(-variance_aux1)

            variance_aux2 = torch.sum(kl_distance(
                torch.log(outputs_aux2_soft[args.labeled_bs:]), preds[args.labeled_bs:]), dim=1, keepdim=True)
            exp_variance_aux2 = torch.exp(-variance_aux2)

            variance_aux3 = torch.sum(kl_distance(
                torch.log(outputs_aux3_soft[args.labeled_bs:]), preds[args.labeled_bs:]), dim=1, keepdim=True)
            exp_variance_aux3 = torch.exp(-variance_aux3)

            variance_aux4 = torch.sum(kl_distance(
                torch.log(outputs_aux4_soft[args.labeled_bs:]), preds[args.labeled_bs:]), dim=1, keepdim=True)
            exp_variance_aux4 = torch.exp(-variance_aux4)

            consistency_dist_aux1 = (
                preds[args.labeled_bs:] - outputs_aux1_soft[args.labeled_bs:]) ** 2
            consistency_loss_aux1 = torch.mean(
                consistency_dist_aux1 * exp_variance_aux1) / (torch.mean(exp_variance_aux1) + 1e-8) + torch.mean(variance_aux1)

            consistency_dist_aux2 = (
                preds[args.labeled_bs:] - outputs_aux2_soft[args.labeled_bs:]) ** 2
            consistency_loss_aux2 = torch.mean(
                consistency_dist_aux2 * exp_variance_aux2) / (torch.mean(exp_variance_aux2) + 1e-8) + torch.mean(variance_aux2)

            consistency_dist_aux3 = (
                preds[args.labeled_bs:] - outputs_aux3_soft[args.labeled_bs:]) ** 2
            consistency_loss_aux3 = torch.mean(
                consistency_dist_aux3 * exp_variance_aux3) / (torch.mean(exp_variance_aux3) + 1e-8) + torch.mean(variance_aux3)

            consistency_dist_aux4 = (
                preds[args.labeled_bs:] - outputs_aux4_soft[args.labeled_bs:]) ** 2
            consistency_loss_aux4 = torch.mean(
                consistency_dist_aux4 * exp_variance_aux4) / (torch.mean(exp_variance_aux4) + 1e-8) + torch.mean(variance_aux4)

            consistency_loss = (consistency_loss_aux1 +
                                consistency_loss_aux2 + consistency_loss_aux3 + consistency_loss_aux4) / 4
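            # Ramp the unsupervised weight up with training; iter_num // 150 converts
            # iterations into the pseudo-epoch expected by the sigmoid ramp-up.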
            consistency_weight = get_current_consistency_weight(iter_num // 150)
            loss = supervised_loss + consistency_weight * consistency_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_num = iter_num + 1
            writer.add_scalar('info/total_loss', loss, iter_num)
            writer.add_scalar('info/supervised_loss',
                              supervised_loss, iter_num)
            writer.add_scalar('info/consistency_loss',
                              consistency_loss, iter_num)
            writer.add_scalar('info/consistency_weight',
                              consistency_weight, iter_num)

            logging.info(
                'iteration %d : loss : %f, supervised_loss: %f' %
                (iter_num, loss.item(), supervised_loss.item()))

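            # Periodic validation: average Dice of the first output head over the test
            # split; keep the checkpoint with the best score.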
            if iter_num >= 800 and iter_num % 200 == 0:
                model.eval()
                with torch.no_grad():
                    dice_sample = 0
                    for sampled_batch in test_loader:
                        img, lbl = sampled_batch['image'].cuda(), sampled_batch['label'].cuda()
                        outputs = model(img)
                        dice_once = cal_dice(outputs[0], lbl)
                        dice_sample += dice_once.item()  # accumulate as a Python float
                    dice_sample = dice_sample / len(test_loader)
                    print('Average center dice:{:.3f}'.format(dice_sample))

                if dice_sample > best_dice:
                    best_dice = dice_sample
                    save_mode_path = os.path.join(snapshot_path, 'iter_{}_dice_{}.pth'.format(iter_num, best_dice))
                    save_best_path = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best_path)
                    logging.info("save best model to {}".format(save_mode_path))
                writer.add_scalar('Var_dice/Dice', dice_sample, iter_num)
                writer.add_scalar('Var_dice/Best_dice', best_dice, iter_num)
                model.train()

            if iter_num >= max_iterations:
                break
            time1 = time.time()
        if iter_num >= max_iterations:
            break
    save_mode_path = os.path.join(snapshot_path, 'iter_' + str(max_iterations) + '.pth')
    torch.save(model.state_dict(), save_mode_path)
    logging.info("save model to {}".format(save_mode_path))
    writer.close()