Diff of /data/base_dataset.py [000000] .. [4cda31]

Switch to unified view

a b/data/base_dataset.py
1
# Manuel A. Morales (moralesq@mit.edu)
2
# Harvard-MIT Department of Health Sciences & Technology  
3
# Athinoula A. Martinos Center for Biomedical Imaging
4
5
import numpy as np
6
from abc import ABC, abstractmethod
7
from tensorflow.keras.utils import Sequence
8
from scipy.ndimage.measurements import center_of_mass
9
10
import nibabel as nib
11
from dipy.align.reslice import reslice
12
13
class BaseDataset(Sequence, ABC):
14
    """This class is an abstract base class (ABC) for datasets."""
15
16
    def __init__(self, opt):
17
        self.opt  = opt
18
        self.root = opt.dataroot
19
20
    @abstractmethod
21
    def __len__(self):
22
        """Return the size of the dataset."""
23
        return 
24
    
25
    @abstractmethod
26
    def __getitem__(self, idx):
27
        """Return a data point and its metadata information."""
28
        pass
29
                       
30
class Transforms():
31
    
32
    def __init__(self, opt):
33
        self.opt = opt 
34
        self.transform, self.transform_inv = self.get_transforms(opt)
35
       
36
    def __crop__(self, x, inv=False):
37
        
38
        if inv:
39
            nx, ny = self.original_shape[:2]
40
            xinv = np.zeros(self.original_shape[:2] + x.shape[2:])
41
            xinv[nx//2-64:nx//2+64, ny//2-64:ny//2+64] += x
42
            return xinv
43
        else:
44
            nx, ny = x.shape[:2]
45
            return x[nx//2-64:nx//2+64, ny//2-64:ny//2+64]
46
    
47
    def __reshape_to_carson__(self, x, inv=False):
48
        
49
        if inv:
50
            if len(self.original_shape)==3:
51
                x = x.transpose(1,2,0,3)
52
            elif len(self.original_shape)==4:
53
                nx,ny,nz,nt=self.original_shape
54
                Nx, Ny = x.shape[1:3]
55
                x = x.reshape((nt, nz, Nx, Ny, self.opt.nlabels))
56
                x = x.transpose(2,3,1,0,4)                
57
        else:
58
            if len(x.shape) == 3:
59
                nx,ny,nz=x.shape
60
                x=x.transpose(2,0,1)
61
            elif len(x.shape) == 4:
62
                nx,ny,nz,nt=x.shape
63
                x=x.transpose(3,2,0,1)
64
                x=x.reshape((nt*nz,nx,ny))            
65
        return x
66
67
    def __reshape_to_carmen__(self, x, inv=False):
68
        if inv:
69
            x = np.concatenate((np.zeros(x[:1].shape), x))
70
            x = x.transpose((1,2,3,0,4)) 
71
        else:
72
            assert len(x.shape) == 4
73
            nx,ny,nz,nt=x.shape
74
            x=x.transpose(3,0,1,2)
75
            x=np.stack((np.repeat(x[:1],nt-1,axis=0), x[1:nt]), -1)
76
        return x  
77
    
78
    def __zscore__(self, x):
79
80
        if len(x.shape) == 3:
81
            axis=(1,2) # normalize in-plane images independently
82
        elif len(x.shape) == 5:
83
            axis=(1,2,3) # normalize volumes independently
84
85
        self.mu = x.mean(axis=axis, keepdims=True)
86
        self.sd = x.std(axis=axis, keepdims=True)
87
        return (x - self.mu)/(self.sd + 1e-8)
88
89
    def get_transforms(self, opt):
90
91
        transform_list     = []
92
        transform_inv_list = []
93
        if 'crop' in opt.preprocess:
94
            transform_list.append(self.__crop__)
95
            transform_inv_list.append(lambda x:self.__crop__(x,inv=True))
96
        if 'reshape_to_carson' in opt.preprocess:
97
            transform_list.append(self.__reshape_to_carson__)
98
            transform_inv_list.append(lambda x:self.__reshape_to_carson__(x,inv=True))
99
        elif 'reshape_to_carmen' in opt.preprocess:
100
            transform_list.append(self.__reshape_to_carmen__)
101
            transform_inv_list.append(lambda x:self.__reshape_to_carmen__(x,inv=True))
102
        if 'zscore' in opt.preprocess:
103
            transform_list.append(self.__zscore__)                
104
        
105
        return transform_list, transform_inv_list
106
          
107
    def apply(self, x):
108
        
109
        self.original_shape = x.shape
110
        for transform in self.transform:
111
            x = transform(x)
112
        return x
113
    
114
    def apply_inv(self, x):
115
        
116
        for transform in self.transform_inv[::-1]:
117
            x = transform(x)
118
        return x    
119
    
120
121
def _centercrop(x):
122
    nx, ny = x.shape[:2]
123
    return x[nx//2-64:nx//2+64,ny//2-64:ny//2+64]
124
125
def _roll(x,rx,ry):
126
    x = np.roll(x,rx,axis=0)
127
    x = np.roll(x,ry,axis=1)
128
    return x
129
130
def _roll2center(x, center):
131
    return _roll(x, int(x.shape[0]//2-center[0]), int(x.shape[1]//2-center[1]))
132
    
133
def _roll2center_crop(x, center):
134
    x = _roll2center(x, center)
135
    return _centercrop(x)
136
    
137
    
138
#####################################################
139
## FUNCTIONS TO ADD MORE FLEXIBILITY IN SEGMENTATION
140
#####################################################
141
142
def resample_nifti_inv(nifti_resampled, zooms, order=1, mode='nearest'):
143
    """ Resample `nifti_resampled` to `zooms` resolution.
144
    """
145
    data_resampled   = nifti_resampled.get_fdata()
146
    zooms_resampled  = nifti_resampled.header.get_zooms()[:3]
147
    affine_resampled = nifti_resampled.affine 
148
        
149
    data_resampled, affine_resampled = reslice(data_resampled, 
150
                                               affine_resampled, zooms_resampled, zooms, order=order, mode=mode)
151
152
    nifti = nib.Nifti1Image(data_resampled, affine_resampled)
153
    
154
    return nifti
155
    
156
def convert_back_to_nifti(data_resampled, nifti_info_subject, inv_256x256=False, order=1, mode='nearest'):
157
158
    if inv_256x256:
159
        data_resampled_mod_corr = roll_and_pad_256x256_to_center_inv(data_resampled, nifti_info=nifti_info_subject)
160
    else:
161
        data_resampled_mod_corr = data_resampled
162
        
163
    affine           = nifti_info_subject['affine']
164
    affine_resampled = nifti_info_subject['affine_resampled']
165
    zooms            = nifti_info_subject['zooms'][:3]
166
    zooms_resampled  = nifti_info_subject['zooms_resampled'][:3]
167
    
168
    data_resampled, affine_resampled = reslice(data_resampled_mod_corr, 
169
                                               affine_resampled, zooms_resampled, zooms, order=order, mode=mode)
170
    nifti = nib.Nifti1Image(data_resampled, affine_resampled)
171
    
172
    return nifti
173
174
def roll(x,rx,ry):
175
        x = np.roll(x,rx,axis=0)
176
        x = np.roll(x,ry,axis=1)
177
        return x
178
    
179
def roll2center(x, center):
180
    return roll(x, int(x.shape[0]//2-center[0]), int(x.shape[1]//2-center[1]))
181
    
182
def pad_256x256(x):
183
        xpad = (512-x.shape[0])//2, (512-x.shape[0])-(512-x.shape[0])//2
184
        ypad = (512-x.shape[1])//2, (512-x.shape[1])-(512-x.shape[1])//2
185
        pads = (xpad,ypad)+((0,0),)*(len(x.shape)-2)
186
        vals = ((0,0),)*len(x.shape)
187
        x = np.pad(x, pads, 'constant', constant_values=vals)
188
        x = x[512//2-256//2:512//2+256//2,512//2-256//2:512//2+256//2]
189
        return x
190
    
191
def roll_and_pad_256x256_to_center(x, center):
192
    x = roll2center(x, center)
193
    x = pad_256x256(x)
194
    return x
195
196
def roll_and_pad_256x256_to_center_inv(x, nifti_info):
197
198
    # Recover 256x256 array that was center-cropped to 128x128!
199
    x_256_256 = np.zeros((256,256)+x.shape[2:])
200
    x_256_256[128-64:128+64,128-64:128+64] += x
201
    
202
    # Coordinates to put the image in its original location.
203
    cx, cy         = nifti_info['center_resampled'][:2]
204
    cx_mod, cy_mod = nifti_info['center_resampled_256x256'][:2]
205
    
206
    x_inv = np.zeros(nifti_info['shape_resampled'][:3]+x.shape[3:])
207
208
    dx = min(int(cx),64)
209
    dy = min(int(cy),64)
210
    if (dx!=64)|(dy!=64):
211
        print('WARNING:FOV < 128x128!')
212
213
    x_inv[int(cx-dx):int(cx+dx),int(cy-dy):int(cy+dy)] += x_256_256[int(cx_mod-dx):int(cx_mod+dx),
214
                                                                    int(cy_mod-dy):int(cy_mod+dy)]
215
    return x_inv