# train_mean_teacher.py
import os
import sys
from tqdm import tqdm
from tensorboardX import SummaryWriter
import argparse
import logging
import time
import random

import numpy as np
import torch
import torch.optim as optim
from torchvision import transforms
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from networks.vnet import VNet
from utils import ramps, losses
from dataloaders.la_heart import *

parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='LA', help='dataset name')
parser.add_argument('--root_path', type=str, default='/data/omnisky/postgraduate/Yb/data_set/LASet/data',
                    help='root path of the dataset')
parser.add_argument('--exp', type=str, default='vnet', help='experiment / backbone name')
parser.add_argument('--model', type=str, default='MT', help='model name')
parser.add_argument('--max_iterations', type=int, default=6000, help='maximum number of iterations to train')
parser.add_argument('--batch_size', type=int, default=4, help='batch size per GPU')
parser.add_argument('--labeled_bs', type=int, default=2, help='labeled batch size per GPU')
parser.add_argument('--labelnum', type=int, default=25, help='number of labeled training samples')
parser.add_argument('--max_samples', type=int, default=123, help='total number of training samples')
parser.add_argument('--base_lr', type=float, default=0.01, help='base learning rate')
parser.add_argument('--deterministic', type=int, default=1, help='whether to use deterministic training')
parser.add_argument('--seed', type=int, default=1337, help='random seed')
parser.add_argument('--gpu', type=str, default='1', help='GPU(s) to use, e.g. "0" or "0,1"')
### consistency cost options
parser.add_argument('--ema_decay', type=float, default=0.99, help='EMA decay rate for the teacher model')
parser.add_argument('--consistency_type', type=str, default="mse", help='consistency loss type: mse or kl')
parser.add_argument('--consistency', type=float, default=0.1, help='maximum consistency loss weight')
parser.add_argument('--consistency_rampup', type=float, default=40.0, help='consistency ramp-up length (epochs)')
args = parser.parse_args()
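# Example invocation (the data path and GPU id below are illustrative):
#   python train_mean_teacher.py --root_path /path/to/LASet/data --gpu 0 --labelnum 25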

patch_size = (112, 112, 80)
snapshot_path = "model/{}_{}_{}_labeled/{}".format(args.dataset_name, args.exp, args.labelnum, args.model)

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
batch_size = args.batch_size * len(args.gpu.split(','))
max_iterations = args.max_iterations
base_lr = args.base_lr
labeled_bs = args.labeled_bs

if args.deterministic:
    cudnn.benchmark = False
    cudnn.deterministic = True
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

num_classes = 2


def cal_dice(output, target, eps=1e-3):
    # Hard Dice between the argmax prediction and the target; eps guards
    # against empty masks (division by zero).
    output = torch.argmax(output, dim=1)
    inter = torch.sum(output * target) + eps
    union = torch.sum(output) + torch.sum(target) + eps * 2
    dice = 2 * inter / union
    return dice


def get_current_consistency_weight(epoch):
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_rampup)
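# Note: in this family of semi-supervised codebases, ramps.sigmoid_rampup(t, T)
# typically evaluates exp(-5 * (1 - t/T)**2), rising smoothly from ~0 at t=0 to 1
# at t=T, so the consistency term is phased in while the student is still unreliable.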


def update_ema_variables(model, ema_model, alpha, global_step):
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
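# The teacher weights follow theta'_t = alpha * theta'_{t-1} + (1 - alpha) * theta_t,
# an exponential moving average of student checkpoints; the min() above warms
# alpha up from zero so that early teacher weights track the student closely.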


if __name__ == "__main__":
    # make logger file
    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)

    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))


    def create_model(ema=False):
        # Network definition
        net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True)
        model = net.cuda()
        if ema:
            # The teacher is updated only by EMA, never by backprop, so detach
            # its parameters from the autograd graph.
            for param in model.parameters():
                param.detach_()
        return model


    model = create_model()
    ema_model = create_model(ema=True)
    db_train = LAHeart(base_dir=args.root_path,
                       split='train',
                       transform=transforms.Compose([
                           RandomRotFlip(),
                           RandomCrop(patch_size),
                           ToTensor(),
                       ]))
    db_test = LAHeart(base_dir=args.root_path,
                      split='test',
                      transform=transforms.Compose([
                          CenterCrop(patch_size),
                          ToTensor()
                      ]))
    labeled_idxs = list(range(args.labelnum))
    unlabeled_idxs = list(range(args.labelnum, args.max_samples))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, batch_size - labeled_bs)
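    # TwoStreamBatchSampler draws from both index streams; the training loop below
    # assumes the first labeled_bs items of every batch are labeled (supervised
    # loss) and the remaining batch_size - labeled_bs items are unlabeled
    # (consistency loss only).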

    def worker_init_fn(worker_id):
        # Seed each DataLoader worker deterministically.
        random.seed(args.seed + worker_id)

    train_loader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True,
                              worker_init_fn=worker_init_fn)
    test_loader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

    model.train()
    ema_model.train()
    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
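    # Only the student's parameters are optimized; the teacher is updated by EMA.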

    if args.consistency_type == 'mse':
        consistency_criterion = losses.softmax_mse_loss
    elif args.consistency_type == 'kl':
        consistency_criterion = losses.softmax_kl_loss
    else:
        raise ValueError('unknown consistency_type: {}'.format(args.consistency_type))

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(train_loader)))

    iter_num = 0
    best_dice = 0
    max_epoch = max_iterations // len(train_loader) + 1
    lr_ = base_lr  # note: the learning rate stays constant in this script
    model.train()
    for epoch_num in tqdm(range(max_epoch), ncols=70):
        time1 = time.time()
        for i_batch, sampled_batch in enumerate(train_loader):
            time2 = time.time()
            # print('fetch data cost {}'.format(time2 - time1))
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            unlabeled_volume_batch = volume_batch[labeled_bs:]

            # Perturb the teacher's copy of the unlabeled volumes with clipped
            # Gaussian noise, so student and teacher see different views.
            noise = torch.clamp(torch.randn_like(unlabeled_volume_batch) * 0.1, -0.2, 0.2)
            ema_inputs = unlabeled_volume_batch + noise
            outputs = model(volume_batch)
            with torch.no_grad():
                ema_output = ema_model(ema_inputs)
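            # The teacher forward pass runs under no_grad: it is never trained by
            # backprop, only nudged towards the student via the EMA update below.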

            # Supervised loss on the labeled half of the batch: cross-entropy
            # plus Dice on the foreground channel.
            loss_seg = F.cross_entropy(outputs[:labeled_bs], label_batch[:labeled_bs])
            outputs_soft = F.softmax(outputs, dim=1)
            loss_seg_dice = losses.dice_loss(outputs_soft[:labeled_bs, 1, :, :, :], label_batch[:labeled_bs] == 1)
            supervised_loss = 0.5 * (loss_seg + loss_seg_dice)

            # Consistency loss on the unlabeled half: student vs. teacher predictions,
            # weighted by the ramp-up schedule (one ramp "epoch" = 150 iterations).
            consistency_weight = get_current_consistency_weight(iter_num // 150)
            consistency_dist = consistency_criterion(outputs[labeled_bs:], ema_output)  # (batch, 2, 112, 112, 80)
            consistency_loss = consistency_weight * consistency_dist
            loss = supervised_loss + consistency_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # After each student update, move the teacher towards the student by EMA.
            update_ema_variables(model, ema_model, args.ema_decay, iter_num)

            iter_num = iter_num + 1
            writer.add_scalar('lr', lr_, iter_num)
            writer.add_scalar('loss/loss', loss, iter_num)
            writer.add_scalar('loss/loss_seg', loss_seg, iter_num)
            writer.add_scalar('loss/loss_seg_dice', loss_seg_dice, iter_num)
            writer.add_scalar('train/consistency_weight', consistency_weight, iter_num)
            writer.add_scalar('train/consistency_dist', consistency_dist, iter_num)

            logging.info('iteration %d : loss : %f cons_dist: %f, loss_weight: %f' %
                         (iter_num, loss.item(), consistency_dist.item(), consistency_weight))

            # Periodic validation: every 200 iterations once past a warm-up of 800.
            if iter_num >= 800 and iter_num % 200 == 0:
                model.eval()
                with torch.no_grad():
                    dice_sample = 0
                    for sampled_batch in test_loader:
                        img, lbl = sampled_batch['image'].cuda(), sampled_batch['label'].cuda()
                        outputs = model(img)
                        dice_once = cal_dice(outputs, lbl)
                        dice_sample += dice_once.item()
                    dice_sample = dice_sample / len(test_loader)
                    logging.info('Average center dice: {:.3f}'.format(dice_sample))

                if dice_sample > best_dice:
                    best_dice = dice_sample
                    save_mode_path = os.path.join(snapshot_path, 'iter_{}_dice_{:.4f}.pth'.format(iter_num, best_dice))
                    save_best_path = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best_path)
                    logging.info("save best model to {}".format(save_mode_path))
                writer.add_scalar('Var_dice/Dice', dice_sample, iter_num)
                writer.add_scalar('Var_dice/Best_dice', best_dice, iter_num)
                model.train()
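                # Note: this validation Dice is computed on a single center crop per
                # case (see CenterCrop in db_test), a cheap proxy for model selection
                # rather than full sliding-window inference.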

            if iter_num >= max_iterations:
                break
            time1 = time.time()
        if iter_num >= max_iterations:
            break

    save_mode_path = os.path.join(snapshot_path, 'iter_' + str(max_iterations) + '.pth')
    torch.save(model.state_dict(), save_mode_path)
    logging.info("save model to {}".format(save_mode_path))
    writer.close()