|
a |
|
b/dataio_reconstruction.py |
|
|
1 |
import torch.utils.data as data |
|
|
2 |
import torch |
|
|
3 |
from os import listdir |
|
|
4 |
from os.path import join |
|
|
5 |
import numpy as np |
|
|
6 |
import nibabel as nib |
|
|
7 |
import glob |
|
|
8 |
import neural_renderer as nr |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
|
|
|
12 |
|
|
|
13 |
class TrainDataset(data.Dataset): |
|
|
14 |
def __init__(self, data_path): |
|
|
15 |
super(TrainDataset, self).__init__() |
|
|
16 |
self.data_path = data_path |
|
|
17 |
self.filename = [f for f in sorted(listdir(self.data_path))] |
|
|
18 |
|
|
|
19 |
def __getitem__(self, index): |
|
|
20 |
input_sa, input_2ch, input_4ch, contour_sa_ed, contour_2ch_ed, contour_4ch_ed, \ |
|
|
21 |
vertex_tpl_ed, faces_tpl, affine_inv, affine, origin, vertex_ed, \ |
|
|
22 |
mesh2seg_sa, mesh2seg_2ch, mesh2seg_4ch = load_data(self.data_path, self.filename[index], T_num=50) |
|
|
23 |
|
|
|
24 |
img_sa_t = input_sa[0] |
|
|
25 |
img_sa_ed = input_sa[1] |
|
|
26 |
|
|
|
27 |
img_2ch_t = input_2ch[0] |
|
|
28 |
img_2ch_ed = input_2ch[1] |
|
|
29 |
|
|
|
30 |
img_4ch_t = input_4ch[0] |
|
|
31 |
img_4ch_ed = input_4ch[1] |
|
|
32 |
|
|
|
33 |
return img_sa_t, img_sa_ed, img_2ch_t, img_2ch_ed, img_4ch_t, img_4ch_ed, contour_sa_ed, contour_2ch_ed, contour_4ch_ed, \ |
|
|
34 |
vertex_tpl_ed, faces_tpl, affine_inv, affine, origin, vertex_ed, mesh2seg_sa, mesh2seg_2ch, mesh2seg_4ch |
|
|
35 |
|
|
|
36 |
def __len__(self): |
|
|
37 |
return len(self.filename) |
|
|
38 |
|
|
|
39 |
class ValDataset(data.Dataset): |
|
|
40 |
def __init__(self, data_path): |
|
|
41 |
super(ValDataset, self).__init__() |
|
|
42 |
self.data_path = data_path |
|
|
43 |
self.filename = [f for f in sorted(listdir(self.data_path))] |
|
|
44 |
|
|
|
45 |
def __getitem__(self, index): |
|
|
46 |
input_sa, input_2ch, input_4ch, \ |
|
|
47 |
vertex_tpl_ed, faces_tpl, affine_inv, affine, origin, vertex_ed, contour_sa_ed, contour_2ch_ed, contour_4ch_ed, \ |
|
|
48 |
mesh2seg_sa, mesh2seg_2ch, mesh2seg_4ch = load_data(self.data_path, self.filename[index], T_num=50, rand_frame=20) |
|
|
49 |
|
|
|
50 |
img_sa_t = input_sa[0] |
|
|
51 |
img_sa_ed = input_sa[1] |
|
|
52 |
|
|
|
53 |
img_2ch_t = input_2ch[0] |
|
|
54 |
img_2ch_ed = input_2ch[1] |
|
|
55 |
|
|
|
56 |
img_4ch_t = input_4ch[0] |
|
|
57 |
img_4ch_ed = input_4ch[1] |
|
|
58 |
|
|
|
59 |
|
|
|
60 |
return img_sa_t, img_sa_ed, img_2ch_t, img_2ch_ed, img_4ch_t, img_4ch_ed, contour_sa_ed, contour_2ch_ed, contour_4ch_ed, \ |
|
|
61 |
vertex_tpl_ed, faces_tpl, affine_inv, affine, origin, vertex_ed, mesh2seg_sa, mesh2seg_2ch, mesh2seg_4ch |
|
|
62 |
|
|
|
63 |
def __len__(self): |
|
|
64 |
return len(self.filename) |
|
|
65 |
|
|
|
66 |
|
|
|
67 |
def get_data(path, fr): |
|
|
68 |
nim = nib.load(path) |
|
|
69 |
image = nim.get_data()[:, :, :, :] # (h, w, slices, frame) |
|
|
70 |
image = np.array(image, dtype='float32') |
|
|
71 |
|
|
|
72 |
|
|
|
73 |
image_fr = image[..., fr] |
|
|
74 |
image_fr = image_fr[np.newaxis] |
|
|
75 |
image_ed = image[..., 0] |
|
|
76 |
image_ed = image_ed[np.newaxis] |
|
|
77 |
|
|
|
78 |
image_bank = np.concatenate((image_fr, image_ed), axis=0) |
|
|
79 |
image_bank = np.transpose(image_bank, (0, 3, 1, 2)) |
|
|
80 |
|
|
|
81 |
|
|
|
82 |
return image_bank |
|
|
83 |
|
|
|
84 |
|
|
|
85 |
def load_data(data_path, filename, T_num, rand_frame=None): |
|
|
86 |
# Load images and labels |
|
|
87 |
img_sa_path = join(data_path, filename, 'sa_img.nii.gz') # (H, W, 1, frames) |
|
|
88 |
img_2ch_path = join(data_path, filename, '2ch_img.nii.gz') |
|
|
89 |
img_4ch_path = join(data_path, filename, '4ch_img.nii.gz') |
|
|
90 |
|
|
|
91 |
mesh2seg_SA_path = join(data_path, filename, 'proj_mesh_SA.npy') # (H, W, D) |
|
|
92 |
mesh2seg_2CH_path = join(data_path, filename, 'proj_mesh_2CH.npy') # (H, W) |
|
|
93 |
mesh2seg_4CH_path = join(data_path, filename, 'proj_mesh_4CH.npy') # (H, W) |
|
|
94 |
|
|
|
95 |
contour_sa_path = join(data_path, filename, 'contour_sa.npy') # (H, W, 9, frames) |
|
|
96 |
contour_2ch_path = join(data_path, filename, 'contour_2ch.npy') # (H, W, 1, frames) |
|
|
97 |
contour_4ch_path = join(data_path, filename, 'contour_4ch.npy') # (H, W, 1, frames) |
|
|
98 |
|
|
|
99 |
vertices_path = join(data_path, filename, 'vertices_init_myo_ED_smooth.npy') |
|
|
100 |
faces_path = join(data_path, filename, 'faces_init_myo_ED_smooth.npy') |
|
|
101 |
affine_path = join(data_path, filename, 'affine.npz') |
|
|
102 |
origin_path = join(data_path, filename, 'origin.npz') |
|
|
103 |
vertices_gt_path = join(data_path, filename, 'vertices_resampled_ED.npy') |
|
|
104 |
|
|
|
105 |
|
|
|
106 |
# generate random index for t and z dimension |
|
|
107 |
if rand_frame is not None: |
|
|
108 |
rand_t = rand_frame |
|
|
109 |
else: |
|
|
110 |
rand_t = np.random.randint(0, T_num) |
|
|
111 |
|
|
|
112 |
image_sa_bank = get_data(img_sa_path, rand_t) |
|
|
113 |
image_2ch_bank = get_data(img_2ch_path, rand_t) |
|
|
114 |
image_4ch_bank = get_data(img_4ch_path, rand_t) |
|
|
115 |
|
|
|
116 |
contour_sa_ed = np.transpose(np.load(contour_sa_path)[:, :, :, 0], (2, 0, 1)) # [H,W,slices,frame] |
|
|
117 |
contour_2ch_ed = np.load(contour_2ch_path)[:, :, 0, 0] # [H,W, 1, frame] |
|
|
118 |
contour_4ch_ed = np.load(contour_4ch_path)[:, :, 0, 0] # [H,W, 1, frame] |
|
|
119 |
|
|
|
120 |
# load mesh |
|
|
121 |
vertex_tpl_ed = np.load(vertices_path) |
|
|
122 |
faces_tpl = np.load(faces_path) |
|
|
123 |
vertex_ed = np.load(vertices_gt_path) |
|
|
124 |
|
|
|
125 |
# load affine |
|
|
126 |
aff_sa_inv = np.load(affine_path)['sainv'] |
|
|
127 |
aff_2ch_inv = np.load(affine_path)['la2chinv'] |
|
|
128 |
aff_4ch_inv = np.load(affine_path)['la4chinv'] |
|
|
129 |
affine_inv = np.stack((aff_sa_inv, aff_2ch_inv, aff_4ch_inv), 0) |
|
|
130 |
aff_sa = np.load(affine_path)['sa'] |
|
|
131 |
aff_2ch = np.load(affine_path)['la2ch'] |
|
|
132 |
aff_4ch = np.load(affine_path)['la4ch'] |
|
|
133 |
affine = np.stack((aff_sa, aff_2ch, aff_4ch), 0) |
|
|
134 |
# load origin |
|
|
135 |
origin_sa = np.load(origin_path)['sa'] |
|
|
136 |
origin_2ch = np.load(origin_path)['la2ch'] |
|
|
137 |
origin_4ch = np.load(origin_path)['la4ch'] |
|
|
138 |
origin = np.stack((origin_sa, origin_2ch, origin_4ch), 0) |
|
|
139 |
|
|
|
140 |
|
|
|
141 |
mesh2seg_sa = np.load(mesh2seg_SA_path) |
|
|
142 |
mesh2seg_2ch = np.load(mesh2seg_2CH_path) |
|
|
143 |
mesh2seg_4ch = np.load(mesh2seg_4CH_path) |
|
|
144 |
|
|
|
145 |
|
|
|
146 |
return image_sa_bank, image_2ch_bank, image_4ch_bank, contour_sa_ed, contour_2ch_ed, contour_4ch_ed, \ |
|
|
147 |
vertex_tpl_ed, faces_tpl, affine_inv, affine, origin, vertex_ed, mesh2seg_sa, mesh2seg_2ch, mesh2seg_4ch |
|
|
148 |
|
|
|
149 |
|