Switch to unified view

a b/datasets/pretreatment_heatmap.py
1
import os
2
import cv2
3
import yaml
4
import math
5
import torch
6
import random
7
import pickle
8
import argparse
9
import numpy as np
10
from glob import glob 
11
from tqdm import tqdm
12
import matplotlib.cm as cm
13
import torch.distributed as dist
14
from torchvision import transforms as T
15
from torch.utils.data import Dataset, DataLoader
16
from sklearn.impute import KNNImputer, SimpleImputer
17
18
torch.manual_seed(347)
19
random.seed(347)
20
21
#########################################################################################################
22
# The following code is the base class code for generating heatmap.
23
#########################################################################################################
24
25
class GeneratePoseTarget:
26
    """Generate pseudo heatmaps based on joint coordinates and confidence.
27
    Required keys are "keypoint", "img_shape", "keypoint_score" (optional),
28
    added or modified keys are "imgs".
29
    Args:
30
        sigma (float): The sigma of the generated gaussian map. Default: 0.6.
31
        use_score (bool): Use the confidence score of keypoints as the maximum
32
            of the gaussian maps. Default: True.
33
        with_kp (bool): Generate pseudo heatmaps for keypoints. Default: True.
34
        with_limb (bool): Generate pseudo heatmaps for limbs. At least one of
35
            'with_kp' and 'with_limb' should be True. Default: False.
36
        skeletons (tuple[tuple]): The definition of human skeletons.
37
            Default: ((0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (5, 7), (7, 9),
38
                      (0, 6), (6, 8), (8, 10), (5, 11), (11, 13), (13, 15),
39
                      (6, 12), (12, 14), (14, 16), (11, 12)),
40
            which is the definition of COCO-17p skeletons.
41
        double (bool): Output both original heatmaps and flipped heatmaps.
42
            Default: False.
43
        left_kp (tuple[int]): Indexes of left keypoints, which is used when
44
            flipping heatmaps. Default: (1, 3, 5, 7, 9, 11, 13, 15),
45
            which is left keypoints in COCO-17p.
46
        right_kp (tuple[int]): Indexes of right keypoints, which is used when
47
            flipping heatmaps. Default: (2, 4, 6, 8, 10, 12, 14, 16),
48
            which is right keypoints in COCO-17p.
49
        left_limb (tuple[int]): Indexes of left limbs, which is used when
50
            flipping heatmaps. Default: (1, 3, 5, 7, 9, 11, 13, 15),
51
            which is left limbs of skeletons we defined for COCO-17p.
52
        right_limb (tuple[int]): Indexes of right limbs, which is used when
53
            flipping heatmaps. Default: (2, 4, 6, 8, 10, 12, 14, 16),
54
            which is right limbs of skeletons we defined for COCO-17p.
55
    """
56
57
    def __init__(self,
58
                 sigma=0.6,
59
                 use_score=True,
60
                 with_kp=True,
61
                 with_limb=False,
62
                 skeletons=((0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (5, 7),
63
                            (7, 9), (0, 6), (6, 8), (8, 10), (5, 11), (11, 13),
64
                            (13, 15), (6, 12), (12, 14), (14, 16), (11, 12)),
65
                 double=False,
66
                 left_kp=(1, 3, 5, 7, 9, 11, 13, 15),
67
                 right_kp=(2, 4, 6, 8, 10, 12, 14, 16),
68
                 left_limb=(0, 2, 4, 5, 6, 10, 11, 12),
69
                 right_limb=(1, 3, 7, 8, 9, 13, 14, 15),
70
                 scaling=1.,
71
                 eps= 1e-3,
72
                 img_h=64,
73
                 img_w = 64):
74
75
        self.sigma = sigma
76
        self.use_score = use_score
77
        self.with_kp = with_kp
78
        self.with_limb = with_limb
79
        self.double = double
80
        self.eps = eps
81
82
        assert self.with_kp + self.with_limb == 1, ('One of "with_limb" and "with_kp" should be set as True.')
83
        self.left_kp = left_kp
84
        self.right_kp = right_kp
85
        self.skeletons = skeletons
86
        self.left_limb = left_limb
87
        self.right_limb = right_limb
88
        self.scaling = scaling
89
        self.img_h = img_h
90
        self.img_w = img_w
91
92
    def generate_a_heatmap(self, arr, centers, max_values, point_center):
93
        """Generate pseudo heatmap for one keypoint in one frame.
94
        Args:
95
            arr (np.ndarray): The array to store the generated heatmaps. Shape: img_h * img_w.
96
            centers (np.ndarray): The coordinates of corresponding keypoints (of multiple persons). Shape: 1 * 2.
97
            max_values (np.ndarray): The max values of each keypoint. Shape: (1, ).
98
            point_center: Shape: (1, 2)
99
        Returns:
100
            np.ndarray: The generated pseudo heatmap.
101
        """
102
103
        sigma = self.sigma
104
        img_h, img_w = arr.shape
105
106
        for center, max_value in zip(centers, max_values):
107
            if max_value < self.eps:
108
                continue
109
110
            mu_x, mu_y = center[0], center[1]
111
112
            tmp_st_x = int(mu_x - 3 * sigma)
113
            tmp_ed_x = int(mu_x + 3 * sigma)
114
            tmp_st_y = int(mu_y - 3 * sigma)
115
            tmp_ed_y = int(mu_y + 3 * sigma)
116
117
            st_x = max(tmp_st_x, 0) 
118
            ed_x = min(tmp_ed_x + 1, img_w)
119
            st_y = max(tmp_st_y, 0)
120
            ed_y = min(tmp_ed_y + 1, img_h)
121
            x = np.arange(st_x, ed_x, 1, np.float32)
122
            y = np.arange(st_y, ed_y, 1, np.float32)
123
124
            # if the keypoint not in the heatmap coordinate system
125
            if not (len(x) and len(y)):
126
                continue
127
            y = y[:, None]
128
129
            patch = np.exp(-((x - mu_x)**2 + (y - mu_y)**2) / 2 / sigma**2)
130
            patch = patch * max_value
131
132
            arr[st_y:ed_y, st_x:ed_x] = np.maximum(arr[st_y:ed_y, st_x:ed_x], patch)
133
134
    def generate_a_limb_heatmap(self, arr, starts, ends, start_values, end_values, point_center):
135
        """Generate pseudo heatmap for one limb in one frame.
136
        Args:
137
            arr (np.ndarray): The array to store the generated heatmaps. Shape: img_h * img_w.
138
            starts (np.ndarray): The coordinates of one keypoint in the corresponding limbs. Shape: 1 * 2.
139
            ends (np.ndarray): The coordinates of the other keypoint in the corresponding limbs. Shape: 1 * 2.
140
            start_values (np.ndarray): The max values of one keypoint in the corresponding limbs. Shape: (1, ).
141
            end_values (np.ndarray): The max values of the other keypoint in the corresponding limbs. Shape: (1, ).
142
        Returns:
143
            np.ndarray: The generated pseudo heatmap.
144
        """
145
146
        sigma = self.sigma
147
        img_h, img_w = arr.shape
148
149
        for start, end, start_value, end_value in zip(starts, ends, start_values, end_values):
150
            value_coeff = min(start_value, end_value)
151
            if value_coeff < self.eps:
152
                continue
153
154
            min_x, max_x = min(start[0], end[0]), max(start[0], end[0])
155
            min_y, max_y = min(start[1], end[1]), max(start[1], end[1])
156
157
158
            
159
            tmp_min_x = int(min_x - 3 * sigma)
160
            tmp_max_x = int(max_x + 3 * sigma) 
161
            tmp_min_y = int(min_y - 3 * sigma)
162
            tmp_max_y = int(max_y + 3 * sigma)
163
164
            min_x = max(tmp_min_x, 0)
165
            max_x = min(tmp_max_x + 1, img_w)
166
            min_y = max(tmp_min_y, 0)
167
            max_y = min(tmp_max_y + 1, img_h)
168
169
            x = np.arange(min_x, max_x, 1, np.float32)
170
            y = np.arange(min_y, max_y, 1, np.float32)
171
172
            if not (len(x) and len(y)):
173
                continue
174
175
            y = y[:, None]
176
            x_0 = np.zeros_like(x)
177
            y_0 = np.zeros_like(y)
178
179
            # distance to start keypoints
180
            d2_start = ((x - start[0])**2 + (y - start[1])**2)
181
182
            # distance to end keypoints
183
            d2_end = ((x - end[0])**2 + (y - end[1])**2)
184
185
            # the distance between start and end keypoints.
186
            d2_ab = ((start[0] - end[0])**2 + (start[1] - end[1])**2)
187
188
            if d2_ab < 1:
189
                self.generate_a_heatmap(arr, start[None], start_value[None], point_center)
190
                continue
191
192
            coeff = (d2_start - d2_end + d2_ab) / 2. / d2_ab
193
194
            a_dominate = coeff <= 0
195
            b_dominate = coeff >= 1
196
            seg_dominate = 1 - a_dominate - b_dominate
197
198
            position = np.stack([x + y_0, y + x_0], axis=-1)
199
            projection = start + np.stack([coeff, coeff], axis=-1) * (end - start)
200
            d2_line = position - projection
201
            d2_line = d2_line[:, :, 0]**2 + d2_line[:, :, 1]**2
202
            d2_seg = a_dominate * d2_start + b_dominate * d2_end + seg_dominate * d2_line
203
204
            patch = np.exp(-d2_seg / 2. / sigma**2)
205
            patch = patch * value_coeff
206
207
            arr[min_y:max_y, min_x:max_x] = np.maximum(arr[min_y:max_y, min_x:max_x], patch)
208
    def generate_heatmap(self, arr, kps, max_values):
209
        """Generate pseudo heatmap for all keypoints and limbs in one frame (if
210
        needed).
211
        Args:
212
            arr (np.ndarray): The array to store the generated heatmaps. Shape: V * img_h * img_w.
213
            kps (np.ndarray): The coordinates of keypoints in this frame. Shape: 1 * V * 2.
214
            max_values (np.ndarray): The confidence score of each keypoint. Shape: 1 * V.
215
        Returns:
216
            np.ndarray: The generated pseudo heatmap.
217
        """
218
219
        point_center = kps.mean(1)
220
221
        if self.with_kp:
222
            num_kp = kps.shape[1]
223
            for i in range(num_kp):
224
                self.generate_a_heatmap(arr[i], kps[:, i], max_values[:, i], point_center)
225
226
        if self.with_limb:
227
            for i, limb in enumerate(self.skeletons):
228
                start_idx, end_idx = limb
229
                starts = kps[:, start_idx]
230
                ends = kps[:, end_idx]
231
232
                start_values = max_values[:, start_idx]
233
                end_values = max_values[:, end_idx]
234
                self.generate_a_limb_heatmap(arr[i], starts, ends, start_values, end_values, point_center)
235
236
    def gen_an_aug(self, pose_data):
237
        """Generate pseudo heatmaps for all frames.
238
        Args:
239
            pose_data (array): [1, T, V, C]
240
        Returns:
241
            list[np.ndarray]: The generated pseudo heatmaps.
242
        """
243
244
        all_kps = pose_data[..., :2]
245
        kp_shape = pose_data.shape # [1, T, V, 2]
246
247
        if pose_data.shape[-1] == 3:
248
            all_kpscores = pose_data[..., -1] # [1, T, V]
249
        else:
250
            all_kpscores = np.ones(kp_shape[:-1], dtype=np.float32)
251
252
        
253
254
        # scale img_h, img_w and kps
255
        img_h = int(self.img_h * self.scaling + 0.5)
256
        img_w = int(self.img_w * self.scaling + 0.5)
257
        all_kps[..., :2] *= self.scaling
258
259
        num_frame = kp_shape[1]
260
        num_c = 0
261
        if self.with_kp:
262
            num_c += all_kps.shape[2]
263
        if self.with_limb:
264
            num_c += len(self.skeletons)
265
        ret = np.zeros([num_frame, num_c, img_h, img_w], dtype=np.float32)
266
267
        for i in range(num_frame):
268
            # 1, V, C
269
            kps = all_kps[:, i]
270
            # 1, V
271
            kpscores = all_kpscores[:, i] if self.use_score else np.ones_like(all_kpscores[:, i])
272
273
            self.generate_heatmap(ret[i], kps, kpscores)
274
        return ret
275
276
    def __call__(self, pose_data):
277
        """
278
        pose_data: (T, V, C=3/2)
279
        1: means person number
280
        """
281
        pose_data = pose_data[None,...] # (1, T, V, C=3/2)
282
283
        heatmap = self.gen_an_aug(pose_data)
284
        
285
        if self.double:
286
            indices = np.arange(heatmap.shape[1], dtype=np.int64)
287
            left, right = (self.left_kp, self.right_kp) if self.with_kp else (self.left_limb, self.right_limb)
288
            for l, r in zip(left, right):  # noqa: E741
289
                indices[l] = r
290
                indices[r] = l
291
            heatmap_flip = heatmap[..., ::-1][:, indices]
292
            heatmap = np.concatenate([heatmap, heatmap_flip])
293
        return heatmap
294
295
    def __repr__(self):
296
        repr_str = (f'{self.__class__.__name__}('
297
                    f'sigma={self.sigma}, '
298
                    f'use_score={self.use_score}, '
299
                    f'with_kp={self.with_kp}, '
300
                    f'with_limb={self.with_limb}, '
301
                    f'skeletons={self.skeletons}, '
302
                    f'double={self.double}, '
303
                    f'left_kp={self.left_kp}, '
304
                    f'right_kp={self.right_kp})')
305
        return repr_str
306
307
class HeatmapToImage:
308
    """
309
    Convert the heatmap data to image data.
310
    """
311
    def __init__(self) -> None:
312
        self.cmap = cm.gray
313
    
314
    def __call__(self, heatmaps):
315
        """
316
        heatmaps: (T, 17, H, W)
317
        return images: (T, 1, H, W)
318
        """
319
        heatmaps = [x.transpose(1, 2, 0) for x in heatmaps]
320
        h, w, _ = heatmaps[0].shape
321
        newh, neww = int(h), int(w)
322
        heatmaps = [np.max(x, axis=-1) for x in heatmaps]
323
        heatmaps = [(self.cmap(x)[..., :3] * 255).astype(np.uint8) for x in heatmaps]
324
        heatmaps = [cv2.resize(x, (neww, newh)) for x in heatmaps]
325
        return np.ascontiguousarray(np.mean(np.array(heatmaps), axis=-1, keepdims=True).transpose(0,3,1,2))
326
327
class CenterAndScaleNormalizer:
328
329
    def __init__(self, pose_format="coco", use_conf=True, heatmap_image_height=128) -> None:
330
        """
331
        Parameters:
332
        - pose_format (str): Specifies the format of the keypoints. 
333
                            This parameter determines how the keypoints are structured and indexed. 
334
                            The supported formats are "coco" or "openpose-x" where 'x' can be either 18 or 25, indicating the number of keypoints used by the OpenPose model. 
335
        - use_conf (bool): Indicates whether confidence scores.
336
        - heatmap_image_height (int): Sets the height (in pixels) for the heatmap images that will be normlization. 
337
        """
338
        self.pose_format = pose_format
339
        self.use_conf = use_conf
340
        self.heatmap_image_height = heatmap_image_height
341
342
    def __call__(self, data):
343
        """
344
            Implements step (a) from Figure 2 in the SkeletonGait paper.
345
            data: (T, V, C)
346
            - T: number of frames
347
            - V: number of joints
348
            - C: dimensionality, where 2 indicates joint coordinates and 1 indicates the confidence score
349
            return data: (T, V, C)
350
        """
351
352
        if self.use_conf:
353
            pose_seq = data[..., :-1]
354
            score = np.expand_dims(data[..., -1], axis=-1)
355
        else:
356
            pose_seq = data[..., :-1]
357
        
358
        # Hip as the center point
359
        if self.pose_format.lower() == "coco":
360
            hip  = (pose_seq[:, 11] + pose_seq[:, 12]) / 2. # [t, 2]
361
        elif self.pose_format.split('-')[0].lower() == "openpose":
362
            hip  = (pose_seq[:, 9] + pose_seq[:, 12]) / 2. # [t, 2]
363
        else:
364
            raise ValueError(f"Error value for pose_format: {self.pose_format} in CenterAndScale Class.")
365
366
        # Center-normalization
367
        pose_seq = pose_seq - hip[:, np.newaxis, :]
368
369
        # Scale-normalization
370
        y_max = np.max(pose_seq[:, :, 1], axis=-1) # [t]
371
        y_min = np.min(pose_seq[:, :, 1], axis=-1) # [t]
372
        pose_seq *= ((self.heatmap_image_height // 1.5) / (y_max - y_min)[:, np.newaxis, np.newaxis]) # [t, v, 2]
373
        
374
        pose_seq += self.heatmap_image_height // 2
375
        
376
        if self.use_conf:
377
            pose_seq = np.concatenate([pose_seq, score], axis=-1)
378
        return pose_seq
379
380
class PadKeypoints:
381
    """
382
    Pad the keypoints with missing values.
383
    """
384
385
    def __init__(self, pad_method="knn", use_conf=True) -> None:
386
        """
387
        pad_method (str): Specifies the method used to pad the missing values.
388
                        The supported methods are "knn" and "simple".
389
        use_conf (bool): Indicates whether confidence scores.
390
        """
391
        self.use_conf = use_conf
392
        if pad_method.lower() == "knn":
393
            self.imputer = KNNImputer(missing_values=0.0, n_neighbors=4, weights="distance", add_indicator=False)
394
        elif pad_method.lower()  == "simple":
395
            self.imputer = SimpleImputer(missing_values=0.0, strategy='mean',add_indicator=True)
396
        else:
397
            raise ValueError(f"Error value for padding method: {pad_method}")
398
    
399
    def __call__(self, raw_data):
400
        """
401
        raw_data: (T, V, C)
402
        - T: number of frames
403
        - V: number of joints
404
        - C: dimensionality, where 2 indicates joint coordinates and 1 indicates the confidence score
405
        return padded_data: (T, V, C)
406
        """
407
        T, V, C = raw_data.shape
408
        if self.use_conf:
409
            data = raw_data[..., :-1]
410
            score = np.expand_dims(raw_data[..., -1], axis=-1)
411
            C = C - 1
412
        else:
413
            data = raw_data[..., :-1]
414
        data = data.reshape((T, V*C))
415
        padded_data = self.imputer.fit_transform(data)
416
        try:
417
            padded_data = padded_data.reshape((T, V, C))
418
        except:
419
            padded_data = data.reshape((T, V, C))
420
        if self.use_conf:
421
            padded_data = np.concatenate([padded_data, score], axis=-1)
422
        return padded_data
423
424
class COCO18toCOCO17:
425
    """
426
    Transfer COCO18 format (Openpose extracted) to COCO17 format
427
    """
428
429
    def __init__(self, transfer_to_coco17=True):
430
        """
431
        transfer_to_coco17 (bool): Indicates whether to transfer the keypoints from COCO18 to COCO17 format.
432
        """
433
        self.map_dict = {
434
                0: 0,# "nose",
435
                1: 15,# "left_eye",
436
                2: 14,# "right_eye",
437
                3: 17,# "left_ear",
438
                4: 16,# "right_ear",
439
                5: 5,# "left_shoulder",
440
                6: 2,# "right_shoulder",
441
                7: 6,# "left_elbow",
442
                8: 3,# "right_elbow",
443
                9: 7,# "left_wrist",
444
                10: 4,# "right_wrist",
445
                11: 11,# "left_hip",
446
                12: 8,# "right_hip",
447
                13: 12,# "left_knee",
448
                14: 9,# "right_knee",
449
                15: 13,# "left_ankle",
450
                16: 10,# "right_ankle"
451
            }
452
        self.transfer = transfer_to_coco17
453
    
454
    def __call__(self, data):
455
456
        """
457
        data: (T, 18, C)
458
        - T: number of frames
459
        - 18: number of joints of COCO18 format
460
        - C: dimensionality, where 2 indicates joint coordinates and 1 indicates the confidence score
461
        return data: (T, 17, C)
462
        """
463
464
        if self.transfer:
465
            """
466
            input data [T, 18, C] coco18 format
467
            return data [T, 17, C] coco17 format
468
            """
469
            T, _, C = data.shape
470
            coco17_pkl_data = np.zeros((T, 17, C))
471
            for i in range(17):
472
                coco17_pkl_data[:,i,:] = data[:,self.map_dict[i],:]
473
            return coco17_pkl_data
474
        else:
475
            return data
476
477
class GatherTransform(object):
478
    """
479
    Gather the different transforms.
480
    """
481
    def __init__(self, base_transform, transform_bone, transform_joint):
482
483
        """
484
        base_transform: Some common transform, e.g., COCO18toCOCO17, PadKeypoints, CenterAndScale
485
        transform_bone: GeneratePoseTarget for generate bone heatmap
486
        transform_joint: GeneratePoseTarget for generate joint heatmap
487
        """
488
        self.base_transform = base_transform
489
        self.transform_bone = transform_bone
490
        self.transform_joint = transform_joint
491
492
    def __call__(self, pose_data):
493
        x = self.base_transform(pose_data)
494
        heatmap_bone = self.transform_bone(x) # [T, 1, H, W]
495
        heatmap_joint = self.transform_joint(x) # [T, 1, H, W]
496
        heatmap = np.concatenate([heatmap_bone, heatmap_joint], axis=1)
497
        return heatmap
498
499
class HeatmapAlignment():
500
    def __init__(self, align=True, final_img_size=64, offset=0, heatmap_image_size=128) -> None:
501
        self.align = align
502
        self.final_img_size = final_img_size
503
        self.offset = offset
504
        self.heatmap_image_size = heatmap_image_size
505
506
    def center_crop(self, heatmap):
507
        """
508
        Input: [1, heatmap_image_size, heatmap_image_size]
509
        Output: [1, final_img_size, final_img_size]
510
        """
511
        raw_heatmap = heatmap[0]
512
        if self.align: 
513
            y_sum = raw_heatmap.sum(axis=1)
514
            y_top = (y_sum != 0).argmax(axis=0)
515
            y_btm = (y_sum != 0).cumsum(axis=0).argmax(axis=0)
516
            height = y_btm - y_top + 1
517
            raw_heatmap = raw_heatmap[y_top - self.offset: y_btm + 1 + self.offset, (self.heatmap_image_size // 2) - (height // 2) : (self.heatmap_image_size // 2) + (height // 2) + 1]
518
        raw_heatmap = cv2.resize(raw_heatmap, (self.final_img_size, self.final_img_size), interpolation=cv2.INTER_AREA)
519
        return raw_heatmap[np.newaxis, :, :] # [1, final_img_size, final_img_size]
520
521
    def __call__(self, heatmap_imgs):
522
        """
523
        heatmap_imgs: (T, 1, raw_size, raw_size)
524
        return (T, 1, final_img_size, final_img_size)
525
        """
526
        heatmap_imgs = heatmap_imgs / 255.
527
        heatmap_imgs = np.array([self.center_crop(heatmap_img) for heatmap_img in heatmap_imgs]) 
528
        return (heatmap_imgs * 255).astype('uint8')
529
530
def GenerateHeatmapTransform(
531
    coco18tococo17_args,
532
    padkeypoints_args,
533
    norm_args,
534
    heatmap_generator_args,
535
    align_args
536
):
537
538
    base_transform = T.Compose([
539
        COCO18toCOCO17(**coco18tococo17_args),
540
        PadKeypoints(**padkeypoints_args), 
541
        CenterAndScaleNormalizer(**norm_args), 
542
    ])
543
544
    heatmap_generator_args["with_limb"] = True
545
    heatmap_generator_args["with_kp"] = False
546
    transform_bone = T.Compose([
547
        GeneratePoseTarget(**heatmap_generator_args), 
548
        HeatmapToImage(), 
549
        HeatmapAlignment(**align_args) 
550
    ])
551
552
    heatmap_generator_args["with_limb"] = False
553
    heatmap_generator_args["with_kp"] = True
554
    transform_joint = T.Compose([
555
        GeneratePoseTarget(**heatmap_generator_args), 
556
        HeatmapToImage(), 
557
        HeatmapAlignment(**align_args) 
558
    ])
559
560
    transform = T.Compose([
561
        GatherTransform(base_transform, transform_bone, transform_joint) # [T, 2, H, W]
562
    ])
563
564
    return transform
565
566
#########################################################################################################
567
# The following code is DDP progress codes.
568
#########################################################################################################
569
class SequentialDistributedSampler(torch.utils.data.sampler.Sampler):
570
    """
571
    Distributed Sampler that subsamples indicies sequentially,
572
    making it easier to collate all results at the end.
573
    Even though we only use this sampler for eval and predict (no training),
574
    which means that the model params won't have to be synced (i.e. will not hang
575
    for synchronization even if varied number of forward passes), we still add extra
576
    samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
577
    to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
578
    """
579
580
    def __init__(self, dataset, batch_size, rank=None, num_replicas=None):
581
        if num_replicas is None:
582
            if not torch.distributed.is_available():
583
                raise RuntimeError("Requires distributed package to be available")
584
            num_replicas = torch.distributed.get_world_size()
585
        if rank is None:
586
            if not torch.distributed.is_available():
587
                raise RuntimeError("Requires distributed package to be available")
588
            rank = torch.distributed.get_rank()
589
        self.dataset = dataset
590
        self.num_replicas = num_replicas
591
        self.rank = rank
592
        self.batch_size = batch_size
593
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.batch_size / self.num_replicas)) * self.batch_size
594
        self.total_size = self.num_samples * self.num_replicas
595
596
    def __iter__(self):
597
        indices = list(range(len(self.dataset)))
598
        # add extra samples to make it evenly divisible
599
        indices += [indices[-1]] * (self.total_size - len(indices))
600
        # subsample
601
        indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
602
        return iter(indices)
603
604
    def __len__(self):
605
        return self.num_samples
606
607
608
class TransferDataset(Dataset):
609
    def __init__(self, args, generate_heatemap_cfgs) -> None:
610
        super().__init__()
611
        pose_root = args.pose_data_path
612
        sigma = generate_heatemap_cfgs['heatmap_generator_args']['sigma']
613
        self.dataset_name = args.dataset_name
614
        assert self.dataset_name.lower() in ["sustech1k", "grew", "ccpg", "oumvlp", "ou-mvlp", "gait3d", "casiab", "casiae"], f"Invalid dataset name: {self.dataset_name}"
615
        self.save_root = os.path.join(args.save_root, f"{self.dataset_name}_sigma_{sigma}_{args.ext_name}")
616
        os.makedirs(self.save_root, exist_ok=True)
617
618
        self.heatmap_transform = GenerateHeatmapTransform(**generate_heatemap_cfgs)
619
620
        if self.dataset_name.lower() == "sustech1k":
621
            self.all_ps_data_paths = sorted(glob(os.path.join(pose_root, "*/*/*/03*.pkl")))
622
        else:
623
            self.all_ps_data_paths = sorted(glob(os.path.join(pose_root, "*/*/*/*.pkl")))
624
625
    def __len__(self):
626
        return len(self.all_ps_data_paths)
627
    
628
    def __getitem__(self, index):
629
        pose_path = self.all_ps_data_paths[index]
630
        with open(pose_path, "rb") as f:
631
            pose_data = pickle.load(f)
632
            if self.dataset_name.lower() == "grew":
633
                # print(pose_data.shape)
634
                pose_data = pose_data[:,2:].reshape(-1, 17, 3)
635
        
636
        tmp_split = pose_path.split('/')
637
638
        heatmap_img = self.heatmap_transform(pose_data) # [T, 2, H, W]
639
        
640
        save_path_pkl = os.path.join(self.save_root, 'pkl', *tmp_split[-4:-1])
641
        os.makedirs(save_path_pkl, exist_ok=True)
642
643
        # save some visualization
644
        if index < 10:
645
            # save images
646
            save_path_img = os.path.join(self.save_root, 'images', *tmp_split[-4:-1])
647
            os.makedirs(save_path_img, exist_ok=True)
648
            # save_heatemapimg_index = random.choice(list(range(heatmap_img.shape[0])))
649
            for save_heatemapimg_index in range(heatmap_img.shape[0]):
650
                cv2.imwrite(os.path.join(save_path_img, f'bone_{save_heatemapimg_index}.jpg'), heatmap_img[save_heatemapimg_index, 0])
651
                cv2.imwrite(os.path.join(save_path_img, f'pose_{save_heatemapimg_index}.jpg'), heatmap_img[save_heatemapimg_index, 1])
652
653
        pickle.dump(heatmap_img, open(os.path.join(save_path_pkl, tmp_split[-1]), 'wb'))
654
        return None
655
656
def mycollate(_):
657
    return None
658
659
660
def get_args():
661
    parser = argparse.ArgumentParser(description='Utility for generating heatmaps from pose data.')
662
    parser.add_argument('--pose_data_path', type=str, required=True, help="Path to the root directory containing pose data (.pkl files, ID-level) files.")
663
    parser.add_argument('--save_root', type=str, required=True, help="Root directory where generated heatmap .pkl files will be saved (ID-level).")
664
    parser.add_argument('--ext_name', type=str, default='', help="Extension name to be appended to the 'save_root' for identification.")
665
    parser.add_argument('--dataset_name', type=str, required=True, help="Name of the dataset being preprocessed.")
666
    parser.add_argument('--heatemap_cfg_path', type=str, default='configs/skeletongait/pretreatment_heatmap.yaml', help="Path to the heatmap generator configuration file.")
667
    parser.add_argument("--local_rank", type=int, default=0, help="Local rank for distributed processing, defaults to 0 for non-distributed setups.")
668
    opt = parser.parse_args()
669
    return opt
670
671
def replace_variables(data, context=None):
672
    if context is None:
673
        context = {}
674
675
    if isinstance(data, dict):
676
        for key, value in data.items():
677
            data[key] = replace_variables(value, context)
678
    elif isinstance(data, list):
679
        data = [replace_variables(item, context) for item in data]
680
    elif isinstance(data, str):
681
        if data.startswith('${') and data.endswith('}'):
682
            var_path = data[2:-1].split('.')
683
            var_value = context
684
            try:
685
                for part in var_path:
686
                    var_value = var_value[part]
687
                return var_value
688
            except KeyError:
689
                raise ValueError(f"Variable {data} not found in context")
690
    return data
691
692
if __name__ == "__main__":
693
    dist.init_process_group("nccl", init_method='env://')
694
    local_rank = torch.distributed.get_rank()
695
    world_size = torch.distributed.get_world_size()
696
697
    args = get_args()
698
699
    # Load the heatmap generator configuration
700
    with open(args.heatemap_cfg_path, 'r') as stream:
701
        generate_heatemap_cfgs = yaml.safe_load(stream)
702
        generate_heatemap_cfgs = replace_variables(generate_heatemap_cfgs, generate_heatemap_cfgs)
703
    # Create the dataset
704
    dataset = TransferDataset(args, generate_heatemap_cfgs)
705
706
    # Create the dataloader
707
    dist_sampler = SequentialDistributedSampler(dataset, batch_size=1, rank=local_rank, num_replicas=world_size)
708
    dataloader = DataLoader(dataset=dataset, batch_size=1, sampler=dist_sampler, num_workers=8, collate_fn=mycollate)
709
    for _, tmp in tqdm(enumerate(dataloader), total=len(dataloader)):
710
        pass
711
712