|
# datasets/pretreatment_heatmap.py
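# Converts skeleton keypoint sequences (.pkl) into two-channel bone/joint heatmap
# images for SkeletonGait-style pretreatment; see the launch note above __main__ below.
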
import os
import cv2
import yaml
import math
import torch
import random
import pickle
import argparse
import numpy as np
from glob import glob
from tqdm import tqdm
import matplotlib.cm as cm
import torch.distributed as dist
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader
from sklearn.impute import KNNImputer, SimpleImputer

torch.manual_seed(347)
random.seed(347)

#########################################################################################################
# The following code is the base class for generating heatmaps.
#########################################################################################################

class GeneratePoseTarget:
    """Generate pseudo heatmaps based on joint coordinates and confidence.
    Required keys are "keypoint", "img_shape", "keypoint_score" (optional),
    added or modified keys are "imgs".
    Args:
        sigma (float): The sigma of the generated gaussian map. Default: 0.6.
        use_score (bool): Use the confidence score of keypoints as the maximum
            of the gaussian maps. Default: True.
        with_kp (bool): Generate pseudo heatmaps for keypoints. Default: True.
        with_limb (bool): Generate pseudo heatmaps for limbs. Exactly one of
            'with_kp' and 'with_limb' should be True. Default: False.
        skeletons (tuple[tuple]): The definition of human skeletons.
            Default: ((0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (5, 7), (7, 9),
                      (0, 6), (6, 8), (8, 10), (5, 11), (11, 13), (13, 15),
                      (6, 12), (12, 14), (14, 16), (11, 12)),
            which is the definition of COCO-17p skeletons.
        double (bool): Output both original heatmaps and flipped heatmaps.
            Default: False.
        left_kp (tuple[int]): Indexes of left keypoints, which is used when
            flipping heatmaps. Default: (1, 3, 5, 7, 9, 11, 13, 15),
            which is left keypoints in COCO-17p.
        right_kp (tuple[int]): Indexes of right keypoints, which is used when
            flipping heatmaps. Default: (2, 4, 6, 8, 10, 12, 14, 16),
            which is right keypoints in COCO-17p.
        left_limb (tuple[int]): Indexes of left limbs, which is used when
            flipping heatmaps. Default: (0, 2, 4, 5, 6, 10, 11, 12),
            which is left limbs of the skeletons we defined for COCO-17p.
        right_limb (tuple[int]): Indexes of right limbs, which is used when
            flipping heatmaps. Default: (1, 3, 7, 8, 9, 13, 14, 15),
            which is right limbs of the skeletons we defined for COCO-17p.
        scaling (float): Scaling factor applied to the canvas and keypoints. Default: 1.
        eps (float): Keypoints with a confidence below this value are skipped. Default: 1e-3.
        img_h (int): Height of the heatmap canvas. Default: 64.
        img_w (int): Width of the heatmap canvas. Default: 64.
    """

    def __init__(self,
                 sigma=0.6,
                 use_score=True,
                 with_kp=True,
                 with_limb=False,
                 skeletons=((0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (5, 7),
                            (7, 9), (0, 6), (6, 8), (8, 10), (5, 11), (11, 13),
                            (13, 15), (6, 12), (12, 14), (14, 16), (11, 12)),
                 double=False,
                 left_kp=(1, 3, 5, 7, 9, 11, 13, 15),
                 right_kp=(2, 4, 6, 8, 10, 12, 14, 16),
                 left_limb=(0, 2, 4, 5, 6, 10, 11, 12),
                 right_limb=(1, 3, 7, 8, 9, 13, 14, 15),
                 scaling=1.,
                 eps=1e-3,
                 img_h=64,
                 img_w=64):

        self.sigma = sigma
        self.use_score = use_score
        self.with_kp = with_kp
        self.with_limb = with_limb
        self.double = double
        self.eps = eps

        assert self.with_kp + self.with_limb == 1, 'Exactly one of "with_kp" and "with_limb" should be set as True.'
        self.left_kp = left_kp
        self.right_kp = right_kp
        self.skeletons = skeletons
        self.left_limb = left_limb
        self.right_limb = right_limb
        self.scaling = scaling
        self.img_h = img_h
        self.img_w = img_w

    def generate_a_heatmap(self, arr, centers, max_values, point_center):
        """Generate pseudo heatmap for one keypoint in one frame.
        Args:
            arr (np.ndarray): The array to store the generated heatmaps. Shape: img_h * img_w.
            centers (np.ndarray): The coordinates of corresponding keypoints (of multiple persons). Shape: 1 * 2.
            max_values (np.ndarray): The max values of each keypoint. Shape: (1, ).
            point_center (np.ndarray): The mean of all keypoints in the frame. Shape: (1, 2).
        Returns:
            np.ndarray: The generated pseudo heatmap.
        """

        sigma = self.sigma
        img_h, img_w = arr.shape

        for center, max_value in zip(centers, max_values):
            if max_value < self.eps:
                continue

            mu_x, mu_y = center[0], center[1]

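            # Rasterize only a +/- 3*sigma window around the keypoint; the Gaussian is
            # negligible outside it, and clipping to the canvas keeps indexing in bounds.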
            tmp_st_x = int(mu_x - 3 * sigma)
            tmp_ed_x = int(mu_x + 3 * sigma)
            tmp_st_y = int(mu_y - 3 * sigma)
            tmp_ed_y = int(mu_y + 3 * sigma)

            st_x = max(tmp_st_x, 0)
            ed_x = min(tmp_ed_x + 1, img_w)
            st_y = max(tmp_st_y, 0)
            ed_y = min(tmp_ed_y + 1, img_h)
            x = np.arange(st_x, ed_x, 1, np.float32)
            y = np.arange(st_y, ed_y, 1, np.float32)

            # skip if the keypoint lies outside the heatmap coordinate system
            if not (len(x) and len(y)):
                continue
            y = y[:, None]

            patch = np.exp(-((x - mu_x)**2 + (y - mu_y)**2) / 2 / sigma**2)
            patch = patch * max_value

            arr[st_y:ed_y, st_x:ed_x] = np.maximum(arr[st_y:ed_y, st_x:ed_x], patch)

    def generate_a_limb_heatmap(self, arr, starts, ends, start_values, end_values, point_center):
        """Generate pseudo heatmap for one limb in one frame.
        Args:
            arr (np.ndarray): The array to store the generated heatmaps. Shape: img_h * img_w.
            starts (np.ndarray): The coordinates of one keypoint in the corresponding limbs. Shape: 1 * 2.
            ends (np.ndarray): The coordinates of the other keypoint in the corresponding limbs. Shape: 1 * 2.
            start_values (np.ndarray): The max values of one keypoint in the corresponding limbs. Shape: (1, ).
            end_values (np.ndarray): The max values of the other keypoint in the corresponding limbs. Shape: (1, ).
        Returns:
            np.ndarray: The generated pseudo heatmap.
        """

        sigma = self.sigma
        img_h, img_w = arr.shape

        for start, end, start_value, end_value in zip(starts, ends, start_values, end_values):
            value_coeff = min(start_value, end_value)
            if value_coeff < self.eps:
                continue

            min_x, max_x = min(start[0], end[0]), max(start[0], end[0])
            min_y, max_y = min(start[1], end[1]), max(start[1], end[1])

            tmp_min_x = int(min_x - 3 * sigma)
            tmp_max_x = int(max_x + 3 * sigma)
            tmp_min_y = int(min_y - 3 * sigma)
            tmp_max_y = int(max_y + 3 * sigma)

            min_x = max(tmp_min_x, 0)
            max_x = min(tmp_max_x + 1, img_w)
            min_y = max(tmp_min_y, 0)
            max_y = min(tmp_max_y + 1, img_h)

            x = np.arange(min_x, max_x, 1, np.float32)
            y = np.arange(min_y, max_y, 1, np.float32)

            if not (len(x) and len(y)):
                continue

            y = y[:, None]
            x_0 = np.zeros_like(x)
            y_0 = np.zeros_like(y)

            # squared distance to the start keypoint
            d2_start = ((x - start[0])**2 + (y - start[1])**2)

            # squared distance to the end keypoint
            d2_end = ((x - end[0])**2 + (y - end[1])**2)

            # squared distance between the start and end keypoints
            d2_ab = ((start[0] - end[0])**2 + (start[1] - end[1])**2)

            if d2_ab < 1:
                self.generate_a_heatmap(arr, start[None], start_value[None], point_center)
                continue

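            # Project each pixel onto the limb segment: coeff in [0, 1] means the foot of the
            # perpendicular lies on the segment; otherwise the nearer endpoint's distance is used.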
            coeff = (d2_start - d2_end + d2_ab) / 2. / d2_ab

            a_dominate = coeff <= 0
            b_dominate = coeff >= 1
            seg_dominate = 1 - a_dominate - b_dominate

            position = np.stack([x + y_0, y + x_0], axis=-1)
            projection = start + np.stack([coeff, coeff], axis=-1) * (end - start)
            d2_line = position - projection
            d2_line = d2_line[:, :, 0]**2 + d2_line[:, :, 1]**2
            d2_seg = a_dominate * d2_start + b_dominate * d2_end + seg_dominate * d2_line

            patch = np.exp(-d2_seg / 2. / sigma**2)
            patch = patch * value_coeff

            arr[min_y:max_y, min_x:max_x] = np.maximum(arr[min_y:max_y, min_x:max_x], patch)

    def generate_heatmap(self, arr, kps, max_values):
        """Generate pseudo heatmap for all keypoints and limbs in one frame (if
        needed).
        Args:
            arr (np.ndarray): The array to store the generated heatmaps. Shape: V * img_h * img_w.
            kps (np.ndarray): The coordinates of keypoints in this frame. Shape: 1 * V * 2.
            max_values (np.ndarray): The confidence score of each keypoint. Shape: 1 * V.
        Returns:
            np.ndarray: The generated pseudo heatmap.
        """

        point_center = kps.mean(1)

        if self.with_kp:
            num_kp = kps.shape[1]
            for i in range(num_kp):
                self.generate_a_heatmap(arr[i], kps[:, i], max_values[:, i], point_center)

        if self.with_limb:
            for i, limb in enumerate(self.skeletons):
                start_idx, end_idx = limb
                starts = kps[:, start_idx]
                ends = kps[:, end_idx]

                start_values = max_values[:, start_idx]
                end_values = max_values[:, end_idx]
                self.generate_a_limb_heatmap(arr[i], starts, ends, start_values, end_values, point_center)

    def gen_an_aug(self, pose_data):
        """Generate pseudo heatmaps for all frames.
        Args:
            pose_data (array): [1, T, V, C]
        Returns:
            np.ndarray: The generated pseudo heatmaps, shape (T, num_c, img_h, img_w).
        """

        all_kps = pose_data[..., :2]
        kp_shape = pose_data.shape  # [1, T, V, C]

        if pose_data.shape[-1] == 3:
            all_kpscores = pose_data[..., -1]  # [1, T, V]
        else:
            all_kpscores = np.ones(kp_shape[:-1], dtype=np.float32)

        # scale img_h, img_w and kps
        img_h = int(self.img_h * self.scaling + 0.5)
        img_w = int(self.img_w * self.scaling + 0.5)
        all_kps[..., :2] *= self.scaling

        num_frame = kp_shape[1]
        num_c = 0
        if self.with_kp:
            num_c += all_kps.shape[2]
        if self.with_limb:
            num_c += len(self.skeletons)
        ret = np.zeros([num_frame, num_c, img_h, img_w], dtype=np.float32)

        for i in range(num_frame):
            # 1, V, C
            kps = all_kps[:, i]
            # 1, V
            kpscores = all_kpscores[:, i] if self.use_score else np.ones_like(all_kpscores[:, i])

            self.generate_heatmap(ret[i], kps, kpscores)
        return ret

    def __call__(self, pose_data):
        """
        pose_data: (T, V, C=3/2)
        A person axis of size 1 is added internally, giving (1, T, V, C).
        """
        pose_data = pose_data[None, ...]  # (1, T, V, C=3/2)

        heatmap = self.gen_an_aug(pose_data)

        if self.double:
            indices = np.arange(heatmap.shape[1], dtype=np.int64)
            left, right = (self.left_kp, self.right_kp) if self.with_kp else (self.left_limb, self.right_limb)
            for l, r in zip(left, right):  # noqa: E741
                indices[l] = r
                indices[r] = l
            heatmap_flip = heatmap[..., ::-1][:, indices]
            heatmap = np.concatenate([heatmap, heatmap_flip])
        return heatmap

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'sigma={self.sigma}, '
                    f'use_score={self.use_score}, '
                    f'with_kp={self.with_kp}, '
                    f'with_limb={self.with_limb}, '
                    f'skeletons={self.skeletons}, '
                    f'double={self.double}, '
                    f'left_kp={self.left_kp}, '
                    f'right_kp={self.right_kp})')
        return repr_str

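# A minimal usage sketch for GeneratePoseTarget (illustrative shapes and values):
#   gen = GeneratePoseTarget(sigma=0.6, with_kp=True, with_limb=False, img_h=64, img_w=64)
#   xy = np.random.rand(30, 17, 2).astype(np.float32) * 64       # (T, V, 2) coordinates
#   conf = np.ones((30, 17, 1), dtype=np.float32)                # (T, V, 1) confidences
#   maps = gen(np.concatenate([xy, conf], axis=-1))              # -> (30, 17, 64, 64)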

class HeatmapToImage:
    """
    Convert the heatmap data to image data.
    """
    def __init__(self) -> None:
        self.cmap = cm.gray

    def __call__(self, heatmaps):
        """
        heatmaps: (T, 17, H, W)
        return images: (T, 1, H, W)
        """
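        # Collapse the per-joint (or per-limb) channels with a per-pixel max, map to a
        # grayscale colormap, resize, then average the RGB channels back to one channel.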
        heatmaps = [x.transpose(1, 2, 0) for x in heatmaps]
        h, w, _ = heatmaps[0].shape
        newh, neww = int(h), int(w)
        heatmaps = [np.max(x, axis=-1) for x in heatmaps]
        heatmaps = [(self.cmap(x)[..., :3] * 255).astype(np.uint8) for x in heatmaps]
        heatmaps = [cv2.resize(x, (neww, newh)) for x in heatmaps]
        return np.ascontiguousarray(np.mean(np.array(heatmaps), axis=-1, keepdims=True).transpose(0, 3, 1, 2))

class CenterAndScaleNormalizer:

    def __init__(self, pose_format="coco", use_conf=True, heatmap_image_height=128) -> None:
        """
        Parameters:
        - pose_format (str): Specifies the format of the keypoints.
          This parameter determines how the keypoints are structured and indexed.
          The supported formats are "coco" or "openpose-x", where 'x' is either 18 or 25 and indicates the number of keypoints used by the OpenPose model.
        - use_conf (bool): Indicates whether the input carries confidence scores as the last channel; if so, they are re-attached after normalization.
        - heatmap_image_height (int): The height (in pixels) of the heatmap image the normalized pose is mapped onto.
        """
        self.pose_format = pose_format
        self.use_conf = use_conf
        self.heatmap_image_height = heatmap_image_height

    def __call__(self, data):
        """
        Implements step (a) from Figure 2 in the SkeletonGait paper.
        data: (T, V, C)
        - T: number of frames
        - V: number of joints
        - C: channels; the first two are the joint coordinates and the optional third is the confidence score
        return data: (T, V, C)
        """

        if self.use_conf:
            pose_seq = data[..., :-1]
            score = np.expand_dims(data[..., -1], axis=-1)
        else:
            pose_seq = data[..., :-1]

        # Hip as the center point
        if self.pose_format.lower() == "coco":
            hip = (pose_seq[:, 11] + pose_seq[:, 12]) / 2.  # [t, 2]
        elif self.pose_format.split('-')[0].lower() == "openpose":
            hip = (pose_seq[:, 9] + pose_seq[:, 12]) / 2.  # [t, 2]
        else:
            raise ValueError(f"Invalid value for pose_format: {self.pose_format} in CenterAndScaleNormalizer.")

        # Center-normalization
        pose_seq = pose_seq - hip[:, np.newaxis, :]

        # Scale-normalization
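        # Scale so the pose's vertical extent spans heatmap_image_height / 1.5 pixels,
        # then translate the hip-centered pose to the image center.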
        y_max = np.max(pose_seq[:, :, 1], axis=-1)  # [t]
        y_min = np.min(pose_seq[:, :, 1], axis=-1)  # [t]
        pose_seq *= ((self.heatmap_image_height // 1.5) / (y_max - y_min)[:, np.newaxis, np.newaxis])  # [t, v, 2]

        pose_seq += self.heatmap_image_height // 2

        if self.use_conf:
            pose_seq = np.concatenate([pose_seq, score], axis=-1)
        return pose_seq

class PadKeypoints:
    """
    Impute keypoints whose values are missing (recorded as zeros).
    """

    def __init__(self, pad_method="knn", use_conf=True) -> None:
        """
        pad_method (str): Specifies the method used to fill in the missing values.
            The supported methods are "knn" and "simple".
        use_conf (bool): Indicates whether the input carries confidence scores as the last channel.
        """
        self.use_conf = use_conf
        if pad_method.lower() == "knn":
            self.imputer = KNNImputer(missing_values=0.0, n_neighbors=4, weights="distance", add_indicator=False)
        elif pad_method.lower() == "simple":
            self.imputer = SimpleImputer(missing_values=0.0, strategy='mean', add_indicator=True)
        else:
            raise ValueError(f"Invalid value for padding method: {pad_method}")

    def __call__(self, raw_data):
        """
        raw_data: (T, V, C)
        - T: number of frames
        - V: number of joints
        - C: channels; the first two are the joint coordinates and the optional third is the confidence score
        return padded_data: (T, V, C)
        """
        T, V, C = raw_data.shape
        if self.use_conf:
            data = raw_data[..., :-1]
            score = np.expand_dims(raw_data[..., -1], axis=-1)
            C = C - 1
        else:
            data = raw_data[..., :-1]
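        # Each (joint, coordinate) pair becomes one column over the T frames, so the
        # imputer fills exact-zero (missing) detections from the remaining frames.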
        data = data.reshape((T, V * C))
        padded_data = self.imputer.fit_transform(data)
        try:
            padded_data = padded_data.reshape((T, V, C))
        except ValueError:
            # The imputer can drop all-zero columns (or append indicator columns),
            # in which case the reshape fails and the raw coordinates are kept.
            padded_data = data.reshape((T, V, C))
        if self.use_conf:
            padded_data = np.concatenate([padded_data, score], axis=-1)
        return padded_data

class COCO18toCOCO17:
    """
    Transfer the COCO18 format (OpenPose extracted) to the COCO17 format.
    """

    def __init__(self, transfer_to_coco17=True):
        """
        transfer_to_coco17 (bool): Indicates whether to transfer the keypoints from COCO18 to COCO17 format.
        """
        self.map_dict = {
            0: 0,    # "nose"
            1: 15,   # "left_eye"
            2: 14,   # "right_eye"
            3: 17,   # "left_ear"
            4: 16,   # "right_ear"
            5: 5,    # "left_shoulder"
            6: 2,    # "right_shoulder"
            7: 6,    # "left_elbow"
            8: 3,    # "right_elbow"
            9: 7,    # "left_wrist"
            10: 4,   # "right_wrist"
            11: 11,  # "left_hip"
            12: 8,   # "right_hip"
            13: 12,  # "left_knee"
            14: 9,   # "right_knee"
            15: 13,  # "left_ankle"
            16: 10,  # "right_ankle"
        }
        self.transfer = transfer_to_coco17

    def __call__(self, data):
        """
        data: (T, 18, C)
        - T: number of frames
        - 18: number of joints in the COCO18 format
        - C: channels; the first two are the joint coordinates and the optional third is the confidence score
        return data: (T, 17, C)
        """

        if self.transfer:
            # input data [T, 18, C] in COCO18 format -> output [T, 17, C] in COCO17 format
            T, _, C = data.shape
            coco17_pkl_data = np.zeros((T, 17, C))
            for i in range(17):
                coco17_pkl_data[:, i, :] = data[:, self.map_dict[i], :]
            return coco17_pkl_data
        else:
            return data

class GatherTransform(object):
    """
    Gather the different transforms.
    """
    def __init__(self, base_transform, transform_bone, transform_joint):
        """
        base_transform: the shared transforms, e.g., COCO18toCOCO17, PadKeypoints, CenterAndScaleNormalizer
        transform_bone: the pipeline that generates the bone heatmap
        transform_joint: the pipeline that generates the joint heatmap
        """
        self.base_transform = base_transform
        self.transform_bone = transform_bone
        self.transform_joint = transform_joint

    def __call__(self, pose_data):
        x = self.base_transform(pose_data)
        heatmap_bone = self.transform_bone(x)  # [T, 1, H, W]
        heatmap_joint = self.transform_joint(x)  # [T, 1, H, W]
        heatmap = np.concatenate([heatmap_bone, heatmap_joint], axis=1)
        return heatmap

class HeatmapAlignment:
    def __init__(self, align=True, final_img_size=64, offset=0, heatmap_image_size=128) -> None:
        self.align = align
        self.final_img_size = final_img_size
        self.offset = offset
        self.heatmap_image_size = heatmap_image_size

    def center_crop(self, heatmap):
        """
        Input: [1, heatmap_image_size, heatmap_image_size]
        Output: [1, final_img_size, final_img_size]
        """
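        # Find the vertical span of non-zero rows, crop a roughly square window of that
        # height centered horizontally, then resize it to final_img_size x final_img_size.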
        raw_heatmap = heatmap[0]
        if self.align:
            y_sum = raw_heatmap.sum(axis=1)
            y_top = (y_sum != 0).argmax(axis=0)
            y_btm = (y_sum != 0).cumsum(axis=0).argmax(axis=0)
            height = y_btm - y_top + 1
            raw_heatmap = raw_heatmap[y_top - self.offset: y_btm + 1 + self.offset, (self.heatmap_image_size // 2) - (height // 2): (self.heatmap_image_size // 2) + (height // 2) + 1]
        raw_heatmap = cv2.resize(raw_heatmap, (self.final_img_size, self.final_img_size), interpolation=cv2.INTER_AREA)
        return raw_heatmap[np.newaxis, :, :]  # [1, final_img_size, final_img_size]

    def __call__(self, heatmap_imgs):
        """
        heatmap_imgs: (T, 1, raw_size, raw_size)
        return (T, 1, final_img_size, final_img_size)
        """
        heatmap_imgs = heatmap_imgs / 255.
        heatmap_imgs = np.array([self.center_crop(heatmap_img) for heatmap_img in heatmap_imgs])
        return (heatmap_imgs * 255).astype('uint8')

def GenerateHeatmapTransform(
    coco18tococo17_args,
    padkeypoints_args,
    norm_args,
    heatmap_generator_args,
    align_args
):

    base_transform = T.Compose([
        COCO18toCOCO17(**coco18tococo17_args),
        PadKeypoints(**padkeypoints_args),
        CenterAndScaleNormalizer(**norm_args),
    ])

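    # Bone branch: the same generator settings, but rasterizing limbs instead of joints.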
heatmap_generator_args["with_limb"] = True |
|
|
545 |
heatmap_generator_args["with_kp"] = False |
|
|
546 |
transform_bone = T.Compose([ |
|
|
547 |
GeneratePoseTarget(**heatmap_generator_args), |
|
|
548 |
HeatmapToImage(), |
|
|
549 |
HeatmapAlignment(**align_args) |
|
|
550 |
]) |
|
|
551 |
|
|
|
552 |
heatmap_generator_args["with_limb"] = False |
|
|
553 |
heatmap_generator_args["with_kp"] = True |
|
|
554 |
transform_joint = T.Compose([ |
|
|
555 |
GeneratePoseTarget(**heatmap_generator_args), |
|
|
556 |
HeatmapToImage(), |
|
|
557 |
HeatmapAlignment(**align_args) |
|
|
558 |
]) |
|
|
559 |
|
|
|
560 |
transform = T.Compose([ |
|
|
561 |
GatherTransform(base_transform, transform_bone, transform_joint) # [T, 2, H, W] |
|
|
562 |
]) |
|
|
563 |
|
|
|
564 |
return transform |
|
|
565 |
|
|
|
566 |
######################################################################################################### |
|
|
567 |
# The following code is DDP progress codes. |
|
|
568 |
######################################################################################################### |
|
|
569 |
class SequentialDistributedSampler(torch.utils.data.sampler.Sampler): |
|
|
570 |
""" |
|
|
571 |
Distributed Sampler that subsamples indicies sequentially, |
|
|
572 |
making it easier to collate all results at the end. |
|
|
573 |
Even though we only use this sampler for eval and predict (no training), |
|
|
574 |
which means that the model params won't have to be synced (i.e. will not hang |
|
|
575 |
for synchronization even if varied number of forward passes), we still add extra |
|
|
576 |
samples to the sampler to make it evenly divisible (like in `DistributedSampler`) |
|
|
577 |
to make it easy to `gather` or `reduce` resulting tensors at the end of the loop. |
|
|
578 |
""" |
|
|
579 |
|
|
|
580 |
def __init__(self, dataset, batch_size, rank=None, num_replicas=None): |
|
|
581 |
if num_replicas is None: |
|
|
582 |
if not torch.distributed.is_available(): |
|
|
583 |
raise RuntimeError("Requires distributed package to be available") |
|
|
584 |
num_replicas = torch.distributed.get_world_size() |
|
|
585 |
if rank is None: |
|
|
586 |
if not torch.distributed.is_available(): |
|
|
587 |
raise RuntimeError("Requires distributed package to be available") |
|
|
588 |
rank = torch.distributed.get_rank() |
|
|
589 |
self.dataset = dataset |
|
|
590 |
self.num_replicas = num_replicas |
|
|
591 |
self.rank = rank |
|
|
592 |
self.batch_size = batch_size |
|
|
593 |
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.batch_size / self.num_replicas)) * self.batch_size |
|
|
594 |
self.total_size = self.num_samples * self.num_replicas |
|
|
595 |
|
|
|
596 |
def __iter__(self): |
|
|
597 |
indices = list(range(len(self.dataset))) |
|
|
598 |
# add extra samples to make it evenly divisible |
|
|
599 |
indices += [indices[-1]] * (self.total_size - len(indices)) |
|
|
600 |
# subsample |
|
|
601 |
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples] |
|
|
602 |
return iter(indices) |
|
|
603 |
|
|
|
604 |
def __len__(self): |
|
|
605 |
return self.num_samples |
|
|
606 |
|
|
|
607 |
|
|
|
608 |
class TransferDataset(Dataset): |
|
|
609 |
def __init__(self, args, generate_heatemap_cfgs) -> None: |
|
|
610 |
super().__init__() |
|
|
611 |
pose_root = args.pose_data_path |
|
|
612 |
sigma = generate_heatemap_cfgs['heatmap_generator_args']['sigma'] |
|
|
613 |
self.dataset_name = args.dataset_name |
|
|
614 |
assert self.dataset_name.lower() in ["sustech1k", "grew", "ccpg", "oumvlp", "ou-mvlp", "gait3d", "casiab", "casiae"], f"Invalid dataset name: {self.dataset_name}" |
|
|
615 |
self.save_root = os.path.join(args.save_root, f"{self.dataset_name}_sigma_{sigma}_{args.ext_name}") |
|
|
616 |
os.makedirs(self.save_root, exist_ok=True) |
|
|
617 |
|
|
|
618 |
self.heatmap_transform = GenerateHeatmapTransform(**generate_heatemap_cfgs) |
|
|
619 |
|
|
|
620 |
if self.dataset_name.lower() == "sustech1k": |
|
|
621 |
self.all_ps_data_paths = sorted(glob(os.path.join(pose_root, "*/*/*/03*.pkl"))) |
|
|
622 |
else: |
|
|
623 |
self.all_ps_data_paths = sorted(glob(os.path.join(pose_root, "*/*/*/*.pkl"))) |
|
|
624 |
|
|
|
625 |
def __len__(self): |
|
|
626 |
return len(self.all_ps_data_paths) |
|
|
627 |
|
|
|
628 |
def __getitem__(self, index): |
|
|
629 |
pose_path = self.all_ps_data_paths[index] |
|
|
630 |
with open(pose_path, "rb") as f: |
|
|
631 |
pose_data = pickle.load(f) |
|
|
632 |
if self.dataset_name.lower() == "grew": |
|
|
633 |
# print(pose_data.shape) |
|
|
634 |
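            # GREW pose pkls carry two leading values per frame row; keep only the
            # 17 x (x, y, conf) keypoint block.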
            pose_data = pose_data[:, 2:].reshape(-1, 17, 3)

        tmp_split = pose_path.split('/')

        heatmap_img = self.heatmap_transform(pose_data)  # [T, 2, H, W]

        save_path_pkl = os.path.join(self.save_root, 'pkl', *tmp_split[-4:-1])
        os.makedirs(save_path_pkl, exist_ok=True)

        # save a few visualizations for sanity checking
        if index < 10:
            # save images
            save_path_img = os.path.join(self.save_root, 'images', *tmp_split[-4:-1])
            os.makedirs(save_path_img, exist_ok=True)
            # save_heatemapimg_index = random.choice(list(range(heatmap_img.shape[0])))
            for save_heatemapimg_index in range(heatmap_img.shape[0]):
                cv2.imwrite(os.path.join(save_path_img, f'bone_{save_heatemapimg_index}.jpg'), heatmap_img[save_heatemapimg_index, 0])
                cv2.imwrite(os.path.join(save_path_img, f'pose_{save_heatemapimg_index}.jpg'), heatmap_img[save_heatemapimg_index, 1])

        with open(os.path.join(save_path_pkl, tmp_split[-1]), 'wb') as f:
            pickle.dump(heatmap_img, f)
        return None

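# __getitem__ writes each heatmap pkl to disk as a side effect; the DataLoader only
# parallelizes the conversion, so a no-op collate_fn avoids collating the None returns.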
def mycollate(_):
    return None


def get_args():
    parser = argparse.ArgumentParser(description='Utility for generating heatmaps from pose data.')
    parser.add_argument('--pose_data_path', type=str, required=True, help="Path to the root directory containing the ID-level pose data (.pkl) files.")
    parser.add_argument('--save_root', type=str, required=True, help="Root directory where the generated heatmap .pkl files will be saved (ID-level).")
    parser.add_argument('--ext_name', type=str, default='', help="Suffix appended to 'save_root' for identification.")
    parser.add_argument('--dataset_name', type=str, required=True, help="Name of the dataset being preprocessed.")
    parser.add_argument('--heatemap_cfg_path', type=str, default='configs/skeletongait/pretreatment_heatmap.yaml', help="Path to the heatmap generator configuration file.")
    parser.add_argument("--local_rank", type=int, default=0, help="Local rank for distributed processing; defaults to 0 for non-distributed setups.")
    opt = parser.parse_args()
    return opt

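# Resolve simple "${a.b.c}" placeholders in the YAML config against the config itself
# (e.g. a value written as ${heatmap_generator_args.sigma}; the key path is illustrative).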
def replace_variables(data, context=None):
    if context is None:
        context = {}

    if isinstance(data, dict):
        for key, value in data.items():
            data[key] = replace_variables(value, context)
    elif isinstance(data, list):
        data = [replace_variables(item, context) for item in data]
    elif isinstance(data, str):
        if data.startswith('${') and data.endswith('}'):
            var_path = data[2:-1].split('.')
            var_value = context
            try:
                for part in var_path:
                    var_value = var_value[part]
                return var_value
            except KeyError:
                raise ValueError(f"Variable {data} not found in context")
    return data

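# The script expects a distributed launcher because it initializes the default process
# group with init_method='env://'. An assumed invocation (adjust GPU count and paths):
#   torchrun --nproc_per_node=<num_gpus> datasets/pretreatment_heatmap.py \
#       --pose_data_path <pose_root> --save_root <output_root> --dataset_name SUSTech1K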
if __name__ == "__main__": |
|
|
693 |
dist.init_process_group("nccl", init_method='env://') |
|
|
694 |
local_rank = torch.distributed.get_rank() |
|
|
695 |
world_size = torch.distributed.get_world_size() |
|
|
696 |
|
|
|
697 |
args = get_args() |
|
|
698 |
|
|
|
699 |
# Load the heatmap generator configuration |
|
|
700 |
with open(args.heatemap_cfg_path, 'r') as stream: |
|
|
701 |
generate_heatemap_cfgs = yaml.safe_load(stream) |
|
|
702 |
generate_heatemap_cfgs = replace_variables(generate_heatemap_cfgs, generate_heatemap_cfgs) |
|
|
703 |
# Create the dataset |
|
|
704 |
dataset = TransferDataset(args, generate_heatemap_cfgs) |
|
|
705 |
|
|
|
706 |
# Create the dataloader |
|
|
707 |
dist_sampler = SequentialDistributedSampler(dataset, batch_size=1, rank=local_rank, num_replicas=world_size) |
|
|
708 |
dataloader = DataLoader(dataset=dataset, batch_size=1, sampler=dist_sampler, num_workers=8, collate_fn=mycollate) |
|
|
709 |
for _, tmp in tqdm(enumerate(dataloader), total=len(dataloader)): |
|
|
710 |
pass |
|
|
711 |
|
|
|
712 |
|