b/train_infer.py

""" Training augmented model """
import os
import time

import numpy as np
import SimpleITK as sitk
import torch
import torch.nn as nn
from apex import amp
from ptflops import get_model_complexity_info
from tensorboardX import SummaryWriter

import data_generator_3D
import utils
from config import TrainConfig
from model import LCOVNet

config = TrainConfig()

device = torch.device("cuda")

# tensorboard
writer = SummaryWriter(log_dir=os.path.join(config.path, "tb"))
writer.add_text('config', config.as_markdown(), 0)

logger = utils.get_logger(os.path.join(config.path, "{}.log".format(config.name)))
config.print_params(logger.info)

def main():
    logger.info("Logger is set - training start")

    # set default gpu device id
    torch.cuda.set_device(config.gpus[0])

    # set seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    torch.backends.cudnn.benchmark = True

    criterion = utils.log_loss().to(device)
    # set_device above makes the bare "cuda" device resolve to config.gpus[0]
    model = LCOVNet(config.input_channels, config.n_classes).to(device)
    with torch.cuda.device(config.gpus[0]):
        macs, params = get_model_complexity_info(model, (1, 240, 160, 48), as_strings=True,
                                                 print_per_layer_stat=True, verbose=True)
        logger.info("{:<30} {:<8}".format('Computational complexity: ', macs))
        logger.info("{:<30} {:<8}".format('Number of parameters: ', params))

    # model size
    mb_params = utils.param_size(model)
    logger.info("Model size = {:.3f} MB".format(mb_params))
    # weights optimizer
    optimizer = torch.optim.SGD(model.parameters(), config.lr, momentum=config.momentum,
                                weight_decay=config.weight_decay)
    # mixed-precision training via NVIDIA apex
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    train_loader = data_generator_3D.Covid19TrainSet()
    valid_loader = data_generator_3D.Covid19EvalSet()
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.epochs)

    best_dice = 0.
    # training loop
    summ_writer = SummaryWriter(config.training_summary_dir)
    for epoch in range(config.epochs):
        # training
        train(train_loader, model, optimizer, criterion, epoch, summ_writer)
        lr_scheduler.step()

        # validation
        mean_dice = validate(valid_loader, model, criterion, epoch, summ_writer, best_dice)

        # save
        is_best = mean_dice > best_dice
        if is_best:
            best_dice = mean_dice
        utils.save_checkpoint(model, config.path, is_best)
        print("")

    logger.info("Final best Dice = {:.4%}".format(best_dice))
    utils.save_results(best_dice, config.path)
    summ_writer.close()

def train(train_loader, model, optimizer, criterion, epoch, summ_writer):
    losses = utils.AverageMeter()
    cur_step = epoch * len(train_loader)
    cur_lr = optimizer.param_groups[0]['lr']
    logger.info("Epoch {} LR {}".format(epoch, cur_lr))
    writer.add_scalar('train/lr', cur_lr, cur_step)
    model.train()
    all_dice = []
    for step, (name, X, y) in enumerate(train_loader):
        X = torch.from_numpy(X).to(device, non_blocking=True)
        y = torch.from_numpy(y).to(device, non_blocking=True)
        N = X.size(0)

        optimizer.zero_grad()
        logits = model(X)
        loss = criterion(logits, y)
        # backward through amp's loss scaling instead of a plain loss.backward()
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        # gradient clipping
        nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        optimizer.step()

        losses.update(loss.item(), N)

        if step % config.print_freq == 0 or step == len(train_loader) - 1:
            logger.info(
                "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {:.3f}".format(
                    epoch + 1, config.epochs, step, len(train_loader), losses.avg))

        writer.add_scalar('train/loss', loss.item(), cur_step)

        # binarize predictions and accumulate the batch Dice score
        predict = (logits.detach() >= 0.5).float().cpu().numpy()
        y = y.cpu().numpy()
        all_dice.append(utils.evaluate(predict, y))
        cur_step += 1

    train_avg_loss = losses.avg
    train_avg_dice = np.mean(all_dice, axis=0)
    summ_writer.add_scalars('loss', {'train': train_avg_loss}, epoch + 1)
    summ_writer.add_scalars('avg_dice', {'train': train_avg_dice}, epoch + 1)

    # periodic checkpoint with optimizer and amp state so training can be resumed
    if (epoch + 1) % 50 == 0:
        chpt_prefx = config.training_checkpoint_prefix
        save_dict = {'epoch': epoch + 1,
                     'model_state_dict': model.state_dict(),
                     'optimizer_state_dict': optimizer.state_dict(),
                     'amp': amp.state_dict()}
        save_name = "{0:}_{1:}.pt".format(chpt_prefx, epoch + 1)
        torch.save(save_dict, save_name)
    print("train_avg_loss", train_avg_loss)
    print("train_avg_dice", train_avg_dice)
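
# For reference: a minimal sketch of the binary Dice score that utils.evaluate is
# expected to compute for the thresholded predictions above. The helper name and
# the smoothing term `eps` are illustrative assumptions, not part of the original
# pipeline; the authoritative metric lives in utils.evaluate.
def _dice_sketch(predict, label, eps=1e-5):
    # Dice = 2 * |P ∩ L| / (|P| + |L|) for binary masks
    intersection = np.sum(predict * label)
    return (2.0 * intersection + eps) / (np.sum(predict) + np.sum(label) + eps)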

def validate(valid_loader, model, criterion, epoch, summ_writer, best_dice):
    # criterion is unused here: chunked inference reports Dice only, no loss
    model.eval()
    all_dice = []
    total_time = 0
    size_z = 48  # depth of each sliding-window chunk along the z axis
    with torch.no_grad():
        for i, (name, image, label) in enumerate(valid_loader):
            image = torch.from_numpy(image)
            predict = np.zeros(shape=label.shape, dtype=label.dtype)
            z = image.shape[4]
            # number of (possibly overlapping) chunks needed to cover the volume
            m = z // size_z if z % size_z == 0 else z // size_z + 1
            start_time = time.time()
            for k in range(m):
                # the last chunk is shifted back so it ends exactly at z
                max_z = min((k + 1) * size_z, z)
                min_z = max_z - size_z
                image_k = image[:, :, :, :, min_z:max_z].float().to(device, non_blocking=True)
                predict_k = model(image_k)
                # binarize the chunk prediction and paste it into the full volume
                predict_k = (predict_k >= 0.5).float()
                predict[:, :, :, :, min_z:max_z] = predict_k.cpu().numpy()
            total_time += time.time() - start_time
            all_dice.append(utils.evaluate(predict, label))

    dice_len = len(all_dice)
    dice_np = np.asarray(all_dice)
    for i in range(dice_len):
        logger.info("{} dice: {:.4%} ".format(i, all_dice[i]))
    logger.info("mean: {}".format(dice_np.mean()))
    logger.info("std : {}".format(dice_np.std()))

    # keep a single best checkpoint on disk
    if best_dice < dice_np.mean():
        chpt_prefx = config.validing_checkpoint_prefix
        save_dict = {'epoch': epoch + 1,
                     'model_state_dict': model.state_dict(),
                     'amp': amp.state_dict()}
        save_name = "{}/best.pt".format(chpt_prefx)
        if os.path.isfile(save_name):
            os.remove(save_name)
        torch.save(save_dict, save_name)

    summ_writer.add_scalars('valid_avg_dice', {'valid': dice_np.mean()}, epoch + 1)

    avg_time = total_time / dice_len
    logger.info("average testing time : {}".format(avg_time))

    mean_dice = np.mean(all_dice, axis=0)
    writer.add_scalar('val/dice', mean_dice, epoch)
    logger.info("Valid: [{:2d}/{}] average dice: {:.4%} ".format(epoch + 1, config.epochs, mean_dice))

    return mean_dice
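
# I/O helpers for exporting predictions as medical images follow; they are not
# called from the loops above. A hypothetical use, e.g. inside validate():
#
#     save_nd_array_as_image(predict[0, 0], "case_{}.nii.gz".format(name))
#
# (the argument shapes and names in this example are illustrative assumptions).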
def save_nd_array_as_image(data, image_name, reference_name=None):
    """
    Save a 3D or 2D numpy array as a medical image or RGB image.
    inputs:
        data: a numpy array with shape [D, H, W] or [C, H, W]
        image_name: the output file name
        reference_name: file name of a reference image whose meta-data is copied
    outputs: None
    """
    data_dim = len(data.shape)
    assert(data_dim == 2 or data_dim == 3)
    if (image_name.endswith(".nii.gz") or image_name.endswith(".nii") or
            image_name.endswith(".mha")):
        assert(data_dim == 3)
        save_array_as_nifty_volume(data, image_name, reference_name)


def save_array_as_nifty_volume(data, image_name, reference_name=None):
    """
    Save a numpy array as a nifty image.
    inputs:
        data: a numpy array with shape [Depth, Height, Width]
        image_name: the output file name
        reference_name: file name of the reference image of which affine and header are used
    outputs: None
    """
    img = sitk.GetImageFromArray(data)
    if reference_name is not None:
        img_ref = sitk.ReadImage(reference_name)
        img.CopyInformation(img_ref)
    sitk.WriteImage(img, image_name)
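
# A minimal sketch of how the periodic checkpoints written by train() could be
# restored for resuming. The function and argument names are illustrative
# assumptions; only the dictionary keys come from the save code above.
def load_training_checkpoint(checkpoint_path, model, optimizer):
    state = torch.load(checkpoint_path, map_location="cuda")
    model.load_state_dict(state['model_state_dict'])
    optimizer.load_state_dict(state['optimizer_state_dict'])
    amp.load_state_dict(state['amp'])  # restore apex loss-scaling state
    return state['epoch']  # epoch to resume from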

if __name__ == "__main__":
    main()