-- code/detectorDNN.lua
require 'torch'
require 'nn'
require 'image'

-- Compute-device switch. Deliberately a global: scan.lua and expand.lua are
-- loaded with dofile() and can only see globals, not this chunk's locals.
-- NOTE(review): confirm those helpers actually read `cuda`.
cuda = false

-- cutorch/cunn are mandatory only on the GPU path. On the CPU path they are
-- loaded best-effort: they may still be needed to deserialize nets that were
-- saved with CUDA tensors (the loop converts nets with :float()), but a
-- CPU-only machine should not abort here just because CUDA is missing.
if cuda then
    require 'cutorch'
    require 'cunn'
else
    pcall(require, 'cutorch')
    pcall(require, 'cunn')
end
require 'sys'
require 'csvigo'
local dir = require 'pl.dir'

-- Start of the whole run: CPU time and wall-clock time.
local c = os.clock()
local t = os.time()

-- Location of the test images and of the two trained networks.
local folder = '/home/andrew/mitosis/data/MITOS/testing/'
--local netPath1 = '/home/andrew/mitosis/data/nets/net.t7'
local netPath1 = '/home/andrew/mitosis/data/nets/dnn1_fullset_aug_20i_lr05_lrd0005_m09_mini200_aeptgl.t7'
local netPath2 = '/home/andrew/mitosis/data/nets/dnn2_fullset_aug_20i_lr05_lrd0005_m09_mini200_aeptgl.t7'

-- Peak picking stops once the strongest remaining value of the normalized
-- detection map falls below this.
local threshold = 0.1

-- getImagePaths() comes from this helper script.
dofile("getImagePaths.lua")
local imagePaths = getImagePaths(folder)

-- Per-image CSV files are written into <folder>results.
local resultsDir = folder .. 'results'
if paths.dirp(resultsDir) == false then
    paths.mkdir(resultsDir)
end

-- Helper scripts that define the global scan() and expand() functions.
dofile("scan.lua")
dofile("expand.lua")

-- Disc-shaped smoothing kernel of size (2d+1)x(2d+1): cells whose squared
-- distance from the centre cell is below d^2 hold 0.001, all others hold 0.
local d = 10
local kernel = torch.FloatTensor(2*d+1, 2*d+1):fill(0.001)
for row = 1, 2*d+1 do
    local dy = row - (d+1)
    for col = 1, 2*d+1 do
        local dx = col - (d+1)
        local outside = dy*dy + dx*dx >= d*d
        if outside then
            kernel[row][col] = 0
        end
    end
end

-- Load a serialized net, pass it through expand() (defined in expand.lua) and
-- move it to the device selected by the global `cuda` flag.
local function loadNet(path)
    local net = expand(torch.load(path))
    if cuda then
        return net:cuda()
    end
    return net:float()
end

-- Scan `img` with `net` under eight transforms: four 90-degree rotations,
-- each with and without a horizontal flip. Every output map has its
-- transform undone (flip back, rotate back) and is appended to `maps`,
-- unflipped before flipped, for rotations 0, 90, 180, 270 degrees.
-- NOTE(review): image.rotate keeps the tensor size, so the maps only line up
-- for square images -- confirm the inputs are square.
local function scanTransforms(img, net, maps)
    for r = 0, 3 do
        local angle = r*math.pi/2
        local rotated = image.rotate(img, angle)
        -- image.hflip returns a new tensor; take the flipped copy before
        -- scan() gets `rotated`, in case scan() touches its input.
        local flipped = image.hflip(rotated)
        maps[#maps+1] = image.rotate(scan(rotated, net), -angle)
        maps[#maps+1] = image.rotate(image.hflip(scan(flipped, net)), -angle)
    end
end

-- Greedy peak picking with non-maximum suppression on the 2D tensor `map`
-- (modified in place). Repeatedly finds the global maximum, records it as
-- {row, col, value}, and zeroes every cell strictly within `radius` of it.
-- Stops, WITHOUT recording, as soon as the maximum is below `minValue` or is
-- not positive; the latter also prevents an endless loop on a fully
-- suppressed map when minValue <= 0.
local function extractPeaks(map, radius, minValue)
    local height, width = map:size(1), map:size(2)
    local radius2 = radius*radius
    local peaks = {}
    while true do
        local first = map:eq(map:max()):nonzero()[1]
        local row, col = first[1], first[2]
        local val = map[row][col]
        if val < minValue or val <= 0 then
            break
        end
        peaks[#peaks+1] = {row, col, val}
        -- clamp the suppression window to the map instead of testing bounds
        -- for every cell
        for i = math.max(-radius, 1-row), math.min(radius, height-row) do
            for j = math.max(-radius, 1-col), math.min(radius, width-col) do
                if i*i + j*j < radius2 then
                    map[row+i][col+j] = 0
                end
            end
        end
    end
    return peaks
end

-- The nets do not depend on the image: load and expand them once instead of
-- once per image.
-- NOTE(review): assumes scan() does not permanently modify the net it is given.
local net1 = loadNet(netPath1)
local net2 = loadNet(netPath2)

for k, imagePath in ipairs(imagePaths) do
    print(k)

    local c1 = os.clock()
    local t1 = os.time()

    local img = image.load(imagePath, 3, 'float')
    if cuda then
        img = img:cuda()
    else
        img = img:float()
    end

    -- Sixteen maps: eight rotation/reflection variants for each of the two
    -- nets; net1's eight maps come first.
    local maps = {}
    scanTransforms(img, net1, maps)
    scanTransforms(img, net2, maps)

    -- Mean of the sixteen maps.
    local sum = maps[1]
    for i = 2, #maps do
        sum = sum + maps[i]
    end
    local map = sum/#maps

    -- `kernel` is a FloatTensor, so convolve a float version of the map
    -- (no-op on the CPU path).
    map = image.convolve(map:float(), kernel, 'same')

    local results = {}
    local peak = map:max()
    -- Normalizing by a non-positive or non-finite max would produce NaNs;
    -- such a map yields no detections.
    if peak > 0 and peak < math.huge then
        map = map/peak                  -- strongest response becomes 1
        image.save('test.png', map)     -- debug dump, overwritten per image
        results = extractPeaks(map, 2*d, threshold)
    end

    local name = paths.basename(imagePath, paths.extname(imagePath))
    local outfile = paths.concat(folder, 'results', name .. '.csv')
    if #results > 0 then
        csvigo.save(outfile, results)
    else
        -- No detections: write an empty file rather than handing csvigo an
        -- empty table, so every image still gets a results file.
        local f = assert(io.open(outfile, 'w'))
        f:close()
    end

    print(os.clock()-c1)
    print(os.time()-t1)
end

-- Report total CPU time, then total wall-clock time, for the whole run.
local totalCpu = os.clock() - c
print(totalCpu)
local totalWall = os.time() - t
print(totalWall)