b/train_UAMT.py

import os
import sys
from tqdm import tqdm
from tensorboardX import SummaryWriter
import argparse
import logging
import time
import random
import numpy as np

import torch
import torch.optim as optim
from torchvision import transforms
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from dataloaders import utils

from networks.vnet import VNet
from utils import ramps, losses
from dataloaders.la_heart import *

parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='LA', help='dataset name')
parser.add_argument('--root_path', type=str, default='/***/data_set/LASet/data', help='dataset root path')
parser.add_argument('--exp', type=str, default='vnet', help='experiment name')
parser.add_argument('--model', type=str, default='UAMT', help='model name')
parser.add_argument('--max_iterations', type=int, default=6000, help='maximum number of training iterations')
parser.add_argument('--batch_size', type=int, default=4, help='batch size per GPU')
parser.add_argument('--labeled_bs', type=int, default=2, help='labeled batch size per GPU')
parser.add_argument('--labelnum', type=int, default=25, help='number of labeled samples')
parser.add_argument('--max_samples', type=int, default=123, help='total number of samples')
parser.add_argument('--base_lr', type=float, default=0.01, help='base learning rate')
parser.add_argument('--deterministic', type=int, default=1, help='whether to use deterministic training')
parser.add_argument('--seed', type=int, default=1337, help='random seed')
parser.add_argument('--gpu', type=str, default='2', help='GPU to use')
### costs
parser.add_argument('--ema_decay', type=float, default=0.99, help='ema_decay')
parser.add_argument('--consistency_type', type=str, default="mse", help='consistency_type')
parser.add_argument('--consistency', type=float, default=0.1, help='consistency')
parser.add_argument('--consistency_rampup', type=float, default=40.0, help='consistency_rampup')
args = parser.parse_args()

num_classes = 2
patch_size = (112, 112, 80)
snapshot_path = "model/{}_{}_{}_labeled/{}".format(args.dataset_name, args.exp, args.labelnum, args.model)

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
batch_size = args.batch_size * len(args.gpu.split(','))
max_iterations = args.max_iterations
base_lr = args.base_lr
labeled_bs = args.labeled_bs

if args.deterministic:
    cudnn.benchmark = False
    cudnn.deterministic = True
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

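# Dice of the foreground class computed from argmax predictions; eps keeps the
# ratio defined when both prediction and target are empty.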
def cal_dice(output, target, eps=1e-3):
    output = torch.argmax(output, dim=1)
    inter = torch.sum(output * target) + eps
    union = torch.sum(output) + torch.sum(target) + eps * 2
    dice = 2 * inter / union
    return dice


def get_current_consistency_weight(epoch):
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_rampup)

def update_ema_variables(model, ema_model, alpha, global_step):
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        # non-deprecated form of the in-place EMA update
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

if __name__ == "__main__":
    # make logger file
    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)

    logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

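    # Student and teacher share the same VNet architecture; with ema=True the parameters
    # are detached so the teacher is only ever updated via the EMA step, never by backprop.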
    def create_model(ema=False):
        # Network definition
        net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True)
        model = net.cuda()
        if ema:
            for param in model.parameters():
                param.detach_()
        return model

    model = create_model()
    ema_model = create_model(ema=True)
    db_train = LAHeart(base_dir=args.root_path,
                       split='train',
                       transform=transforms.Compose([
                           RandomRotFlip(),
                           RandomCrop(patch_size),
                           ToTensor(),
                       ]))
    db_test = LAHeart(base_dir=args.root_path,
                      split='test',
                      transform=transforms.Compose([
                          CenterCrop(patch_size),
                          ToTensor()
                      ]))
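    # The first labelnum cases are treated as labeled; TwoStreamBatchSampler builds each
    # batch from labeled_bs labeled indices plus (batch_size - labeled_bs) unlabeled ones.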
    labeled_idxs = list(range(args.labelnum))
    unlabeled_idxs = list(range(args.labelnum, args.max_samples))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, batch_size - labeled_bs)

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
    test_loader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

    model.train()
    ema_model.train()
    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)

    if args.consistency_type == 'mse':
        consistency_criterion = losses.softmax_mse_loss
    elif args.consistency_type == 'kl':
        consistency_criterion = losses.softmax_kl_loss
    else:
        assert False, args.consistency_type

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(trainloader)))

    iter_num = 0
    best_dice = 0
    max_epoch = max_iterations // len(trainloader) + 1
    lr_ = base_lr
    model.train()
    for epoch_num in tqdm(range(max_epoch), ncols=70):
        time1 = time.time()
        for i_batch, sampled_batch in enumerate(trainloader):
            time2 = time.time()
            # print('fetch data cost {}'.format(time2-time1))
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            unlabeled_volume_batch = volume_batch[labeled_bs:]

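            # The student sees the clean batch; the teacher (EMA) model gets the unlabeled
            # sub-batch perturbed with small clipped Gaussian noise.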
            noise = torch.clamp(torch.randn_like(unlabeled_volume_batch) * 0.1, -0.2, 0.2)
            ema_inputs = unlabeled_volume_batch + noise
            outputs = model(volume_batch)
            with torch.no_grad():
                ema_output = ema_model(ema_inputs)
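            # Uncertainty estimation: T=8 stochastic forward passes of the teacher (two noisy
            # copies per loop iteration, dropout active); the per-voxel entropy of the mean
            # softmax serves as the uncertainty map.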
            T = 8
            volume_batch_r = unlabeled_volume_batch.repeat(2, 1, 1, 1, 1)
            stride = volume_batch_r.shape[0] // 2
            preds = torch.zeros([stride * T, 2, 112, 112, 80]).cuda()
            for i in range(T // 2):
                ema_inputs = volume_batch_r + torch.clamp(torch.randn_like(volume_batch_r) * 0.1, -0.2, 0.2)
                with torch.no_grad():
                    preds[2 * stride * i:2 * stride * (i + 1)] = ema_model(ema_inputs)
            preds = F.softmax(preds, dim=1)
            preds = preds.reshape(T, stride, 2, 112, 112, 80)
            preds = torch.mean(preds, dim=0)  # (batch, 2, 112, 112, 80)
            uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True)  # (batch, 1, 112, 112, 80)

            ## calculate the loss
            loss_seg = F.cross_entropy(outputs[:labeled_bs], label_batch[:labeled_bs])
            outputs_soft = F.softmax(outputs, dim=1)
            loss_seg_dice = losses.dice_loss(outputs_soft[:labeled_bs, 1, :, :, :], label_batch[:labeled_bs] == 1)
            supervised_loss = 0.5 * (loss_seg + loss_seg_dice)

            consistency_weight = get_current_consistency_weight(iter_num // 150)
            consistency_dist = consistency_criterion(outputs[labeled_bs:], ema_output)  # (batch, 2, 112, 112, 80)
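            # Uncertainty-aware masking: the entropy threshold ramps from 0.75*ln(2) to ln(2)
            # over training, so only voxels where the teacher is sufficiently certain
            # contribute to the consistency term.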
            threshold = (0.75 + 0.25 * ramps.sigmoid_rampup(iter_num, max_iterations)) * np.log(2)
            mask = (uncertainty < threshold).float()
            consistency_dist = torch.sum(mask * consistency_dist) / (2 * torch.sum(mask) + 1e-16)
            consistency_loss = consistency_weight * consistency_dist
            loss = supervised_loss + consistency_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_ema_variables(model, ema_model, args.ema_decay, iter_num)

            iter_num = iter_num + 1
            writer.add_scalar('uncertainty/mean', uncertainty[0, 0].mean(), iter_num)
            writer.add_scalar('uncertainty/max', uncertainty[0, 0].max(), iter_num)
            writer.add_scalar('uncertainty/min', uncertainty[0, 0].min(), iter_num)
            writer.add_scalar('uncertainty/mask_per', torch.sum(mask) / mask.numel(), iter_num)
            writer.add_scalar('uncertainty/threshold', threshold, iter_num)
            writer.add_scalar('lr', lr_, iter_num)
            writer.add_scalar('loss/loss', loss, iter_num)
            writer.add_scalar('loss/loss_seg', loss_seg, iter_num)
            writer.add_scalar('loss/loss_seg_dice', loss_seg_dice, iter_num)
            writer.add_scalar('train/consistency_loss', consistency_loss, iter_num)
            writer.add_scalar('train/consistency_weight', consistency_weight, iter_num)
            writer.add_scalar('train/consistency_dist', consistency_dist, iter_num)

            logging.info('iteration %d : loss : %f cons_dist: %f, loss_weight: %f' %
                         (iter_num, loss.item(), consistency_dist.item(), consistency_weight))

            if iter_num % 50 == 0:
                image = volume_batch[0, 0:1, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/Image', grid_image, iter_num)

                # image = outputs_soft[0, 3:4, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                image = torch.max(outputs_soft[0, :, :, :, 20:61:10], 0)[1].permute(2, 0, 1).data.cpu().numpy()
                image = utils.decode_seg_map_sequence(image)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Predicted_label', grid_image, iter_num)

                image = label_batch[0, :, :, 20:61:10].permute(2, 0, 1)
                grid_image = make_grid(utils.decode_seg_map_sequence(image.data.cpu().numpy()), 5, normalize=False)
                writer.add_image('train/Groundtruth_label', grid_image, iter_num)

                image = uncertainty[0, 0:1, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/uncertainty', grid_image, iter_num)

                mask2 = (uncertainty > threshold).float()
                image = mask2[0, 0:1, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/mask', grid_image, iter_num)
                #####
                image = volume_batch[-1, 0:1, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('unlabel/Image', grid_image, iter_num)

                # image = outputs_soft[-1, 3:4, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                image = torch.max(outputs_soft[-1, :, :, :, 20:61:10], 0)[1].permute(2, 0, 1).data.cpu().numpy()
                image = utils.decode_seg_map_sequence(image)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('unlabel/Predicted_label', grid_image, iter_num)

                image = label_batch[-1, :, :, 20:61:10].permute(2, 0, 1)
                grid_image = make_grid(utils.decode_seg_map_sequence(image.data.cpu().numpy()), 5, normalize=False)
                writer.add_image('unlabel/Groundtruth_label', grid_image, iter_num)

            ## change lr
            if iter_num % 2500 == 0:
                lr_ = base_lr * 0.1 ** (iter_num // 2500)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_

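            # Evaluate on the test split every 200 iterations after an 800-iteration warm-up,
            # keeping the checkpoint with the best average Dice.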
            if iter_num >= 800 and iter_num % 200 == 0:
                model.eval()
                with torch.no_grad():
                    dice_sample = 0
                    for sampled_batch in test_loader:
                        img, lbl = sampled_batch['image'].cuda(), sampled_batch['label'].cuda()
                        outputs = model(img)
                        dice_once = cal_dice(outputs, lbl)
                        dice_sample += dice_once
                    dice_sample = dice_sample / len(test_loader)
                    print('Average center dice:{:.3f}'.format(dice_sample))

                if dice_sample > best_dice:
                    best_dice = dice_sample
                    save_mode_path = os.path.join(snapshot_path, 'iter_{}_dice_{:.4f}.pth'.format(iter_num, best_dice))
                    save_best_path = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best_path)
                    logging.info("save best model to {}".format(save_mode_path))
                writer.add_scalar('Var_dice/Dice', dice_sample, iter_num)
                writer.add_scalar('Var_dice/Best_dice', best_dice, iter_num)
                model.train()

            if iter_num >= max_iterations:
                break
            time1 = time.time()
        if iter_num >= max_iterations:
            break
    save_mode_path = os.path.join(snapshot_path, 'iter_' + str(max_iterations) + '.pth')
    torch.save(model.state_dict(), save_mode_path)
    logging.info("save model to {}".format(save_mode_path))
    writer.close()