|
a |
|
b/utils.py |
|
|
1 |
""" Utilities """ |
|
|
2 |
import os |
|
|
3 |
import logging |
|
|
4 |
import shutil |
|
|
5 |
import torch |
|
|
6 |
import torchvision.datasets as dset |
|
|
7 |
import numpy as np |
|
|
8 |
import torch.nn as nn |
|
|
9 |
device = torch.device("cuda") |
|
|
10 |
|
|
|
11 |
def get_logger(file_path): |
|
|
12 |
""" Make python logger """ |
|
|
13 |
# [!] Since tensorboardX use default logger (e.g. logging.info()), we should use custom logger |
|
|
14 |
logger = logging.getLogger('darts') |
|
|
15 |
log_format = '%(asctime)s | %(message)s' |
|
|
16 |
formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p') |
|
|
17 |
file_handler = logging.FileHandler(file_path) |
|
|
18 |
file_handler.setFormatter(formatter) |
|
|
19 |
stream_handler = logging.StreamHandler() |
|
|
20 |
stream_handler.setFormatter(formatter) |
|
|
21 |
|
|
|
22 |
logger.addHandler(file_handler) |
|
|
23 |
logger.addHandler(stream_handler) |
|
|
24 |
logger.setLevel(logging.INFO) |
|
|
25 |
|
|
|
26 |
return logger |
|
|
27 |
|
|
|
28 |
|
|
|
29 |
def param_size(model): |
|
|
30 |
""" Compute parameter size in MB """ |
|
|
31 |
n_params = sum( |
|
|
32 |
np.prod(v.size()) for k, v in model.named_parameters() if not k.startswith('aux_head')) |
|
|
33 |
return n_params / 1024. / 1024. |
|
|
34 |
|
|
|
35 |
|
|
|
36 |
class AverageMeter(): |
|
|
37 |
""" Computes and stores the average and current value """ |
|
|
38 |
def __init__(self): |
|
|
39 |
self.reset() |
|
|
40 |
|
|
|
41 |
def reset(self): |
|
|
42 |
""" Reset all statistics """ |
|
|
43 |
self.val = 0 |
|
|
44 |
self.avg = 0 |
|
|
45 |
self.sum = 0 |
|
|
46 |
self.count = 1 # avoid the count of some calsses in the first batch is zero |
|
|
47 |
|
|
|
48 |
def update(self, val, n): |
|
|
49 |
""" Update statistics """ |
|
|
50 |
self.val = val |
|
|
51 |
self.sum += val * n |
|
|
52 |
self.count += n |
|
|
53 |
self.avg = self.sum / self.count |
|
|
54 |
|
|
|
55 |
def evaluate(logits, label): |
|
|
56 |
logits = logits.astype(np.float32) |
|
|
57 |
label = label.astype(np.float32) |
|
|
58 |
inter = np.dot(logits.flatten(), label.flatten()) |
|
|
59 |
union = np.sum(logits) + np.sum(label) |
|
|
60 |
dice = (2 * inter + 1e-5) / (union + 1e-5) |
|
|
61 |
return dice |
|
|
62 |
|
|
|
63 |
def save_results(results, path): |
|
|
64 |
|
|
|
65 |
filename = os.path.join(path, 'final_results.txt') |
|
|
66 |
f = open(filename, 'a') |
|
|
67 |
f.write('Best dice: {:.5f}\n'.format(results)) |
|
|
68 |
|
|
|
69 |
|
|
|
70 |
def save_checkpoint(state, ckpt_dir, is_best=False): |
|
|
71 |
filename = os.path.join(ckpt_dir, 'checkpoint.pth.tar') |
|
|
72 |
torch.save(state, filename) |
|
|
73 |
if is_best: |
|
|
74 |
best_filename = os.path.join(ckpt_dir, 'best.pth.tar') |
|
|
75 |
shutil.copyfile(filename, best_filename) |
|
|
76 |
|
|
|
77 |
class log_loss(nn.Module): |
|
|
78 |
def __init__(self, w_dice = 0.5, w_cross = 0.5): |
|
|
79 |
super(log_loss, self).__init__() |
|
|
80 |
self.w_dice = w_dice |
|
|
81 |
self.w_cross = w_cross |
|
|
82 |
def forward(self, logits, label, smooth = 1.): |
|
|
83 |
|
|
|
84 |
area_union = torch.sum(logits * label, dim = (0,2,3), keepdim = True) |
|
|
85 |
area_logits = torch.sum(logits, dim = (0,2,3), keepdim = True) |
|
|
86 |
area_label = torch.sum(label, dim = (0,2,3), keepdim = True) |
|
|
87 |
in_dice = torch.mean(torch.pow((-1) * torch.log((2 * area_union + 1e-7)/(area_logits + area_label + smooth)), 0.3)) |
|
|
88 |
return in_dice |