b/data/base_dataset.py

# Manuel A. Morales (moralesq@mit.edu)
# Harvard-MIT Department of Health Sciences & Technology
# Athinoula A. Martinos Center for Biomedical Imaging

import numpy as np
from abc import ABC, abstractmethod
from tensorflow.keras.utils import Sequence
from scipy.ndimage.measurements import center_of_mass

import nibabel as nib
from dipy.align.reslice import reslice

class BaseDataset(Sequence, ABC):
    """This class is an abstract base class (ABC) for datasets."""

    def __init__(self, opt):
        self.opt = opt
        self.root = opt.dataroot

    @abstractmethod
    def __len__(self):
        """Return the size of the dataset."""
        return

    @abstractmethod
    def __getitem__(self, idx):
        """Return a data point and its metadata information."""
        pass
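
# Illustrative sketch only (not part of the original module): a hypothetical
# concrete subclass, assuming `opt.dataroot` contains NIfTI files and that
# `glob`/`os` are imported by the subclass's own module.
#
#   class NiftiDataset(BaseDataset):
#       def __init__(self, opt):
#           super().__init__(opt)
#           self.paths = sorted(glob.glob(os.path.join(self.root, '*.nii.gz')))
#
#       def __len__(self):
#           return len(self.paths)
#
#       def __getitem__(self, idx):
#           return nib.load(self.paths[idx]).get_fdata()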
|
|
class Transforms():

    def __init__(self, opt):
        self.opt = opt
        self.transform, self.transform_inv = self.get_transforms(opt)

    def __crop__(self, x, inv=False):

        if inv:
            # Undo the crop: place the 128x128 patch back at the center of an
            # array with the original in-plane size.
            nx, ny = self.original_shape[:2]
            xinv = np.zeros(self.original_shape[:2] + x.shape[2:])
            xinv[nx//2-64:nx//2+64, ny//2-64:ny//2+64] += x
            return xinv
        else:
            # Center-crop the in-plane dimensions to 128x128.
            nx, ny = x.shape[:2]
            return x[nx//2-64:nx//2+64, ny//2-64:ny//2+64]

    def __reshape_to_carson__(self, x, inv=False):

        if inv:
            if len(self.original_shape) == 3:
                x = x.transpose(1, 2, 0, 3)
            elif len(self.original_shape) == 4:
                nx, ny, nz, nt = self.original_shape
                Nx, Ny = x.shape[1:3]
                x = x.reshape((nt, nz, Nx, Ny, self.opt.nlabels))
                x = x.transpose(2, 3, 1, 0, 4)
        else:
            if len(x.shape) == 3:
                # (nx, ny, nz) -> (nz, nx, ny): slices become the leading axis.
                nx, ny, nz = x.shape
                x = x.transpose(2, 0, 1)
            elif len(x.shape) == 4:
                # (nx, ny, nz, nt) -> (nt*nz, nx, ny): stack time frames and slices.
                nx, ny, nz, nt = x.shape
                x = x.transpose(3, 2, 0, 1)
                x = x.reshape((nt*nz, nx, ny))
        return x

    def __reshape_to_carmen__(self, x, inv=False):
        if inv:
            # Re-insert an all-zero frame in place of the reference frame and
            # restore (nx, ny, nz, nt, channels) ordering.
            x = np.concatenate((np.zeros(x[:1].shape), x))
            x = x.transpose((1, 2, 3, 0, 4))
        else:
            # Pair the first frame with every later frame:
            # (nx, ny, nz, nt) -> (nt-1, nx, ny, nz, 2).
            assert len(x.shape) == 4
            nx, ny, nz, nt = x.shape
            x = x.transpose(3, 0, 1, 2)
            x = np.stack((np.repeat(x[:1], nt-1, axis=0), x[1:nt]), -1)
        return x

    def __zscore__(self, x):

        if len(x.shape) == 3:
            axis = (1, 2)      # normalize in-plane images independently
        elif len(x.shape) == 5:
            axis = (1, 2, 3)   # normalize volumes independently

        self.mu = x.mean(axis=axis, keepdims=True)
        self.sd = x.std(axis=axis, keepdims=True)
        return (x - self.mu) / (self.sd + 1e-8)

    def get_transforms(self, opt):

        transform_list = []
        transform_inv_list = []
        if 'crop' in opt.preprocess:
            transform_list.append(self.__crop__)
            transform_inv_list.append(lambda x: self.__crop__(x, inv=True))
        if 'reshape_to_carson' in opt.preprocess:
            transform_list.append(self.__reshape_to_carson__)
            transform_inv_list.append(lambda x: self.__reshape_to_carson__(x, inv=True))
        elif 'reshape_to_carmen' in opt.preprocess:
            transform_list.append(self.__reshape_to_carmen__)
            transform_inv_list.append(lambda x: self.__reshape_to_carmen__(x, inv=True))
        if 'zscore' in opt.preprocess:
            transform_list.append(self.__zscore__)

        return transform_list, transform_inv_list

    def apply(self, x):

        # Record the input shape so the inverse transforms can restore it.
        self.original_shape = x.shape
        for transform in self.transform:
            x = transform(x)
        return x

    def apply_inv(self, x):

        # Apply the registered inverse transforms in reverse order.
        for transform in self.transform_inv[::-1]:
            x = transform(x)
        return x
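
# Usage sketch (illustrative only), assuming a hypothetical `opt` with
# opt.preprocess = 'crop_reshape_to_carson_zscore' and a 4D (nx, ny, nz, nt)
# input. `apply` records the original shape and runs the forward transforms;
# `apply_inv` runs the registered inverses in reverse order (zscore registers
# no inverse, so only the reshape and crop are undone).
#
#   transforms = Transforms(opt)
#   x = transforms.apply(volume)          # -> (nt*nz, 128, 128), z-scored
#   y = model.predict(x[..., None])       # hypothetical segmentation model
#   y_full = transforms.apply_inv(y)      # -> (nx, ny, nz, nt, nlabels)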
|
|
def _centercrop(x):
    nx, ny = x.shape[:2]
    return x[nx//2-64:nx//2+64, ny//2-64:ny//2+64]

def _roll(x, rx, ry):
    x = np.roll(x, rx, axis=0)
    x = np.roll(x, ry, axis=1)
    return x

def _roll2center(x, center):
    return _roll(x, int(x.shape[0]//2 - center[0]), int(x.shape[1]//2 - center[1]))

def _roll2center_crop(x, center):
    # Shift the image so `center` moves to the array center, then crop to 128x128.
    x = _roll2center(x, center)
    return _centercrop(x)


#####################################################
## FUNCTIONS TO ADD MORE FLEXIBILITY IN SEGMENTATION
#####################################################

def resample_nifti_inv(nifti_resampled, zooms, order=1, mode='nearest'):
    """Resample `nifti_resampled` to `zooms` resolution."""
    data_resampled = nifti_resampled.get_fdata()
    zooms_resampled = nifti_resampled.header.get_zooms()[:3]
    affine_resampled = nifti_resampled.affine

    data_resampled, affine_resampled = reslice(data_resampled, affine_resampled,
                                               zooms_resampled, zooms,
                                               order=order, mode=mode)

    nifti = nib.Nifti1Image(data_resampled, affine_resampled)

    return nifti
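
# Usage sketch (illustrative only), assuming `original_zooms` (the voxel
# spacing of the subject's native NIfTI) was stored before resampling:
#
#   nifti_native = resample_nifti_inv(nifti_resampled, zooms=original_zooms)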
|
|
def convert_back_to_nifti(data_resampled, nifti_info_subject, inv_256x256=False, order=1, mode='nearest'):

    if inv_256x256:
        # Undo the roll-and-pad-to-256x256 preprocessing before reslicing.
        data_resampled_mod_corr = roll_and_pad_256x256_to_center_inv(data_resampled,
                                                                     nifti_info=nifti_info_subject)
    else:
        data_resampled_mod_corr = data_resampled

    affine = nifti_info_subject['affine']
    affine_resampled = nifti_info_subject['affine_resampled']
    zooms = nifti_info_subject['zooms'][:3]
    zooms_resampled = nifti_info_subject['zooms_resampled'][:3]

    data_resampled, affine_resampled = reslice(data_resampled_mod_corr, affine_resampled,
                                               zooms_resampled, zooms,
                                               order=order, mode=mode)
    nifti = nib.Nifti1Image(data_resampled, affine_resampled)

    return nifti

def roll(x, rx, ry):
    x = np.roll(x, rx, axis=0)
    x = np.roll(x, ry, axis=1)
    return x

def roll2center(x, center):
    # Shift the image so `center` moves to the center of the array.
    return roll(x, int(x.shape[0]//2 - center[0]), int(x.shape[1]//2 - center[1]))

def pad_256x256(x):
    # Center the image on a 512x512 canvas, then keep the central 256x256
    # region, so the output is always 256x256 in-plane.
    xpad = (512-x.shape[0])//2, (512-x.shape[0])-(512-x.shape[0])//2
    ypad = (512-x.shape[1])//2, (512-x.shape[1])-(512-x.shape[1])//2
    pads = (xpad, ypad)+((0, 0),)*(len(x.shape)-2)
    vals = ((0, 0),)*len(x.shape)
    x = np.pad(x, pads, 'constant', constant_values=vals)
    x = x[512//2-256//2:512//2+256//2, 512//2-256//2:512//2+256//2]
    return x

def roll_and_pad_256x256_to_center(x, center):
    x = roll2center(x, center)
    x = pad_256x256(x)
    return x
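
# Usage sketch (illustrative only): `center` is an in-plane (x, y) location to
# move to the array center, e.g. computed with the imported `center_of_mass`
# on a rough mask (hypothetical `mask` variable):
#
#   cx, cy = center_of_mass(mask)[:2]
#   x_256 = roll_and_pad_256x256_to_center(x, center=(cx, cy))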
|
|
195 |
|
|
|
196 |
def roll_and_pad_256x256_to_center_inv(x, nifti_info): |
|
|
197 |
|
|
|
198 |
# Recover 256x256 array that was center-cropped to 128x128! |
|
|
199 |
x_256_256 = np.zeros((256,256)+x.shape[2:]) |
|
|
200 |
x_256_256[128-64:128+64,128-64:128+64] += x |
|
|
201 |
|
|
|
202 |
# Coordinates to put the image in its original location. |
|
|
203 |
cx, cy = nifti_info['center_resampled'][:2] |
|
|
204 |
cx_mod, cy_mod = nifti_info['center_resampled_256x256'][:2] |
|
|
205 |
|
|
|
206 |
x_inv = np.zeros(nifti_info['shape_resampled'][:3]+x.shape[3:]) |
|
|
207 |
|
|
|
208 |
dx = min(int(cx),64) |
|
|
209 |
dy = min(int(cy),64) |
|
|
210 |
if (dx!=64)|(dy!=64): |
|
|
211 |
print('WARNING:FOV < 128x128!') |
|
|
212 |
|
|
|
213 |
x_inv[int(cx-dx):int(cx+dx),int(cy-dy):int(cy+dy)] += x_256_256[int(cx_mod-dx):int(cx_mod+dx), |
|
|
214 |
int(cy_mod-dy):int(cy_mod+dy)] |
|
|
215 |
return x_inv |