Diff of /quicknat.py [000000] .. [6f9c00]

Switch to unified view

a b/quicknat.py
1
"""Quicknat architecture"""
2
import numpy as np
3
import torch
4
import torch.nn as nn
5
from nn_common_modules import modules as sm
6
from squeeze_and_excitation import squeeze_and_excitation as se
7
8
class QuickNat(nn.Module):
9
    """
10
    A PyTorch implementation of QuickNAT
11
12
    """
13
    def __init__(self, params):
14
        """
15
16
        :param params: {'num_channels':1,
17
                        'num_filters':64,
18
                        'kernel_h':5,
19
                        'kernel_w':5,
20
                        'stride_conv':1,
21
                        'pool':2,
22
                        'stride_pool':2,
23
                        'num_classes':28
24
                        'se_block': False,
25
                        'drop_out':0.2}
26
        """
27
        super(QuickNat, self).__init__()
28
        print(se.SELayer(params['se_block']))
29
        self.encode1 = sm.EncoderBlock(params, se_block_type=params['se_block'])
30
        params['num_channels'] = params['num_filters']
31
        self.encode2 = sm.EncoderBlock(params, se_block_type=params['se_block'])
32
        self.encode3 = sm.EncoderBlock(params, se_block_type=params['se_block'])
33
        self.encode4 = sm.EncoderBlock(params, se_block_type=params['se_block'])
34
        self.bottleneck = sm.DenseBlock(params, se_block_type=params['se_block'])
35
        params['num_channels'] = params['num_filters'] * 2
36
        self.decode1 = sm.DecoderBlock(params, se_block_type=params['se_block'])
37
        self.decode2 = sm.DecoderBlock(params, se_block_type=params['se_block'])
38
        self.decode3 = sm.DecoderBlock(params, se_block_type=params['se_block'])
39
        self.decode4 = sm.DecoderBlock(params, se_block_type=params['se_block'])
40
        params['num_channels'] = params['num_filters']
41
        self.classifier = sm.ClassifierBlock(params)
42
43
    def forward(self, input):
44
        """
45
46
        :param input: X
47
        :return: probabiliy map
48
        """
49
        e1, out1, ind1 = self.encode1.forward(input)
50
        e2, out2, ind2 = self.encode2.forward(e1)
51
        e3, out3, ind3 = self.encode3.forward(e2)
52
        e4, out4, ind4 = self.encode4.forward(e3)
53
54
        bn = self.bottleneck.forward(e4)
55
56
        d4 = self.decode4.forward(bn, out4, ind4)
57
        d3 = self.decode1.forward(d4, out3, ind3)
58
        d2 = self.decode2.forward(d3, out2, ind2)
59
        d1 = self.decode3.forward(d2, out1, ind1)
60
        prob = self.classifier.forward(d1)
61
62
        return prob
63
64
    def enable_test_dropout(self):
65
        """
66
        Enables test time drop out for uncertainity
67
        :return:
68
        """
69
        attr_dict = self.__dict__['_modules']
70
        for i in range(1, 5):
71
            encode_block, decode_block = attr_dict['encode' + str(i)], attr_dict['decode' + str(i)]
72
            encode_block.drop_out = encode_block.drop_out.apply(nn.Module.train)
73
            decode_block.drop_out = decode_block.drop_out.apply(nn.Module.train)
74
75
    @property
76
    def is_cuda(self):
77
        """
78
        Check if model parameters are allocated on the GPU.
79
        """
80
        return next(self.parameters()).is_cuda
81
82
    def save(self, path):
83
        """
84
        Save model with its parameters to the given path. Conventionally the
85
        path should end with '*.model'.
86
87
        Inputs:
88
        - path: path string
89
        """
90
        print('Saving model... %s' % path)
91
        torch.save(self, path)
92
93
    def predict(self, X, device=0, enable_dropout=False, out_prob=False):
94
        """
95
        Predicts the outout after the model is trained.
96
        Inputs:
97
        - X: Volume to be predicted
98
        """
99
        self.eval()
100
101
        if type(X) is np.ndarray:
102
            X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor).cuda(device, non_blocking=True)
103
        elif type(X) is torch.Tensor and not X.is_cuda:
104
            X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)
105
106
        if enable_dropout:
107
            self.enable_test_dropout()
108
109
        with torch.no_grad():
110
            out = self.forward(X)
111
112
        if out_prob:
113
            return out
114
        else:
115
            max_val, idx = torch.max(out, 1)
116
            idx = idx.data.cpu().numpy()
117
            prediction = np.squeeze(idx)
118
            del X, out, idx, max_val
119
            return prediction