Diff of /utils.py [000000] .. [2162c1]

""" Utilities """
import os
import logging
import shutil
import torch
import torchvision.datasets as dset
import numpy as np
import torch.nn as nn

device = torch.device("cuda")


def get_logger(file_path):
    """ Make python logger """
    # [!] Since tensorboardX writes to the default root logger (e.g. logging.info()), we use a custom logger here
    logger = logging.getLogger('darts')
    log_format = '%(asctime)s | %(message)s'
    formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p')
    file_handler = logging.FileHandler(file_path)
    file_handler.setFormatter(formatter)
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    logger.setLevel(logging.INFO)

    return logger
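# Usage sketch (assumed, not part of the original file; the log path is hypothetical):
#
#   logger = get_logger(os.path.join(save_path, 'train.log'))
#   logger.info("start training")   # printed to the console and appended to train.log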


def param_size(model):
    """ Number of model parameters (auxiliary head excluded), divided by 1024**2 (~millions) """
    n_params = sum(
        np.prod(v.size()) for k, v in model.named_parameters() if not k.startswith('aux_head'))
    return n_params / 1024. / 1024.
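# Usage sketch (assumed; `model` and `logger` are placeholders defined elsewhere):
#
#   logger.info("param size = %.3f M", param_size(model))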


class AverageMeter:
    """ Computes and stores the average and current value """
    def __init__(self):
        self.reset()

    def reset(self):
        """ Reset all statistics """
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 1  # start from 1 so a class that is absent from the first batch does not cause a division by zero

    def update(self, val, n):
        """ Update statistics """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
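# Usage sketch (assumed; `dice` and `batch_size` are hypothetical per-batch values):
#
#   dice_meter = AverageMeter()
#   dice_meter.update(dice, batch_size)   # accumulate a weighted running average
#   print(dice_meter.avg)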


def evaluate(logits, label):
    """ Dice coefficient between a prediction and a label, both given as numpy arrays """
    logits = logits.astype(np.float32)
    label = label.astype(np.float32)
    inter = np.dot(logits.flatten(), label.flatten())
    union = np.sum(logits) + np.sum(label)  # sum of the two areas (denominator of the dice score)
    dice = (2 * inter + 1e-5) / (union + 1e-5)
    return dice
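# Worked example (illustrative only): two 2x2 binary masks that overlap in one pixel
# give dice = (2*1 + 1e-5) / (1 + 2 + 1e-5) ~= 0.667.
#
#   pred = np.array([[1, 0], [0, 0]])
#   gt = np.array([[1, 1], [0, 0]])
#   evaluate(pred, gt)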


def save_results(results, path):
    filename = os.path.join(path, 'final_results.txt')
    with open(filename, 'a') as f:  # context manager so the file is closed after writing
        f.write('Best dice: {:.5f}\n'.format(results))


def save_checkpoint(state, ckpt_dir, is_best=False):
    filename = os.path.join(ckpt_dir, 'checkpoint.pth.tar')
    torch.save(state, filename)
    if is_best:
        best_filename = os.path.join(ckpt_dir, 'best.pth.tar')
        shutil.copyfile(filename, best_filename)
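# Usage sketch (assumed; the state dict keys, `model`, `epoch` and `ckpt_dir` are hypothetical):
#
#   save_checkpoint({'epoch': epoch, 'state_dict': model.state_dict()}, ckpt_dir, is_best=True)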


class log_loss(nn.Module):
    """ Log-dice loss: mean over channels of (-log(soft dice))**0.3 """
    def __init__(self, w_dice=0.5, w_cross=0.5):
        super(log_loss, self).__init__()
        self.w_dice = w_dice
        self.w_cross = w_cross  # w_dice / w_cross are stored but not used in forward()

    def forward(self, logits, label, smooth=1.):
        # per-channel soft-dice terms, summed over batch and spatial dimensions
        area_union = torch.sum(logits * label, dim=(0, 2, 3), keepdim=True)
        area_logits = torch.sum(logits, dim=(0, 2, 3), keepdim=True)
        area_label = torch.sum(label, dim=(0, 2, 3), keepdim=True)
        # negative log of the smoothed dice ratio, raised to 0.3, averaged over channels
        in_dice = torch.mean(torch.pow((-1) * torch.log((2 * area_union + 1e-7) / (area_logits + area_label + smooth)), 0.3))
        return in_dice
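# Usage sketch (assumed, not from the original file): `logits` and `label` are expected as
# (N, C, H, W) tensors with values in [0, 1], e.g. sigmoid outputs and one-hot masks;
# `net`, `x` and `target` below are hypothetical names.
#
#   criterion = log_loss().to(device)
#   loss = criterion(torch.sigmoid(net(x)), target)
#   loss.backward()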