[fd9ef4]: / datasets / ln_sil_heatmap.py

Download this file

78 lines (66 with data), 3.7 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import sys
import argparse
from tqdm import tqdm
from glob import glob
def get_args(argv=None):
    """Parse the command-line options of this symlinking script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what
            the script did before this parameter existed.

    Returns:
        argparse.Namespace with the attributes ``heatmap_data_path``,
        ``silhouette_data_path``, ``dataset_pkl_ext_name`` (defaults to
        ``'.pkl'``) and ``output_path``.
    """
    parser = argparse.ArgumentParser(
        description='Symlink silouette data and pose data into the same folder for SkeletonGait++ training.'
    )
    parser.add_argument(
        '--heatmap_data_path', type=str, required=True,
        help="path of heatmap data, must be the absolute path.",
    )
    parser.add_argument(
        '--silhouette_data_path', type=str, required=True,
        help="path of silouette data, must be the absolute path.",
    )
    parser.add_argument(
        '--dataset_pkl_ext_name', type=str, default='.pkl',
        help="The extent name for .pkl files of silouettes data.",
    )
    parser.add_argument(
        '--output_path', type=str, required=True,
        help="path of output data",
    )
    # Passing argv through lets callers (and tests) supply their own arguments.
    return parser.parse_args(argv)
def _sequence_dirs(file_path):
    """Return the three directory names directly above *file_path*."""
    # Both listings come from a "*/*/*/<file>" glob, so the last four path
    # components are always three directories followed by the file name.
    return os.path.normpath(file_path).split(os.sep)[-4:-1]


def _make_symlink(target, link_path):
    """Create *link_path* pointing at *target*; keep an existing entry as is."""
    try:
        # An absolute target keeps the link valid regardless of where the
        # output directory lives relative to the current working directory.
        os.symlink(os.path.abspath(target), link_path)
    except FileExistsError:
        # Same outcome as `ln -s` without -f: the existing entry wins.
        print(f"{link_path} already exists, skipped.")


def _link_sequences(driver_files, driver_link, partner_root, partner_glob,
                    partner_label, partner_link, output_root):
    """Link every driver file together with its partner file.

    For each file in *driver_files* the matching folder under *partner_root*
    is looked up by the same three directory names; the first (sorted) file
    matching *partner_glob* in that folder is linked as *partner_link* and
    the driver file itself as *driver_link*, both inside the mirrored folder
    under *output_root*.
    """
    for driver_file in tqdm(driver_files):
        seq_dirs = _sequence_dirs(driver_file)
        partner_folder = os.path.join(partner_root, *seq_dirs)
        if not os.path.exists(partner_folder):
            print(f"{partner_label} folder {partner_folder} does not exist.")
            continue
        candidates = sorted(glob(os.path.join(partner_folder, partner_glob)))
        if not candidates:
            # The folder exists but holds nothing to pair with; skip it
            # instead of failing on an empty list.
            print(f"{partner_label} folder {partner_folder} has no file matching {partner_glob}.")
            continue
        output_dir = os.path.join(output_root, *seq_dirs)
        os.makedirs(output_dir, exist_ok=True)
        _make_symlink(candidates[0], os.path.join(output_dir, partner_link))
        _make_symlink(driver_file, os.path.join(output_dir, driver_link))


def main():
    """Symlink heatmap and silhouette files into one mirrored output tree.

    Reads the options from ``get_args()``, exits with status 1 when either
    input root is missing, then lists the files three directory levels below
    each root. For every sequence folder present in both trees the output
    folder receives ``0_heatmap.pkl`` and ``1_sil.pkl`` symlinks.
    """
    opt = get_args()
    heatmap_data_path = opt.heatmap_data_path
    silhouette_data_path = opt.silhouette_data_path

    if not os.path.exists(heatmap_data_path):
        print(f"heatmap data path {heatmap_data_path} does not exist.")
        sys.exit(1)
    if not os.path.exists(silhouette_data_path):
        print(f"silouette data path {silhouette_data_path} does not exist.")
        sys.exit(1)

    heatmap_glob = "*.pkl"
    silhouette_glob = f"*{opt.dataset_pkl_ext_name}"
    all_heatmap_files = sorted(
        glob(os.path.join(heatmap_data_path, "*", "*", "*", heatmap_glob))
    )
    all_silouette_files = sorted(
        glob(os.path.join(silhouette_data_path, "*", "*", "*", silhouette_glob))
    )

    # The two listings are not required to have the same length; the longer
    # one drives the loop and folders missing on the other side are skipped.
    if len(all_heatmap_files) >= len(all_silouette_files):
        _link_sequences(
            all_heatmap_files, "0_heatmap.pkl",
            silhouette_data_path, silhouette_glob, "silouette", "1_sil.pkl",
            opt.output_path,
        )
    else:
        _link_sequences(
            all_silouette_files, "1_sil.pkl",
            heatmap_data_path, heatmap_glob, "heatmap", "0_heatmap.pkl",
            opt.output_path,
        )

    print("Done! Output data is in ", opt.output_path)
# Run the linking job only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()