Diff of /dataset.py [000000] .. [390c2f]

import os
import random
import re
import sys
import glob

import numpy as np
import torch
import torch.utils.data as data
import pyvista

sys.path.append('.')
sys.path.append('..')
from utils import visualize_PC_with_label

class LoadDataset(data.Dataset):
    def __init__(self, path, num_input=2048, split='train'):  # 16384
        self.path = path
        self.num_input = num_input
        self.use_cobiveco = True
        self.data_augment = False
        self.signal_length = 512

        with open(path + 'my_split/{}.list'.format(split), 'r') as f:
            filenames = [line.strip() for line in f]

        self.metadata = list()
        for filename in filenames:
            print(filename)
            datapath = path + filename + '/'

            unit = 0.1
            if self.use_cobiveco:
                nodesXYZ, label_index = getCobiveco_vtu(datapath + filename + '_cobiveco_AHA17.vtu')
            else:
                nodesXYZ = np.loadtxt(datapath + filename + '_xyz.csv', delimiter=',')
                # label_index is built as an (N, 1) column so it can be concatenated
                # directly with the XYZ coordinates below
                label_index = np.zeros((nodesXYZ.shape[0], 1))
                LVendo_node = np.unique((np.loadtxt(datapath + filename + '_lvface.csv', delimiter=',') - 1).astype(int))
                RVendo_node = np.unique((np.loadtxt(datapath + filename + '_rvface.csv', delimiter=',') - 1).astype(int))
                epi_node = np.unique((np.loadtxt(datapath + filename + '_epiface.csv', delimiter=',') - 1).astype(int))
                label_index[LVendo_node] = 1
                label_index[RVendo_node] = 2
                label_index[epi_node] = 3
                surface_index = np.concatenate((LVendo_node, RVendo_node, epi_node), axis=0)

            PC_XYZ_labeled = np.concatenate((unit * nodesXYZ, label_index), axis=1)
            electrode_node = np.loadtxt(datapath + filename + '_electrodePositions.csv', delimiter=',')
            Coord_base_apex = np.loadtxt(datapath + filename + '_BaseApexCoord.csv', delimiter=',')
            Coord_apex, Coord_base = Coord_base_apex[1], Coord_base_apex[0]
            electrode_index = 4 * np.ones(electrode_node.shape[0], dtype=np.int32)
            electrode_XYZ_labeled = np.concatenate((unit * electrode_node, electrode_index[..., np.newaxis]), axis=1)

            signal_files = glob.glob(datapath + filename + '*_simulated_ECG' + '*.csv')
            num_signal = len(signal_files)
            # print(num_signal)
            for i in range(num_signal):
                MI_index = np.zeros(nodesXYZ.shape[0], dtype=np.int32)
                ECG_value = np.loadtxt(signal_files[i], delimiter=',')
                # zero-pad each lead to the fixed signal length
                ECG_value_u = np.pad(ECG_value, ((0, 0), (0, self.signal_length - ECG_value.shape[1])), 'constant')
                # recover the MI type from the file name; strip both '\\' and '/'
                # so the extraction works on Windows and POSIX paths alike
                MI_type = signal_files[i].replace(path, '').replace(filename, '').replace('_simulated_ECG_', '').replace('.csv', '').replace('\\', '').replace('/', '')

                if MI_type in ('B1_large_transmural_slow', 'normal', 'A2_30_40_transmural'):
                    continue

                if re.compile(r'5_transmural|0_transmural', re.IGNORECASE).search(MI_type):  # remove apical MI size test case
                    continue

                if re.compile(r'AHA', re.IGNORECASE).search(MI_type):  # remove randomly generated MI
                    continue

                # Alternative filters, kept for reference:
                # if not re.compile(r'5_transmural|0_transmural', re.IGNORECASE).search(MI_type) and not (MI_type == 'A2_transmural'):  # keep only apical MI size test cases
                #     continue
                # if not re.compile(r'AHA', re.IGNORECASE).search(MI_type):  # test only random MI
                #     continue
                # if MI_type.find('subendo') != -1:
                #     continue
                # if MI_type != 'B3_transmural' and MI_type != 'A3_transmural' and MI_type != 'A2_transmural':
                #     continue

                if MI_type != 'normal':
                    Scar_filename = signal_files[i].replace('simulated_ECG', 'lvscarnodes')
                    BZ_filename = signal_files[i].replace('simulated_ECG', 'lvborderzonenodes')
                    if MI_type == 'B1_large_transmural_slow':
                        Scar_filename = Scar_filename.replace('_slow', '')
                        BZ_filename = BZ_filename.replace('_slow', '')

                    Scar_node = np.unique((np.loadtxt(Scar_filename, delimiter=',') - 1).astype(int))
                    BZ_node = np.unique((np.loadtxt(BZ_filename, delimiter=',') - 1).astype(int))
                    MI_index[Scar_node] = 1  # scar
                    MI_index[BZ_node] = 2    # border zone
                ECG_array = np.array(ECG_value_u)
                MI_array = np.array(MI_index)
                MI_type_id = np.array(i)
                # print(MI_type_id)

                partial_PC_labeled_array, idx_remained = resample_pcd(PC_XYZ_labeled, self.num_input)
                partial_MI_lab_array = MI_array[idx_remained]
                partial_PC_labeled_array_coarse, idx_remained = resample_pcd(PC_XYZ_labeled, self.num_input // 4)
                # visualize_PC_with_label(partial_PC_labeled_array[:, 0:3], partial_MI_lab_array)
                partial_PC_electrode_labeled_array = partial_PC_labeled_array  # np.concatenate((partial_PC_labeled_array, electrode_XYZ_labeled), axis=0)
                partial_PC_electrode_XYZ = partial_PC_electrode_labeled_array[:, 0:3]
                partial_PC_electrode_lab = partial_PC_electrode_labeled_array[:, 3:]
                # partial_MI_lab_array = partial_MI_lab_array + np.where(partial_PC_electrode_labeled_array[0:self.num_input, -1]==1.0, 3, 0)
                # visualize_PC_with_label(partial_PC_labeled_array[:, 0:3], partial_MI_lab_array)

                partial_PC_electrode_XYZ_normalized = normalize_data(partial_PC_electrode_XYZ, Coord_apex)
                if self.data_augment:
                    scaling = random.uniform(0.8, 1.2)
                    partial_PC_electrode_XYZ_normalized = scaling * translate_point(jitter_point(rotate_point(partial_PC_electrode_XYZ_normalized, np.random.random() * np.pi)))
                partial_PC_electrode_XYZ_normalized_labeled = np.concatenate((partial_PC_electrode_XYZ_normalized, partial_PC_electrode_lab), axis=1)

                partial_PC_electrode_XYZ_normalized_coarse = normalize_data(partial_PC_labeled_array_coarse[:, 0:3], Coord_apex)
                partial_PC_electrode_XYZ_normalized_labeled_coarse = np.concatenate((partial_PC_electrode_XYZ_normalized_coarse, partial_PC_labeled_array_coarse[:, 3:]), axis=1)

                self.metadata.append((partial_PC_electrode_XYZ_normalized_labeled, partial_MI_lab_array, ECG_array, partial_PC_electrode_XYZ_normalized_labeled_coarse, MI_type))

    def __getitem__(self, index):
        partial_PC_electrode_XYZ_normalized_labeled, partial_MI_array, ECG_array, partial_PC_electrode_XYZ, MI_type = self.metadata[index]

        partial_input = torch.from_numpy(partial_PC_electrode_XYZ_normalized_labeled).float()
        gt_MI = torch.from_numpy(partial_MI_array).long()
        ECG_input = torch.from_numpy(ECG_array).float()
        partial_input_coarse = torch.from_numpy(partial_PC_electrode_XYZ).float()

        return partial_input, ECG_input, gt_MI, partial_input_coarse, MI_type

    def __len__(self):
        return len(self.metadata)

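# A minimal usage sketch (not part of the original file; the dataset root and batch
# size are assumptions for illustration). LoadDataset yields one (point cloud, ECG,
# MI label) sample per simulated signal, so it plugs directly into a DataLoader:
#
# from torch.utils.data import DataLoader
# train_set = LoadDataset('./dataset/', num_input=2048, split='train')
# train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
# for partial_input, ECG_input, gt_MI, partial_input_coarse, MI_type in train_loader:
#     print(partial_input.shape, ECG_input.shape, gt_MI.shape)
#     break
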
class LoadDataset_all(data.Dataset):
    def __init__(self, path, num_input=2048, split='train'):  # 16384
        self.path = path
        self.num_input = num_input
        self.use_cobiveco = False
        self.data_augment = False
        self.signal_length = 512

        with open(path + 'my_split/{}.list'.format(split), 'r') as f:
            filenames = [line.strip() for line in f]

        self.metadata = list()
        for filename in filenames:
            print(filename)
            datapath = path + filename + '/'

            unit = 1
            if self.use_cobiveco:
                nodesXYZ, label_index = getCobiveco_vtu(datapath + filename + '_heart_cobiveco.vtu')
            else:
                # note: surface_index (used below) is only defined on this branch,
                # so use_cobiveco must remain False for this class
                nodesXYZ = np.loadtxt(datapath + filename + '_xyz.csv', delimiter=',')
                label_index = np.zeros((nodesXYZ.shape[0], 1), dtype=int)
                LVendo_node = np.unique((np.loadtxt(datapath + filename + '_lvface.csv', delimiter=',') - 1).astype(int))
                RVendo_node = np.unique((np.loadtxt(datapath + filename + '_rvface.csv', delimiter=',') - 1).astype(int))
                epi_node = np.unique((np.loadtxt(datapath + filename + '_epiface.csv', delimiter=',') - 1).astype(int))
                label_index[LVendo_node] = 1
                label_index[RVendo_node] = 2
                label_index[epi_node] = 3
                surface_index = np.concatenate((LVendo_node, RVendo_node, epi_node), axis=0)

            PC_XYZ_labeled = np.concatenate((unit * nodesXYZ, label_index), axis=1)
            electrode_node = np.loadtxt(datapath + filename + '_electrodePositions.csv', delimiter=',')
            Coord_base_apex = np.loadtxt(datapath + filename + '_BaseApexCoord.csv', delimiter=',')
            Coord_apex, Coord_base = Coord_base_apex[1], Coord_base_apex[0]
            electrode_index = 4 * np.ones(electrode_node.shape[0], dtype=int)
            electrode_XYZ_labeled = np.concatenate((unit * electrode_node, electrode_index[..., np.newaxis]), axis=1)

            signal_files = glob.glob(datapath + filename + '*_simulated_ECG' + '*.csv')
            ECG_list, MI_index_list = list(), list()
            MItype_list = list()
            num_signal = len(signal_files)
            for i in range(num_signal):
                MI_index = np.zeros(nodesXYZ.shape[0], dtype=int)
                ECG_value = np.loadtxt(signal_files[i], delimiter=',')
                ECG_value_u = np.pad(ECG_value, ((0, 0), (0, self.signal_length - ECG_value.shape[1])), 'constant')
                MI_type = signal_files[i].replace(path, '').replace(filename, '').replace('_simulated_ECG_', '').replace('.csv', '').replace('\\', '').replace('/', '')
                if MI_type == 'B1_large_transmural_slow':
                    continue
                if MI_type != 'normal':
                    Scar_filename = signal_files[i].replace('simulated_ECG', 'lvscarnodes')
                    BZ_filename = signal_files[i].replace('simulated_ECG', 'lvborderzonenodes')
                    Scar_node = np.unique((np.loadtxt(Scar_filename, delimiter=',') - 1).astype(int))
                    BZ_node = np.unique((np.loadtxt(BZ_filename, delimiter=',') - 1).astype(int))
                    MI_index[Scar_node] = 421  # scar label
                    MI_index[BZ_node] = 422    # border zone label

                ECG_list.append(ECG_value_u)
                MI_index_list.append(MI_index)
                MItype_list.append(MI_type)

            ECG_array = np.array(ECG_list).transpose(1, 2, 0)
            MI_array = np.array(MI_index_list).transpose(1, 0)
            partial_PC_labeled_array, idx_remained = resample_pcd(PC_XYZ_labeled[surface_index], self.num_input)
            partial_MI_array = MI_array[surface_index][idx_remained]
            partial_PC_electrode_labeled_array = np.concatenate((partial_PC_labeled_array, electrode_XYZ_labeled), axis=0)
            partial_PC_electrode_XYZ = partial_PC_electrode_labeled_array[:, 0:3]
            partial_PC_electrode_lab = np.expand_dims(partial_PC_electrode_labeled_array[:, 3], axis=1)
            partial_PC_electrode_XYZ_normalized = normalize_data(partial_PC_electrode_XYZ, Coord_apex)
            if self.data_augment:
                scaling = random.uniform(0.8, 1.2)
                partial_PC_electrode_XYZ_normalized = scaling * translate_point(jitter_point(rotate_point(partial_PC_electrode_XYZ_normalized, np.random.random() * np.pi)))
            partial_PC_electrode_XYZ_normalized_labeled = np.concatenate((partial_PC_electrode_XYZ_normalized, partial_PC_electrode_lab), axis=1)

            self.metadata.append((partial_PC_electrode_XYZ_normalized_labeled, partial_MI_array, ECG_array, partial_PC_electrode_XYZ))

    def __getitem__(self, index):
        partial_PC_electrode_XYZ_normalized_labeled, partial_MI_array, ECG_array, partial_PC_electrode_XYZ = self.metadata[index]

        ECG_array[np.isnan(ECG_array)] = 0  # convert NaN values in the ECG to 0; ECG_array has shape (n_leads, signal_length, n_signals)
        partial_input = torch.from_numpy(partial_PC_electrode_XYZ_normalized_labeled).float()
        gt_MI, ECG_input = torch.from_numpy(partial_MI_array).float(), torch.from_numpy(ECG_array).float()
        partial_input_ori = torch.from_numpy(partial_PC_electrode_XYZ).float()

        return partial_input, ECG_input, gt_MI, partial_input_ori

    def __len__(self):
        return len(self.metadata)

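# Illustration (not part of the original file) of the zero-padding applied to each
# ECG in both dataset classes: every lead (row) is padded on the right with zeros
# up to signal_length samples. The shapes below are toy values:
#
# ecg = np.ones((8, 300))                                     # 8 leads, 300 samples
# padded = np.pad(ecg, ((0, 0), (0, 512 - 300)), 'constant')
# padded.shape  # -> (8, 512)
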
def getCobiveco_vtu(cobiveco_fileName):  # Read Cobiveco data in .vtu format (added by Lei on 2023/01/30)
    cobiveco_vol = pyvista.read(cobiveco_fileName)  # , force_ext='.vtu'

    cobiveco_nodesXYZ = cobiveco_vol.points
    cobiveco_nodes_array = cobiveco_vol.point_data
    ab = cobiveco_nodes_array['ab']    # apex-to-base
    rt = cobiveco_nodes_array['rt']    # rotation angle
    tm = cobiveco_nodes_array['tm']    # transmurality
    tv = cobiveco_nodes_array['tv']    # ventricle
    aha = cobiveco_nodes_array['aha']  # AHA-17 map

    return cobiveco_nodesXYZ, np.transpose(np.array([ab, rt, tm, tv, aha], dtype=float))

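# Usage sketch (illustration only; the .vtu filename is hypothetical). The returned
# shapes follow directly from the function above:
#
# nodesXYZ, labels = getCobiveco_vtu('patient01_cobiveco_AHA17.vtu')
# nodesXYZ.shape  # (N, 3) node coordinates
# labels.shape    # (N, 5): columns are ab, rt, tm, tv, aha
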
### point cloud augmentation ###

# translate point cloud by one random shift per axis
def translate_point(point):
    point = np.array(point)
    shift = [random.uniform(-0.5, 0.5), random.uniform(-0.5, 0.5), random.uniform(-0.5, 0.5)]
    shift = np.expand_dims(np.array(shift), axis=0)
    shifted_point = np.repeat(shift, point.shape[0], axis=0)
    shifted_point += point

    return shifted_point


# add clipped Gaussian noise to each point
def jitter_point(point, sigma=0.01, clip=0.01):
    assert clip > 0
    point = np.array(point)
    point = point.reshape(-1, 3)
    Row, Col = point.shape
    jittered_point = np.clip(sigma * np.random.randn(Row, Col), -1 * clip, clip)
    jittered_point += point

    return jittered_point


# rotate point cloud (only the Z-axis rotation is applied;
# the X- and Y-axis matrices are kept for reference but unused)
def rotate_point(point, rotation_angle=0.5*np.pi):
    point = np.array(point)
    cos_theta = np.cos(rotation_angle)
    sin_theta = np.sin(rotation_angle)
    # Rotation around X axis
    rotation_matrix_X = np.array([[1, 0, 0],
                                  [0, cos_theta, -sin_theta],
                                  [0, sin_theta, cos_theta]])
    # Rotation around Y axis
    rotation_matrix_Y = np.array([[cos_theta, 0, sin_theta],
                                  [0, 1, 0],
                                  [-sin_theta, 0, cos_theta]])
    # Rotation around Z axis
    rotation_matrix_Z = np.array([[cos_theta, sin_theta, 0],
                                  [-sin_theta, cos_theta, 0],
                                  [0, 0, 1]])

    rotated_point = np.dot(point.reshape(-1, 3), rotation_matrix_Z)

    return rotated_point

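# A small illustrative check (not part of the original file): the augmentation
# chain used in the dataset classes preserves point count and dimensionality.
#
# pts = np.random.rand(100, 3)
# aug = translate_point(jitter_point(rotate_point(pts, 0.25 * np.pi)))
# assert aug.shape == (100, 3)
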
# normalize point cloud based on apex coordinate
def normalize_data(PC, Coord_apex):
    """ Normalize the point cloud: shift so the apex is the origin,
        then min-max scale each axis to [0, 1].
        Input:
            NxC array
        Output:
            NxC array
    """
    N, C = PC.shape
    normal_data = np.zeros((N, C))
    # centroid = np.mean(PC, axis=0)
    PC = PC - Coord_apex
    # m = np.max(np.sqrt(np.sum(PC ** 2, axis=1)))
    # PC = PC / m
    # normal_data = PC

    # compute the minimum and maximum values of each coordinate
    min_coords = np.min(PC, axis=0)
    max_coords = np.max(PC, axis=0)

    # normalize the point cloud coordinates
    normal_data = (PC - min_coords) / (max_coords - min_coords)

    return normal_data

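# Worked toy example (illustration only): after subtracting the apex coordinate,
# each axis is min-max scaled to [0, 1] independently.
#
# pc = np.array([[0., 0., 0.], [2., 4., 8.]])
# normalize_data(pc, Coord_apex=np.zeros(3))
# # -> [[0., 0., 0.], [1., 1., 1.]]
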
def resample_pcd_ATM(pcd, ATM, n):
    """Drop points so that pcd has exactly n points, always keeping the root
    nodes (points with ATM[:, 0] == 1.0) and sampling the rest uniformly"""
    idx_root_nodes = np.where(ATM[:, 0] == 1.0)
    prob = 1 / (pcd.shape[0] - idx_root_nodes[0].shape[0])
    node_prob = prob * np.ones(pcd.shape[0])
    node_prob[idx_root_nodes] = 0
    idx = np.random.choice(np.arange(pcd.shape[0]), n - idx_root_nodes[0].shape[0], p=node_prob, replace=False)
    idx_remained = np.union1d(idx, idx_root_nodes)

    return pcd[idx_remained], ATM[idx_remained], idx_remained


def resample_pcd_ATM_ori(pcd, ATM, n):
    """Drop or duplicate points so that pcd has exactly n points"""
    idx = np.random.permutation(pcd.shape[0])
    if idx.shape[0] < n:
        idx = np.concatenate([idx, np.random.randint(pcd.shape[0], size=n - pcd.shape[0])])
    return pcd[idx[:n]], ATM[idx[:n]]


def resample_gd(gt_output, num_coarse, num_dense):  # added by Lei on 2022/02/10 to separately resample ground-truth labels
    """Resample the ground truth into a coarse and a dense point set"""
    choice = np.random.choice(len(gt_output), num_coarse, replace=True)
    coarse_gt = gt_output[choice, :]
    dense_gt, _ = resample_pcd(gt_output, num_dense)  # resample_pcd returns (points, indices)
    return coarse_gt, dense_gt


def resample_pcd(pcd, n):
    """Drop or duplicate points so that pcd has exactly n points"""
    idx = np.random.permutation(pcd.shape[0])
    if idx.shape[0] < n:
        idx = np.concatenate([idx, np.random.randint(pcd.shape[0], size=n - pcd.shape[0])])
    return pcd[idx[:n]], idx[:n]

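# Usage sketch (illustration only): resample a point cloud to a fixed size and use
# the returned indices to keep per-point labels aligned, mirroring how the dataset
# classes pair MI labels with the resampled cloud. Shapes below are toy values:
#
# pc = np.random.rand(5000, 4)            # XYZ + surface label
# labels = np.random.randint(0, 3, 5000)  # per-point MI labels
# pc_fixed, idx = resample_pcd(pc, 2048)
# labels_fixed = labels[idx]              # labels stay aligned with pc_fixed
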
if __name__ == '__main__':
    ROOT = './dataset/'

    train_dataset = LoadDataset(ROOT, split='train')
    val_dataset = LoadDataset(ROOT, split='val')
    test_dataset = LoadDataset(ROOT, split='test')
    print("\033[33mTraining dataset\033[0m has {} samples".format(len(train_dataset)))
    print("\033[33mValidation dataset\033[0m has {} samples".format(len(val_dataset)))
    print("\033[33mTesting dataset\033[0m has {} samples".format(len(test_dataset)))

    # inspect a random sample
    partial_input, ECG_input, gt_MI, partial_input_coarse, MI_type = train_dataset[random.randint(0, len(train_dataset) - 1)]
    print("partial input point cloud has {} points".format(len(partial_input)))
    print("coarse input point cloud has {} points".format(len(partial_input_coarse)))
    print("ECG input has shape {}".format(tuple(ECG_input.shape)))