a b/hippodeep.py
1
import torch
2
import nibabel
3
import numpy as np
4
import os, sys, time
5
import scipy.ndimage
6
import torch.nn as nn
7
import torch.nn.functional as F
8
from numpy.linalg import inv
9
try:
    import resource  # POSIX-only module; main() uses it to report peak memory
except ImportError:
    pass

# monkey-patch for back-compatibility with older (~1.0.0) torch, whose
# F.grid_sample() has no align_corners parameter: wrap it so keyword
# arguments are dropped and only the positional ones are forwarded.
try:
    import inspect
    if "align_corners" not in inspect.signature(F.grid_sample).parameters:
        old_grid_sample = torch.nn.functional.grid_sample
        F.grid_sample = lambda *x, **k: old_grid_sample(*x)
except (TypeError, ValueError):
    pass  # signature not introspectable: keep the stock grid_sample

if not sys.argv[1:]:
    print("Need to pass one or more T1 image filename as argument")
    sys.exit(1)

print("Using all available CPU threads")
if 0:  # otherwise, set a limit (useful for running multiple instances)
    torch.set_num_threads(4)
30
31
32
class HeadModel(nn.Module):
    """Encoder/decoder 3D CNN producing 4 sigmoid head-prior maps.

    Input is (N, 1, D, H, W); D, H and W must be divisible by 8 because the
    three 2x max-pool levels are upsampled back and concatenated with their
    skip connections.  Output is (N, 4, D, H, W).  Intermediate activations
    are kept on the instance: li0..li3 (encoder levels), lo1, la1 and the
    pre-sigmoid logits in ``out``.
    """

    def __init__(self):
        super(HeadModel, self).__init__()
        conv = nn.Conv3d
        norm = nn.BatchNorm3d

        # encoder, full resolution
        self.conv0a = conv(1, 8, 3, padding=1)
        self.conv0b = conv(8, 8, 3, padding=1)
        self.bn0a = norm(8)

        # encoder, 1/2 resolution
        self.ma1 = nn.MaxPool3d(2)
        self.conv1a = conv(8, 16, 3, padding=1)
        self.conv1b = conv(16, 24, 3, padding=1)
        self.bn1a = norm(24)

        # encoder, 1/4 resolution
        self.ma2 = nn.MaxPool3d(2)
        self.conv2a = conv(24, 24, 3, padding=1)
        self.conv2b = conv(24, 32, 3, padding=1)
        self.bn2a = norm(32)

        # bottleneck, 1/8 resolution
        self.ma3 = nn.MaxPool3d(2)
        self.conv3a = conv(32, 48, 3, padding=1)
        self.conv3b = conv(48, 48, 3, padding=1)
        self.bn3a = norm(48)

        # decoder at 1/4: upsampled features + li2 skip (32 channels)
        self.conv2u = conv(48, 24, 3, padding=1)
        self.conv2v = conv(24 + 32, 24, 3, padding=1)
        self.bn2u = norm(24)

        # decoder at 1/2: + li1 skip (24 channels)
        self.conv1u = conv(24, 24, 3, padding=1)
        self.conv1v = conv(24 + 24, 24, 3, padding=1)
        self.bn1u = norm(24)

        # decoder at full resolution: + li0 skip (8 channels)
        self.conv0u = conv(24, 16, 3, padding=1)
        self.conv0v = conv(16 + 8, 8, 3, padding=1)
        self.bn0u = norm(8)

        # 1x1x1 head producing the 4 output channels
        self.conv1x = conv(8, 4, 1, padding=0)

    def forward(self, x):
        def up(t):
            return F.interpolate(t, scale_factor=2, mode="nearest")

        # encoder: the second conv of each level is batch-normalised and its
        # output is remembered as the skip connection for the decoder
        h = F.elu(self.conv0a(x))
        h = F.elu(self.bn0a(self.conv0b(h)))
        self.li0 = h

        h = F.elu(self.conv1a(self.ma1(h)))
        h = F.elu(self.bn1a(self.conv1b(h)))
        self.li1 = h

        h = F.elu(self.conv2a(self.ma2(h)))
        h = F.elu(self.bn2a(self.conv2b(h)))
        self.li2 = h

        h = F.elu(self.conv3a(self.ma3(h)))
        h = F.elu(self.bn3a(self.conv3b(h)))
        self.li3 = h

        # decoder: upsample, conv, concatenate the skip, fuse
        h = F.elu(self.conv2u(up(h)))
        h = F.elu(self.bn2u(self.conv2v(torch.cat([h, self.li2], 1))))
        self.lo1 = h

        h = F.elu(self.conv1u(up(h)))
        h = F.elu(self.bn1u(self.conv1v(torch.cat([h, self.li1], 1))))

        h = up(h)
        self.la1 = h
        h = F.elu(self.conv0u(h))
        h = F.elu(self.bn0u(self.conv0v(torch.cat([h, self.li0], 1))))

        logits = self.conv1x(h)
        self.out = logits
        return torch.sigmoid(logits)
110
111
112
113
114
class ModelAff(nn.Module):
    """Regress a 3D affine transform from a 2-channel 64^3 volume and warp by it.

    ``forward(outc1)`` expects (N, 2, 64, 64, 64): the LayerNorm shapes fix
    the spatial size.  It returns ``(warped outc1, tA)``, where tA is the
    (N, 4, 4) transposed homogeneous matrix, also stored as ``self.tA``.
    ``resample_other`` must be called after ``forward``, since it reuses
    ``self.tA``.
    """

    def __init__(self):
        super(ModelAff, self).__init__()
        self.convaff1 = nn.Conv3d(2, 16, 3, padding=1)
        self.maaff1 = nn.MaxPool3d(2)
        self.convaff2 = nn.Conv3d(16, 16, 3, padding=1)
        self.bnaff2 = nn.LayerNorm([32, 32, 32])

        self.maaff2 = nn.MaxPool3d(2)
        self.convaff3 = nn.Conv3d(16, 32, 3, padding=1)
        self.bnaff3 = nn.LayerNorm([16, 16, 16])

        self.maaff3 = nn.MaxPool3d(2)
        self.convaff4 = nn.Conv3d(32, 64, 3, padding=1)
        self.maaff4 = nn.MaxPool3d(2)
        self.bnaff4 = nn.LayerNorm([8, 8, 8])
        self.convaff5 = nn.Conv3d(64, 128, 1, padding=0)
        self.convaff6 = nn.Conv3d(128, 12, 4, padding=0)

        # Identity sampling grid over [-1, 1]^3 in homogeneous coordinates:
        # grid[a, b, c] = (t[a], t[b], t[c], 1) with t = linspace(-1, 1, 64).
        t = np.linspace(-1, 1, 64)
        axes = np.meshgrid(t, t, t, indexing="ij")
        netgrid = np.stack([*axes, np.ones_like(axes[0])], axis=-1)

        self.register_buffer('grid', torch.tensor(netgrid.astype("float32"), requires_grad = False))
        self.register_buffer('diagA', torch.eye(4, dtype=torch.float32))

    def _warp(self, volume):
        """Resample *volume* through the identity grid mapped by ``self.tA``."""
        wgrid = self.grid @ self.tA[:, None, None]
        # grid_sample wants (x, y, z) = (W, H, D) order in the last axis
        return F.grid_sample(volume, wgrid[..., [2, 1, 0]], align_corners=True)

    def forward(self, outc1):
        h = self.maaff1(F.relu(self.convaff1(outc1)))
        h = self.maaff2(F.relu(self.bnaff2(self.convaff2(h))))
        h = self.maaff3(F.relu(self.bnaff3(self.convaff3(h))))
        h = self.maaff4(F.relu(self.bnaff4(self.convaff4(h))))
        params = self.convaff6(F.relu(self.convaff5(h)))

        # 12 regressed values -> 3x4 matrix; append a zero row and add the
        # identity so the network predicts a residual w.r.t. identity
        rows = params.view(-1, 3, 4)
        rows = torch.cat([rows, rows[:, 0:1] * 0], dim=1)
        self.tA = torch.transpose(rows + self.diagA, 1, 2)

        return self._warp(outc1), self.tA

    def resample_other(self, other):
        with torch.no_grad():
            return self._warp(other)
168
169
170
171
def bbox_world(affine, shape):
    """Map the 8 corner voxels of an image of *shape* through *affine*.

    Corners are ordered (0,0,0), (x,0,0), (0,y,0), (0,0,z), (x,y,0), (x,0,z),
    (0,y,z), (x,y,z), where x, y, z are the last index along each axis (the
    same order as ``bbox_one``).  Returns an (8, 4) homogeneous array.
    """
    corner_flags = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1],
                             [1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1]])
    last_index = np.array([shape[0] - 1, shape[1] - 1, shape[2] - 1])
    corners = corner_flags * last_index
    homogeneous = np.hstack([corners, np.ones((8, 1), dtype=corners.dtype)])
    return (affine @ homogeneous.T).T
176
177
# Corners of the normalised cube [-1, 1]^3 in homogeneous coordinates, in
# the same corner order as returned by bbox_world().
bbox_one = np.array([
    [-1, -1, -1, 1],
    [ 1, -1, -1, 1],
    [-1,  1, -1, 1],
    [-1, -1,  1, 1],
    [ 1,  1, -1, 1],
    [ 1, -1,  1, 1],
    [-1,  1,  1, 1],
    [ 1,  1,  1, 1],
])

# Voxel-to-world affine attached to the 64^3 volumes warped by the affine
# network (e.g. the *_mniwrapc1 debug output of main()).
affine64_mni = np.array([
    [-2.85714293, -0.,          0.,           90.],
    [-0.,          3.42857146, -0.,         -126.],
    [ 0.,          0.,          2.85714293,  -72.],
    [ 0.,          0.,          0.,            1.],
])
184
185
186
# Network weights live in torchparams/ next to this script; both networks are
# loaded once, at import time, on the CPU.
scriptpath = os.path.dirname(os.path.realpath(__file__))

device = torch.device("cpu")

# head-prior network
net = HeadModel().to(device)
net.load_state_dict(torch.load(scriptpath + "/torchparams/params_head_00075_00000.pt", map_location=device))
net.eval()

# affine (MNI) network; strict=False tolerates missing/unexpected keys
# (NOTE(review): presumably the checkpoint lacks the grid/diagA buffers — confirm)
netAff = ModelAff()
netAff.load_state_dict(torch.load(scriptpath + "/torchparams/paramsaffineta_00079_00000.pt", map_location=device), strict=False)
netAff = netAff.to(device).eval()
198
199
200
201
class HippoModel(nn.Module):
    """Hippocampus segmentation network applied to one hemisphere crop.

    The four unpadded stem convolutions shrink every spatial dimension by 4,
    so for an (N, 1, D, H, W) input the output is an (N, 1, D-4, H-4, W-4)
    sigmoid map.  D-4, H-4 and W-4 must be divisible by 4 (two 2x pools are
    undone by two 2x upsamplings).  Intermediate maps are kept on the
    instance (out_conv_f1, out_maxpool1/2, lx2, out_output1, out_output2).
    """

    def __init__(self):
        super(HippoModel, self).__init__()
        # separable 3-tap stem along the last, middle, then first spatial axis
        self.conv0a_0 = nn.Conv3d(1, 16, (1, 1, 3), padding=0)
        self.conv0a_1 = nn.Conv3d(16, 16, (1, 3, 1), padding=0)
        self.conv0a = nn.Conv3d(16, 16, (3, 1, 1), padding=0)

        self.convf1 = nn.Conv3d(16, 48, (3, 3, 3), padding=0)

        self.maxpool1 = nn.MaxPool3d(2)

        # batch norms are pinned to inference mode (running statistics only)
        self.bn1 = nn.BatchNorm3d(48, momentum=1)
        self.bn1.training = False
        self.convout0 = nn.Conv3d(48, 48, (3, 3, 3), padding=1)
        self.convout1 = nn.Conv3d(48, 48, (3, 3, 3), padding=1)

        self.maxpool2 = nn.MaxPool3d(2)

        self.bn2 = nn.BatchNorm3d(48, momentum=1)
        self.bn2.training = False

        self.convout2p = nn.Conv3d(48, 48, (3, 3, 3), padding=1)
        self.convout2 = nn.Conv3d(48, 48, (3, 3, 3), padding=1)

        # coarse decoder
        self.convlx3 = nn.Conv3d(48, 48, (3, 3, 3), padding=1)
        self.convlx5 = nn.Conv3d(48, 48, (3, 3, 3), padding=1)
        self.convlx7 = nn.Conv3d(48, 16, (3, 3, 3), padding=1)
        self.convlx8 = nn.Conv3d(16, 1, 1, padding=0)

        # refinement stage
        self.blur = nn.Conv3d(1, 1, 7, padding=3)
        self.conv_extract = nn.Conv3d(48, 47, 3, padding=1)
        self.convmix = nn.Conv3d(48, 16, 3, padding=1)
        self.convout1x = nn.Conv3d(16, 1, 1, padding=0)

    def forward(self, x):
        def up(t):
            return F.interpolate(t, scale_factor=2, mode="nearest")

        h = F.relu(self.conv0a_0(x))
        h = F.relu(self.conv0a_1(h))
        h = F.relu(self.conv0a(h))
        h = F.relu(self.convf1(h))
        self.out_conv_f1 = h

        # residual block at 1/2 resolution
        h = self.maxpool1(h)
        self.out_maxpool1 = h
        r = F.relu(self.convout0(self.bn1(h)))
        h = F.relu(self.convout1(r) + self.out_maxpool1)

        # residual block at 1/4 resolution
        h = self.maxpool2(h)
        self.out_maxpool2 = h
        r = F.relu(self.convout2p(self.bn2(h)))
        h = F.relu(self.convout2(r) + self.out_maxpool2)

        self.lx2 = up(h)  # kept for inspection only; not used below

        # coarse map back at the convf1 resolution
        h = up(F.relu(self.convlx3(h)))
        h = up(F.relu(self.convlx5(h)))
        h = F.relu(self.convlx7(h))
        coarse = torch.sigmoid(self.convlx8(h))
        self.out_output1 = coarse

        # refinement: the blurred coarse map gates the convf1 features
        gate = torch.sigmoid(self.blur(coarse))
        feats = F.leaky_relu(self.conv_extract(gate * self.out_conv_f1))
        h = F.relu(self.convmix(torch.cat([coarse, feats], dim=1)))
        self.out_output2 = torch.sigmoid(self.convout1x(h))
        return self.out_output2
278
279
# Hippocampus network.  Weights are loaded with map_location=device, like the
# head and affine networks, so a checkpoint holding GPU tensors still loads on
# a CPU-only machine.  The module itself stays on the CPU because main() feeds
# it CPU tensors.  eval() makes inference mode explicit (the batch norms are
# already pinned to inference mode in HippoModel.__init__).
hipponet = HippoModel()
hipponet.load_state_dict(torch.load(scriptpath + "/torchparams/hippodeep.pt", map_location=device))
hipponet.eval()
281
282
283
# Output switches read by main():
#   OUTPUT_RES64  - also save the 64^3 and crop-box intermediate images
#   OUTPUT_NATIVE - save brain / cerebrum masks resampled to the input grid
#   OUTPUT_DEBUG  - extra debug images, prints and *_scalars_hippo.csv
OUTPUT_RES64 = False
OUTPUT_NATIVE = True
OUTPUT_DEBUG = False

# One (fname, eTIV, hippoL, hippoR) tuple per processed subject; main() appends.
allsubjects_scalar_report = []


def mul_homo(g, Mt):
    """Apply the transposed 4x4 homogeneous matrix *Mt* to points *g* (..., 3).

    Computes ``g @ Mt[:3, :3] + Mt[3, :3]`` with the matrix parts cast to
    float32, i.e. ``(A @ [p, 1])[:3]`` for every point p when ``Mt = A.T``.
    """
    linear = Mt[:3, :3].astype(np.float32)
    translation = Mt[3, :3].astype(np.float32)
    return g @ linear + translation
290
291
def indices_unitary(dimensions, dtype):
    """Return normalised index grids, like ``np.indices`` but scaled to [-1, 1].

    The result has shape ``(len(dimensions),) + dimensions``; entry ``[k]``
    varies along axis k from -1 to 1 (``np.linspace``) and is constant along
    the other axes.  Each axis is filled by broadcasting, so no temporary
    full-size grids are created (callers use this for large native grids).
    """
    dims = tuple(dimensions)
    ndim = len(dims)
    out = np.empty((ndim,) + dims, dtype=dtype)
    for axis, size in enumerate(dims):
        view_shape = [1] * ndim
        view_shape[axis] = size
        out[axis] = np.linspace(-1, 1, size, dtype=dtype).reshape(view_shape)
    return out
299
300
def tiv_output_filename(fname):
    """Return the ``*_tiv.nii.gz`` path from which main() derives every output name.

    Inputs ending in .nii, .nii.gz, .mgz or .mnc are renamed as before
    (e.g. ``sub.mgz`` -> ``sub_tiv.nii.gz``).  For any other extension
    (e.g. .mgh, .img) the replace chain leaves the name unchanged, which made
    every ``_tiv`` -> suffix substitution in main() a no-op and addressed all
    outputs to the input file itself.  Such names now have their extension
    replaced instead.
    """
    outfilename = (fname.replace(".mnc", ".nii")
                        .replace(".mgz", ".nii")
                        .replace(".nii.gz", ".nii")
                        .replace(".nii", "_tiv.nii.gz"))
    if outfilename == fname:
        outfilename = os.path.splitext(fname)[0] + "_tiv.nii.gz"
    return outfilename


def _bbox_xyz(shape, affine):
    """Return the world coordinates (8, 3) of the corner voxels of *shape* under *affine*."""
    s = shape[0] - 1, shape[1] - 1, shape[2] - 1
    bbox = [[0, 0, 0], [s[0], 0, 0], [0, s[1], 0], [0, 0, s[2]],
            [s[0], s[1], 0], [s[0], 0, s[2]], [0, s[1], s[2]], [s[0], s[1], s[2]]]
    return mul_homo(bbox, affine.T)


def _indices_xyz(shape, affine, offset_vox=None):
    """Return world coordinates (shape + (3,)) of every voxel index of *shape* shifted by *offset_vox*."""
    if len(shape) != 3:
        raise ValueError("expected a 3D shape, got %r" % (shape,))
    if offset_vox is None:
        offset_vox = np.zeros(3)
    ind = np.indices(shape).astype(np.float32) + np.asarray(offset_vox).reshape(3, 1, 1, 1).astype(np.float32)
    return mul_homo(np.rollaxis(ind, 0, 4), affine.T)


def _xyz_to_DHW3(xyz, iaffine, srcshape):
    """Convert world coordinates to [-1, 1] sampling coordinates of a (srcshape, iaffine) image.

    Voxel indices are normalised by (size - 1) per axis (the align_corners=True
    convention), and the first and third spatial axes of the array are
    swapped for use as a grid_sample grid.
    """
    ijk3 = mul_homo(xyz, np.linalg.inv(iaffine).T)
    for axis in range(3):
        ijk3[..., axis] /= srcshape[axis] - 1
    ijk3 = ijk3 * 2 - 1
    return np.swapaxes(ijk3, 0, 2)


def main():
    """Segment the hippocampi of every image path given in sys.argv[1:].

    Outputs per subject, with names derived from tiv_output_filename():
      - native-space brain and cerebrum masks (when OUTPUT_NATIVE);
      - the left/right hippocampal masks ``*_mask_L`` / ``*_mask_R``
        (uint8, 0-255);
      - ``*_hippoLR_volumes.csv`` holding eTIV and both hippocampal volumes.

    With more than one input, ``all_subjects_hippo_report.csv`` is also written
    in the directory of the last input.  Files whose name contains ``_mask``
    are skipped.  Files that fail to load get a ``<input>.warning.txt`` note.
    """
    for fname in sys.argv[1:]:
        if "_mask" in fname:
            print("Skipping %s because the filename contains _mask in it" % fname)
            continue
        Ti = time.time()
        try:
            print("Loading image " + fname)
            outfilename = tiv_output_filename(fname)
            img = nibabel.load(fname)

            if type(img) is nibabel.Nifti1Image:
                img._affine = img.get_qform()  # for ANTs compatibility
                if img.header["qform_code"] == 0 and img.header["sform_code"] == 0:
                    print(" *** Error: the header of this nifti file has no qform_code defined.")
                    print(" Fix the header manually or reconvert from the original DICOM.")
                    if not OUTPUT_DEBUG:
                        continue

                if not np.allclose(img.get_sform(), img.get_qform()):
                    img._affine = img.get_qform()  # simplify later ANTs compatibility
                    print("This image has an sform defined, ignoring it - work in scanner space using the qform")

        except Exception:
            with open(fname + ".warning.txt", "a") as f:
                f.write("can't open the file\n")
            print(" *** Error: can't open file. Skip")
            continue

        d = img.get_fdata(caching="unchanged", dtype=np.float32)
        while len(d.shape) > 3:
            print("Warning: this looks like a timeserie. Averaging it")
            with open(fname + ".warning.txt", "a") as f:
                f.write("dim not 3. Averaging last dimension\n")
            d = d.mean(-1)

        d = (d - d.mean()) / d.std()

        o1 = nibabel.orientations.io_orientation(img.affine)
        o2 = np.array([[0., -1.], [1., 1.], [2., 1.]])  # We work in LAS space (same as the mni_icbm152 template)
        trn = nibabel.orientations.ornt_transform(o1, o2)  # o1 to o2 (apply to o2 to obtain o1)
        trn_back = nibabel.orientations.ornt_transform(o2, o1)

        revaff1 = nibabel.orientations.inv_ornt_aff(trn, (1, 1, 1))  # mult on o1 to obtain o2
        revaff1i = nibabel.orientations.inv_ornt_aff(trn_back, (1, 1, 1))  # mult on o2 to obtain o1

        aff_orig64 = np.linalg.lstsq(bbox_world(np.identity(4), (64, 64, 64)), bbox_world(img.affine, img.shape[:3]), rcond=None)[0].T
        voxscale_native64 = np.abs(np.linalg.det(aff_orig64))
        revaff64i = nibabel.orientations.inv_ornt_aff(trn_back, (64, 64, 64))
        aff_reor64 = np.linalg.lstsq(bbox_world(revaff64i, (64, 64, 64)), bbox_world(img.affine, img.shape[:3]), rcond=None)[0].T

        # resample the normalised image onto the reoriented 64^3 grid
        wgridt = (netAff.grid @ torch.tensor(revaff1i, device=device, dtype=torch.float32))[None, ..., [2, 1, 0]]
        d_orr = F.grid_sample(torch.as_tensor(d, dtype=torch.float32, device=device)[None, None], wgridt, align_corners=True)

        if OUTPUT_DEBUG:
            nibabel.Nifti1Image(np.asarray(d_orr[0, 0].cpu()), aff_reor64).to_filename(outfilename.replace("_tiv", "_orig_b64"))

        # ---- Head priors
        with torch.no_grad():
            out1t = net(d_orr)
        out1 = np.asarray(out1t.cpu())

        scalar_output = []
        scalar_output_report = []

        # brain mask (largest-connected-component filtering is disabled here)
        output = out1[0, 0].astype("float32")
        brainmask_cc = torch.tensor(output)

        vol = (output[output > .5]).sum() * voxscale_native64
        if OUTPUT_DEBUG:
            print(" Estimated intra-cranial volume (mm^3): %d" % vol)
        if 0:
            with open(outfilename.replace("_tiv.nii.gz", "_eTIV.txt"), "w") as f:
                f.write("%d\n" % vol)
        scalar_output.append(vol)
        scalar_output_report.append(vol)

        if OUTPUT_RES64:
            out = (output.clip(0, 1) * 255).astype("uint8")
            nibabel.Nifti1Image(out, aff_reor64, img.header).to_filename(outfilename.replace("_tiv", "_tissues%d_b64" % 0))

        if OUTPUT_NATIVE:
            # sampling grid from the 64^3 space back to the native image grid;
            # it is reused below for the cerebrum and cortex maps
            gsx, gsy, gsz = img.shape[:3]
            # this is a big array, so use float16
            sgrid = np.rollaxis(indices_unitary((gsx, gsy, gsz), dtype=np.float16), 0, 4)
            wgridt = torch.as_tensor(mul_homo(sgrid, inv(revaff1i))[None, ..., [2, 1, 0]], device=device, dtype=torch.float32)
            del sgrid

            dnat = np.asarray(F.grid_sample(torch.as_tensor(output, dtype=torch.float32, device=device)[None, None], wgridt, align_corners=True).cpu())[0, 0]
            nibabel.Nifti1Image((dnat > .5).astype("uint8"), img.affine).to_filename(outfilename.replace("_tiv", "_brain_mask"))
            vol = (dnat > .5).sum() * np.abs(np.linalg.det(img.affine))
            print(" Estimated intra-cranial volume (mm^3) (native space): %d" % vol)
            scalar_output.append(vol)
            scalar_output_report[-1] = vol  # authoritative, so overwrite previous
            del dnat

        # cerebrum mask, restricted to its largest connected component
        output = out1[0, 2].astype("float32")
        out_cc, lab = scipy.ndimage.label(output > .01)
        if lab > 0:
            output *= (out_cc == np.bincount(out_cc.flat)[1:].argmax() + 1)
        else:
            output[:] = 0  # no component at all: nothing survives the filter

        vol = (output[output > .5]).sum() * voxscale_native64
        if OUTPUT_DEBUG:
            print(" Estimated cerebrum volume (mm^3): %d" % vol)
        if 0:
            with open(outfilename.replace("_tiv.nii.gz", "_eTIV_nocerebellum.txt"), "w") as f:
                f.write("%d\n" % vol)
        scalar_output.append(vol)

        if OUTPUT_RES64:
            out = (output.clip(0, 1) * 255).astype("uint8")
            nibabel.Nifti1Image(out, aff_reor64, img.header).to_filename(outfilename.replace("_tiv", "_tissues%d_b64" % 2))
        if OUTPUT_NATIVE:
            dnat = np.asarray(F.grid_sample(torch.as_tensor(output, dtype=torch.float32, device=device)[None, None], wgridt, align_corners=True).cpu()[0, 0])
            nibabel.Nifti1Image((dnat > .5).astype("uint8"), img.affine).to_filename(outfilename.replace("_tiv", "_cerebrum_mask"))
            vol = (dnat > .5).sum() * np.abs(np.linalg.det(img.affine))
            print(" Estimated cerebrum volume (mm^3) (native space): %d" % vol)
            scalar_output.append(vol)
            del dnat

        # cortex
        output = out1[0, 1].astype("float32")
        output[output < .01] = 0
        if OUTPUT_RES64:
            out = (output.clip(0, 1) * 255).astype("uint8")
            nibabel.Nifti1Image(out, aff_reor64, img.header).to_filename(outfilename.replace("_tiv", "_tissues%d_b64" % 1))
        if OUTPUT_NATIVE and OUTPUT_DEBUG:
            dnat = np.asarray(F.grid_sample(torch.as_tensor(output, dtype=torch.float32, device=device)[None, None], wgridt, align_corners=True).cpu()[0, 0])
            nibabel.Nifti1Image(dnat, img.affine).to_filename(outfilename.replace("_tiv", "_tissues%d" % 1))
            del dnat

        # ---- MNI affine
        with torch.no_grad():
            wc1, tA = netAff(out1t[:, [1, 3]] * brainmask_cc)

        wnat = np.linalg.lstsq(bbox_world(img.affine, img.shape[:3]), bbox_one @ revaff1, rcond=None)[0]
        wmni = np.linalg.lstsq(bbox_world(affine64_mni, (64, 64, 64)), bbox_one, rcond=None)[0]
        M = (wnat @ inv(np.asarray(tA[0].cpu())) @ inv(wmni)).T
        # [native world coord] @ M.T -> [mni world coord] , in LAS space

        if OUTPUT_DEBUG:
            # Output MNI, mostly for debug, save in box64, uint8
            out2 = np.asarray(wc1.to("cpu"))
            out2 = np.clip((out2 * 255), 0, 255).astype("uint8")
            nibabel.Nifti1Image(out2[0, 0], affine64_mni).to_filename(outfilename.replace("_tiv", "_mniwrapc1"))
            del out2
        if 0:
            out2r = np.asarray(netAff.resample_other(d_orr).cpu())
            out2r = (out2r - out2r.min()) * 255 / np.ptp(out2r)
            nibabel.Nifti1Image(out2r[0, 0].astype("uint8"), affine64_mni).to_filename(outfilename.replace("_tiv", "_mniwrap"))
            del out2r

        # output an ANTs-compatible matrix (AntsApplyTransforms -t)
        f3 = np.array([[1, 1, -1, -1], [1, 1, -1, -1], [-1, -1, 1, 1], [1, 1, 1, 1]])  # ANTs LPS
        MI = inv(M) * f3
        txt = """#Insight Transform File V1.0\nTransform: AffineTransform_float_3_3\nFixedParameters: 0 0 0\nParameters: """
        txt += " ".join(["%4.6f %4.6f %4.6f" % tuple(x) for x in MI[:3, :3].tolist()]) + " %4.6f %4.6f %4.6f\n" % (MI[0, 3], MI[1, 3], MI[2, 3])
        if 0:
            with open(outfilename.replace("_tiv.nii.gz", "_mni0Affine.txt"), "w") as f:
                f.write(txt)

        u, s, vt = np.linalg.svd(MI[:3, :3])
        MI3rigid = u @ vt
        txt = """#Insight Transform File V1.0\nTransform: AffineTransform_float_3_3\nFixedParameters: 0 0 0\nParameters: """
        txt += " ".join(["%4.6f %4.6f %4.6f" % tuple(x) for x in MI3rigid.tolist()]) + " %4.6f %4.6f %4.6f\n" % (MI[0, 3], MI[1, 3], MI[2, 3])
        if 0:
            with open(outfilename.replace("_tiv.nii.gz", "_mni0Rigid.txt"), "w") as f:
                f.write(txt)

        # ---- Hippodeep
        imgcroproi_affine = np.array([[-1., -0., 0., 54.], [-0., 1., -0., -59.], [0., 0., 1., -45.], [0., 0., 0., 1.]])
        imgcroproi_shape = (107, 72, 68)
        # coord in mm bbox
        sgrid = np.rollaxis(indices_unitary(imgcroproi_shape, dtype=np.float32), 0, 4)

        bboxnat = bbox_world(imgcroproi_affine, imgcroproi_shape) @ inv(M.T) @ wnat
        matzoom = np.linalg.lstsq(bbox_one, bboxnat, rcond=None)[0]  # in -1..1 space
        # wgridt for hippo box
        wgridt = torch.tensor(mul_homo(sgrid, (matzoom @ revaff1i))[None, ..., [2, 1, 0]], device=device, dtype=torch.float32)
        del sgrid
        dout = F.grid_sample(torch.as_tensor(d, dtype=torch.float32, device=device)[None, None], wgridt, align_corners=True)
        # note: d was normalized from full-image
        d_in = np.asarray(dout[0, 0].cpu())  # back to numpy since torch does not support negative step/strides

        if OUTPUT_RES64:
            # np.ptp: the ndarray.ptp method no longer exists in NumPy 2
            d_in_u8 = (((d_in - d_in.min()) / np.ptp(d_in)) * 255).astype("uint8")
            nibabel.Nifti1Image(d_in_u8, imgcroproi_affine).to_filename(outfilename.replace("_tiv", "_affcrop"))

        d_in -= d_in.mean()
        d_in /= d_in.std()
        # split Left and Right (flipping Right)
        with torch.no_grad():
            hippoR = hipponet(torch.as_tensor(d_in[None, None, 6:54:+1, :, 2:-2].copy()))
            hippoL = hipponet(torch.as_tensor(d_in[None, None, -7:-55:-1, :, 2:-2].copy()))

        hippoRL = np.vstack([np.asarray(hippoR.cpu()), np.asarray(hippoL.cpu())])

        # smoothly rescale (.5 ~ .75) to (.5 ~ 1.)
        hippoRL = np.clip(((hippoRL - .5) * 2 + .5), 0, 1) * (hippoRL > .5)
        # lots numpy/torch copy below, because torch raises errors on negative strides
        output = np.zeros((2, 107, 72, 68), np.float32)
        output[0, -7:-55:-1, :, 2:-2][2:-2, 2:-2, 2:-2] = np.clip(hippoRL[1] * 255, 0, 255)
        output[1, 6:54:+1, :, 2:-2][2:-2, 2:-2, 2:-2] = np.clip(hippoRL[0] * 255, 0, 255)

        if OUTPUT_DEBUG:
            outputfn = outfilename.replace("_tiv", "_affcrop_outseg_mask")
            nibabel.Nifti1Image(output.sum(0), imgcroproi_affine).to_filename(outputfn)

        boxvols = hippoRL[[1, 0]].reshape(2, -1).sum(1) * np.abs(np.linalg.det(imgcroproi_affine @ inv(M)))
        scalar_output.append(boxvols)

        # Resample both hippocampal maps from the crop box back to the native
        # grid, only over the native sub-box that covers the crop.
        pts = _bbox_xyz(imgcroproi_shape, imgcroproi_affine)
        pts = mul_homo(pts, inv(M).T)
        pts_ijk = mul_homo(pts, inv(img.affine).T)
        for i in range(3):
            np.clip(pts_ijk[:, i], 0, img.shape[i], out=pts_ijk[:, i])
        pmin = np.floor(np.min(pts_ijk, 0)).astype(int)
        pwidth = np.ceil(np.max(pts_ijk, 0)).astype(int) - pmin
        box = tuple(slice(lo, lo + width) for lo, width in zip(pmin, pwidth))

        widx = _indices_xyz(pwidth, img.affine, offset_vox=pmin)
        widx = mul_homo(widx, M.T)
        DHW3 = _xyz_to_DHW3(widx, imgcroproi_affine, imgcroproi_shape)
        grid_native = torch.tensor(DHW3[None])

        wdata = np.zeros(img.shape[:3], np.uint8)
        native_vols = []
        for side, suffix in ((0, "_mask_L"), (1, "_mask_R")):
            src = torch.tensor(output[side].T, dtype=torch.float32)
            outDHW = F.grid_sample(src[None, None], grid_native, align_corners=True)
            dnat = np.asarray(outDHW[0, 0].permute(2, 1, 0))
            dnat[dnat < 32] = 0  # remove noise
            native_vols.append(dnat.sum() / 255. * np.abs(np.linalg.det(img.affine)))
            # the whole box is overwritten, so the R file never contains L data
            wdata[box] = dnat.astype(np.uint8)
            nibabel.Nifti1Image(wdata.astype("uint8"), img.affine).to_filename(outfilename.replace("_tiv", suffix))
        volsAA_L, volsAA_R = native_vols

        print(" Hippocampal volumes (L,R)", volsAA_L, volsAA_R)
        scalar_output.append([volsAA_L, volsAA_R])
        scalar_output_report.append([volsAA_L, volsAA_R])

        if OUTPUT_DEBUG:
            txt = "eTIV_mni,eTIV,cerebrum_mni,cerebrum,mni_hippoL,mni_hippoR,hippoL,hippoR\n"
            txt += "%4f,%4f,%4f,%4f,%4.4f,%4.4f,%4.4f,%4.4f\n" % (tuple(scalar_output[:4]) + tuple(scalar_output[4]) + tuple(scalar_output[5]))
            with open(outfilename.replace("_tiv.nii.gz", "_scalars_hippo.csv"), "w") as f:
                f.write(txt)

        txt = "eTIV,hippoL,hippoR\n"
        txt += "%4f,%4f,%4f\n" % (scalar_output_report[0], scalar_output_report[1][0], scalar_output_report[1][1])
        with open(outfilename.replace("_tiv.nii.gz", "_hippoLR_volumes.csv"), "w") as f:
            f.write(txt)

        if OUTPUT_RES64:
            print("fslview %s %s -t .5 &" % (outfilename.replace("_tiv", "_affcrop"), outfilename.replace("_tiv", "_affcrop_outseg_mask")))

        print(" Elapsed time for subject %4.2fs " % (time.time() - Ti))
        print(" To display using fsleyes or fslview, try:")
        print("  fsleyes %s %s -a 75 -cm Red-Yellow %s -a 75 -cm Blue-Lightblue &" % (fname, outfilename.replace("_tiv", "_mask_L"), outfilename.replace("_tiv", "_mask_R")))
        print("  fslview %s %s -t .5 %s -t .5 &" % (fname, outfilename.replace("_tiv", "_mask_L"), outfilename.replace("_tiv", "_mask_R")))

        allsubjects_scalar_report.append((fname, scalar_output_report[0], scalar_output_report[1][0], scalar_output_report[1][1]))

    try:
        print("Peak memory used (Gb) " + str(resource.getrusage(resource.RUSAGE_SELF)[2] / (1024. * 1024)))
    except Exception:
        pass  # e.g. `resource` unavailable on this platform

    print("Done")

    if len(sys.argv[1:]) > 1:
        outfilename = (os.path.dirname(fname) or ".") + "/all_subjects_hippo_report.csv"
        txt_entries = ["%s,%4f,%4f,%4f\n" % s for s in allsubjects_scalar_report]
        with open(outfilename, "w") as f:
            f.writelines(["filename,eTIV,hippoL,hippoR\n"] + txt_entries)
        print("Volumes of every subjects saved as " + outfilename)


if __name__ == "__main__":
    main()