# dataset.py
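"""Dataset loaders for ECG-based myocardial infarction (MI) localization on
cardiac point clouds.

Overview (inferred from the code below): each sample pairs a resampled,
apex-normalized ventricular point cloud (optionally carrying Cobiveco
coordinates) with zero-padded simulated ECG signals and per-node
scar / border-zone labels.
"""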
import os
import random
import numpy as np
import torch
import glob
import torch.utils.data as data
import sys
import pyvista
sys.path.append('.')
sys.path.append('..')
from utils import visualize_PC_with_label
import re


class LoadDataset(data.Dataset):
    def __init__(self, path, num_input=2048, split='train'):  # 16384
        self.path = path
        self.num_input = num_input
        self.use_cobiveco = True
        self.data_augment = False
        self.signal_length = 512

        with open(path + 'my_split/{}.list'.format(split), 'r') as f:
            filenames = [line.strip() for line in f]

        self.metadata = list()
        for filename in filenames:
            print(filename)
            datapath = path + filename + '/'

            unit = 0.1
            if self.use_cobiveco:
                nodesXYZ, label_index = getCobiveco_vtu(datapath + filename + '_cobiveco_AHA17.vtu')
            else:
                nodesXYZ = np.loadtxt(datapath + filename + '_xyz.csv', delimiter=',')
                label_index = np.zeros((nodesXYZ.shape[0], 1))
                LVendo_node = np.unique((np.loadtxt(datapath + filename + '_lvface.csv', delimiter=',') - 1).astype(int))
                RVendo_node = np.unique((np.loadtxt(datapath + filename + '_rvface.csv', delimiter=',') - 1).astype(int))
                epi_node = np.unique((np.loadtxt(datapath + filename + '_epiface.csv', delimiter=',') - 1).astype(int))
                label_index[LVendo_node] = 1
                label_index[RVendo_node] = 2
                label_index[epi_node] = 3
                surface_index = np.concatenate((LVendo_node, RVendo_node, epi_node), axis=0)

            PC_XYZ_labeled = np.concatenate((unit * nodesXYZ, label_index), axis=1)
            electrode_node = np.loadtxt(datapath + filename + '_electrodePositions.csv', delimiter=',')
            Coord_base_apex = np.loadtxt(datapath + filename + '_BaseApexCoord.csv', delimiter=',')
            Coord_apex, Coord_base = Coord_base_apex[1], Coord_base_apex[0]
            electrode_index = 4 * np.ones(electrode_node.shape[0], dtype=np.int32)
            electrode_XYZ_labeled = np.concatenate((unit * electrode_node, electrode_index[..., np.newaxis]), axis=1)

            signal_files = glob.glob(datapath + filename + '*_simulated_ECG' + '*.csv')
            num_signal = len(signal_files)
            # print(num_signal)
            for id in range(num_signal):
                MI_index = np.zeros(nodesXYZ.shape[0], dtype=np.int32)
                ECG_value = np.loadtxt(signal_files[id], delimiter=',')
                # zero-pad each ECG to a fixed length of self.signal_length samples
                ECG_value_u = np.pad(ECG_value, ((0, 0), (0, self.signal_length - ECG_value.shape[1])), 'constant')
                MI_type = signal_files[id].replace(path, '').replace(filename, '').replace('_simulated_ECG_', '').replace('.csv', '').replace('\\', '')

                if MI_type == 'B1_large_transmural_slow' or MI_type == 'normal' or MI_type == 'A2_30_40_transmural':
                    continue

                if re.compile(r'5_transmural|0_transmural', re.IGNORECASE).search(MI_type):  # remove apical MI size test case
                    continue

                if re.compile(r'AHA', re.IGNORECASE).search(MI_type):  # remove randomly generated MI
                    continue

                # if not re.compile(r'5_transmural|0_transmural', re.IGNORECASE).search(MI_type) and not (MI_type == 'A2_transmural'):  # remove apical MI size test case
                #     continue

                # if not re.compile(r'AHA', re.IGNORECASE).search(MI_type):  # test only random MI!
                #     continue
                #
                # if MI_type.find('subendo') != -1:
                #     continue

                # if MI_type != 'B3_transmural' and MI_type != 'A3_transmural' and MI_type != 'A2_transmural':
                #     continue

                # print(MI_type)

                if MI_type != 'normal':
                    Scar_filename = signal_files[id].replace('simulated_ECG', 'lvscarnodes')
                    BZ_filename = signal_files[id].replace('simulated_ECG', 'lvborderzonenodes')
                    if MI_type == 'B1_large_transmural_slow':
                        Scar_filename = Scar_filename.replace('_slow', '')
                        BZ_filename = BZ_filename.replace('_slow', '')

                    Scar_node = np.unique((np.loadtxt(Scar_filename, delimiter=',') - 1).astype(int))
                    BZ_node = np.unique((np.loadtxt(BZ_filename, delimiter=',') - 1).astype(int))
                    MI_index[Scar_node] = 1  # scar nodes
                    MI_index[BZ_node] = 2  # border zone nodes
                ECG_array = np.array(ECG_value_u)
                MI_array = np.array(MI_index)
                MI_type_id = np.array(id)
                # print(MI_type_id)

                partial_PC_labeled_array, idx_remained = resample_pcd(PC_XYZ_labeled, self.num_input)
                partial_MI_lab_array = MI_array[idx_remained]
                partial_PC_labeled_array_coarse, idx_remained = resample_pcd(PC_XYZ_labeled, self.num_input // 4)
                # visualize_PC_with_label(partial_PC_labeled_array[:, 0:3], partial_MI_lab_array)
                partial_PC_electrode_labeled_array = partial_PC_labeled_array  # np.concatenate((partial_PC_labeled_array, electrode_XYZ_labeled), axis=0)
                partial_PC_electrode_XYZ = partial_PC_electrode_labeled_array[:, 0:3]
                partial_PC_electrode_lab = partial_PC_electrode_labeled_array[:, 3:]
                # partial_MI_lab_array = partial_MI_lab_array + np.where(partial_PC_electrode_labeled_array[0:self.num_input, -1] == 1.0, 3, 0)
                # visualize_PC_with_label(partial_PC_labeled_array[:, 0:3], partial_MI_lab_array)

                partial_PC_electrode_XYZ_normalized = normalize_data(partial_PC_electrode_XYZ, Coord_apex)
                if self.data_augment:
                    scaling = random.uniform(0.8, 1.2)
                    partial_PC_electrode_XYZ_normalized = scaling * translate_point(jitter_point(rotate_point(partial_PC_electrode_XYZ_normalized, np.random.random() * np.pi)))
                partial_PC_electrode_XYZ_normalized_labeled = np.concatenate((partial_PC_electrode_XYZ_normalized, partial_PC_electrode_lab), axis=1)

                partial_PC_electrode_XYZ_normalized_coarse = normalize_data(partial_PC_labeled_array_coarse[:, 0:3], Coord_apex)
                partial_PC_electrode_XYZ_normalized_labeled_coarse = np.concatenate((partial_PC_electrode_XYZ_normalized_coarse, partial_PC_labeled_array_coarse[:, 3:]), axis=1)

                self.metadata.append((partial_PC_electrode_XYZ_normalized_labeled, partial_MI_lab_array, ECG_array, partial_PC_electrode_XYZ_normalized_labeled_coarse, MI_type))

    def __getitem__(self, index):
        partial_PC_electrode_XYZ_normalized_labeled, partial_MI_array, ECG_array, partial_PC_electrode_XYZ, MI_type = self.metadata[index]

        partial_input = torch.from_numpy(partial_PC_electrode_XYZ_normalized_labeled).float()
        gt_MI = torch.from_numpy(partial_MI_array).long()
        ECG_input = torch.from_numpy(ECG_array).float()
        partial_input_coarse = torch.from_numpy(partial_PC_electrode_XYZ).float()

        return partial_input, ECG_input, gt_MI, partial_input_coarse, MI_type

    def __len__(self):
        return len(self.metadata)
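
# Minimal usage sketch (assumes the directory layout read by __init__ above:
# <path>/my_split/<split>.list plus one folder per mesh holding the
# *_cobiveco_AHA17.vtu, *_BaseApexCoord.csv and *_simulated_ECG_*.csv files):
#
#   dataset = LoadDataset(path='./dataset/', num_input=2048, split='train')
#   pc, ecg, mi, pc_coarse, mi_type = dataset[0]
#   # pc: (2048, 3 + n_label) float tensor; ecg: (n_leads, 512) float tensor;
#   # mi: (2048,) long tensor with 0 = healthy, 1 = scar, 2 = border zone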


class LoadDataset_all(data.Dataset):
    def __init__(self, path, num_input=2048, split='train'):  # 16384
        self.path = path
        self.num_input = num_input
        self.use_cobiveco = False
        self.data_augment = False
        self.signal_length = 512

        with open(path + 'my_split/{}.list'.format(split), 'r') as f:
            filenames = [line.strip() for line in f]

        self.metadata = list()
        for filename in filenames:
            print(filename)
            datapath = path + filename + '/'

            unit = 1
            if self.use_cobiveco:
                nodesXYZ, label_index = getCobiveco_vtu(datapath + filename + '_heart_cobiveco.vtu')
            else:
                nodesXYZ = np.loadtxt(datapath + filename + '_xyz.csv', delimiter=',')
                label_index = np.zeros((nodesXYZ.shape[0], 1), dtype=int)
                LVendo_node = np.unique((np.loadtxt(datapath + filename + '_lvface.csv', delimiter=',') - 1).astype(int))
                RVendo_node = np.unique((np.loadtxt(datapath + filename + '_rvface.csv', delimiter=',') - 1).astype(int))
                epi_node = np.unique((np.loadtxt(datapath + filename + '_epiface.csv', delimiter=',') - 1).astype(int))
                label_index[LVendo_node] = 1
                label_index[RVendo_node] = 2
                label_index[epi_node] = 3
                surface_index = np.concatenate((LVendo_node, RVendo_node, epi_node), axis=0)

            PC_XYZ_labeled = np.concatenate((unit * nodesXYZ, label_index), axis=1)
            electrode_node = np.loadtxt(datapath + filename + '_electrodePositions.csv', delimiter=',')
            Coord_base_apex = np.loadtxt(datapath + filename + '_BaseApexCoord.csv', delimiter=',')
            Coord_apex, Coord_base = Coord_base_apex[1], Coord_base_apex[0]
            electrode_index = 4 * np.ones(electrode_node.shape[0], dtype=int)
            electrode_XYZ_labeled = np.concatenate((unit * electrode_node, electrode_index[..., np.newaxis]), axis=1)

            signal_files = glob.glob(datapath + filename + '*_simulated_ECG' + '*.csv')
            ECG_list, MI_index_list = list(), list()
            MItype_list = list()
            num_signal = len(signal_files)
            for id in range(num_signal):
                MI_index = np.zeros(nodesXYZ.shape[0], dtype=int)
                ECG_value = np.loadtxt(signal_files[id], delimiter=',')
                ECG_value_u = np.pad(ECG_value, ((0, 0), (0, self.signal_length - ECG_value.shape[1])), 'constant')
                MI_type = signal_files[id].replace(path, '').replace(filename, '').replace('_simulated_ECG_', '').replace('.csv', '').replace('\\', '')
                if MI_type == 'B1_large_transmural_slow':
                    continue
                if MI_type != 'normal':
                    Scar_filename = signal_files[id].replace('simulated_ECG', 'lvscarnodes')
                    BZ_filename = signal_files[id].replace('simulated_ECG', 'lvborderzonenodes')
                    Scar_node = np.unique((np.loadtxt(Scar_filename, delimiter=',') - 1).astype(int))
                    BZ_node = np.unique((np.loadtxt(BZ_filename, delimiter=',') - 1).astype(int))
                    MI_index[Scar_node] = 1  # scar nodes
                    MI_index[BZ_node] = 2  # border zone nodes

                ECG_list.append(ECG_value_u)
                MI_index_list.append(MI_index)
                MItype_list.append(MI_type)

            ECG_array = np.array(ECG_list).transpose(1, 2, 0)
            MI_array = np.array(MI_index_list).transpose(1, 0)
            partial_PC_labeled_array, idx_remained = resample_pcd(PC_XYZ_labeled[surface_index], self.num_input)
            partial_MI_array = MI_array[surface_index][idx_remained]
            partial_PC_electrode_labeled_array = np.concatenate((partial_PC_labeled_array, electrode_XYZ_labeled), axis=0)
            partial_PC_electrode_XYZ = partial_PC_electrode_labeled_array[:, 0:3]
            partial_PC_electrode_lab = np.expand_dims(partial_PC_electrode_labeled_array[:, 3], axis=1)
            partial_PC_electrode_XYZ_normalized = normalize_data(partial_PC_electrode_XYZ, Coord_apex)
            if self.data_augment:
                scaling = random.uniform(0.8, 1.2)
                partial_PC_electrode_XYZ_normalized = scaling * translate_point(jitter_point(rotate_point(partial_PC_electrode_XYZ_normalized, np.random.random() * np.pi)))
            partial_PC_electrode_XYZ_normalized_labeled = np.concatenate((partial_PC_electrode_XYZ_normalized, partial_PC_electrode_lab), axis=1)

            self.metadata.append((partial_PC_electrode_XYZ_normalized_labeled, partial_MI_array, ECG_array, partial_PC_electrode_XYZ))

    def __getitem__(self, index):
        partial_PC_electrode_XYZ_normalized_labeled, partial_MI_array, ECG_array, partial_PC_electrode_XYZ = self.metadata[index]

        ECG_array[np.isnan(ECG_array)] = 0  # convert any NaN values in the ECG into 0
        partial_input = torch.from_numpy(partial_PC_electrode_XYZ_normalized_labeled).float()
        gt_MI, ECG_input = torch.from_numpy(partial_MI_array).float(), torch.from_numpy(ECG_array).float()
        partial_input_ori = torch.from_numpy(partial_PC_electrode_XYZ).float()

        return partial_input, ECG_input, gt_MI, partial_input_ori

    def __len__(self):
        return len(self.metadata)
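
# Note (inferred from the code above): unlike LoadDataset, LoadDataset_all
# stacks every MI scenario of a mesh into one sample, so ECG_input has shape
# (n_leads, 512, n_scenarios) and gt_MI has shape (num_input, n_scenarios),
# aligned along the last axis; the electrode positions are appended to the
# point cloud as extra rows with surface label 4.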


def getCobiveco_vtu(cobiveco_fileName):  # read Cobiveco data in .vtu format (added by Lei on 2023/01/30)
    cobiveco_vol = pyvista.read(cobiveco_fileName)  # , force_ext='.vtu'

    cobiveco_nodesXYZ = cobiveco_vol.points
    cobiveco_nodes_array = cobiveco_vol.point_data
    ab = cobiveco_nodes_array['ab']    # apex-to-base coordinate
    rt = cobiveco_nodes_array['rt']    # rotation angle
    tm = cobiveco_nodes_array['tm']    # transmurality
    tv = cobiveco_nodes_array['tv']    # ventricle indicator
    aha = cobiveco_nodes_array['aha']  # AHA-17 segment map

    return cobiveco_nodesXYZ, np.transpose(np.array([ab, rt, tm, tv, aha], dtype=float))
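
# Shape note: for a mesh with N nodes this returns nodesXYZ with shape (N, 3)
# and an (N, 5) array whose columns are the per-node Cobiveco coordinates
# [ab, rt, tm, tv, aha].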


### point cloud augmentation ###
# translate point cloud
def translate_point(point):
    point = np.array(point)
    shift = [random.uniform(-0.5, 0.5), random.uniform(-0.5, 0.5), random.uniform(-0.5, 0.5)]
    shift = np.expand_dims(np.array(shift), axis=0)
    shifted_point = np.repeat(shift, point.shape[0], axis=0)
    shifted_point += point

    return shifted_point


# add Gaussian noise
def jitter_point(point, sigma=0.01, clip=0.01):
    assert clip > 0
    point = np.array(point)
    point = point.reshape(-1, 3)
    Row, Col = point.shape
    jittered_point = np.clip(sigma * np.random.randn(Row, Col), -1 * clip, clip)
    jittered_point += point

    return jittered_point


# rotate point cloud
def rotate_point(point, rotation_angle=0.5 * np.pi):
    point = np.array(point)
    cos_theta = np.cos(rotation_angle)
    sin_theta = np.sin(rotation_angle)
    # rotation around the X axis (built but currently unused)
    rotation_matrix_X = np.array([[1, 0, 0],
                                  [0, cos_theta, -sin_theta],
                                  [0, sin_theta, cos_theta]])
    # rotation around the Y axis (built but currently unused)
    rotation_matrix_Y = np.array([[cos_theta, 0, sin_theta],
                                  [0, 1, 0],
                                  [-sin_theta, 0, cos_theta]])
    # rotation around the Z axis (the only rotation applied below)
    rotation_matrix_Z = np.array([[cos_theta, sin_theta, 0],
                                  [-sin_theta, cos_theta, 0],
                                  [0, 0, 1]])

    rotated_point = np.dot(point.reshape(-1, 3), rotation_matrix_Z)

    return rotated_point


# normalize point cloud based on the apex coordinate
def normalize_data(PC, Coord_apex):
    """Normalize the point cloud relative to the apex coordinate.
    Input:
        NxC array
    Output:
        NxC array
    """
    N, C = PC.shape
    # centroid = np.mean(PC, axis=0)
    PC = PC - Coord_apex
    # m = np.max(np.sqrt(np.sum(PC ** 2, axis=1)))
    # PC = PC / m

    # compute the minimum and maximum values of each coordinate
    min_coords = np.min(PC, axis=0)
    max_coords = np.max(PC, axis=0)

    # normalize the point cloud coordinates to [0, 1] per axis
    normal_data = (PC - min_coords) / (max_coords - min_coords)

    return normal_data
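
# Worked example (hypothetical numbers): with Coord_apex = [0, 0, 0] and
# PC = [[0, 0, 0], [2, 4, 8]], min_coords = [0, 0, 0] and max_coords = [2, 4, 8],
# so the function returns [[0, 0, 0], [1, 1, 1]]: each axis is min-max scaled
# to [0, 1] independently after shifting the cloud so the apex sits at the origin.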


def resample_pcd_ATM(pcd, ATM, n):
    """Resample pcd to exactly n points while always keeping the root nodes."""
    idx_root_nodes = np.where(ATM[:, 0] == 1.0)
    prob = 1 / (pcd.shape[0] - idx_root_nodes[0].shape[0])
    node_prob = prob * np.ones(pcd.shape[0])
    node_prob[idx_root_nodes] = 0
    idx = np.random.choice(np.arange(pcd.shape[0]), n - idx_root_nodes[0].shape[0], p=node_prob, replace=False)
    idx_remained = np.union1d(idx, idx_root_nodes[0])
    # idx_updated_permuted = np.random.permutation(idx_updated)
    # if idx_updated_permuted.shape[0] < n:
    #     idx = np.concatenate([idx, np.random.randint(pcd.shape[0], size=n-pcd.shape[0])])

    return pcd[idx_remained], ATM[idx_remained], idx_remained
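
# Note: root nodes (ATM[:, 0] == 1.0) are given zero sampling probability and
# re-added through union1d, so they always survive the resampling; the other
# n - n_root points are drawn without replacement from the remaining nodes,
# giving exactly n points in total.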


def resample_pcd_ATM_ori(pcd, ATM, n):
    """Drop or duplicate points so that pcd has exactly n points."""
    idx = np.random.permutation(pcd.shape[0])
    if idx.shape[0] < n:
        idx = np.concatenate([idx, np.random.randint(pcd.shape[0], size=n - pcd.shape[0])])
    return pcd[idx[:n]], ATM[idx[:n]]


def resample_gd(gt_output, num_coarse, num_dense):  # added by Lei on 2022/02/10 to separately resample the ground-truth label
    """Drop or duplicate points so that the ground truth has exactly n points."""
    choice = np.random.choice(len(gt_output), num_coarse, replace=True)
    coarse_gt = gt_output[choice, :]
    dense_gt, _ = resample_pcd(gt_output, num_dense)  # resample_pcd returns (points, indices)
    return coarse_gt, dense_gt


def resample_pcd(pcd, n):
    """Drop or duplicate points so that pcd has exactly n points."""
    idx = np.random.permutation(pcd.shape[0])
    if idx.shape[0] < n:
        idx = np.concatenate([idx, np.random.randint(pcd.shape[0], size=n - pcd.shape[0])])
    return pcd[idx[:n]], idx[:n]
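
# Example (illustrative shapes): resample_pcd(pcd, 2048) on a (5000, 4) labeled
# cloud returns a random (2048, 4) subset plus the chosen row indices; on a
# (1500, 4) cloud it pads by duplicating randomly chosen rows, so the output
# always has exactly n rows.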


if __name__ == '__main__':
    ROOT = './dataset/'

    train_dataset = LoadDataset(path=ROOT, split='train')
    val_dataset = LoadDataset(path=ROOT, split='val')
    test_dataset = LoadDataset(path=ROOT, split='test')
    print("\033[33mTraining dataset\033[0m has {} samples".format(len(train_dataset)))
    print("\033[33mValidation dataset\033[0m has {} samples".format(len(val_dataset)))
    print("\033[33mTesting dataset\033[0m has {} samples".format(len(test_dataset)))

    # visualization
    partial_input, ECG_input, gt_MI, partial_input_coarse, MI_type = train_dataset[random.randint(0, len(train_dataset) - 1)]
    print("partial input point cloud has {} points".format(len(partial_input)))
    print("coarse input point cloud has {} points".format(len(partial_input_coarse)))
    print("ECG input has shape {}".format(tuple(ECG_input.shape)))