Diff of /dataio_reconstruction.py [000000] .. [721c7a]

Switch to unified view

a b/dataio_reconstruction.py
1
import torch.utils.data as data
2
import torch
3
from os import listdir
4
from os.path import join
5
import numpy as np
6
import nibabel as nib
7
import glob
8
import neural_renderer as nr
9
10
11
12
13
class TrainDataset(data.Dataset):
14
    def __init__(self, data_path):
15
        super(TrainDataset, self).__init__()
16
        self.data_path = data_path
17
        self.filename = [f for f in sorted(listdir(self.data_path))]
18
19
    def __getitem__(self, index):
20
        input_sa, input_2ch, input_4ch, contour_sa_ed, contour_2ch_ed, contour_4ch_ed, \
21
        vertex_tpl_ed, faces_tpl, affine_inv, affine, origin, vertex_ed, \
22
        mesh2seg_sa, mesh2seg_2ch, mesh2seg_4ch = load_data(self.data_path, self.filename[index], T_num=50)
23
24
        img_sa_t = input_sa[0]
25
        img_sa_ed = input_sa[1]
26
27
        img_2ch_t = input_2ch[0]
28
        img_2ch_ed = input_2ch[1]
29
30
        img_4ch_t = input_4ch[0]
31
        img_4ch_ed = input_4ch[1]
32
33
        return img_sa_t, img_sa_ed, img_2ch_t, img_2ch_ed, img_4ch_t, img_4ch_ed, contour_sa_ed, contour_2ch_ed, contour_4ch_ed, \
34
               vertex_tpl_ed, faces_tpl, affine_inv, affine, origin, vertex_ed, mesh2seg_sa, mesh2seg_2ch, mesh2seg_4ch
35
36
    def __len__(self):
37
        return len(self.filename)
38
39
class ValDataset(data.Dataset):
40
    def __init__(self, data_path):
41
        super(ValDataset, self).__init__()
42
        self.data_path = data_path
43
        self.filename = [f for f in sorted(listdir(self.data_path))]
44
45
    def __getitem__(self, index):
46
        input_sa, input_2ch, input_4ch, \
47
        vertex_tpl_ed, faces_tpl, affine_inv, affine, origin, vertex_ed, contour_sa_ed, contour_2ch_ed, contour_4ch_ed,  \
48
        mesh2seg_sa, mesh2seg_2ch, mesh2seg_4ch = load_data(self.data_path, self.filename[index], T_num=50,  rand_frame=20)
49
50
        img_sa_t = input_sa[0]
51
        img_sa_ed = input_sa[1]
52
53
        img_2ch_t = input_2ch[0]
54
        img_2ch_ed = input_2ch[1]
55
56
        img_4ch_t = input_4ch[0]
57
        img_4ch_ed = input_4ch[1]
58
59
60
        return img_sa_t, img_sa_ed, img_2ch_t, img_2ch_ed, img_4ch_t, img_4ch_ed, contour_sa_ed, contour_2ch_ed, contour_4ch_ed, \
61
               vertex_tpl_ed, faces_tpl, affine_inv, affine, origin, vertex_ed, mesh2seg_sa, mesh2seg_2ch, mesh2seg_4ch
62
63
    def __len__(self):
64
        return len(self.filename)
65
66
67
def get_data(path, fr):
68
    nim = nib.load(path)
69
    image = nim.get_data()[:, :, :, :]  # (h, w, slices, frame)
70
    image = np.array(image, dtype='float32')
71
72
73
    image_fr = image[..., fr]
74
    image_fr = image_fr[np.newaxis]
75
    image_ed = image[..., 0]
76
    image_ed = image_ed[np.newaxis]
77
78
    image_bank = np.concatenate((image_fr, image_ed), axis=0)
79
    image_bank = np.transpose(image_bank, (0, 3, 1, 2))
80
81
82
    return image_bank
83
84
85
def load_data(data_path, filename, T_num, rand_frame=None):
86
    # Load images and labels
87
    img_sa_path = join(data_path, filename, 'sa_img.nii.gz')  # (H, W, 1, frames)
88
    img_2ch_path = join(data_path, filename, '2ch_img.nii.gz')
89
    img_4ch_path = join(data_path, filename, '4ch_img.nii.gz')
90
91
    mesh2seg_SA_path = join(data_path, filename, 'proj_mesh_SA.npy') # (H, W, D)
92
    mesh2seg_2CH_path = join(data_path, filename, 'proj_mesh_2CH.npy') # (H, W)
93
    mesh2seg_4CH_path = join(data_path, filename, 'proj_mesh_4CH.npy') # (H, W)
94
95
    contour_sa_path = join(data_path, filename, 'contour_sa.npy')  # (H, W, 9, frames)
96
    contour_2ch_path = join(data_path, filename, 'contour_2ch.npy')  # (H, W, 1, frames)
97
    contour_4ch_path = join(data_path, filename, 'contour_4ch.npy')  # (H, W, 1, frames)
98
99
    vertices_path = join(data_path, filename, 'vertices_init_myo_ED_smooth.npy')
100
    faces_path = join(data_path, filename, 'faces_init_myo_ED_smooth.npy')
101
    affine_path = join(data_path, filename, 'affine.npz')
102
    origin_path = join(data_path, filename, 'origin.npz')
103
    vertices_gt_path = join(data_path, filename, 'vertices_resampled_ED.npy')
104
105
106
    # generate random index for t and z dimension
107
    if rand_frame is not None:
108
        rand_t = rand_frame
109
    else:
110
        rand_t = np.random.randint(0, T_num)
111
112
    image_sa_bank = get_data(img_sa_path, rand_t)
113
    image_2ch_bank = get_data(img_2ch_path, rand_t)
114
    image_4ch_bank = get_data(img_4ch_path, rand_t)
115
116
    contour_sa_ed = np.transpose(np.load(contour_sa_path)[:, :, :, 0], (2, 0, 1))  # [H,W,slices,frame]
117
    contour_2ch_ed = np.load(contour_2ch_path)[:, :, 0, 0]  # [H,W, 1, frame]
118
    contour_4ch_ed = np.load(contour_4ch_path)[:, :, 0, 0]  # [H,W, 1, frame]
119
120
    # load mesh
121
    vertex_tpl_ed = np.load(vertices_path)
122
    faces_tpl = np.load(faces_path)
123
    vertex_ed = np.load(vertices_gt_path)
124
125
    # load affine
126
    aff_sa_inv = np.load(affine_path)['sainv']
127
    aff_2ch_inv = np.load(affine_path)['la2chinv']
128
    aff_4ch_inv = np.load(affine_path)['la4chinv']
129
    affine_inv = np.stack((aff_sa_inv, aff_2ch_inv, aff_4ch_inv), 0)
130
    aff_sa = np.load(affine_path)['sa']
131
    aff_2ch = np.load(affine_path)['la2ch']
132
    aff_4ch = np.load(affine_path)['la4ch']
133
    affine = np.stack((aff_sa, aff_2ch, aff_4ch), 0)
134
    # load origin
135
    origin_sa = np.load(origin_path)['sa']
136
    origin_2ch = np.load(origin_path)['la2ch']
137
    origin_4ch = np.load(origin_path)['la4ch']
138
    origin = np.stack((origin_sa, origin_2ch, origin_4ch), 0)
139
140
141
    mesh2seg_sa = np.load(mesh2seg_SA_path)
142
    mesh2seg_2ch = np.load(mesh2seg_2CH_path)
143
    mesh2seg_4ch = np.load(mesh2seg_4CH_path)
144
145
146
    return image_sa_bank, image_2ch_bank, image_4ch_bank, contour_sa_ed, contour_2ch_ed, contour_4ch_ed, \
147
           vertex_tpl_ed, faces_tpl, affine_inv, affine, origin, vertex_ed, mesh2seg_sa, mesh2seg_2ch, mesh2seg_4ch
148
149