Diff of /dataio_motion.py [000000] .. [721c7a]

Switch to unified view

a b/dataio_motion.py
1
import torch.utils.data as data
2
import torch
3
from os import listdir
4
from os.path import join
5
import numpy as np
6
import nibabel as nib
7
import glob
8
import neural_renderer as nr
9
10
11
12
13
class TrainDataset(data.Dataset):
14
    def __init__(self, data_path):
15
        super(TrainDataset, self).__init__()
16
        self.data_path = data_path
17
        self.filename = [f for f in sorted(listdir(self.data_path))]
18
19
    def __getitem__(self, index):
20
        input_sa, input_2ch, input_4ch, contour_sa, contour_2ch, contour_4ch, \
21
        vertex_ed, faces, affine_inv, affine, origin = load_data(self.data_path, self.filename[index], T_num=50)
22
23
        img_sa_t = input_sa[0]
24
        img_sa_ed = input_sa[1]
25
26
        img_2ch_t = input_2ch[0]
27
        img_2ch_ed = input_2ch[1]
28
29
        img_4ch_t = input_4ch[0]
30
        img_4ch_ed = input_4ch[1]
31
32
        return img_sa_t, img_sa_ed, img_2ch_t, img_2ch_ed, img_4ch_t, img_4ch_ed,\
33
               contour_sa, contour_2ch, contour_4ch, \
34
               vertex_ed, faces, affine_inv, affine, origin
35
36
    def __len__(self):
37
        return len(self.filename)
38
39
class ValDataset(data.Dataset):
40
    def __init__(self, data_path):
41
        super(ValDataset, self).__init__()
42
        self.data_path = data_path
43
        self.filename = [f for f in sorted(listdir(self.data_path))]
44
45
    def __getitem__(self, index):
46
        input_sa, input_2ch, input_4ch, contour_sa, contour_2ch, contour_4ch, \
47
        vertex_ed, faces, affine_inv, affine, origin = load_data(self.data_path, self.filename[index], T_num=50,  rand_frame=20)
48
49
        img_sa_t = input_sa[0]
50
        img_sa_ed = input_sa[1]
51
52
        img_2ch_t = input_2ch[0]
53
        img_2ch_ed = input_2ch[1]
54
55
        img_4ch_t = input_4ch[0]
56
        img_4ch_ed = input_4ch[1]
57
58
59
        return img_sa_t, img_sa_ed, img_2ch_t, img_2ch_ed, img_4ch_t, img_4ch_ed,\
60
               contour_sa, contour_2ch, contour_4ch, \
61
               vertex_ed, faces, affine_inv, affine, origin
62
63
    def __len__(self):
64
        return len(self.filename)
65
66
class TestDataset(data.Dataset):
67
    def __init__(self, data_path):
68
        super(TestDataset, self).__init__()
69
        self.data_path = data_path
70
        self.filename = [f for f in sorted(listdir(self.data_path))]
71
        # print (self.filename)
72
73
    def __getitem__(self, index):
74
        input_sa, input_2ch, input_4ch, contour_sa_es, contour_2ch_es, contour_4ch_es, \
75
        vertex_es, faces, affine_inv, affine, origin = load_data_ES(self.data_path, self.filename[index])
76
77
        img_sa_es = input_sa[0]
78
        img_sa_ed = input_sa[1]
79
80
        img_2ch_es = input_2ch[0]
81
        img_2ch_ed = input_2ch[1]
82
83
84
        img_4ch_es = input_4ch[0]
85
        img_4ch_ed = input_4ch[1]
86
87
88
        return img_sa_es, img_sa_ed, img_2ch_es, img_2ch_ed, img_4ch_es, img_4ch_ed, \
89
               contour_sa_es, contour_2ch_es, contour_4ch_es, vertex_es, faces, affine_inv, affine, origin
90
91
    def __len__(self):
92
        return len(self.filename)
93
94
def get_data(path, fr):
95
    nim = nib.load(path)
96
    image = nim.get_data()[:, :, :, :]  # (h, w, slices, frame)
97
    image = np.array(image, dtype='float32')
98
99
100
    image_fr = image[..., fr]
101
    image_fr = image_fr[np.newaxis]
102
    image_ed = image[..., 0]
103
    image_ed = image_ed[np.newaxis]
104
105
    image_bank = np.concatenate((image_fr, image_ed), axis=0)
106
    image_bank = np.transpose(image_bank, (0, 3, 1, 2))
107
108
109
    return image_bank
110
111
def get_data_ES(path, path_ES):
112
    nim = nib.load(path)
113
    image = nim.get_data()[:, :, :, :]  # (h, w, slices, frame)
114
    image = np.array(image, dtype='float32')
115
116
    nim_ES = nib.load(path_ES)
117
    image_ES = nim_ES.get_data()[:, :, :, :]  # (h, w, slices, frame=0)
118
    image_ES = np.array(image_ES, dtype='float32')
119
120
121
    image_z_ed = image[..., 0]
122
    image_z_ed = image_z_ed[np.newaxis]
123
    image_z_es = image_ES[..., 0]
124
    image_z_es = image_z_es[np.newaxis]
125
126
127
    image_bank = np.concatenate((image_z_es, image_z_ed), axis=0)
128
    image_bank = np.transpose(image_bank, (0, 3, 1, 2))
129
130
131
    return image_bank
132
133
def load_data(data_path, filename, T_num, rand_frame=None):
134
    # Load images and labels
135
    img_sa_path = join(data_path, filename, 'sa_img.nii.gz')  # (H, W, 1, frames)
136
    img_2ch_path = join(data_path, filename, '2ch_img.nii.gz')
137
    img_4ch_path = join(data_path, filename, '4ch_img.nii.gz')
138
139
    contour_sa_path = join(data_path, filename, 'contour_sa.npy') # (H, W, 9, frames)
140
    contour_2ch_path = join(data_path, filename, 'contour_2ch.npy') # (H, W, 1, frames)
141
    contour_4ch_path = join(data_path, filename, 'contour_4ch.npy')# (H, W, 1, frames)
142
143
    vertices_path = join(data_path, filename, 'pred_vertices_ED_new.npy')
144
    faces_path = join(data_path, filename, 'faces_init_myo_ED.npy')
145
    affine_path = join(data_path, filename, 'affine.npz')
146
    origin_path = join(data_path, filename, 'origin.npz')
147
148
    # generate random index for t and z dimension
149
    if rand_frame is not None:
150
        rand_t = rand_frame
151
    else:
152
        rand_t = np.random.randint(0, T_num)
153
154
    image_sa_bank = get_data(img_sa_path, rand_t)
155
    image_2ch_bank = get_data(img_2ch_path, rand_t)
156
    image_4ch_bank = get_data(img_4ch_path, rand_t)
157
158
    contour_sa = np.transpose(np.load(contour_sa_path)[:,:,:,rand_t], (2,0,1)) # [H,W,slices,frame]
159
    contour_2ch = np.load(contour_2ch_path)[:,:, 0, rand_t] # [H,W,1, frame]
160
    contour_4ch = np.load(contour_4ch_path)[:,:, 0, rand_t] # [H,W,1, frame]
161
162
163
    # load mesh
164
    vertex_ed = np.load(vertices_path)
165
    faces = np.load(faces_path)
166
167
    # load affine
168
    aff_sa_inv = np.load(affine_path)['sainv']
169
    aff_2ch_inv = np.load(affine_path)['la2chinv']
170
    aff_4ch_inv = np.load(affine_path)['la4chinv']
171
    affine_inv = np.stack((aff_sa_inv, aff_2ch_inv, aff_4ch_inv), 0)
172
    aff_sa = np.load(affine_path)['sa']
173
    aff_2ch = np.load(affine_path)['la2ch']
174
    aff_4ch = np.load(affine_path)['la4ch']
175
    affine = np.stack((aff_sa, aff_2ch, aff_4ch), 0)
176
    # load origin
177
    origin_sa = np.load(origin_path)['sa']
178
    origin_2ch = np.load(origin_path)['la2ch']
179
    origin_4ch = np.load(origin_path)['la4ch']
180
    origin = np.stack((origin_sa, origin_2ch, origin_4ch), 0)
181
182
183
184
    return image_sa_bank, image_2ch_bank, image_4ch_bank, contour_sa, contour_2ch, contour_4ch, \
185
           vertex_ed, faces, affine_inv, affine, origin
186
187
def load_data_ES(data_path, filename):
188
    # Load images and labels
189
    img_sa_path = join(data_path, filename, 'sa_img.nii.gz')
190
    img_2ch_path = join(data_path, filename, '2ch_img.nii.gz')
191
    img_4ch_path = join(data_path, filename, '4ch_img.nii.gz')
192
193
    img_sa_ES_path = join(data_path, filename, 'sa_ES_img.nii.gz')
194
    img_2ch_ES_path = join(data_path, filename, '2ch_ES_img.nii.gz')
195
    img_4ch_ES_path = join(data_path, filename, '4ch_ES_img.nii.gz')
196
197
    contour_sa_path = join(data_path, filename, 'contour_sa_es.npy')
198
    contour_2ch_path = join(data_path, filename, 'contour_2ch_es.npy')
199
    contour_4ch_path = join(data_path, filename, 'contour_4ch_es.npy')
200
201
    vertices_path = join(data_path, filename, 'pred_vertices_ED_new.npy')
202
    faces_path = join(data_path, filename, 'faces_init_myo_ED.npy')
203
    affine_path = join(data_path, filename, 'affine.npz')
204
    origin_path = join(data_path, filename, 'origin.npz')
205
206
    # load obj
207
    vertex_ed = np.load(vertices_path)
208
    faces = np.load(faces_path)
209
    # load affine
210
    aff_sa_inv = np.load(affine_path)['sainv']
211
    aff_2ch_inv = np.load(affine_path)['la2chinv']
212
    aff_4ch_inv = np.load(affine_path)['la4chinv']
213
    affine_inv = np.stack((aff_sa_inv, aff_2ch_inv, aff_4ch_inv), 0)
214
    aff_sa = np.load(affine_path)['sa']
215
    aff_2ch = np.load(affine_path)['la2ch']
216
    aff_4ch = np.load(affine_path)['la4ch']
217
    affine = np.stack((aff_sa, aff_2ch, aff_4ch), 0)
218
    # load origin
219
    origin_sa = np.load(origin_path)['sa']
220
    origin_2ch = np.load(origin_path)['la2ch']
221
    origin_4ch = np.load(origin_path)['la4ch']
222
    origin = np.stack((origin_sa, origin_2ch, origin_4ch), 0)
223
224
    image_sa_ES_bank = get_data_ES(img_sa_path, img_sa_ES_path)
225
    image_2ch_ES_bank = get_data_ES(img_2ch_path, img_2ch_ES_path)
226
    image_4ch_ES_bank = get_data_ES(img_4ch_path, img_4ch_ES_path)
227
228
229
    contour_sa_es = np.transpose(np.load(contour_sa_path)[:, :, :, 0], (2, 0, 1))  # [H,W,slices,frame]
230
    contour_2ch_es = np.load(contour_2ch_path)[:, :, 0, 0]  # [H,W,frame]
231
    contour_4ch_es = np.load(contour_4ch_path)[:, :, 0, 0]  # [H,W,frame]
232
233
234
    return image_sa_ES_bank, image_2ch_ES_bank, image_4ch_ES_bank, \
235
           contour_sa_es, contour_2ch_es, contour_4ch_es, vertex_ed, faces, affine_inv, affine, origin
236