Diff of /dataset.py [000000] .. [390c2f]

--- a
+++ b/dataset.py
@@ -0,0 +1,365 @@
+import os
+import random
+import numpy as np
+import torch
+import glob
+import torch.utils.data as data
+import sys
+import pyvista
+sys.path.append('.')
+sys.path.append('..')
+from utils import visualize_PC_with_label
+import re
+
+class LoadDataset(data.Dataset):
+    def __init__(self, path, num_input=2048, split='train'): #16384
+        self.path = path
+        self.num_input = num_input
+        self.use_cobiveco = True
+        self.data_augment = False
+        self.signal_length = 512
+
+        with open(path + 'my_split/{}.list'.format(split), 'r') as f:
+            filenames = [line.strip() for line in f]
+
+        self.metadata = list()
+        for filename in filenames:
+            print(filename)
+            datapath = path + filename + '/'
+
+            unit = 0.1
+            if self.use_cobiveco:
+                nodesXYZ, label_index = getCobiveco_vtu(datapath + filename + '_cobiveco_AHA17.vtu')
+            else:
+                nodesXYZ = np.loadtxt(datapath + filename + '_xyz.csv', delimiter=',')
+                label_index = np.zeros((nodesXYZ.shape[0], 1))
+                LVendo_node = np.unique((np.loadtxt(datapath + filename + '_lvface.csv', delimiter=',')-1).astype(int))
+                RVendo_node = np.unique((np.loadtxt(datapath + filename + '_rvface.csv', delimiter=',')-1).astype(int))
+                epi_node = np.unique((np.loadtxt(datapath + filename + '_epiface.csv', delimiter=',')-1).astype(int))
+                label_index[LVendo_node] = 1
+                label_index[RVendo_node] = 2
+                label_index[epi_node] = 3
+                surface_index = np.concatenate((LVendo_node, RVendo_node, epi_node), axis=0) 
+            
+            PC_XYZ_labeled = np.concatenate((unit*nodesXYZ, label_index), axis=1)           
+            electrode_node = np.loadtxt(datapath + filename + '_electrodePositions.csv', delimiter=',')
+            Coord_base_apex = np.loadtxt(datapath + filename + '_BaseApexCoord.csv', delimiter=',')
+            Coord_apex, Coord_base = Coord_base_apex[1], Coord_base_apex[0] 
+            electrode_index = 4*np.ones(electrode_node.shape[0], dtype=np.int32)
+            electrode_XYZ_labeled = np.concatenate((unit*electrode_node, electrode_index[..., np.newaxis]), axis=1)
+            
+            signal_files = glob.glob(datapath + filename + '*_simulated_ECG' + '*.csv')
+            num_signal = len(signal_files)
+            # print(num_signal)
+            for id in range(num_signal):
+                MI_index = np.zeros(nodesXYZ.shape[0], dtype=np.int32)
+                ECG_value = np.loadtxt(signal_files[id], delimiter=',')
+                ECG_value_u = np.pad(ECG_value, ((0, 0), (0, self.signal_length-ECG_value.shape[1])), 'constant')           
+                # strip path, case name and file-name boilerplate to recover the
+                # MI label; handle both POSIX '/' and Windows '\\' separators
+                MI_type = signal_files[id].replace(path, '').replace(filename, '').replace('_simulated_ECG_', '').replace('.csv', '').replace('\\', '').replace('/', '')
+                
+                if MI_type == 'B1_large_transmural_slow' or MI_type == 'normal' or MI_type == 'A2_30_40_transmural':
+                    continue
+
+                if re.compile(r'5_transmural|0_transmural', re.IGNORECASE).search(MI_type): # remove apical MI size test case
+                    continue
+
+                if re.compile(r'AHA', re.IGNORECASE).search(MI_type): # remove randomly generated MI
+                    continue
+
+                # if not re.compile(r'5_transmural|0_transmural', re.IGNORECASE).search(MI_type) and not (MI_type == 'A2_transmural'): # remove apical MI size test case
+                #     continue
+
+                # if not re.compile(r'AHA', re.IGNORECASE).search(MI_type): # test only random MI!
+                #     continue
+                #             
+                # if MI_type.find('subendo') != -1:
+                #     continue
+ 
+                # if MI_type != 'B3_transmural' and MI_type != 'A3_transmural' and MI_type != 'A2_transmural':
+                #     continue  
+
+                # print(MI_type)             
+
+                if MI_type != 'normal':
+                    Scar_filename = signal_files[id].replace('simulated_ECG', 'lvscarnodes')
+                    BZ_filename = signal_files[id].replace('simulated_ECG', 'lvborderzonenodes')
+                    if MI_type == 'B1_large_transmural_slow':
+                        Scar_filename = Scar_filename.replace('_slow', '')
+                        BZ_filename = BZ_filename.replace('_slow', '')
+
+                    Scar_node = np.unique((np.loadtxt(Scar_filename, delimiter=',')-1).astype(int))
+                    BZ_node = np.unique((np.loadtxt(BZ_filename, delimiter=',')-1).astype(int))
+                    MI_index[Scar_node] = 1
+                    MI_index[BZ_node] = 2
+                ECG_array = np.array(ECG_value_u)
+                MI_array = np.array(MI_index)
+                MI_type_id = np.array(id)
+                # print(MI_type_id)
+                
+                partial_PC_labeled_array, idx_remained = resample_pcd(PC_XYZ_labeled, self.num_input)
+                partial_MI_lab_array = MI_array[idx_remained]
+                partial_PC_labeled_array_coarse, _ = resample_pcd(PC_XYZ_labeled, self.num_input//4)
+                # visualize_PC_with_label(partial_PC_labeled_array[:, 0:3], partial_MI_array)
+                partial_PC_electrode_labeled_array = partial_PC_labeled_array # np.concatenate((partial_PC_labeled_array, electrode_XYZ_labeled), axis=0)
+                partial_PC_electrode_XYZ = partial_PC_electrode_labeled_array[:, 0:3]
+                partial_PC_electrode_lab = partial_PC_electrode_labeled_array[:, 3:]
+                # partial_MI_lab_array = partial_MI_lab_array + np.where(partial_PC_electrode_labeled_array[0:self.num_input, -1]==1.0, 3, 0)
+                # visualize_PC_with_label(partial_PC_labeled_array[:, 0:3], partial_MI_lab_array)
+
+                partial_PC_electrode_XYZ_normalized = normalize_data(partial_PC_electrode_XYZ, Coord_apex)
+                if self.data_augment:
+                    scaling = random.uniform(0.8, 1.2)
+                    partial_PC_electrode_XYZ_normalized = scaling*translate_point(jitter_point(rotate_point(partial_PC_electrode_XYZ_normalized, np.random.random()*np.pi)))
+                partial_PC_electrode_XYZ_normalized_labeled = np.concatenate((partial_PC_electrode_XYZ_normalized, partial_PC_electrode_lab), axis=1)
+
+                partial_PC_electrode_XYZ_normalized_coarse = normalize_data(partial_PC_labeled_array_coarse[:, 0:3], Coord_apex)
+                partial_PC_electrode_XYZ_normalized_labeled_coarse = np.concatenate((partial_PC_electrode_XYZ_normalized_coarse, partial_PC_labeled_array_coarse[:, 3:]), axis=1)
+
+                self.metadata.append((partial_PC_electrode_XYZ_normalized_labeled, partial_MI_lab_array, ECG_array, partial_PC_electrode_XYZ_normalized_labeled_coarse, MI_type))
+
+    def __getitem__(self, index):
+        partial_PC_electrode_XYZ_normalized_labeled, partial_MI_array, ECG_array, partial_PC_coarse, MI_type = self.metadata[index]
+
+        partial_input = torch.from_numpy(partial_PC_electrode_XYZ_normalized_labeled).float()
+        gt_MI = torch.from_numpy(partial_MI_array).long()
+        ECG_input = torch.from_numpy(ECG_array).float()
+        partial_input_coarse = torch.from_numpy(partial_PC_coarse).float()
+
+        return partial_input, ECG_input, gt_MI, partial_input_coarse, MI_type
+
+    def __len__(self):
+        return len(self.metadata)
+
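+
+# Hedged usage sketch (illustration only, not part of the training pipeline):
+# how LoadDataset is expected to plug into a torch DataLoader. './dataset/' is
+# a hypothetical root; it must contain my_split/{train,val,test}.list and one
+# sub-directory per case, as read in __init__ above.
+def _demo_loader(root='./dataset/'):
+    dataset = LoadDataset(root, num_input=2048, split='train')
+    loader = data.DataLoader(dataset, batch_size=4, shuffle=True)
+    partial_input, ECG_input, gt_MI, partial_input_coarse, MI_type = next(iter(loader))
+    # partial_input: [B, 2048, 3 + n_label], ECG_input: [B, n_lead, 512], gt_MI: [B, 2048]
+    return partial_input.shape, ECG_input.shape, gt_MI.shape
+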
+
+class LoadDataset_all(data.Dataset):
+    def __init__(self, path, num_input=2048, split='train'): #16384
+        self.path = path
+        self.num_input = num_input
+        self.use_cobiveco = False
+        self.data_augment = False
+        self.signal_length = 512
+
+        with open(path + 'my_split/{}.list'.format(split), 'r') as f:
+            filenames = [line.strip() for line in f]
+
+
+        self.metadata = list()
+        for filename in filenames:
+            print(filename)
+            datapath = path + filename + '/'
+
+            unit = 1
+            if self.use_cobiveco:
+                nodesXYZ, label_index = getCobiveco_vtu(datapath + filename + '_heart_cobiveco.vtu')
+            else:
+                nodesXYZ = np.loadtxt(datapath + filename + '_xyz.csv', delimiter=',')
+                label_index = np.zeros((nodesXYZ.shape[0], 1), dtype=int)
+                LVendo_node = np.unique((np.loadtxt(datapath + filename + '_lvface.csv', delimiter=',')-1).astype(int))
+                RVendo_node = np.unique((np.loadtxt(datapath + filename + '_rvface.csv', delimiter=',')-1).astype(int))
+                epi_node = np.unique((np.loadtxt(datapath + filename + '_epiface.csv', delimiter=',')-1).astype(int))
+                label_index[LVendo_node] = 1
+                label_index[RVendo_node] = 2
+                label_index[epi_node] = 3
+                surface_index = np.concatenate((LVendo_node, RVendo_node, epi_node), axis=0) 
+            
+            PC_XYZ_labeled = np.concatenate((unit*nodesXYZ, label_index), axis=1)           
+            electrode_node = np.loadtxt(datapath + filename + '_electrodePositions.csv', delimiter=',')
+            Coord_base_apex = np.loadtxt(datapath + filename + '_BaseApexCoord.csv', delimiter=',')
+            Coord_apex, Coord_base = Coord_base_apex[1], Coord_base_apex[0] 
+            electrode_index = 4*np.ones(electrode_node.shape[0], dtype=np.int32)
+            electrode_XYZ_labeled = np.concatenate((unit*electrode_node, electrode_index[..., np.newaxis]), axis=1)
+            
+            signal_files = glob.glob(datapath + filename + '*_simulated_ECG' + '*.csv')
+            ECG_list, MI_index_list = list(), list()
+            MItype_list = list()
+            num_signal = len(signal_files)
+            for id in range(num_signal):
+                MI_index = np.zeros(nodesXYZ.shape[0], dtype=np.int32)
+                ECG_value = np.loadtxt(signal_files[id], delimiter=',')
+                ECG_value_u = np.pad(ECG_value, ((0, 0), (0, self.signal_length-ECG_value.shape[1])), 'constant')           
+                MI_type = signal_files[id].replace(path, '').replace(filename, '').replace('_simulated_ECG_', '').replace('.csv', '').replace('\\', '')
+                if MI_type == 'B1_large_transmural_slow':
+                    continue
+                if MI_type != 'normal':
+                    Scar_filename = signal_files[id].replace('simulated_ECG', 'lvscarnodes')
+                    BZ_filename = signal_files[id].replace('simulated_ECG', 'lvborderzonenodes')
+                    Scar_node = np.unique((np.loadtxt(Scar_filename, delimiter=',')-1).astype(int))
+                    BZ_node = np.unique((np.loadtxt(BZ_filename, delimiter=',')-1).astype(int))
+                    MI_index[Scar_node] = 1  # scar label, matching LoadDataset
+                    MI_index[BZ_node] = 2    # border-zone label
+
+                ECG_list.append(ECG_value_u)
+                MI_index_list.append(MI_index)
+                MItype_list.append(MI_type)
+                                                           
+            ECG_array = np.array(ECG_list).transpose(1, 2, 0) 
+            MI_array = np.array(MI_index_list).transpose(1, 0)                     
+            partial_PC_labeled_array, idx_remained = resample_pcd(PC_XYZ_labeled[surface_index], self.num_input)
+            partial_MI_array = MI_array[surface_index][idx_remained]
+            partial_PC_electrode_labeled_array = np.concatenate((partial_PC_labeled_array, electrode_XYZ_labeled), axis=0)
+            partial_PC_electrode_XYZ = partial_PC_electrode_labeled_array[:, 0:3]
+            partial_PC_electrode_lab = np.expand_dims(partial_PC_electrode_labeled_array[:, 3], axis=1)
+            partial_PC_electrode_XYZ_normalized = normalize_data(partial_PC_electrode_XYZ, Coord_apex)
+            if self.data_augment:
+                scaling = random.uniform(0.8, 1.2)
+                partial_PC_electrode_XYZ_normalized = scaling*translate_point(jitter_point(rotate_point(partial_PC_electrode_XYZ_normalized, np.random.random()*np.pi)))
+            partial_PC_electrode_XYZ_normalized_labeled = np.concatenate((partial_PC_electrode_XYZ_normalized, partial_PC_electrode_lab), axis=1)
+          
+                      
+            self.metadata.append((partial_PC_electrode_XYZ_normalized_labeled, partial_MI_array, ECG_array, partial_PC_electrode_XYZ))
+
+    def __getitem__(self, index):
+        partial_PC_electrode_XYZ_normalized_labeled, partial_MI_array, ECG_array, partial_PC_electrode_XYZ = self.metadata[index]
+
+        ECG_array[np.isnan(ECG_array)] = 0  # convert any NaN values in the simulated ECG to 0
+        partial_input = torch.from_numpy(partial_PC_electrode_XYZ_normalized_labeled).float()
+        gt_MI, ECG_input = torch.from_numpy(partial_MI_array).float(), torch.from_numpy(ECG_array).float()
+        partial_input_ori = torch.from_numpy(partial_PC_electrode_XYZ).float()
+
+        return partial_input, ECG_input, gt_MI, partial_input_ori
+
+    def __len__(self):
+        return len(self.metadata)
+
+
+def getCobiveco_vtu(cobiveco_fileName): # Read Cobiveco data in .vtu format (added by Lei on 2023/01/30)
+    cobiveco_vol = pyvista.read(cobiveco_fileName) #, force_ext='.vtu'
+
+    cobiveco_nodesXYZ = cobiveco_vol.points
+    cobiveco_nodes_array = cobiveco_vol.point_data
+    # Apex-to-Base - ab
+    ab = cobiveco_nodes_array['ab']
+    # Rotation angle - rt
+    rt = cobiveco_nodes_array['rt']
+    # Transmurality - tm
+    tm = cobiveco_nodes_array['tm']
+    # Ventricle - tv
+    tv = cobiveco_nodes_array['tv']
+    # AHA-17 map - aha
+    aha = cobiveco_nodes_array['aha']
+
+    return cobiveco_nodesXYZ, np.transpose(np.array([ab, rt, tm, tv, aha], dtype=float))
+
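+# Hedged usage sketch (illustration only): the function above returns the node
+# coordinates together with an (N, 5) array of per-node Cobiveco coordinates
+# [ab, rt, tm, tv, aha]. The default path is a hypothetical example file.
+def _demo_cobiveco(vtu_path='case01/case01_cobiveco_AHA17.vtu'):
+    nodesXYZ, labels = getCobiveco_vtu(vtu_path)
+    assert labels.shape == (nodesXYZ.shape[0], 5)  # [ab, rt, tm, tv, aha]
+    return nodesXYZ, labels
+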
+### point cloud augmentation ###
+# translate point cloud
+def translate_point(point):
+    point = np.array(point)
+    shift = [random.uniform(-0.5, 0.5), random.uniform(-0.5, 0.5), random.uniform(-0.5, 0.5)]
+    shift = np.expand_dims(np.array(shift), axis=0)
+    shifted_point = np.repeat(shift, point.shape[0], axis=0)
+    shifted_point += point
+
+    return shifted_point
+
+# add Gaussian noise 
+def jitter_point(point, sigma=0.01, clip=0.01):
+    assert clip > 0
+    point = np.array(point)
+    point = point.reshape(-1,3)
+    Row, Col = point.shape
+    jittered_point = np.clip(sigma * np.random.randn(Row, Col), -1*clip, clip)
+    jittered_point += point
+
+    return jittered_point
+
+# rotate point cloud
+def rotate_point(point, rotation_angle=0.5*np.pi):
+    point = np.array(point)
+    cos_theta = np.cos(rotation_angle)
+    sin_theta = np.sin(rotation_angle)
+    # Rotation around X axis
+    rotation_matrix_X = np.array([[1, 0, 0],
+                                [0, cos_theta, -sin_theta],
+                                [0, sin_theta, cos_theta]])
+    # Rotation around Y axis
+    rotation_matrix_Y = np.array([[cos_theta, 0, sin_theta],
+                                [0, 1, 0],
+                                [-sin_theta, 0, cos_theta]])
+    # Rotation around Z axis (transposed form: right-multiplying row vectors
+    # by this matrix is equivalent to the standard R_z(theta) @ p on column vectors)
+    rotation_matrix_Z = np.array([[cos_theta, sin_theta, 0],
+                                  [-sin_theta, cos_theta, 0],
+                                  [0, 0, 1]])
+
+    # only the Z-rotation is applied; the X/Y matrices above are kept for reference
+    rotated_point = np.dot(point.reshape(-1, 3), rotation_matrix_Z)
+
+    return rotated_point
+
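+# Hedged sanity check (illustration only): a rigid Z-rotation must preserve
+# point norms, and rotating by 2*pi returns the original cloud (up to
+# floating-point error).
+def _demo_rotate_point():
+    pts = np.random.rand(16, 3)
+    rotated = rotate_point(pts, rotation_angle=0.5*np.pi)
+    assert np.allclose(np.linalg.norm(pts, axis=1), np.linalg.norm(rotated, axis=1))
+    assert np.allclose(rotate_point(pts, rotation_angle=2*np.pi), pts)
+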
+# normalize point cloud based on apex coordinate
+def normalize_data(PC, Coord_apex):
+    """ Normalize the point cloud, use coordinates of centroid/ apex,
+        Input:
+            NxC array
+        Output:
+            NxC array
+    """
+    N, C = PC.shape
+    normal_data = np.zeros((N, C))
+    # centroid = np.mean(PC, axis=0)
+    PC = PC - Coord_apex
+    # m = np.max(np.sqrt(np.sum(PC ** 2, axis=1)))
+    # PC = PC / m
+    # normal_data = PC
+
+    # compute the minimum and maximum values of each coordinate
+    min_coords = np.min(PC, axis=0)
+    max_coords = np.max(PC, axis=0)
+
+    # normalize the point cloud coordinates
+    normal_data = (PC - min_coords) / (max_coords - min_coords)
+
+    return normal_data
+
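+# Hedged sanity check (illustration only): after normalization every axis spans
+# exactly [0, 1]; the apex translation cancels in the min-max step, so it only
+# matters for the commented-out unit-sphere variant above.
+def _demo_normalize_data():
+    pc = np.random.rand(64, 3) * 10.0
+    normalized = normalize_data(pc, Coord_apex=np.zeros(3))
+    assert np.allclose(normalized.min(axis=0), 0.0)
+    assert np.allclose(normalized.max(axis=0), 1.0)
+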
+def resample_pcd_ATM(pcd, ATM, n):
+    """Drop or duplicate points so that pcd has exactly n points"""
+    idx_root_nodes = np.where(ATM[:, 0]==1.0) # ATM[:, 0]
+    prob = 1/(pcd.shape[0]-idx_root_nodes[0].shape[0])
+    node_prob = prob*np.ones(pcd.shape[0])
+    node_prob[idx_root_nodes] = 0
+    idx = np.random.choice(np.arange(pcd.shape[0]), n-idx_root_nodes[0].shape[0], p=node_prob, replace=False)
+    idx_remained = np.union1d(idx, idx_root_nodes)
+    # idx_updated_permuted = np.random.permutation(idx_updated)
+    # if idx_updated_permuted.shape[0] < n:
+    #     idx = np.concatenate([idx, np.random.randint(pcd.shape[0], size=n-pcd.shape[0])])
+    
+    return pcd[idx_remained], ATM[idx_remained], idx_remained
+
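+# Hedged sanity check (illustration only): the ATM-aware resampling always keeps
+# the root nodes (ATM[:, 0] == 1.0) in the surviving index set.
+def _demo_resample_pcd_ATM():
+    pcd = np.random.rand(100, 4)
+    ATM = np.zeros((100, 2))
+    ATM[[3, 42, 77], 0] = 1.0  # three hypothetical root nodes
+    _, _, idx_remained = resample_pcd_ATM(pcd, ATM, n=32)
+    assert np.all(np.isin([3, 42, 77], idx_remained)) and idx_remained.shape[0] == 32
+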
+def resample_pcd_ATM_ori(pcd, ATM, n):
+    """Drop or duplicate points so that pcd has exactly n points"""
+    idx = np.random.permutation(pcd.shape[0])
+    if idx.shape[0] < n:
+        idx = np.concatenate([idx, np.random.randint(pcd.shape[0], size=n-pcd.shape[0])])
+    return pcd[idx[:n]], ATM[idx[:n]]
+
+def resample_gd(gt_output, num_coarse, num_dense): # added by Lei on 2022/02/10 to separately resample the ground-truth label
+    """Resample the ground truth into a coarse set of num_coarse points and a dense set of num_dense points"""
+    choice = np.random.choice(len(gt_output), num_coarse, replace=True)
+    coarse_gt = gt_output[choice, :]
+    dense_gt, _ = resample_pcd(gt_output, num_dense)  # resample_pcd returns (points, indices)
+    return coarse_gt, dense_gt
+
+def resample_pcd(pcd, n):
+    """Drop or duplicate points so that pcd has exactly n points"""
+    idx = np.random.permutation(pcd.shape[0])
+    if idx.shape[0] < n:
+        idx = np.concatenate([idx, np.random.randint(pcd.shape[0], size=n-pcd.shape[0])])
+    return pcd[idx[:n]], idx[:n]
+
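+# Hedged sanity check (illustration only): resample_pcd returns exactly n points
+# whether the input cloud is larger (random drop) or smaller (duplication via
+# random re-draws) than n.
+def _demo_resample_pcd():
+    big, idx_big = resample_pcd(np.random.rand(500, 4), 256)
+    small, idx_small = resample_pcd(np.random.rand(100, 4), 256)
+    assert big.shape == (256, 4) and small.shape == (256, 4)
+    assert idx_big.shape[0] == 256 and idx_small.shape[0] == 256
+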
+
+if __name__ == '__main__':
+    ROOT = './dataset/'
+
+    train_dataset = LoadDataset(path=ROOT, split='train')
+    val_dataset = LoadDataset(path=ROOT, split='val')
+    test_dataset = LoadDataset(path=ROOT, split='test')
+    print("\033[33mTraining dataset\033[0m has {} samples".format(len(train_dataset)))
+    print("\033[33mValidation dataset\033[0m has {} samples".format(len(val_dataset)))
+    print("\033[33mTesting dataset\033[0m has {} samples".format(len(test_dataset)))
+
+    # inspect one random sample
+    partial_input, ECG_input, gt_MI, partial_input_coarse, MI_type = train_dataset[random.randrange(len(train_dataset))]
+    print("partial input point cloud has {} points".format(len(partial_input)))
+    print("coarse input point cloud has {} points".format(len(partial_input_coarse)))
+    print("ECG input has shape {} for MI type '{}'".format(tuple(ECG_input.shape), MI_type))