import glob
import os
import random
import re
import sys

import numpy as np
import pyvista
import torch
import torch.utils.data as data

sys.path.append('.')
sys.path.append('..')
from utils import visualize_PC_with_label  # local helper; needs the sys.path entries above
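
# Dataset loaders pairing labeled ventricular point clouds (surface or Cobiveco
# labels, optionally with torso-electrode positions) with simulated ECGs and
# node-wise myocardial infarction (MI) scar/border-zone labels.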
class LoadDataset(data.Dataset):
    def __init__(self, path, num_input=2048, split='train'):  # num_input alternative: 16384
self.path = path
self.num_input = num_input
self.use_cobiveco = True
self.data_augment = False
self.signal_length = 512
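        # use_cobiveco: read per-node Cobiveco coordinates from a .vtu file instead
        # of plain XYZ + surface labels; data_augment: random rotate/jitter/translate/
        # scale; signal_length: every ECG is zero-padded to this many time samples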
with open(path + 'my_split/{}.list'.format(split), 'r') as f:
filenames = [line.strip() for line in f]
self.metadata = list()
for filename in filenames:
print(filename)
datapath = path + filename + '/'
unit = 0.1
if self.use_cobiveco:
nodesXYZ, label_index = getCobiveco_vtu(datapath + filename + '_cobiveco_AHA17.vtu')
else:
nodesXYZ = np.loadtxt(datapath + filename + '_xyz.csv', delimiter=',')
label_index = np.zeros((nodesXYZ.shape[0], 1))
LVendo_node = np.unique((np.loadtxt(datapath + filename + '_lvface.csv', delimiter=',')-1).astype(int))
RVendo_node = np.unique((np.loadtxt(datapath + filename + '_rvface.csv', delimiter=',')-1).astype(int))
epi_node = np.unique((np.loadtxt(datapath + filename + '_epiface.csv', delimiter=',')-1).astype(int))
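            # anatomical surface labels: 0 = interior, 1 = LV endocardium,
            # 2 = RV endocardium, 3 = epicardium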
            label_index[LVendo_node] = 1
            label_index[RVendo_node] = 2
            label_index[epi_node] = 3
            # label_index is already 2-D (N, C), ready for column-wise concatenation
            surface_index = np.concatenate((LVendo_node, RVendo_node, epi_node), axis=0)
            PC_XYZ_labeled = np.concatenate((unit*nodesXYZ, label_index), axis=1)
electrode_node = np.loadtxt(datapath + filename + '_electrodePositions.csv', delimiter=',')
Coord_base_apex = np.loadtxt(datapath + filename + '_BaseApexCoord.csv', delimiter=',')
Coord_apex, Coord_base = Coord_base_apex[1], Coord_base_apex[0]
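            # torso electrodes carry their own surface label (4)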
electrode_index = 4*np.ones(electrode_node.shape[0], dtype=np.int32)
electrode_XYZ_labeled = np.concatenate((unit*electrode_node, electrode_index[..., np.newaxis]), axis=1)
            signal_files = glob.glob(datapath + filename + '*_simulated_ECG' + '*.csv')
            for signal_file in signal_files:
                MI_index = np.zeros(nodesXYZ.shape[0], dtype=np.int32)
                ECG_value = np.loadtxt(signal_file, delimiter=',')
                # zero-pad every ECG to the fixed temporal length (assumes ECGs are at most signal_length samples)
                ECG_value_u = np.pad(ECG_value, ((0, 0), (0, self.signal_length - ECG_value.shape[1])), 'constant')
                # recover the MI type from the file name, <case>_simulated_ECG_<type>.csv
                MI_type = os.path.basename(signal_file).replace(filename, '').replace('_simulated_ECG_', '').replace('.csv', '')
                if MI_type in ('B1_large_transmural_slow', 'normal', 'A2_30_40_transmural'):
                    continue
                if re.search(r'5_transmural|0_transmural', MI_type, re.IGNORECASE):  # remove apical MI size test cases
                    continue
                if re.search(r'AHA', MI_type, re.IGNORECASE):  # remove randomly generated MIs
                    continue
                # Alternative filters kept for experiments:
                #   keep only the apical MI size test cases:
                #     if not re.search(r'5_transmural|0_transmural', MI_type, re.IGNORECASE) and MI_type != 'A2_transmural': continue
                #   test only the randomly generated MIs:
                #     if not re.search(r'AHA', MI_type, re.IGNORECASE): continue
                #   skip subendocardial MIs:
                #     if 'subendo' in MI_type: continue
                #   keep only selected transmural cases:
                #     if MI_type not in ('B3_transmural', 'A3_transmural', 'A2_transmural'): continue
                if MI_type != 'normal':
                    Scar_filename = signal_file.replace('simulated_ECG', 'lvscarnodes')
                    BZ_filename = signal_file.replace('simulated_ECG', 'lvborderzonenodes')
                    if MI_type == 'B1_large_transmural_slow':  # slow-conduction ECGs reuse the non-slow scar geometry
                        Scar_filename = Scar_filename.replace('_slow', '')
                        BZ_filename = BZ_filename.replace('_slow', '')
                    Scar_node = np.unique((np.loadtxt(Scar_filename, delimiter=',') - 1).astype(int))
                    BZ_node = np.unique((np.loadtxt(BZ_filename, delimiter=',') - 1).astype(int))
                    # node-wise MI labels: 0 = healthy, 1 = scar, 2 = border zone
                    MI_index[Scar_node] = 1
                    MI_index[BZ_node] = 2
                ECG_array = np.array(ECG_value_u)
                MI_array = np.array(MI_index)
                # resample the labeled cloud to num_input points and carry the
                # node-wise MI labels along via the surviving indices
                partial_PC_labeled_array, idx_remained = resample_pcd(PC_XYZ_labeled, self.num_input)
                partial_MI_lab_array = MI_array[idx_remained]
                partial_PC_labeled_array_coarse, _ = resample_pcd(PC_XYZ_labeled, self.num_input // 4)
                # visualize_PC_with_label(partial_PC_labeled_array[:, 0:3], partial_MI_lab_array)
partial_PC_electrode_labeled_array = partial_PC_labeled_array # np.concatenate((partial_PC_labeled_array, electrode_XYZ_labeled), axis=0)
partial_PC_electrode_XYZ = partial_PC_electrode_labeled_array[:, 0:3]
partial_PC_electrode_lab = partial_PC_electrode_labeled_array[:, 3:]
# partial_MI_lab_array = partial_MI_lab_array + np.where(partial_PC_electrode_labeled_array[0:self.num_input, -1]==1.0, 3, 0)
# visualize_PC_with_label(partial_PC_labeled_array[:, 0:3], partial_MI_lab_array)
partial_PC_electrode_XYZ_normalized = normalize_data(partial_PC_electrode_XYZ, Coord_apex)
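                # optional augmentation: rotate about Z, jitter, translate, then scale globally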
if self.data_augment:
scaling = random.uniform(0.8, 1.2)
partial_PC_electrode_XYZ_normalized = scaling*translate_point(jitter_point(rotate_point(partial_PC_electrode_XYZ_normalized, np.random.random()*np.pi)))
partial_PC_electrode_XYZ_normalized_labeled = np.concatenate((partial_PC_electrode_XYZ_normalized, partial_PC_electrode_lab), axis=1)
partial_PC_electrode_XYZ_normalized_coarse = normalize_data(partial_PC_labeled_array_coarse[:, 0:3], Coord_apex)
partial_PC_electrode_XYZ_normalized_labeled_coarse = np.concatenate((partial_PC_electrode_XYZ_normalized_coarse, partial_PC_labeled_array_coarse[:, 3:]), axis=1)
self.metadata.append((partial_PC_electrode_XYZ_normalized_labeled, partial_MI_lab_array, ECG_array, partial_PC_electrode_XYZ_normalized_labeled_coarse, MI_type))
    def __getitem__(self, index):
        partial_PC_electrode_XYZ_normalized_labeled, partial_MI_array, ECG_array, partial_PC_labeled_coarse, MI_type = self.metadata[index]
        partial_input = torch.from_numpy(partial_PC_electrode_XYZ_normalized_labeled).float()
        gt_MI = torch.from_numpy(partial_MI_array).long()
        ECG_input = torch.from_numpy(ECG_array).float()
        partial_input_coarse = torch.from_numpy(partial_PC_labeled_coarse).float()
        return partial_input, ECG_input, gt_MI, partial_input_coarse, MI_type
def __len__(self):
return len(self.metadata)
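
# Example usage (a sketch; assumes ./dataset/ holds one folder per case plus
# my_split/train.list, as read by __init__ above):
#   from torch.utils.data import DataLoader
#   dataset = LoadDataset('./dataset/', num_input=2048, split='train')
#   loader = DataLoader(dataset, batch_size=8, shuffle=True)
#   partial_input, ECG_input, gt_MI, partial_input_coarse, MI_type = next(iter(loader))
#   # partial_input: (B, num_input, 3 + C) normalized labeled points
#   # ECG_input:     (B, n_leads, signal_length) zero-padded ECGs
#   # gt_MI:         (B, num_input) node-wise MI labels (0 = healthy, 1 = scar, 2 = border zone)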

class LoadDataset_all(data.Dataset):
    def __init__(self, path, num_input=2048, split='train'):  # num_input alternative: 16384
self.path = path
self.num_input = num_input
self.use_cobiveco = False
self.data_augment = False
self.signal_length = 512
with open(path + 'my_split/{}.list'.format(split), 'r') as f:
filenames = [line.strip() for line in f]
self.metadata = list()
for filename in filenames:
print(filename)
datapath = path + filename + '/'
unit = 1
if self.use_cobiveco:
nodesXYZ, label_index = getCobiveco_vtu(datapath + filename + '_heart_cobiveco.vtu')
else:
nodesXYZ = np.loadtxt(datapath + filename + '_xyz.csv', delimiter=',')
                label_index = np.zeros((nodesXYZ.shape[0], 1), dtype=np.int32)
LVendo_node = np.unique((np.loadtxt(datapath + filename + '_lvface.csv', delimiter=',')-1).astype(int))
RVendo_node = np.unique((np.loadtxt(datapath + filename + '_rvface.csv', delimiter=',')-1).astype(int))
epi_node = np.unique((np.loadtxt(datapath + filename + '_epiface.csv', delimiter=',')-1).astype(int))
label_index[LVendo_node] = 1
label_index[RVendo_node] = 2
label_index[epi_node] = 3
surface_index = np.concatenate((LVendo_node, RVendo_node, epi_node), axis=0)
PC_XYZ_labeled = np.concatenate((unit*nodesXYZ, label_index), axis=1)
electrode_node = np.loadtxt(datapath + filename + '_electrodePositions.csv', delimiter=',')
Coord_base_apex = np.loadtxt(datapath + filename + '_BaseApexCoord.csv', delimiter=',')
Coord_apex, Coord_base = Coord_base_apex[1], Coord_base_apex[0]
            electrode_index = 4*np.ones(electrode_node.shape[0], dtype=np.int32)
electrode_XYZ_labeled = np.concatenate((unit*electrode_node, electrode_index[..., np.newaxis]), axis=1)
signal_files = glob.glob(datapath + filename + '*_simulated_ECG' + '*.csv')
ECG_list, MI_index_list = list(), list()
MItype_list = list()
            for signal_file in signal_files:
                MI_index = np.zeros(nodesXYZ.shape[0], dtype=np.int32)
                ECG_value = np.loadtxt(signal_file, delimiter=',')
                ECG_value_u = np.pad(ECG_value, ((0, 0), (0, self.signal_length - ECG_value.shape[1])), 'constant')
                MI_type = os.path.basename(signal_file).replace(filename, '').replace('_simulated_ECG_', '').replace('.csv', '')
                if MI_type == 'B1_large_transmural_slow':
                    continue
if MI_type != 'normal':
                    Scar_filename = signal_file.replace('simulated_ECG', 'lvscarnodes')
                    BZ_filename = signal_file.replace('simulated_ECG', 'lvborderzonenodes')
Scar_node = np.unique((np.loadtxt(Scar_filename, delimiter=',')-1).astype(int))
BZ_node = np.unique((np.loadtxt(BZ_filename, delimiter=',')-1).astype(int))
MI_index[Scar_node] = 421
MI_index[BZ_node] = 422
ECG_list.append(ECG_value_u)
MI_index_list.append(MI_index)
MItype_list.append(MI_type)
ECG_array = np.array(ECG_list).transpose(1, 2, 0)
MI_array = np.array(MI_index_list).transpose(1, 0)
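            # stacked per-mesh arrays: ECG_array is (n_leads, signal_length, n_signals),
            # MI_array is (n_nodes, n_signals)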
partial_PC_labeled_array, idx_remained = resample_pcd(PC_XYZ_labeled[surface_index], self.num_input)
partial_MI_array = MI_array[surface_index][idx_remained]
partial_PC_electrode_labeled_array = np.concatenate((partial_PC_labeled_array, electrode_XYZ_labeled), axis=0)
partial_PC_electrode_XYZ = partial_PC_electrode_labeled_array[:, 0:3]
partial_PC_electrode_lab = np.expand_dims(partial_PC_electrode_labeled_array[:, 3], axis=1)
partial_PC_electrode_XYZ_normalized = normalize_data(partial_PC_electrode_XYZ, Coord_apex)
if self.data_augment:
scaling = random.uniform(0.8, 1.2)
partial_PC_electrode_XYZ_normalized = scaling*translate_point(jitter_point(rotate_point(partial_PC_electrode_XYZ_normalized, np.random.random()*np.pi)))
partial_PC_electrode_XYZ_normalized_labeled = np.concatenate((partial_PC_electrode_XYZ_normalized, partial_PC_electrode_lab), axis=1)
self.metadata.append((partial_PC_electrode_XYZ_normalized_labeled, partial_MI_array, ECG_array, partial_PC_electrode_XYZ))
def __getitem__(self, index):
partial_PC_electrode_XYZ_normalized_labeled, partial_MI_array, ECG_array, partial_PC_electrode_XYZ = self.metadata[index]
        ECG_array[np.isnan(ECG_array)] = 0  # convert NaN values in the stacked ECGs to 0
partial_input = torch.from_numpy(partial_PC_electrode_XYZ_normalized_labeled).float()
gt_MI, ECG_input = torch.from_numpy(partial_MI_array).float(), torch.from_numpy(ECG_array).float()
partial_input_ori = torch.from_numpy(partial_PC_electrode_XYZ).float()
return partial_input, ECG_input, gt_MI, partial_input_ori
def __len__(self):
return len(self.metadata)

def getCobiveco_vtu(cobiveco_fileName):  # read Cobiveco data in .vtu format (added by Lei on 2023/01/30)
    """Read Cobiveco ventricular coordinates from a .vtu file and return the
    node coordinates (N, 3) plus an (N, 5) array with columns [ab, rt, tm, tv, aha]."""
    cobiveco_vol = pyvista.read(cobiveco_fileName)  # , force_ext='.vtu'
    cobiveco_nodesXYZ = cobiveco_vol.points
    cobiveco_nodes_array = cobiveco_vol.point_data
    ab = cobiveco_nodes_array['ab']    # apex-to-base coordinate
    rt = cobiveco_nodes_array['rt']    # rotational angle
    tm = cobiveco_nodes_array['tm']    # transmurality
    tv = cobiveco_nodes_array['tv']    # ventricle indicator
    aha = cobiveco_nodes_array['aha']  # AHA-17 segment
    return cobiveco_nodesXYZ, np.transpose(np.array([ab, rt, tm, tv, aha], dtype=float))
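
# Example (a sketch; the file name is illustrative):
#   nodesXYZ, feats = getCobiveco_vtu('case01_cobiveco_AHA17.vtu')
#   # nodesXYZ: (N, 3) coordinates; feats: (N, 5) columns [ab, rt, tm, tv, aha]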
### point cloud augmentation ###
# translate point cloud
def translate_point(point):
point = np.array(point)
shift = [random.uniform(-0.5, 0.5), random.uniform(-0.5, 0.5), random.uniform(-0.5, 0.5)]
shift = np.expand_dims(np.array(shift), axis=0)
shifted_point = np.repeat(shift, point.shape[0], axis=0)
shifted_point += point
return shifted_point
# add clipped Gaussian noise to every coordinate
def jitter_point(point, sigma=0.01, clip=0.01):
    assert clip > 0
    point = np.array(point).reshape(-1, 3)
    jittered_point = np.clip(sigma * np.random.randn(*point.shape), -clip, clip)
    return jittered_point + point
# rotate point cloud (only the Z-axis rotation is applied; X/Y matrices kept for reference)
def rotate_point(point, rotation_angle=0.5*np.pi):
    point = np.array(point)
    cos_theta = np.cos(rotation_angle)
    sin_theta = np.sin(rotation_angle)
    # Rotation around X axis
    rotation_matrix_X = np.array([[1, 0, 0],
                                  [0, cos_theta, -sin_theta],
                                  [0, sin_theta, cos_theta]])
    # Rotation around Y axis
    rotation_matrix_Y = np.array([[cos_theta, 0, sin_theta],
                                  [0, 1, 0],
                                  [-sin_theta, 0, cos_theta]])
    # Rotation around Z axis, stored transposed so that right-multiplying the
    # row-vector points below rotates them by +rotation_angle
    rotation_matrix_Z = np.array([[cos_theta, sin_theta, 0],
                                  [-sin_theta, cos_theta, 0],
                                  [0, 0, 1]])
    rotated_point = np.dot(point.reshape(-1, 3), rotation_matrix_Z)
    return rotated_point
# normalize point cloud based on apex coordinate
def normalize_data(PC, Coord_apex):
    """Normalize an (N, C) point cloud: shift the apex to the origin,
    then min-max scale each coordinate to [0, 1]."""
    PC = PC - Coord_apex
    # alternative: scale to the unit sphere instead
    # m = np.max(np.sqrt(np.sum(PC ** 2, axis=1)))
    # PC = PC / m
    # compute the per-axis minimum and maximum
    min_coords = np.min(PC, axis=0)
    max_coords = np.max(PC, axis=0)
    normal_data = (PC - min_coords) / (max_coords - min_coords)
    return normal_data
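
# Note: the min-max scaling above is per axis, so an axis with zero extent would
# divide by zero; the ventricular meshes used here are assumed to span all three axes.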

def resample_pcd_ATM(pcd, ATM, n):
    """Subsample pcd to exactly n points while always keeping the root nodes
    (ATM[:, 0] == 1.0); the remaining points are drawn uniformly without replacement."""
    idx_root_nodes = np.where(ATM[:, 0] == 1.0)
    # uniform sampling probability over non-root nodes, zero for root nodes
    prob = 1/(pcd.shape[0] - idx_root_nodes[0].shape[0])
    node_prob = prob*np.ones(pcd.shape[0])
    node_prob[idx_root_nodes] = 0
    idx = np.random.choice(np.arange(pcd.shape[0]), n - idx_root_nodes[0].shape[0], p=node_prob, replace=False)
    idx_remained = np.union1d(idx, idx_root_nodes)
    return pcd[idx_remained], ATM[idx_remained], idx_remained
def resample_pcd_ATM_ori(pcd, ATM, n):
"""Drop or duplicate points so that pcd has exactly n points"""
idx = np.random.permutation(pcd.shape[0])
if idx.shape[0] < n:
idx = np.concatenate([idx, np.random.randint(pcd.shape[0], size=n-pcd.shape[0])])
return pcd[idx[:n]], ATM[idx[:n]]
def resample_gd(gt_output, num_coarse, num_dense):  # added by Lei on 2022/02/10 to separately resample ground-truth labels
    """Resample the ground truth into a coarse subset (sampled with replacement)
    and a dense subset of exactly num_dense points."""
    choice = np.random.choice(len(gt_output), num_coarse, replace=True)
    coarse_gt = gt_output[choice, :]
    dense_gt, _ = resample_pcd(gt_output, num_dense)  # resample_pcd also returns the surviving indices
    return coarse_gt, dense_gt
def resample_pcd(pcd, n):
"""Drop or duplicate points so that pcd has exactly n points"""
idx = np.random.permutation(pcd.shape[0])
if idx.shape[0] < n:
idx = np.concatenate([idx, np.random.randint(pcd.shape[0], size=n-pcd.shape[0])])
return pcd[idx[:n]], idx[:n]
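
# Example (a sketch mirroring the usage in LoadDataset.__init__):
#   pts, idx = resample_pcd(PC_XYZ_labeled, 2048)  # pts.shape == (2048, C)
#   labels = MI_index[idx]                         # labels follow the surviving points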

if __name__ == '__main__':
    ROOT = './dataset/'
    train_dataset = LoadDataset(ROOT, split='train')
    val_dataset = LoadDataset(ROOT, split='val')
    test_dataset = LoadDataset(ROOT, split='test')
    print("\033[33mTraining dataset\033[0m has {} samples".format(len(train_dataset)))
    print("\033[33mValidation dataset\033[0m has {} samples".format(len(val_dataset)))
    print("\033[33mTesting dataset\033[0m has {} samples".format(len(test_dataset)))
    # sanity check on one random sample
    partial_input, ECG_input, gt_MI, partial_input_coarse, MI_type = train_dataset[random.randint(0, len(train_dataset) - 1)]
    print("partial input point cloud has {} points".format(len(partial_input)))
    print("coarse input point cloud has {} points".format(len(partial_input_coarse)))
    print("ECG input has shape {}; MI type: {}".format(tuple(ECG_input.shape), MI_type))