Diff of /train_UAMT.py [000000] .. [903821]

b/train_UAMT.py
import os
import sys
from tqdm import tqdm
from tensorboardX import SummaryWriter
import argparse
import logging
import time
import random
import numpy as np

import torch
import torch.optim as optim
from torchvision import transforms
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from dataloaders import utils

from networks.vnet import VNet
from utils import ramps, losses
from dataloaders.la_heart import *

parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='LA', help='dataset name')
parser.add_argument('--root_path', type=str, default='/***/data_set/LASet/data', help='path to the dataset root')
parser.add_argument('--exp', type=str, default='vnet', help='experiment name')
parser.add_argument('--model', type=str, default='UAMT', help='model name')
parser.add_argument('--max_iterations', type=int, default=6000, help='maximum number of iterations to train')
parser.add_argument('--batch_size', type=int, default=4, help='batch size per GPU')
parser.add_argument('--labeled_bs', type=int, default=2, help='labeled batch size per GPU')
parser.add_argument('--labelnum', type=int, default=25, help='number of labeled samples')
parser.add_argument('--max_samples', type=int, default=123, help='total number of training samples')
parser.add_argument('--base_lr', type=float, default=0.01, help='initial learning rate')
parser.add_argument('--deterministic', type=int, default=1, help='whether to use deterministic training')
parser.add_argument('--seed', type=int, default=1337, help='random seed')
parser.add_argument('--gpu', type=str, default='2', help='GPU(s) to use')
### costs
parser.add_argument('--ema_decay', type=float, default=0.99, help='EMA decay rate for the teacher model')
parser.add_argument('--consistency_type', type=str, default="mse", help='consistency loss type (mse or kl)')
parser.add_argument('--consistency', type=float, default=0.1, help='maximum consistency loss weight')
parser.add_argument('--consistency_rampup', type=float, default=40.0, help='consistency ramp-up length')
args = parser.parse_args()

num_classes = 2
patch_size = (112, 112, 80)
snapshot_path = "model/{}_{}_{}_labeled/{}".format(args.dataset_name, args.exp, args.labelnum, args.model)

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
batch_size = args.batch_size * len(args.gpu.split(','))
max_iterations = args.max_iterations
base_lr = args.base_lr
labeled_bs = args.labeled_bs

if args.deterministic:
    cudnn.benchmark = False
    cudnn.deterministic = True
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

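# cal_dice: Dice overlap between the hard (argmax) prediction and the target,
# computed over the whole batch; used as the validation metric during training.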
def cal_dice(output, target, eps=1e-3):
    output = torch.argmax(output, dim=1)
    inter = torch.sum(output * target) + eps
    union = torch.sum(output) + torch.sum(target) + eps * 2
    dice = 2 * inter / union
    return dice


def get_current_consistency_weight(epoch):
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_rampup)

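# update_ema_variables: exponential moving average of the student weights,
#   teacher = alpha * teacher + (1 - alpha) * student.
# alpha is capped by 1 - 1/(global_step + 1), so early on the teacher tracks the
# student closely before settling at ema_decay. For example, with ema_decay=0.99:
#   step 0  -> alpha = 0.0   (teacher copies the student)
#   step 9  -> alpha = 0.9
#   step 99 -> alpha = 0.99  (and 0.99 thereafter)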
def update_ema_variables(model, ema_model, alpha, global_step):
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

if __name__ == "__main__":
    # make logger file
    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)

    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    def create_model(ema=False):
        # Network definition
        net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True)
        model = net.cuda()
        if ema:
            # The teacher is never updated by backprop; its weights follow the student via EMA.
            for param in model.parameters():
                param.detach_()
        return model

    model = create_model()
    ema_model = create_model(ema=True)
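    # Data pipeline: training volumes get random rotation/flip and a random crop to
    # patch_size; the test split is center-cropped to the same size.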
    db_train = LAHeart(base_dir=args.root_path,
                       split='train',
                       transform=transforms.Compose([
                           RandomRotFlip(),
                           RandomCrop(patch_size),
                           ToTensor(),
                       ]))
    db_test = LAHeart(base_dir=args.root_path,
                      split='test',
                      transform=transforms.Compose([
                          CenterCrop(patch_size),
                          ToTensor()
                      ]))
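    # Two-stream batching: each batch holds labeled_bs labeled scans first, then
    # (batch_size - labeled_bs) unlabeled scans, so volume_batch[labeled_bs:] is unlabeled.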
    labeled_idxs = list(range(args.labelnum))
    unlabeled_idxs = list(range(args.labelnum, args.max_samples))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, batch_size - labeled_bs)

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
    test_loader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

    model.train()
    ema_model.train()
    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)

    if args.consistency_type == 'mse':
        consistency_criterion = losses.softmax_mse_loss
    elif args.consistency_type == 'kl':
        consistency_criterion = losses.softmax_kl_loss
    else:
        assert False, args.consistency_type

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(trainloader)))

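    # Training is driven by iteration count rather than epochs; max_epoch is only
    # derived from max_iterations and the number of batches per epoch.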
    iter_num = 0
    best_dice = 0
    max_epoch = max_iterations // len(trainloader) + 1
    lr_ = base_lr
    model.train()
    for epoch_num in tqdm(range(max_epoch), ncols=70):
        time1 = time.time()
        for i_batch, sampled_batch in enumerate(trainloader):
            time2 = time.time()
            # print('fetch data cost {}'.format(time2-time1))
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            unlabeled_volume_batch = volume_batch[labeled_bs:]

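            # The student sees the whole batch; the teacher (EMA model) sees only the
            # unlabeled half with small additive Gaussian noise, and its prediction is
            # the consistency target.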
            noise = torch.clamp(torch.randn_like(unlabeled_volume_batch) * 0.1, -0.2, 0.2)
            ema_inputs = unlabeled_volume_batch + noise
            outputs = model(volume_batch)
            with torch.no_grad():
                ema_output = ema_model(ema_inputs)
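            # Uncertainty estimation via Monte Carlo dropout: the teacher (kept in train
            # mode, so dropout stays active) is run T=8 times on noisy copies of the
            # unlabeled volumes; the predictive entropy of the averaged softmax gives the
            # per-voxel uncertainty map.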
            T = 8
            volume_batch_r = unlabeled_volume_batch.repeat(2, 1, 1, 1, 1)
            stride = volume_batch_r.shape[0] // 2
            preds = torch.zeros([stride * T, 2, 112, 112, 80]).cuda()
            for i in range(T // 2):
                ema_inputs = volume_batch_r + torch.clamp(torch.randn_like(volume_batch_r) * 0.1, -0.2, 0.2)
                with torch.no_grad():
                    preds[2 * stride * i:2 * stride * (i + 1)] = ema_model(ema_inputs)
            preds = F.softmax(preds, dim=1)
            preds = preds.reshape(T, stride, 2, 112, 112, 80)
            preds = torch.mean(preds, dim=0)  # (batch, 2, 112, 112, 80)
            uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True)  # (batch, 1, 112, 112, 80)

            ## calculate the loss
            loss_seg = F.cross_entropy(outputs[:labeled_bs], label_batch[:labeled_bs])
            outputs_soft = F.softmax(outputs, dim=1)
            loss_seg_dice = losses.dice_loss(outputs_soft[:labeled_bs, 1, :, :, :], label_batch[:labeled_bs] == 1)
            supervised_loss = 0.5 * (loss_seg + loss_seg_dice)

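            # Uncertainty-aware consistency: the weight ramps up with training (one "epoch"
            # is treated as 150 iterations), and only voxels whose predictive entropy falls
            # below a ramped threshold contribute. For two classes the entropy is at most
            # ln(2) ~= 0.693, so the threshold grows from roughly 0.52 to 0.693 over training.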
            consistency_weight = get_current_consistency_weight(iter_num // 150)
            consistency_dist = consistency_criterion(outputs[labeled_bs:], ema_output)  # (batch, 2, 112, 112, 80)
            threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(iter_num, max_iterations)) * np.log(2)
            mask = (uncertainty < threshold).float()
            consistency_dist = torch.sum(mask * consistency_dist) / (2 * torch.sum(mask) + 1e-16)
            consistency_loss = consistency_weight * consistency_dist
            loss = supervised_loss + consistency_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_ema_variables(model, ema_model, args.ema_decay, iter_num)

            iter_num = iter_num + 1
            writer.add_scalar('uncertainty/mean', uncertainty[0, 0].mean(), iter_num)
            writer.add_scalar('uncertainty/max', uncertainty[0, 0].max(), iter_num)
            writer.add_scalar('uncertainty/min', uncertainty[0, 0].min(), iter_num)
            writer.add_scalar('uncertainty/mask_per', torch.sum(mask) / mask.numel(), iter_num)
            writer.add_scalar('uncertainty/threshold', threshold, iter_num)
            writer.add_scalar('lr', lr_, iter_num)
            writer.add_scalar('loss/loss', loss, iter_num)
            writer.add_scalar('loss/loss_seg', loss_seg, iter_num)
            writer.add_scalar('loss/loss_seg_dice', loss_seg_dice, iter_num)
            writer.add_scalar('train/consistency_loss', consistency_loss, iter_num)
            writer.add_scalar('train/consistency_weight', consistency_weight, iter_num)
            writer.add_scalar('train/consistency_dist', consistency_dist, iter_num)

            logging.info('iteration %d : loss : %f cons_dist: %f, loss_weight: %f' %
                         (iter_num, loss.item(), consistency_dist.item(), consistency_weight))

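            # Every 50 iterations, log example slices (input, prediction, ground truth,
            # uncertainty map, and the rejected high-uncertainty mask) to TensorBoard.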
            if iter_num % 50 == 0:
                image = volume_batch[0, 0:1, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/Image', grid_image, iter_num)

                # image = outputs_soft[0, 3:4, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                image = torch.max(outputs_soft[0, :, :, :, 20:61:10], 0)[1].permute(2, 0, 1).data.cpu().numpy()
                image = utils.decode_seg_map_sequence(image)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Predicted_label', grid_image, iter_num)

                image = label_batch[0, :, :, 20:61:10].permute(2, 0, 1)
                grid_image = make_grid(utils.decode_seg_map_sequence(image.data.cpu().numpy()), 5, normalize=False)
                writer.add_image('train/Groundtruth_label', grid_image, iter_num)

                image = uncertainty[0, 0:1, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/uncertainty', grid_image, iter_num)

                mask2 = (uncertainty > threshold).float()
                image = mask2[0, 0:1, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/mask', grid_image, iter_num)
                #####
                image = volume_batch[-1, 0:1, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('unlabel/Image', grid_image, iter_num)

                # image = outputs_soft[-1, 3:4, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                image = torch.max(outputs_soft[-1, :, :, :, 20:61:10], 0)[1].permute(2, 0, 1).data.cpu().numpy()
                image = utils.decode_seg_map_sequence(image)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('unlabel/Predicted_label', grid_image, iter_num)

                image = label_batch[-1, :, :, 20:61:10].permute(2, 0, 1)
                grid_image = make_grid(utils.decode_seg_map_sequence(image.data.cpu().numpy()), 5, normalize=False)
                writer.add_image('unlabel/Groundtruth_label', grid_image, iter_num)

            ## change lr
            if iter_num % 2500 == 0:
                lr_ = base_lr * 0.1 ** (iter_num // 2500)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_

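            # From iteration 800 onward, evaluate on the test split every 200 iterations
            # and checkpoint whenever the average Dice improves.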
            if iter_num >= 800 and iter_num % 200 == 0:
                model.eval()
                with torch.no_grad():
                    dice_sample = 0
                    for sampled_batch in test_loader:
                        img, lbl = sampled_batch['image'].cuda(), sampled_batch['label'].cuda()
                        outputs = model(img)
                        dice_once = cal_dice(outputs, lbl)
                        dice_sample += dice_once
                    dice_sample = dice_sample / len(test_loader)
                    print('Average center dice:{:.3f}'.format(dice_sample))

                if dice_sample > best_dice:
                    best_dice = dice_sample
                    save_mode_path = os.path.join(snapshot_path, 'iter_{}_dice_{}.pth'.format(iter_num, best_dice))
                    save_best_path = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best_path)
                    logging.info("save best model to {}".format(save_mode_path))
                writer.add_scalar('Var_dice/Dice', dice_sample, iter_num)
                writer.add_scalar('Var_dice/Best_dice', best_dice, iter_num)
                model.train()

            if iter_num >= max_iterations:
                break
            time1 = time.time()
        if iter_num >= max_iterations:
            break
    save_mode_path = os.path.join(snapshot_path, 'iter_' + str(max_iterations) + '.pth')
    torch.save(model.state_dict(), save_mode_path)
    logging.info("save model to {}".format(save_mode_path))
    writer.close()