|
a |
|
b/datasets/SUSTech1K/pretreatment_SUSTech1K.py |
|
|
1 |
# This source is based on https://github.com/AbnerHqC/GaitSet/blob/master/pretreatment.py
import argparse
import json
import logging
import multiprocessing as mp
import os
import pickle
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import Tuple

import cv2
import numpy as np
import open3d as o3d
from tqdm import tqdm
|
|
18 |
|
|
|
19 |
def compare_pcd_rgb_timestamp(pcd_file, rgb_file):
    """Extract comparable timestamps (in seconds) from a PCD and an RGB file path.

    PCD files are named ``<seconds>.pcd``; RGB files are named with a 13-digit
    timestamp ``<seconds(10)><milliseconds(3)>.jpg``, so the first 10 digits are
    the integer seconds and the remainder the fractional part.

    Args:
        pcd_file (str): Path to a ``.pcd`` file.
        rgb_file (str): Path to a ``.jpg`` file.

    Returns:
        Tuple[float, float]: ``(pcd_time, rgb_time)`` in seconds.
    """
    # os.path.basename instead of split('/') so Windows-style paths also work.
    pcd_stem = os.path.basename(pcd_file).replace('.pcd', '')
    rgb_stem = os.path.basename(rgb_file).replace('.jpg', '')
    # +0.05 s offset applied to the LiDAR timestamp — presumably a fixed
    # sensor-synchronization correction; confirm against the capture setup.
    pcd_time = float(pcd_stem) + 0.05
    rgb_time = float(rgb_stem[:10] + '.' + rgb_stem[10:])
    return pcd_time, rgb_time
|
|
23 |
|
|
|
24 |
|
|
|
25 |
|
|
|
26 |
def _load_modality(modality, data_files, img_size):
    """Load one modality's files into memory.

    Args:
        modality (str): Modality folder name (e.g. 'PCDs', 'RGB_raw').
        data_files (list): Paths of this modality's files, in frame order.
        img_size (int): Square resize target for image modalities.

    Returns:
        Tuple: ``(data, HWs)`` where ``data`` is the per-frame payload (a list
        of point arrays for 'PCDs' — frames have varying point counts — else an
        np.ndarray), and ``HWs`` is the array of original (H, W) sizes for
        'RGB_raw', else None.

    Raises:
        ValueError: If the modality folder name is not recognized (the original
            code silently pickled stale data from the previous modality here).
    """
    HWs = None
    if modality == 'PCDs':
        data = [np.asarray(o3d.io.read_point_cloud(pcd).points) for pcd in data_files]
    elif modality == 'RGB_raw':
        imgs = [cv2.cvtColor(cv2.imread(rgb), cv2.COLOR_BGR2RGB) for rgb in data_files]
        # Keep the original sizes so aspect ratios can be restored downstream.
        HWs = np.asarray([img.shape[:2] for img in imgs])
        data = np.asarray([cv2.resize(img, (img_size, img_size), interpolation=cv2.INTER_CUBIC) for img in imgs])
    elif modality in ('Sils_raw', 'Sils_aligned'):
        sils = [cv2.imread(sil, cv2.IMREAD_GRAYSCALE) for sil in data_files]
        data = np.asarray([cv2.resize(sil, (img_size, img_size), interpolation=cv2.INTER_CUBIC) for sil in sils])
    elif modality == 'Pose':
        poses = []
        for pose in data_files:
            # with-block fixes the original's unclosed file handles.
            with open(pose) as f:
                poses.append(json.load(f))
        data = np.asarray(poses)
    elif modality == 'PCDs_depths':
        imgs = [cv2.cvtColor(cv2.imread(depth), cv2.COLOR_BGR2RGB) for depth in data_files]
        # transpose to (C, H, W)
        data = np.asarray([img.transpose(2, 0, 1) for img in imgs])
    elif modality == 'PCDs_sils':
        data = np.asarray([cv2.imread(sil, cv2.IMREAD_GRAYSCALE) for sil in data_files])
    else:
        raise ValueError(f'Unknown modality: {modality}')
    return data, HWs


def _dump_modality(data, HWs, modality, dst_path, sinfo, cnt, prefix=''):
    """Pickle one modality (plus the RGB size table) and return the updated cnt.

    Args:
        data: Payload from ``_load_modality``.
        HWs: (H, W) array for 'RGB_raw', else None.
        modality (str): Modality folder name.
        dst_path (str): Destination directory (must already exist).
        sinfo (Tuple): (subject id, sequence type, view).
        cnt (int): Running file-index prefix keeping pkl names ordered.
        prefix (str, optional): Tag inserted after cnt ('' or 'sync-').

    Returns:
        int: ``cnt`` advanced by the number of pkl files written.
    """
    if modality == 'RGB_raw':
        pkl_path = os.path.join(dst_path, f'{cnt:02d}-{prefix}{sinfo[2]}-Camera-Ratios-HW.pkl')
        with open(pkl_path, 'wb') as f:
            pickle.dump(HWs, f)
        cnt += 1
    sensor = 'LiDAR' if 'PCDs' in modality else 'Camera'
    pkl_path = os.path.join(dst_path, f'{cnt:02d}-{prefix}{sinfo[2]}-{sensor}-{modality}.pkl')
    with open(pkl_path, 'wb') as f:
        pickle.dump(data, f)
    return cnt + 1


def imgs2pickle(img_groups: Tuple, output_path: Path, img_size: int = 64, verbose: bool = False, dataset='CASIAB') -> None:
    """Reads a group of multi-modal files and saves the data in pickle format.

    Two passes are written per sequence: first every modality as-is, then a
    'sync-' copy restricted to the LiDAR/camera frame pairs whose timestamps
    match within 20 ms.

    Args:
        img_groups (Tuple): ((sid, seq, view), [(modality, file paths), ...]).
        output_path (Path): Output path.
        img_size (int, optional): Image resizing size. Defaults to 64.
        verbose (bool, optional): Display debug info. Defaults to False.
        dataset (str, optional): Dataset name; unused here, kept for interface
            compatibility with the other pretreatment scripts.
    """
    sinfo = img_groups[0]
    img_paths = img_groups[1]  # [(modality name, file list), ...]
    pcd_list, rgb_list = [], []
    threshold = 0.020  # max LiDAR/camera timestamp gap: 20 ms

    dst_path = os.path.join(output_path, *sinfo)
    os.makedirs(dst_path, exist_ok=True)

    # Pass 1: dump every modality unfiltered, remembering the PCD/RGB file
    # lists for the timestamp matching below.
    cnt = 0
    for modality, data_files in img_paths:
        if modality == 'PCDs':
            pcd_list = data_files
        elif modality == 'RGB_raw':
            rgb_list = data_files
        data, HWs = _load_modality(modality, data_files, img_size)
        cnt = _dump_modality(data, HWs, modality, dst_path, sinfo, cnt)

    # Match each point cloud to its nearest-in-time RGB frame; keep the pair
    # only when the gap is within the threshold.
    pcd_indexs = []
    rgb_indexs = []
    for pcd_index in range(len(pcd_list)):
        time_diff = 1  # seconds; anything >= 1 s counts as no match
        tmp = pcd_index, 0
        for rgb_index in range(len(rgb_list)):
            pcd_t, rgb_t = compare_pcd_rgb_timestamp(pcd_list[pcd_index], rgb_list[rgb_index])
            diff = abs(pcd_t - rgb_t)
            if diff < time_diff:
                tmp = pcd_index, rgb_index
                time_diff = diff
        if time_diff <= threshold:
            pcd_indexs.append(tmp[0])
            rgb_indexs.append(tmp[1])

    if len(set(pcd_indexs)) != len(pcd_indexs):
        # Bug fix: the original printed `len(pcd_indexs) == len(pcd_indexs)`,
        # which is trivially True; report RGB-index uniqueness instead, which
        # is the quantity that can actually contain duplicates.
        print(sinfo, pcd_indexs, rgb_indexs, len(set(rgb_indexs)) == len(rgb_indexs))

    # Pass 2: dump the time-synchronized subset of every modality.
    for modality, data_files in img_paths:
        keep = pcd_indexs if 'PCDs' in modality else rgb_indexs
        synced_files = [data_files[i] for i in keep]
        data, HWs = _load_modality(modality, synced_files, img_size)
        cnt = _dump_modality(data, HWs, modality, dst_path, sinfo, cnt, prefix='sync-')
|
|
166 |
|
|
|
167 |
|
|
|
168 |
def pretreat(input_path: Path, output_path: Path, img_size: int = 64, workers: int = 4, verbose: bool = False, dataset: str = 'CASIAB') -> None:
    """Reads a dataset and saves the data in pickle format.

    Walks ``input_path`` as ``<id>/<type>/<view>/<modality>/<files>`` and
    dispatches one job per (id, type, view) folder to a worker pool.

    Args:
        input_path (Path): Dataset root path.
        output_path (Path): Output path.
        img_size (int, optional): Image resizing size. Defaults to 64.
        workers (int, optional): Number of worker processes. Defaults to 4.
        verbose (bool, optional): Display debug info. Defaults to False.
        dataset (str, optional): Dataset name forwarded to imgs2pickle.
            Defaults to 'CASIAB'.
    """
    img_groups = defaultdict(list)
    logging.info(f'Listing {input_path}')
    total_files = 0
    # sorted() throughout makes the modality order (and hence the pkl file
    # numbering) deterministic across filesystems.
    for id_ in tqdm(sorted(os.listdir(input_path))):
        for type_ in sorted(os.listdir(os.path.join(input_path, id_))):
            for view_ in sorted(os.listdir(os.path.join(input_path, id_, type_))):
                for modality in sorted(os.listdir(os.path.join(input_path, id_, type_, view_))):
                    modality_path = os.path.join(input_path, id_, type_, view_, modality)
                    file_names = sorted(os.listdir(modality_path))
                    file_names = [os.path.join(modality_path, file_name) for file_name in file_names]
                    img_groups[(id_, type_, view_)].append((modality, file_names))
                    total_files += 1

    logging.info(f'Total files listed: {total_files}')

    worker = partial(imgs2pickle, output_path=output_path, img_size=img_size, verbose=verbose, dataset=dataset)
    # Context managers fix the original's never-closed tqdm bar and make pool
    # shutdown explicit.
    with tqdm(total=len(img_groups), desc='Pretreating', unit='folder') as progress:
        with mp.Pool(workers) as pool:
            logging.info(f'Start pretreating {input_path}')
            for _ in pool.imap_unordered(worker, img_groups.items()):
                progress.update(1)
    logging.info('Done')
|
|
200 |
|
|
|
201 |
|
|
|
202 |
if __name__ == '__main__':
    # CLI entry point: collect options, configure logging, then run pretreat().
    cli = argparse.ArgumentParser(description='OpenGait dataset pretreatment module.')
    option_table = (
        (('-i', '--input_path'), dict(default='', type=str, help='Root path of raw dataset.')),
        (('-o', '--output_path'), dict(default='', type=str, help='Output path of pickled dataset.')),
        (('-l', '--log_file'), dict(default='./pretreatment.log', type=str, help='Log file path. Default: ./pretreatment.log')),
        (('-n', '--n_workers'), dict(default=4, type=int, help='Number of thread workers. Default: 4')),
        (('-r', '--img_size'), dict(default=64, type=int, help='Image resizing size. Default 64')),
        (('-d', '--dataset'), dict(default='CASIAB', type=str, help='Dataset for pretreatment.')),
        (('-v', '--verbose'), dict(default=False, action='store_true', help='Display debug info.')),
    )
    for flags, options in option_table:
        cli.add_argument(*flags, **options)
    args = cli.parse_args()

    logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w', format='[%(asctime)s - %(levelname)s]: %(message)s')

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)
        logging.info('Verbose mode is on.')
        for key, value in vars(args).items():
            logging.debug(f'{key}: {value}')

    pretreat(input_path=Path(args.input_path), output_path=Path(args.output_path), img_size=args.img_size, workers=args.n_workers, verbose=args.verbose, dataset=args.dataset)