--- a +++ b/quicknat.py @@ -0,0 +1,119 @@ +"""Quicknat architecture""" +import numpy as np +import torch +import torch.nn as nn +from nn_common_modules import modules as sm +from squeeze_and_excitation import squeeze_and_excitation as se + +class QuickNat(nn.Module): + """ + A PyTorch implementation of QuickNAT + + """ + def __init__(self, params): + """ + + :param params: {'num_channels':1, + 'num_filters':64, + 'kernel_h':5, + 'kernel_w':5, + 'stride_conv':1, + 'pool':2, + 'stride_pool':2, + 'num_classes':28 + 'se_block': False, + 'drop_out':0.2} + """ + super(QuickNat, self).__init__() + print(se.SELayer(params['se_block'])) + self.encode1 = sm.EncoderBlock(params, se_block_type=params['se_block']) + params['num_channels'] = params['num_filters'] + self.encode2 = sm.EncoderBlock(params, se_block_type=params['se_block']) + self.encode3 = sm.EncoderBlock(params, se_block_type=params['se_block']) + self.encode4 = sm.EncoderBlock(params, se_block_type=params['se_block']) + self.bottleneck = sm.DenseBlock(params, se_block_type=params['se_block']) + params['num_channels'] = params['num_filters'] * 2 + self.decode1 = sm.DecoderBlock(params, se_block_type=params['se_block']) + self.decode2 = sm.DecoderBlock(params, se_block_type=params['se_block']) + self.decode3 = sm.DecoderBlock(params, se_block_type=params['se_block']) + self.decode4 = sm.DecoderBlock(params, se_block_type=params['se_block']) + params['num_channels'] = params['num_filters'] + self.classifier = sm.ClassifierBlock(params) + + def forward(self, input): + """ + + :param input: X + :return: probabiliy map + """ + e1, out1, ind1 = self.encode1.forward(input) + e2, out2, ind2 = self.encode2.forward(e1) + e3, out3, ind3 = self.encode3.forward(e2) + e4, out4, ind4 = self.encode4.forward(e3) + + bn = self.bottleneck.forward(e4) + + d4 = self.decode4.forward(bn, out4, ind4) + d3 = self.decode1.forward(d4, out3, ind3) + d2 = self.decode2.forward(d3, out2, ind2) + d1 = self.decode3.forward(d2, out1, ind1) + prob = self.classifier.forward(d1) + + return prob + + def enable_test_dropout(self): + """ + Enables test time drop out for uncertainity + :return: + """ + attr_dict = self.__dict__['_modules'] + for i in range(1, 5): + encode_block, decode_block = attr_dict['encode' + str(i)], attr_dict['decode' + str(i)] + encode_block.drop_out = encode_block.drop_out.apply(nn.Module.train) + decode_block.drop_out = decode_block.drop_out.apply(nn.Module.train) + + @property + def is_cuda(self): + """ + Check if model parameters are allocated on the GPU. + """ + return next(self.parameters()).is_cuda + + def save(self, path): + """ + Save model with its parameters to the given path. Conventionally the + path should end with '*.model'. + + Inputs: + - path: path string + """ + print('Saving model... %s' % path) + torch.save(self, path) + + def predict(self, X, device=0, enable_dropout=False, out_prob=False): + """ + Predicts the outout after the model is trained. + Inputs: + - X: Volume to be predicted + """ + self.eval() + + if type(X) is np.ndarray: + X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor).cuda(device, non_blocking=True) + elif type(X) is torch.Tensor and not X.is_cuda: + X = X.type(torch.FloatTensor).cuda(device, non_blocking=True) + + if enable_dropout: + self.enable_test_dropout() + + with torch.no_grad(): + out = self.forward(X) + + if out_prob: + return out + else: + max_val, idx = torch.max(out, 1) + idx = idx.data.cpu().numpy() + prediction = np.squeeze(idx) + del X, out, idx, max_val + return prediction