|
a |
|
b/hippodeep.py |
|
|
1 |
import torch |
|
|
2 |
import nibabel |
|
|
3 |
import numpy as np |
|
|
4 |
import os, sys, time |
|
|
5 |
import scipy.ndimage |
|
|
6 |
import torch.nn as nn |
|
|
7 |
import torch.nn.functional as F |
|
|
8 |
from numpy.linalg import inv |
|
|
9 |
try: |
|
|
10 |
import resource |
|
|
11 |
except: |
|
|
12 |
pass |
|
|
13 |
|
|
|
14 |
# monkey-patch for back-compatibility with older (~1.0.0) torch |
|
|
15 |
try: |
|
|
16 |
import inspect |
|
|
17 |
if not "align_corners" in inspect.signature(F.grid_sample).parameters: |
|
|
18 |
old_grid_sample = torch.nn.functional.grid_sample |
|
|
19 |
F.grid_sample = lambda *x, **k : old_grid_sample(*x) |
|
|
20 |
except: |
|
|
21 |
pass |
|
|
22 |
|
|
|
23 |
if len(sys.argv[1:]) == 0: |
|
|
24 |
print("Need to pass one or more T1 image filename as argument") |
|
|
25 |
sys.exit(1) |
|
|
26 |
|
|
|
27 |
print("Using all available CPU threads") |
|
|
28 |
if 0: # otherwise, set a limit (useful for running multiple instances) |
|
|
29 |
torch.set_num_threads(4) |
|
|
30 |
|
|
|
31 |
|
|
|
32 |
class HeadModel(nn.Module): |
|
|
33 |
def __init__(self): |
|
|
34 |
super(HeadModel, self).__init__() |
|
|
35 |
self.conv0a = nn.Conv3d(1, 8, 3, padding=1) |
|
|
36 |
self.conv0b = nn.Conv3d(8, 8, 3, padding=1) |
|
|
37 |
self.bn0a = nn.BatchNorm3d(8) |
|
|
38 |
|
|
|
39 |
self.ma1 = nn.MaxPool3d(2) |
|
|
40 |
self.conv1a = nn.Conv3d(8, 16, 3, padding=1) |
|
|
41 |
self.conv1b = nn.Conv3d(16, 24, 3, padding=1) |
|
|
42 |
self.bn1a = nn.BatchNorm3d(24) |
|
|
43 |
|
|
|
44 |
self.ma2 = nn.MaxPool3d(2) |
|
|
45 |
self.conv2a = nn.Conv3d(24, 24, 3, padding=1) |
|
|
46 |
self.conv2b = nn.Conv3d(24, 32, 3, padding=1) |
|
|
47 |
self.bn2a = nn.BatchNorm3d(32) |
|
|
48 |
|
|
|
49 |
self.ma3 = nn.MaxPool3d(2) |
|
|
50 |
self.conv3a = nn.Conv3d(32, 48, 3, padding=1) |
|
|
51 |
self.conv3b = nn.Conv3d(48, 48, 3, padding=1) |
|
|
52 |
self.bn3a = nn.BatchNorm3d(48) |
|
|
53 |
|
|
|
54 |
|
|
|
55 |
self.conv2u = nn.Conv3d(48, 24, 3, padding=1) |
|
|
56 |
self.conv2v = nn.Conv3d(24+32, 24, 3, padding=1) |
|
|
57 |
self.bn2u = nn.BatchNorm3d(24) |
|
|
58 |
|
|
|
59 |
|
|
|
60 |
self.conv1u = nn.Conv3d(24, 24, 3, padding=1) |
|
|
61 |
self.conv1v = nn.Conv3d(24+24, 24, 3, padding=1) |
|
|
62 |
self.bn1u = nn.BatchNorm3d(24) |
|
|
63 |
|
|
|
64 |
|
|
|
65 |
self.conv0u = nn.Conv3d(24, 16, 3, padding=1) |
|
|
66 |
self.conv0v = nn.Conv3d(16+8, 8, 3, padding=1) |
|
|
67 |
self.bn0u = nn.BatchNorm3d(8) |
|
|
68 |
|
|
|
69 |
self.conv1x = nn.Conv3d(8, 4, 1, padding=0) |
|
|
70 |
|
|
|
71 |
def forward(self, x): |
|
|
72 |
x = F.elu(self.conv0a(x)) |
|
|
73 |
self.li0 = x = F.elu(self.bn0a(self.conv0b(x))) |
|
|
74 |
|
|
|
75 |
x = self.ma1(x) |
|
|
76 |
x = F.elu(self.conv1a(x)) |
|
|
77 |
self.li1 = x = F.elu(self.bn1a(self.conv1b(x))) |
|
|
78 |
|
|
|
79 |
x = self.ma2(x) |
|
|
80 |
x = F.elu(self.conv2a(x)) |
|
|
81 |
self.li2 = x = F.elu(self.bn2a(self.conv2b(x))) |
|
|
82 |
|
|
|
83 |
x = self.ma3(x) |
|
|
84 |
x = F.elu(self.conv3a(x)) |
|
|
85 |
self.li3 = x = F.elu(self.bn3a(self.conv3b(x))) |
|
|
86 |
|
|
|
87 |
x = F.interpolate(x, scale_factor=2, mode="nearest") |
|
|
88 |
|
|
|
89 |
x = F.elu(self.conv2u(x)) |
|
|
90 |
x = torch.cat([x, self.li2], 1) |
|
|
91 |
x = F.elu(self.bn2u(self.conv2v(x))) |
|
|
92 |
|
|
|
93 |
self.lo1 = x |
|
|
94 |
x = F.interpolate(x, scale_factor=2, mode="nearest") |
|
|
95 |
|
|
|
96 |
x = F.elu(self.conv1u(x)) |
|
|
97 |
x = torch.cat([x, self.li1], 1) |
|
|
98 |
x = F.elu(self.bn1u(self.conv1v(x))) |
|
|
99 |
|
|
|
100 |
x = F.interpolate(x, scale_factor=2, mode="nearest") |
|
|
101 |
self.la1 = x |
|
|
102 |
|
|
|
103 |
x = F.elu(self.conv0u(x)) |
|
|
104 |
x = torch.cat([x, self.li0], 1) |
|
|
105 |
x = F.elu(self.bn0u(self.conv0v(x))) |
|
|
106 |
|
|
|
107 |
self.out = x = self.conv1x(x) |
|
|
108 |
x = torch.sigmoid(x) |
|
|
109 |
return x |
|
|
110 |
|
|
|
111 |
|
|
|
112 |
|
|
|
113 |
|
|
|
114 |
class ModelAff(nn.Module): |
|
|
115 |
def __init__(self): |
|
|
116 |
super(ModelAff, self).__init__() |
|
|
117 |
self.convaff1 = nn.Conv3d(2, 16, 3, padding=1) |
|
|
118 |
self.maaff1 = nn.MaxPool3d(2) |
|
|
119 |
self.convaff2 = nn.Conv3d(16, 16, 3, padding=1) |
|
|
120 |
self.bnaff2 = nn.LayerNorm([32, 32, 32]) |
|
|
121 |
|
|
|
122 |
self.maaff2 = nn.MaxPool3d(2) |
|
|
123 |
self.convaff3 = nn.Conv3d(16, 32, 3, padding=1) |
|
|
124 |
self.bnaff3 = nn.LayerNorm([16, 16, 16]) |
|
|
125 |
|
|
|
126 |
self.maaff3 = nn.MaxPool3d(2) |
|
|
127 |
self.convaff4 = nn.Conv3d(32, 64, 3, padding=1) |
|
|
128 |
self.maaff4 = nn.MaxPool3d(2) |
|
|
129 |
self.bnaff4 = nn.LayerNorm([8, 8, 8]) |
|
|
130 |
self.convaff5 = nn.Conv3d(64, 128, 1, padding=0) |
|
|
131 |
self.convaff6 = nn.Conv3d(128, 12, 4, padding=0) |
|
|
132 |
|
|
|
133 |
gsx, gsy, gsz = 64, 64, 64 |
|
|
134 |
gx, gy, gz = np.linspace(-1, 1, gsx), np.linspace(-1, 1, gsy), np.linspace(-1,1, gsz) |
|
|
135 |
grid = np.meshgrid(gx, gy, gz) # Y, X, Z |
|
|
136 |
grid = np.stack([grid[2], grid[1], grid[0], np.ones_like(grid[0])], axis=3) |
|
|
137 |
netgrid = np.swapaxes(grid, 0, 1)[...,[2,1,0,3]] |
|
|
138 |
|
|
|
139 |
self.register_buffer('grid', torch.tensor(netgrid.astype("float32"), requires_grad = False)) |
|
|
140 |
self.register_buffer('diagA', torch.eye(4, dtype=torch.float32)) |
|
|
141 |
|
|
|
142 |
def forward(self, outc1): |
|
|
143 |
x = outc1 |
|
|
144 |
x = F.relu(self.convaff1(x)) |
|
|
145 |
x = self.maaff1(x) |
|
|
146 |
x = F.relu(self.bnaff2(self.convaff2(x))) |
|
|
147 |
x = self.maaff2(x) |
|
|
148 |
x = F.relu(self.bnaff3(self.convaff3(x))) |
|
|
149 |
x = self.maaff3(x) |
|
|
150 |
x = F.relu(self.bnaff4(self.convaff4(x))) |
|
|
151 |
x = self.maaff4(x) |
|
|
152 |
x = F.relu(self.convaff5(x)) |
|
|
153 |
x = self.convaff6(x) |
|
|
154 |
|
|
|
155 |
x = x.view(-1, 3, 4) |
|
|
156 |
x = torch.cat([x, x[:,0:1] * 0], dim=1) |
|
|
157 |
self.tA = torch.transpose(x + self.diagA, 1, 2) |
|
|
158 |
|
|
|
159 |
wgrid = self.grid @ self.tA[:,None,None] |
|
|
160 |
gout = F.grid_sample(outc1, wgrid[...,[2,1,0]], align_corners=True) |
|
|
161 |
return gout, self.tA |
|
|
162 |
|
|
|
163 |
def resample_other(self, other): |
|
|
164 |
with torch.no_grad(): |
|
|
165 |
wgrid = self.grid @ self.tA[:,None,None] |
|
|
166 |
gout = F.grid_sample(other, wgrid[...,[2,1,0]], align_corners=True) |
|
|
167 |
return gout |
|
|
168 |
|
|
|
169 |
|
|
|
170 |
|
|
|
171 |
def bbox_world(affine, shape):
    """Map the 8 corner voxels of a ``shape`` image through ``affine``.

    Returns an (8, 4) array of homogeneous world coordinates. The corner order
    (origin, +i, +j, +k, +ij, +ik, +jk, +ijk) is relied upon: callers pair these rows
    with ``bbox_one`` and with other ``bbox_world`` results in least-squares fits.
    """
    last = [n - 1 for n in shape[:3]]
    picks = [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1),
             (1, 1, 0), (1, 0, 1), (0, 1, 1), (1, 1, 1)]
    corners = [[hi if on else 0 for hi, on in zip(last, pick)] + [1] for pick in picks]
    return (affine @ np.array(corners).T).T
|
|
176 |
|
|
|
177 |
# Corners of the normalized [-1, 1]^3 cube as homogeneous row vectors, listed in the
# same corner order as bbox_world() so the two can be paired row-by-row.
bbox_one = np.array([
    [-1, -1, -1, 1],
    [+1, -1, -1, 1],
    [-1, +1, -1, 1],
    [-1, -1, +1, 1],
    [+1, +1, -1, 1],
    [+1, -1, +1, 1],
    [-1, +1, +1, 1],
    [+1, +1, +1, 1],
])

# Voxel-to-world affine of the 64^3 box used as the reference ("mni") space:
# 2.857 x 3.429 x 2.857 mm voxels, first axis flipped.
affine64_mni = np.array([
    [-2.85714293, -0.0,        0.0,        90.0],
    [-0.0,         3.42857146, -0.0,      -126.0],
    [0.0,          0.0,        2.85714293, -72.0],
    [0.0,          0.0,        0.0,          1.0],
])
|
|
184 |
|
|
|
185 |
|
|
|
186 |
# Weight files live in <directory of this script>/torchparams/.
scriptpath = os.path.dirname(os.path.realpath(__file__))

device = torch.device("cpu")

# Head/tissue network; tensors are remapped onto `device` while loading.
net = HeadModel().to(device)
net.load_state_dict(torch.load(scriptpath + "/torchparams/params_head_00075_00000.pt", map_location=device))
net.eval()

# Affine-to-template network. strict=False: the checkpoint is not required to contain
# every entry of the state dict (e.g. the grid/diagA buffers built in __init__).
netAff = ModelAff()
netAff.load_state_dict(torch.load(scriptpath + "/torchparams/paramsaffineta_00079_00000.pt", map_location=device), strict=False)
netAff.to(device)
netAff.eval()
|
|
198 |
|
|
|
199 |
|
|
|
200 |
|
|
|
201 |
class HippoModel(nn.Module):
    """Hippocampus segmentation network applied to one cropped slab of the image.

    Maps a ``(N, 1, D, H, W)`` tensor to a ``(N, 1, D-4, H-4, W-4)`` sigmoid map. The
    unpadded stem removes 4 voxels per axis. The decoder then upsamples a 1/4-resolution
    feature map back by 4, so ``D-4``, ``H-4`` and ``W-4`` must be divisible by 4 for
    ``x * self.out_conv_f1`` to line up.

    After each call, intermediate tensors remain on the instance: ``out_conv_f1``,
    ``out_maxpool1``, ``out_maxpool2``, ``lx2``, ``out_output1`` (first-pass map) and
    ``out_output2`` (refined map, also the return value).
    """

    def __init__(self):
        super().__init__()
        # Unpadded 3x3x3 stem factorised into 1-D convolutions along W, then H, then D.
        self.conv0a_0 = nn.Conv3d(1, 16, (1, 1, 3), padding=0)
        self.conv0a_1 = nn.Conv3d(16, 16, (1, 3, 1), padding=0)
        self.conv0a = nn.Conv3d(16, 16, (3, 1, 1), padding=0)

        self.convf1 = nn.Conv3d(16, 48, (3, 3, 3), padding=0)

        self.maxpool1 = nn.MaxPool3d(2)

        # BatchNorms are forced out of training mode here, so they normalise with their
        # stored running statistics. A later .train() on the parent would undo this.
        self.bn1 = nn.BatchNorm3d(48, momentum=1)
        self.bn1.training = False
        self.convout0 = nn.Conv3d(48, 48, (3, 3, 3), padding=1)
        self.convout1 = nn.Conv3d(48, 48, (3, 3, 3), padding=1)

        self.maxpool2 = nn.MaxPool3d(2)

        self.bn2 = nn.BatchNorm3d(48, momentum=1)
        self.bn2.training = False

        self.convout2p = nn.Conv3d(48, 48, (3, 3, 3), padding=1)
        self.convout2 = nn.Conv3d(48, 48, (3, 3, 3), padding=1)

        # First-pass decoder.
        self.convlx3 = nn.Conv3d(48, 48, (3, 3, 3), padding=1)
        self.convlx5 = nn.Conv3d(48, 48, (3, 3, 3), padding=1)
        self.convlx7 = nn.Conv3d(48, 16, (3, 3, 3), padding=1)
        self.convlx8 = nn.Conv3d(16, 1, 1, padding=0)

        # Refinement: a learned 7^3 "blur" of the first-pass map gates the stem features.
        self.blur = nn.Conv3d(1, 1, 7, padding=3)

        self.conv_extract = nn.Conv3d(48, 47, 3, padding=1)
        self.convmix = nn.Conv3d(48, 16, 3, padding=1)
        self.convout1x = nn.Conv3d(16, 1, 1, padding=0)

    def forward(self, x):
        # Stem: each unpadded step shrinks the volume (4 voxels per axis overall).
        x = F.relu(self.conv0a_0(x))
        x = F.relu(self.conv0a_1(x))
        x = F.relu(self.conv0a(x))
        self.out_conv_f1 = x = F.relu(self.convf1(x))

        # Residual block at 1/2 resolution.
        self.out_maxpool1 = x = self.maxpool1(x)
        x = self.bn1(x)
        x = F.relu(self.convout0(x))
        x = self.convout1(x)
        x = F.relu(x + self.out_maxpool1)

        # Residual block at 1/4 resolution.
        self.out_maxpool2 = x = self.maxpool2(x)
        x = self.bn2(x)
        x = F.relu(self.convout2p(x))
        x = self.convout2(x)
        x = F.relu(x + self.out_maxpool2)

        # Kept as an inspectable attribute; not used further in this method.
        self.lx2 = F.interpolate(x, scale_factor=2, mode="nearest")

        # First-pass decoder back to stem resolution.
        x = F.relu(self.convlx3(x))
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        x = F.relu(self.convlx5(x))
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        x = F.relu(self.convlx7(x))
        self.out_output1 = x = torch.sigmoid(self.convlx8(x))

        # Refinement: gate the 48-channel stem features by the blurred first-pass map,
        # then fuse them with the first-pass map itself.
        x = torch.sigmoid(self.blur(x))
        x = x * self.out_conv_f1
        x = F.leaky_relu(self.conv_extract(x))
        x = torch.cat([self.out_output1, x], dim=1)

        x = F.relu(self.convmix(x))
        self.out_output2 = x = torch.sigmoid(self.convout1x(x))
        return x
|
|
278 |
|
|
|
279 |
# Hippocampus network. Loaded like `net`/`netAff`: map_location remaps the stored
# tensors onto `device`, so loading does not depend on where the checkpoint was saved.
# eval() puts every submodule in inference mode (its BatchNorms are already forced
# there in HippoModel.__init__, so outputs are unchanged).
hipponet = HippoModel()
hipponet.load_state_dict(torch.load(scriptpath + "/torchparams/hippodeep.pt", map_location=device))
hipponet.to(device)
hipponet.eval()
|
|
281 |
|
|
|
282 |
|
|
|
283 |
# Output switches read by main():
#   OUTPUT_RES64  - also write the low-resolution intermediates (64^3 box, hippocampus crop)
#   OUTPUT_NATIVE - resample the brain/cerebrum masks back onto the input grid
#   OUTPUT_DEBUG  - write extra debug volumes/CSVs, and keep going on images without qform/sform
OUTPUT_RES64, OUTPUT_NATIVE, OUTPUT_DEBUG = False, True, False

# One (filename, eTIV, hippoL, hippoR) tuple per processed subject, filled by main().
allsubjects_scalar_report = []
|
|
288 |
|
|
|
289 |
def mul_homo(g, Mt):
    """Apply a 4x4 homogeneous transform to 3-D row vectors.

    ``g`` has shape ``(..., 3)``. ``Mt`` is the *transposed* 4x4 matrix; callers pass
    ``affine.T`` to apply ``affine``. The rotation/scale block ``Mt[:3, :3]`` and the
    translation row ``Mt[3, :3]`` are cast to float32 first, so a float16/float32 ``g``
    yields a float32 result.
    """
    return g @ Mt[:3, :3].astype(np.float32) + Mt[3, :3].astype(np.float32)
|
|
290 |
|
|
|
291 |
def indices_unitary(dimensions, dtype):
    """Like ``np.indices`` but with coordinates spanning [-1, 1] along each axis.

    Returns an array of shape ``(len(dimensions),) + tuple(dimensions)`` whose slice
    ``[axis]`` holds ``np.linspace(-1, 1, dimensions[axis])`` broadcast along that axis.
    """
    dims = tuple(dimensions)
    ndim = len(dims)
    out = np.empty((ndim,) + dims, dtype=dtype)
    for axis in range(ndim):
        view_shape = [1] * ndim
        view_shape[axis] = dims[axis]
        out[axis] = np.linspace(-1, 1, dims[axis], dtype=dtype).reshape(view_shape)
    return out
|
|
299 |
|
|
|
300 |
def _output_stem(fname):
    """Return ``fname`` without its image extension.

    A trailing ``.gz`` is removed together with the extension before it, so
    ``sub.nii.gz``, ``sub.nii``, ``sub.mgz`` and ``sub.mnc`` all give ``sub``. Only the
    last path component is examined. Dots or ``_tiv`` elsewhere in the path are not
    touched, and inputs with any other extension still get distinct output names.
    """
    root, ext = os.path.splitext(fname)
    if ext.lower() == ".gz":
        root = os.path.splitext(root)[0]
    return root


def _append_warning(fname, message):
    """Append ``message`` to the ``<fname>.warning.txt`` side file."""
    with open(fname + ".warning.txt", "a") as fh:
        fh.write(message)


def _write_text(path, text):
    """Write ``text`` to ``path`` (overwriting), closing the file before returning."""
    with open(path, "w") as fh:
        fh.write(text)


def _bbox_xyz(shape, affine):
    """Return the world coordinates (8 x 3) of the corner voxels of an image."""
    s = shape[0] - 1, shape[1] - 1, shape[2] - 1
    bbox = [[0, 0, 0], [s[0], 0, 0], [0, s[1], 0], [0, 0, s[2]],
            [s[0], s[1], 0], [s[0], 0, s[2]], [0, s[1], s[2]], [s[0], s[1], s[2]]]
    return mul_homo(bbox, affine.T)


def _indices_xyz(shape, affine, offset_vox=None):
    """Return world coordinates (``shape + (3,)``) of every voxel of a 3-D box.

    The box starts at voxel ``offset_vox`` (default: the origin) of the image whose
    voxel-to-world matrix is ``affine``.
    """
    assert len(shape) == 3
    if offset_vox is None:
        offset_vox = np.zeros(3)
    ind = np.indices(shape).astype(np.float32) + np.asarray(offset_vox).reshape(3, 1, 1, 1).astype(np.float32)
    return mul_homo(np.rollaxis(ind, 0, 4), affine.T)


def _xyz_to_DHW3(xyz, iaffine, srcshape):
    """Convert world coordinates into ``grid_sample`` coordinates of a source image.

    ``xyz`` is mapped through the inverse of ``iaffine`` to voxel indices and rescaled
    to [-1, 1] (the ``align_corners=True`` convention). The grid's first and third axes
    are then swapped, to match the transposed source volume handed to ``grid_sample``.
    """
    ijk3 = mul_homo(xyz, np.linalg.inv(iaffine).T)
    for axis in range(3):
        ijk3[..., axis] /= srcshape[axis] - 1
    ijk3 = ijk3 * 2 - 1
    return np.swapaxes(ijk3, 0, 2)


def main():
    """Segment the hippocampi of every image path given on the command line.

    Paths whose file name contains ``_mask`` are skipped. For each remaining image:

    1. Load it with nibabel, using the NIfTI qform as the affine.
    2. Resample it into a 64^3 box reoriented to LAS and run ``net``. Channels 0 and 2
       give the brain and cerebrum masks and volumes; channel 1 gives the cortex map.
    3. Estimate the affine to the 64^3 reference space (``affine64_mni``) with ``netAff``.
    4. Crop a fixed box around the hippocampi and run ``hipponet`` on two slabs. The
       slab assigned to ``hippoL`` is read with a reversed first index (mirrored).
    5. Resample both results back onto the input grid.

    Outputs are written next to the input as ``<stem><suffix>``, where ``<stem>`` is the
    input path minus its image extension: ``_mask_L/_mask_R.nii.gz``,
    ``_brain_mask.nii.gz``, ``_cerebrum_mask.nii.gz`` (native outputs) and
    ``_hippoLR_volumes.csv``. With more than one input, ``all_subjects_hippo_report.csv``
    is also written in the directory of the last input.
    """
    for fname in sys.argv[1:]:
        # Skip this script's own outputs; only the file name is checked, so a
        # directory containing "_mask" does not exclude everything inside it.
        if "_mask" in os.path.basename(fname):
            print("Skipping %s because the filename contains _mask in it" % fname)
            continue
        Ti = time.time()
        stem = _output_stem(fname)

        try:
            print("Loading image " + fname)
            img = nibabel.load(fname)

            if type(img) is nibabel.nifti1.Nifti1Image:
                img._affine = img.get_qform()  # for ANTs compatibility

            if type(img) is nibabel.Nifti1Image:
                if img.header["qform_code"] == 0:
                    if img.header["sform_code"] == 0:
                        print(" *** Error: the header of this nifti file has no qform_code defined.")
                        print(" Fix the header manually or reconvert from the original DICOM.")
                        if not OUTPUT_DEBUG:
                            continue

                if not np.allclose(img.get_sform(), img.get_qform()):
                    img._affine = img.get_qform()  # simplify later ANTs compatibility
                    print("This image has an sform defined, ignoring it - work in scanner space using the qform")

        except Exception:
            # Not a bare except: KeyboardInterrupt/SystemExit must still stop the batch.
            _append_warning(fname, "can't open the file\n")
            print(" *** Error: can't open file. Skip")
            continue

        d = img.get_fdata(caching="unchanged", dtype=np.float32)
        while len(d.shape) > 3:
            print("Warning: this looks like a timeserie. Averaging it")
            _append_warning(fname, "dim not 3. Averaging last dimension\n")
            d = d.mean(-1)

        # Z-score the intensities. A constant (or NaN-containing) image would turn into
        # NaNs here and crash later (least-squares fits), aborting the whole batch.
        d_std = d.std()
        if not np.isfinite(d_std) or d_std == 0:
            print(" *** Error: image intensities are constant or not finite. Skip")
            _append_warning(fname, "constant or non-finite image intensities\n")
            continue
        d = (d - d.mean()) / d_std

        o1 = nibabel.orientations.io_orientation(img.affine)
        o2 = np.array([[0., -1.], [1., 1.], [2., 1.]])  # We work in LAS space (same as the mni_icbm152 template)
        trn = nibabel.orientations.ornt_transform(o1, o2)  # o1 to o2 (apply to o2 to obtain o1)
        trn_back = nibabel.orientations.ornt_transform(o2, o1)

        revaff1 = nibabel.orientations.inv_ornt_aff(trn, (1, 1, 1))  # mult on o1 to obtain o2
        revaff1i = nibabel.orientations.inv_ornt_aff(trn_back, (1, 1, 1))  # mult on o2 to obtain o1

        aff_orig64 = np.linalg.lstsq(bbox_world(np.identity(4), (64, 64, 64)), bbox_world(img.affine, img.shape[:3]), rcond=None)[0].T
        voxscale_native64 = np.abs(np.linalg.det(aff_orig64))
        revaff64i = nibabel.orientations.inv_ornt_aff(trn_back, (64, 64, 64))
        aff_reor64 = np.linalg.lstsq(bbox_world(revaff64i, (64, 64, 64)), bbox_world(img.affine, img.shape[:3]), rcond=None)[0].T

        # Resample the whole image into the reoriented 64^3 box.
        wgridt = (netAff.grid @ torch.tensor(revaff1i, device=device, dtype=torch.float32))[None, ..., [2, 1, 0]]
        d_orr = F.grid_sample(torch.as_tensor(d, dtype=torch.float32, device=device)[None, None], wgridt, align_corners=True)

        if OUTPUT_DEBUG:
            nibabel.Nifti1Image(np.asarray(d_orr[0, 0].cpu()), aff_reor64).to_filename(stem + "_orig_b64.nii.gz")

        ## Head priors
        with torch.no_grad():
            out1t = net(d_orr)
        out1 = np.asarray(out1t.cpu())

        scalar_output = []
        scalar_output_report = []

        # brain mask (channel 0)
        output = out1[0, 0].astype("float32")
        brainmask_cc = torch.tensor(output)

        vol = (output[output > .5]).sum() * voxscale_native64
        if OUTPUT_DEBUG:
            print(" Estimated intra-cranial volume (mm^3): %d" % vol)
        if 0:
            _write_text(stem + "_eTIV.txt", "%d\n" % vol)
        scalar_output.append(vol)
        scalar_output_report.append(vol)

        if OUTPUT_RES64:
            out = (output.clip(0, 1) * 255).astype("uint8")
            nibabel.Nifti1Image(out, aff_reor64, img.header).to_filename(stem + "_tissues0_b64.nii.gz")

        if OUTPUT_NATIVE:
            # Sampling grid from the native image back into the 64^3 box.
            gsx, gsy, gsz = img.shape[:3]
            # this is a big array, so use float16
            sgrid = np.rollaxis(indices_unitary((gsx, gsy, gsz), dtype=np.float16), 0, 4)
            wgridt = torch.as_tensor(mul_homo(sgrid, inv(revaff1i))[None, ..., [2, 1, 0]], device=device, dtype=torch.float32)
            del sgrid

            dnat = np.asarray(F.grid_sample(torch.as_tensor(output, dtype=torch.float32, device=device)[None, None], wgridt, align_corners=True).cpu())[0, 0]
            nibabel.Nifti1Image((dnat > .5).astype("uint8"), img.affine).to_filename(stem + "_brain_mask.nii.gz")
            vol = (dnat > .5).sum() * np.abs(np.linalg.det(img.affine))
            print(" Estimated intra-cranial volume (mm^3) (native space): %d" % vol)
            scalar_output.append(vol)
            scalar_output_report[-1] = vol  # authoritative, so overwrite previous
            del dnat

        # cerebrum mask (channel 2), restricted to its largest connected component
        output = out1[0, 2].astype("float32")
        out_cc, lab = scipy.ndimage.label(output > .01)
        if lab > 0:  # with no component, argmax over an empty histogram would raise
            output *= (out_cc == np.bincount(out_cc.flat)[1:].argmax() + 1)

        vol = (output[output > .5]).sum() * voxscale_native64
        if OUTPUT_DEBUG:
            print(" Estimated cerebrum volume (mm^3): %d" % vol)
        if 0:
            _write_text(stem + "_eTIV_nocerebellum.txt", "%d\n" % vol)
        scalar_output.append(vol)

        if OUTPUT_RES64:
            out = (output.clip(0, 1) * 255).astype("uint8")
            nibabel.Nifti1Image(out, aff_reor64, img.header).to_filename(stem + "_tissues2_b64.nii.gz")
        if OUTPUT_NATIVE:
            dnat = np.asarray(F.grid_sample(torch.as_tensor(output, dtype=torch.float32, device=device)[None, None], wgridt, align_corners=True).cpu()[0, 0])
            nibabel.Nifti1Image((dnat > .5).astype("uint8"), img.affine).to_filename(stem + "_cerebrum_mask.nii.gz")
            vol = (dnat > .5).sum() * np.abs(np.linalg.det(img.affine))
            print(" Estimated cerebrum volume (mm^3) (native space): %d" % vol)
            scalar_output.append(vol)
            del dnat

        # cortex (channel 1)
        output = out1[0, 1].astype("float32")
        output[output < .01] = 0
        if OUTPUT_RES64:
            out = (output.clip(0, 1) * 255).astype("uint8")
            nibabel.Nifti1Image(out, aff_reor64, img.header).to_filename(stem + "_tissues1_b64.nii.gz")
        if OUTPUT_NATIVE and OUTPUT_DEBUG:
            dnat = np.asarray(F.grid_sample(torch.as_tensor(output, dtype=torch.float32, device=device)[None, None], wgridt, align_corners=True).cpu()[0, 0])
            nibabel.Nifti1Image(dnat, img.affine).to_filename(stem + "_tissues1.nii.gz")
            del dnat

        ## MNI affine
        with torch.no_grad():
            wc1, tA = netAff(out1t[:, [1, 3]] * brainmask_cc)

        wnat = np.linalg.lstsq(bbox_world(img.affine, img.shape[:3]), bbox_one @ revaff1, rcond=None)[0]
        wmni = np.linalg.lstsq(bbox_world(affine64_mni, (64, 64, 64)), bbox_one, rcond=None)[0]
        M = (wnat @ inv(np.asarray(tA[0].cpu())) @ inv(wmni)).T
        # [native world coord] @ M.T -> [mni world coord] , in LAS space

        if OUTPUT_DEBUG:
            # Output MNI, mostly for debug, save in box64, uint8
            out2 = np.asarray(wc1.to("cpu"))
            out2 = np.clip((out2 * 255), 0, 255).astype("uint8")
            nibabel.Nifti1Image(out2[0, 0], affine64_mni).to_filename(stem + "_mniwrapc1.nii.gz")
            del out2
        if 0:
            out2r = np.asarray(netAff.resample_other(d_orr).cpu())
            out2r = (out2r - out2r.min()) * 255 / np.ptp(out2r)
            nibabel.Nifti1Image(out2r[0, 0].astype("uint8"), affine64_mni).to_filename(stem + "_mniwrap.nii.gz")
            del out2r

        # output an ANTs-compatible matrix (AntsApplyTransforms -t)
        f3 = np.array([[1, 1, -1, -1], [1, 1, -1, -1], [-1, -1, 1, 1], [1, 1, 1, 1]])  # ANTs LPS
        MI = inv(M) * f3
        txt = "#Insight Transform File V1.0\nTransform: AffineTransform_float_3_3\nFixedParameters: 0 0 0\nParameters: "
        txt += " ".join(["%4.6f %4.6f %4.6f" % tuple(x) for x in MI[:3, :3].tolist()]) + " %4.6f %4.6f %4.6f\n" % (MI[0, 3], MI[1, 3], MI[2, 3])
        if 0:
            _write_text(stem + "_mni0Affine.txt", txt)

        # Same translation with the 3x3 part replaced by its closest rotation (SVD).
        u, s, vt = np.linalg.svd(MI[:3, :3])
        MI3rigid = u @ vt
        txt = "#Insight Transform File V1.0\nTransform: AffineTransform_float_3_3\nFixedParameters: 0 0 0\nParameters: "
        txt += " ".join(["%4.6f %4.6f %4.6f" % tuple(x) for x in MI3rigid.tolist()]) + " %4.6f %4.6f %4.6f\n" % (MI[0, 3], MI[1, 3], MI[2, 3])
        if 0:
            _write_text(stem + "_mni0Rigid.txt", txt)

        ## Hippodeep
        imgcroproi_affine = np.array([[-1., -0., 0., 54.], [-0., 1., -0., -59.], [0., 0., 1., -45.], [0., 0., 0., 1.]])
        imgcroproi_shape = (107, 72, 68)
        # coord in mm bbox
        gsx, gsy, gsz = 107, 72, 68
        sgrid = np.rollaxis(indices_unitary((gsx, gsy, gsz), dtype=np.float32), 0, 4)

        bboxnat = bbox_world(imgcroproi_affine, imgcroproi_shape) @ inv(M.T) @ wnat
        matzoom = np.linalg.lstsq(bbox_one, bboxnat, rcond=None)[0]  # in -1..1 space
        # wgridt for hippo box
        wgridt = torch.tensor(mul_homo(sgrid, (matzoom @ revaff1i))[None, ..., [2, 1, 0]], device=device, dtype=torch.float32)
        del sgrid
        dout = F.grid_sample(torch.as_tensor(d, dtype=torch.float32, device=device)[None, None], wgridt, align_corners=True)
        # note: d was normalized from full-image
        d_in = np.asarray(dout[0, 0].cpu())  # back to numpy since torch does not support negative step/strides

        if OUTPUT_RES64:
            # np.ptp: the ndarray.ptp method no longer exists in NumPy 2.
            d_in_u8 = (((d_in - d_in.min()) / np.ptp(d_in)) * 255).astype("uint8")
            nibabel.Nifti1Image(d_in_u8, imgcroproi_affine).to_filename(stem + "_affcrop.nii.gz")

        d_in -= d_in.mean()
        d_in /= d_in.std()
        # Two slabs of the crop; the one for hippoL is read with a reversed first index.
        with torch.no_grad():
            hippoR = hipponet(torch.as_tensor(d_in[None, None, 6:54:+1, :, 2:-2].copy()))
            hippoL = hipponet(torch.as_tensor(d_in[None, None, -7:-55:-1, :, 2:-2].copy()))

        hippoRL = np.vstack([np.asarray(hippoR.cpu()), np.asarray(hippoL.cpu())])

        # smoothly rescale (.5 ~ .75) to (.5 ~ 1.)
        hippoRL = np.clip(((hippoRL - .5) * 2 + .5), 0, 1) * (hippoRL > .5)
        # Paste back into the crop; the network output is 2 voxels smaller on each side.
        output = np.zeros((2, 107, 72, 68), np.float32)
        output[0, -7:-55:-1, :, 2:-2][2:-2, 2:-2, 2:-2] = np.clip(hippoRL[1] * 255, 0, 255)
        output[1, 6:54:+1, :, 2:-2][2:-2, 2:-2, 2:-2] = np.clip(hippoRL[0] * 255, 0, 255)

        if OUTPUT_DEBUG:
            nibabel.Nifti1Image(output.sum(0), imgcroproi_affine).to_filename(stem + "_affcrop_outseg_mask.nii.gz")

        boxvols = hippoRL[[1, 0]].reshape(2, -1).sum(1) * np.abs(np.linalg.det(imgcroproi_affine @ inv(M)))
        scalar_output.append(boxvols)

        # Resample both crop segmentations onto the part of the native grid covered by the crop.
        pts = _bbox_xyz(imgcroproi_shape, imgcroproi_affine)
        pts = mul_homo(pts, np.linalg.inv(M).T)
        pts_ijk = mul_homo(pts, np.linalg.inv(img.affine).T)
        for i in range(3):
            np.clip(pts_ijk[:, i], 0, img.shape[i], out=pts_ijk[:, i])
        pmin = np.floor(np.min(pts_ijk, 0)).astype(int)
        pwidth = np.ceil(np.max(pts_ijk, 0)).astype(int) - pmin

        widx = _indices_xyz(pwidth, img.affine, offset_vox=pmin)
        widx = mul_homo(widx, M.T)
        DHW3 = _xyz_to_DHW3(widx, imgcroproi_affine, imgcroproi_shape)

        wdata = np.zeros(img.shape[:3], np.uint8)
        box = tuple(slice(pmin[k], pmin[k] + pwidth[k]) for k in range(3))
        native_vols = []
        # output[0] -> _mask_L, output[1] -> _mask_R. Both fill the same box of wdata,
        # so the second pass fully overwrites the first.
        for side, suffix in ((0, "_mask_L"), (1, "_mask_R")):
            d = torch.tensor(output[side].T, dtype=torch.float32)
            outDHW = F.grid_sample(d[None, None], torch.tensor(DHW3[None]), align_corners=True)
            dnat = np.asarray(outDHW[0, 0].permute(2, 1, 0))
            dnat[dnat < 32] = 0  # remove noise
            native_vols.append(dnat.sum() / 255. * np.abs(np.linalg.det(img.affine)))
            wdata[box] = dnat.astype(np.uint8)
            nibabel.Nifti1Image(wdata.astype("uint8"), img.affine).to_filename(stem + suffix + ".nii.gz")
        volsAA_L, volsAA_R = native_vols

        print(" Hippocampal volumes (L,R)", volsAA_L, volsAA_R)
        scalar_output.append([volsAA_L, volsAA_R])
        scalar_output_report.append([volsAA_L, volsAA_R])

        if OUTPUT_DEBUG:
            txt = "eTIV_mni,eTIV,cerebrum_mni,cerebrum,mni_hippoL,mni_hippoR,hippoL,hippoR\n"
            txt += "%4f,%4f,%4f,%4f,%4.4f,%4.4f,%4.4f,%4.4f\n" % (tuple(scalar_output[:4]) + tuple(scalar_output[4]) + tuple(scalar_output[5]))
            _write_text(stem + "_scalars_hippo.csv", txt)

        txt = "eTIV,hippoL,hippoR\n"
        txt += "%4f,%4f,%4f\n" % (scalar_output_report[0], scalar_output_report[1][0], scalar_output_report[1][1])
        _write_text(stem + "_hippoLR_volumes.csv", txt)

        if OUTPUT_RES64:
            print("fslview %s %s -t .5 &" % (stem + "_affcrop.nii.gz", stem + "_affcrop_outseg_mask.nii.gz"))

        mask_l, mask_r = stem + "_mask_L.nii.gz", stem + "_mask_R.nii.gz"
        print(" Elapsed time for subject %4.2fs " % (time.time() - Ti))
        print(" To display using fsleyes or fslview, try:")
        print(" fsleyes %s %s -a 75 -cm Red-Yellow %s -a 75 -cm Blue-Lightblue &" % (fname, mask_l, mask_r))
        print(" fslview %s %s -t .5 %s -t .5 &" % (fname, mask_l, mask_r))

        allsubjects_scalar_report.append((fname, scalar_output_report[0], scalar_output_report[1][0], scalar_output_report[1][1]))

        try:
            print("Peak memory used (Gb) " + str(resource.getrusage(resource.RUSAGE_SELF)[2] / (1024. * 1024)))
        except Exception:
            # Best effort only: `resource` is unavailable on some platforms.
            pass

    print("Done")

    if len(sys.argv[1:]) > 1:
        outfilename = os.path.join(os.path.dirname(fname) or ".", "all_subjects_hippo_report.csv")
        txt_entries = ["%s,%4f,%4f,%4f\n" % s for s in allsubjects_scalar_report]
        with open(outfilename, "w") as fh:
            fh.writelines(["filename,eTIV,hippoL,hippoR\n"] + txt_entries)
        print("Volumes of every subjects saved as " + outfilename)


if __name__ == "__main__":
    main()