--- /dev/null
+++ b/datasets/pretreatment_heatmap.py
@@ -0,0 +1,712 @@
+import os
+import cv2
+import yaml
+import math
+import torch
+import random
+import pickle
+import argparse
+import numpy as np
+from glob import glob
+from tqdm import tqdm
+import matplotlib.cm as cm
+import torch.distributed as dist
+from torchvision import transforms as T
+from torch.utils.data import Dataset, DataLoader
+from sklearn.impute import KNNImputer, SimpleImputer
+
+torch.manual_seed(347)
+random.seed(347)
+
+#########################################################################################################
+# The following code defines the base classes used to generate heatmaps.
+#########################################################################################################
+
+class GeneratePoseTarget:
+    """Generate pseudo heatmaps based on joint coordinates and confidence.
+    __call__ takes a pose array of shape (T, V, C) and returns the stacked
+    per-frame heatmaps.
+    Args:
+        sigma (float): The sigma of the generated gaussian map. Default: 0.6.
+        use_score (bool): Use the confidence score of keypoints as the maximum
+            of the gaussian maps. Default: True.
+        with_kp (bool): Generate pseudo heatmaps for keypoints. Default: True.
+        with_limb (bool): Generate pseudo heatmaps for limbs. At least one of
+            'with_kp' and 'with_limb' should be True. Default: False.
+        skeletons (tuple[tuple]): The definition of human skeletons.
+            Default: ((0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (5, 7), (7, 9),
+                      (0, 6), (6, 8), (8, 10), (5, 11), (11, 13), (13, 15),
+                      (6, 12), (12, 14), (14, 16), (11, 12)),
+            which is the definition of COCO-17p skeletons.
+        double (bool): Output both original heatmaps and flipped heatmaps.
+            Default: False.
+        left_kp (tuple[int]): Indexes of left keypoints, which is used when
+            flipping heatmaps. Default: (1, 3, 5, 7, 9, 11, 13, 15),
+            which is left keypoints in COCO-17p.
+        right_kp (tuple[int]): Indexes of right keypoints, which is used when
+            flipping heatmaps. Default: (2, 4, 6, 8, 10, 12, 14, 16),
+            which is right keypoints in COCO-17p.
+        left_limb (tuple[int]): Indexes of left limbs, which is used when
+            flipping heatmaps. Default: (0, 2, 4, 5, 6, 10, 11, 12),
+            which is left limbs of the skeletons defined above for COCO-17p.
+        right_limb (tuple[int]): Indexes of right limbs, which is used when
+            flipping heatmaps. Default: (1, 3, 7, 8, 9, 13, 14, 15),
+            which is right limbs of the skeletons defined above for COCO-17p.
+        scaling (float): Scale factor applied to the canvas and keypoints. Default: 1.
+        eps (float): Confidence threshold below which a keypoint is skipped. Default: 1e-3.
+        img_h (int): Height of the heatmap canvas. Default: 64.
+        img_w (int): Width of the heatmap canvas. Default: 64.
+    """
+
+    def __init__(self,
+                 sigma=0.6,
+                 use_score=True,
+                 with_kp=True,
+                 with_limb=False,
+                 skeletons=((0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (5, 7),
+                            (7, 9), (0, 6), (6, 8), (8, 10), (5, 11), (11, 13),
+                            (13, 15), (6, 12), (12, 14), (14, 16), (11, 12)),
+                 double=False,
+                 left_kp=(1, 3, 5, 7, 9, 11, 13, 15),
+                 right_kp=(2, 4, 6, 8, 10, 12, 14, 16),
+                 left_limb=(0, 2, 4, 5, 6, 10, 11, 12),
+                 right_limb=(1, 3, 7, 8, 9, 13, 14, 15),
+                 scaling=1.,
+                 eps=1e-3,
+                 img_h=64,
+                 img_w=64):
+
+        self.sigma = sigma
+        self.use_score = use_score
+        self.with_kp = with_kp
+        self.with_limb = with_limb
+        self.double = double
+        self.eps = eps
+
+        assert self.with_kp + self.with_limb == 1, 'Exactly one of "with_kp" and "with_limb" should be set to True.'
+        self.left_kp = left_kp
+        self.right_kp = right_kp
+        self.skeletons = skeletons
+        self.left_limb = left_limb
+        self.right_limb = right_limb
+        self.scaling = scaling
+        self.img_h = img_h
+        self.img_w = img_w
+
+    def generate_a_heatmap(self, arr, centers, max_values, point_center):
+        """Generate pseudo heatmap for one keypoint in one frame.
+        Args:
+            arr (np.ndarray): The array to store the generated heatmaps. Shape: img_h * img_w.
+            centers (np.ndarray): The coordinates of corresponding keypoints (of multiple persons). Shape: 1 * 2.
+            max_values (np.ndarray): The max values of each keypoint. Shape: (1, ).
+            point_center: Shape: (1, 2)
+        Returns:
+            np.ndarray: The generated pseudo heatmap.
+        """
+
+        sigma = self.sigma
+        img_h, img_w = arr.shape
+
+        for center, max_value in zip(centers, max_values):
+            if max_value < self.eps:
+                continue
+
+            mu_x, mu_y = center[0], center[1]
+
+            tmp_st_x = int(mu_x - 3 * sigma)
+            tmp_ed_x = int(mu_x + 3 * sigma)
+            tmp_st_y = int(mu_y - 3 * sigma)
+            tmp_ed_y = int(mu_y + 3 * sigma)
+
+            st_x = max(tmp_st_x, 0) 
+            ed_x = min(tmp_ed_x + 1, img_w)
+            st_y = max(tmp_st_y, 0)
+            ed_y = min(tmp_ed_y + 1, img_h)
+            x = np.arange(st_x, ed_x, 1, np.float32)
+            y = np.arange(st_y, ed_y, 1, np.float32)
+
+            # skip keypoints whose Gaussian patch falls entirely outside the canvas
+            if not (len(x) and len(y)):
+                continue
+            y = y[:, None]
+
+            patch = np.exp(-((x - mu_x)**2 + (y - mu_y)**2) / 2 / sigma**2)
+            patch = patch * max_value
+
+            arr[st_y:ed_y, st_x:ed_x] = np.maximum(arr[st_y:ed_y, st_x:ed_x], patch)
+
+    def generate_a_limb_heatmap(self, arr, starts, ends, start_values, end_values, point_center):
+        """Generate pseudo heatmap for one limb in one frame.
+        Args:
+            arr (np.ndarray): The array to store the generated heatmaps. Shape: img_h * img_w.
+            starts (np.ndarray): The coordinates of one keypoint in the corresponding limbs. Shape: 1 * 2.
+            ends (np.ndarray): The coordinates of the other keypoint in the corresponding limbs. Shape: 1 * 2.
+            start_values (np.ndarray): The max values of one keypoint in the corresponding limbs. Shape: (1, ).
+            end_values (np.ndarray): The max values of the other keypoint in the corresponding limbs. Shape: (1, ).
+        Returns:
+            np.ndarray: The generated pseudo heatmap.
+        """
+
+        sigma = self.sigma
+        img_h, img_w = arr.shape
+
+        for start, end, start_value, end_value in zip(starts, ends, start_values, end_values):
+            value_coeff = min(start_value, end_value)
+            if value_coeff < self.eps:
+                continue
+
+            min_x, max_x = min(start[0], end[0]), max(start[0], end[0])
+            min_y, max_y = min(start[1], end[1]), max(start[1], end[1])
+
+            tmp_min_x = int(min_x - 3 * sigma)
+            tmp_max_x = int(max_x + 3 * sigma) 
+            tmp_min_y = int(min_y - 3 * sigma)
+            tmp_max_y = int(max_y + 3 * sigma)
+
+            min_x = max(tmp_min_x, 0)
+            max_x = min(tmp_max_x + 1, img_w)
+            min_y = max(tmp_min_y, 0)
+            max_y = min(tmp_max_y + 1, img_h)
+
+            x = np.arange(min_x, max_x, 1, np.float32)
+            y = np.arange(min_y, max_y, 1, np.float32)
+
+            if not (len(x) and len(y)):
+                continue
+
+            y = y[:, None]
+            x_0 = np.zeros_like(x)
+            y_0 = np.zeros_like(y)
+
+            # distance to start keypoints
+            d2_start = ((x - start[0])**2 + (y - start[1])**2)
+
+            # distance to end keypoints
+            d2_end = ((x - end[0])**2 + (y - end[1])**2)
+
+            # the distance between start and end keypoints.
+            d2_ab = ((start[0] - end[0])**2 + (start[1] - end[1])**2)
+
+            if d2_ab < 1:
+                self.generate_a_heatmap(arr, start[None], start_value[None], point_center)
+                continue
+
+            coeff = (d2_start - d2_end + d2_ab) / 2. / d2_ab
+
+            a_dominate = coeff <= 0
+            b_dominate = coeff >= 1
+            seg_dominate = 1 - a_dominate - b_dominate
+
+            position = np.stack([x + y_0, y + x_0], axis=-1)
+            projection = start + np.stack([coeff, coeff], axis=-1) * (end - start)
+            d2_line = position - projection
+            d2_line = d2_line[:, :, 0]**2 + d2_line[:, :, 1]**2
+            d2_seg = a_dominate * d2_start + b_dominate * d2_end + seg_dominate * d2_line
+
+            patch = np.exp(-d2_seg / 2. / sigma**2)
+            patch = patch * value_coeff
+
+            arr[min_y:max_y, min_x:max_x] = np.maximum(arr[min_y:max_y, min_x:max_x], patch)
+
+    def generate_heatmap(self, arr, kps, max_values):
+        """Generate pseudo heatmaps for all keypoints and limbs in one frame (if
+        needed).
+        Args:
+            arr (np.ndarray): The array to store the generated heatmaps. Shape: V * img_h * img_w.
+            kps (np.ndarray): The coordinates of keypoints in this frame. Shape: 1 * V * 2.
+            max_values (np.ndarray): The confidence score of each keypoint. Shape: 1 * V.
+        Returns:
+            None: heatmaps are written into ``arr`` in place.
+        """
+
+        point_center = kps.mean(1)
+
+        if self.with_kp:
+            num_kp = kps.shape[1]
+            for i in range(num_kp):
+                self.generate_a_heatmap(arr[i], kps[:, i], max_values[:, i], point_center)
+
+        if self.with_limb:
+            for i, limb in enumerate(self.skeletons):
+                start_idx, end_idx = limb
+                starts = kps[:, start_idx]
+                ends = kps[:, end_idx]
+
+                start_values = max_values[:, start_idx]
+                end_values = max_values[:, end_idx]
+                self.generate_a_limb_heatmap(arr[i], starts, ends, start_values, end_values, point_center)
+
+    def gen_an_aug(self, pose_data):
+        """Generate pseudo heatmaps for all frames.
+        Args:
+            pose_data (array): [1, T, V, C]
+        Returns:
+            list[np.ndarray]: The generated pseudo heatmaps.
+        """
+
+        all_kps = pose_data[..., :2]
+        kp_shape = pose_data.shape # [1, T, V, C], C = 2 or 3
+
+        if pose_data.shape[-1] == 3:
+            all_kpscores = pose_data[..., -1] # [1, T, V]
+        else:
+            all_kpscores = np.ones(kp_shape[:-1], dtype=np.float32)
+
+        # scale img_h, img_w and kps
+        img_h = int(self.img_h * self.scaling + 0.5)
+        img_w = int(self.img_w * self.scaling + 0.5)
+        all_kps[..., :2] *= self.scaling
+
+        num_frame = kp_shape[1]
+        num_c = 0
+        if self.with_kp:
+            num_c += all_kps.shape[2]
+        if self.with_limb:
+            num_c += len(self.skeletons)
+        ret = np.zeros([num_frame, num_c, img_h, img_w], dtype=np.float32)
+
+        for i in range(num_frame):
+            # 1, V, C
+            kps = all_kps[:, i]
+            # 1, V
+            kpscores = all_kpscores[:, i] if self.use_score else np.ones_like(all_kpscores[:, i])
+
+            self.generate_heatmap(ret[i], kps, kpscores)
+        return ret
+
+    def __call__(self, pose_data):
+        """
+        pose_data: (T, V, C=3/2)
+        1: means person number
+        """
+        pose_data = pose_data[None,...] # (1, T, V, C=3/2)
+
+        heatmap = self.gen_an_aug(pose_data)
+        
+        if self.double:
+            indices = np.arange(heatmap.shape[1], dtype=np.int64)
+            left, right = (self.left_kp, self.right_kp) if self.with_kp else (self.left_limb, self.right_limb)
+            for l, r in zip(left, right):  # noqa: E741
+                indices[l] = r
+                indices[r] = l
+            heatmap_flip = heatmap[..., ::-1][:, indices]
+            heatmap = np.concatenate([heatmap, heatmap_flip])
+        return heatmap
+
+    def __repr__(self):
+        repr_str = (f'{self.__class__.__name__}('
+                    f'sigma={self.sigma}, '
+                    f'use_score={self.use_score}, '
+                    f'with_kp={self.with_kp}, '
+                    f'with_limb={self.with_limb}, '
+                    f'skeletons={self.skeletons}, '
+                    f'double={self.double}, '
+                    f'left_kp={self.left_kp}, '
+                    f'right_kp={self.right_kp})')
+        return repr_str
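+
+# Usage sketch (illustrative only; assumes a (T, V, 3) array of COCO-17 keypoints):
+#   gen = GeneratePoseTarget(sigma=0.6, with_kp=True, with_limb=False)
+#   joint_maps = gen(np.zeros((8, 17, 3), dtype=np.float32))  # -> (8, 17, 64, 64)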
+
+class HeatmapToImage:
+    """
+    Convert the heatmap data to image data.
+    """
+    def __init__(self) -> None:
+        self.cmap = cm.gray
+    
+    def __call__(self, heatmaps):
+        """
+        heatmaps: (T, 17, H, W)
+        return images: (T, 1, H, W)
+        """
+        heatmaps = [x.transpose(1, 2, 0) for x in heatmaps]
+        h, w, _ = heatmaps[0].shape
+        newh, neww = int(h), int(w)
+        heatmaps = [np.max(x, axis=-1) for x in heatmaps]
+        heatmaps = [(self.cmap(x)[..., :3] * 255).astype(np.uint8) for x in heatmaps]
+        heatmaps = [cv2.resize(x, (neww, newh)) for x in heatmaps]
+        return np.ascontiguousarray(np.mean(np.array(heatmaps), axis=-1, keepdims=True).transpose(0,3,1,2))
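+
+# Rendering sketch: per frame, the V channel maps are max-reduced to a single map,
+# passed through matplotlib's gray colormap to an RGB uint8 image, and averaged back
+# to one channel, so the output is (T, 1, H, W) with values in [0, 255].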
+
+class CenterAndScaleNormalizer:
+
+    def __init__(self, pose_format="coco", use_conf=True, heatmap_image_height=128) -> None:
+        """
+        Parameters:
+        - pose_format (str): Specifies the format of the keypoints. 
+                            This parameter determines how the keypoints are structured and indexed. 
+                            The supported formats are "coco" or "openpose-x" where 'x' can be either 18 or 25, indicating the number of keypoints used by the OpenPose model. 
+        - use_conf (bool): Indicates whether the last channel of the input contains confidence scores.
+        - heatmap_image_height (int): Sets the height (in pixels) of the heatmap canvas that the normalized keypoints are mapped onto.
+        """
+        self.pose_format = pose_format
+        self.use_conf = use_conf
+        self.heatmap_image_height = heatmap_image_height
+
+    def __call__(self, data):
+        """
+            Implements step (a) from Figure 2 in the SkeletonGait paper.
+            data: (T, V, C)
+            - T: number of frames
+            - V: number of joints
+            - C: dimensionality, where 2 indicates joint coordinates and 1 indicates the confidence score
+            return data: (T, V, C)
+        """
+
+        if self.use_conf:
+            pose_seq = data[..., :-1]
+            score = np.expand_dims(data[..., -1], axis=-1)
+        else:
+            pose_seq = data
+        
+        # Hip as the center point
+        if self.pose_format.lower() == "coco":
+            hip  = (pose_seq[:, 11] + pose_seq[:, 12]) / 2. # [t, 2]
+        elif self.pose_format.split('-')[0].lower() == "openpose":
+            hip  = (pose_seq[:, 9] + pose_seq[:, 12]) / 2. # [t, 2]
+        else:
+            raise ValueError(f"Error value for pose_format: {self.pose_format} in CenterAndScale Class.")
+
+        # Center-normalization
+        pose_seq = pose_seq - hip[:, np.newaxis, :]
+
+        # Scale-normalization (guard against degenerate frames where all joints share one y)
+        y_max = np.max(pose_seq[:, :, 1], axis=-1) # [t]
+        y_min = np.min(pose_seq[:, :, 1], axis=-1) # [t]
+        pose_seq *= ((self.heatmap_image_height // 1.5) / np.maximum(y_max - y_min, 1e-6)[:, np.newaxis, np.newaxis]) # [t, v, 2]
+        
+        pose_seq += self.heatmap_image_height // 2
+        
+        if self.use_conf:
+            pose_seq = np.concatenate([pose_seq, score], axis=-1)
+        return pose_seq
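+
+# Normalization sketch: with heatmap_image_height=128, a skeleton whose joints span
+# y_max - y_min = 170 px is centered on the hip midpoint, scaled by
+# (128 // 1.5) / 170 = 0.5, and shifted by 128 // 2 = 64 onto a 128 x 128 canvas.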
+
+class PadKeypoints:
+    """
+    Impute missing (zero-valued) keypoint coordinates.
+    """
+
+    def __init__(self, pad_method="knn", use_conf=True) -> None:
+        """
+        pad_method (str): Specifies the method used to impute missing values.
+                        The supported methods are "knn" and "simple".
+        use_conf (bool): Indicates whether the last channel contains confidence scores.
+        """
+        self.use_conf = use_conf
+        if pad_method.lower() == "knn":
+            self.imputer = KNNImputer(missing_values=0.0, n_neighbors=4, weights="distance", add_indicator=False)
+        elif pad_method.lower() == "simple":
+            # add_indicator must stay False so the imputed array keeps its (T, V*C) shape
+            self.imputer = SimpleImputer(missing_values=0.0, strategy='mean', add_indicator=False)
+        else:
+            raise ValueError(f"Error value for padding method: {pad_method}")
+    
+    def __call__(self, raw_data):
+        """
+        raw_data: (T, V, C)
+        - T: number of frames
+        - V: number of joints
+        - C: dimensionality, where 2 indicates joint coordinates and 1 indicates the confidence score
+        return padded_data: (T, V, C)
+        """
+        T, V, C = raw_data.shape
+        if self.use_conf:
+            data = raw_data[..., :-1]
+            score = np.expand_dims(raw_data[..., -1], axis=-1)
+            C = C - 1
+        else:
+            data = raw_data
+        data = data.reshape((T, V*C))
+        padded_data = self.imputer.fit_transform(data)
+        try:
+            padded_data = padded_data.reshape((T, V, C))
+        except ValueError:
+            # fall back to the raw coordinates if the imputer changed the column count
+            padded_data = data.reshape((T, V, C))
+        if self.use_conf:
+            padded_data = np.concatenate([padded_data, score], axis=-1)
+        return padded_data
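+
+# Imputation sketch: a (T, V, 2) coordinate block is flattened to (T, V*2) so that each
+# frame is one sample; with pad_method="knn", zero entries (treated as missing) are then
+# filled from the 4 most similar frames, weighted by distance.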
+
+class COCO18toCOCO17:
+    """
+    Convert COCO-18 keypoints (as extracted by OpenPose) to COCO-17 format.
+    """
+
+    def __init__(self, transfer_to_coco17=True):
+        """
+        transfer_to_coco17 (bool): Indicates whether to transfer the keypoints from COCO18 to COCO17 format.
+        """
+        # COCO-17 index -> OpenPose BODY-18 index
+        self.map_dict = {
+            0: 0,    # nose
+            1: 15,   # left_eye
+            2: 14,   # right_eye
+            3: 17,   # left_ear
+            4: 16,   # right_ear
+            5: 5,    # left_shoulder
+            6: 2,    # right_shoulder
+            7: 6,    # left_elbow
+            8: 3,    # right_elbow
+            9: 7,    # left_wrist
+            10: 4,   # right_wrist
+            11: 11,  # left_hip
+            12: 8,   # right_hip
+            13: 12,  # left_knee
+            14: 9,   # right_knee
+            15: 13,  # left_ankle
+            16: 10,  # right_ankle
+        }
+        self.transfer = transfer_to_coco17
+    
+    def __call__(self, data):
+        """
+        data: (T, 18, C)
+        - T: number of frames
+        - 18: number of joints of COCO18 format
+        - C: dimensionality, where 2 indicates joint coordinates and 1 indicates the confidence score
+        return data: (T, 17, C)
+        """
+
+        if self.transfer:
+            T, _, C = data.shape
+            coco17_pkl_data = np.zeros((T, 17, C))
+            for i in range(17):
+                coco17_pkl_data[:,i,:] = data[:,self.map_dict[i],:]
+            return coco17_pkl_data
+        else:
+            return data
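+
+# Mapping sketch: COCO-17 index 1 (left_eye) is read from OpenPose BODY-18 index 15,
+# i.e. coco17[:, 1, :] = coco18[:, 15, :].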
+
+class GatherTransform(object):
+    """
+    Gather the different transforms.
+    """
+    def __init__(self, base_transform, transform_bone, transform_joint):
+        """
+        base_transform: common preprocessing transforms, e.g., COCO18toCOCO17, PadKeypoints, CenterAndScaleNormalizer
+        transform_bone: GeneratePoseTarget pipeline for generating the bone (limb) heatmap
+        transform_joint: GeneratePoseTarget pipeline for generating the joint heatmap
+        """
+        self.base_transform = base_transform
+        self.transform_bone = transform_bone
+        self.transform_joint = transform_joint
+
+    def __call__(self, pose_data):
+        x = self.base_transform(pose_data)
+        heatmap_bone = self.transform_bone(x) # [T, 1, H, W]
+        heatmap_joint = self.transform_joint(x) # [T, 1, H, W]
+        heatmap = np.concatenate([heatmap_bone, heatmap_joint], axis=1)
+        return heatmap
+
+class HeatmapAlignment:
+    def __init__(self, align=True, final_img_size=64, offset=0, heatmap_image_size=128) -> None:
+        self.align = align
+        self.final_img_size = final_img_size
+        self.offset = offset
+        self.heatmap_image_size = heatmap_image_size
+
+    def center_crop(self, heatmap):
+        """
+        Input: [1, heatmap_image_size, heatmap_image_size]
+        Output: [1, final_img_size, final_img_size]
+        """
+        raw_heatmap = heatmap[0]
+        if self.align:
+            y_sum = raw_heatmap.sum(axis=1)
+            y_top = (y_sum != 0).argmax(axis=0)                 # first non-empty row
+            y_btm = (y_sum != 0).cumsum(axis=0).argmax(axis=0)  # last non-empty row
+            height = y_btm - y_top + 1
+            # crop the full vertical extent and a similarly sized window around the horizontal center
+            raw_heatmap = raw_heatmap[y_top - self.offset: y_btm + 1 + self.offset, (self.heatmap_image_size // 2) - (height // 2) : (self.heatmap_image_size // 2) + (height // 2) + 1]
+        raw_heatmap = cv2.resize(raw_heatmap, (self.final_img_size, self.final_img_size), interpolation=cv2.INTER_AREA)
+        return raw_heatmap[np.newaxis, :, :] # [1, final_img_size, final_img_size]
+
+    def __call__(self, heatmap_imgs):
+        """
+        heatmap_imgs: (T, 1, raw_size, raw_size)
+        return (T, 1, final_img_size, final_img_size)
+        """
+        heatmap_imgs = heatmap_imgs / 255.
+        heatmap_imgs = np.array([self.center_crop(heatmap_img) for heatmap_img in heatmap_imgs]) 
+        return (heatmap_imgs * 255).astype('uint8')
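+
+# Alignment sketch: for a 128 x 128 heatmap whose non-empty rows span y_top=20..y_btm=107
+# (height 88), the crop keeps rows 20:108 and a similarly sized column window around the
+# horizontal center, then resizes the patch to final_img_size x final_img_size.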
+
+def GenerateHeatmapTransform(
+    coco18tococo17_args,
+    padkeypoints_args,
+    norm_args,
+    heatmap_generator_args,
+    align_args
+):
+
+    base_transform = T.Compose([
+        COCO18toCOCO17(**coco18tococo17_args),
+        PadKeypoints(**padkeypoints_args), 
+        CenterAndScaleNormalizer(**norm_args), 
+    ])
+
+    heatmap_generator_args["with_limb"] = True
+    heatmap_generator_args["with_kp"] = False
+    transform_bone = T.Compose([
+        GeneratePoseTarget(**heatmap_generator_args), 
+        HeatmapToImage(), 
+        HeatmapAlignment(**align_args) 
+    ])
+
+    heatmap_generator_args["with_limb"] = False
+    heatmap_generator_args["with_kp"] = True
+    transform_joint = T.Compose([
+        GeneratePoseTarget(**heatmap_generator_args), 
+        HeatmapToImage(), 
+        HeatmapAlignment(**align_args) 
+    ])
+
+    transform = T.Compose([
+        GatherTransform(base_transform, transform_bone, transform_joint) # [T, 2, H, W]
+    ])
+
+    return transform
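+
+# Usage sketch (hypothetical argument values, mirroring the YAML loaded in __main__):
+#   transform = GenerateHeatmapTransform(
+#       coco18tococo17_args={"transfer_to_coco17": False},
+#       padkeypoints_args={"pad_method": "knn", "use_conf": True},
+#       norm_args={"pose_format": "coco", "use_conf": True, "heatmap_image_height": 128},
+#       heatmap_generator_args={"sigma": 0.6, "img_h": 128, "img_w": 128},
+#       align_args={"align": True, "final_img_size": 64, "heatmap_image_size": 128},
+#   )
+#   heatmap = transform(pose_data)  # pose_data: (T, V, 3) -> heatmap: (T, 2, 64, 64)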
+
+#########################################################################################################
+# The following code drives the distributed (DDP) preprocessing.
+#########################################################################################################
+class SequentialDistributedSampler(torch.utils.data.sampler.Sampler):
+    """
+    Distributed sampler that subsamples indices sequentially,
+    making it easier to collate all results at the end.
+    Even though we only use this sampler for eval and predict (no training),
+    which means that the model params won't have to be synced (i.e. will not hang
+    for synchronization even if varied number of forward passes), we still add extra
+    samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
+    to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
+    """
+
+    def __init__(self, dataset, batch_size, rank=None, num_replicas=None):
+        if num_replicas is None:
+            if not torch.distributed.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = torch.distributed.get_world_size()
+        if rank is None:
+            if not torch.distributed.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = torch.distributed.get_rank()
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.batch_size = batch_size
+        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.batch_size / self.num_replicas)) * self.batch_size
+        self.total_size = self.num_samples * self.num_replicas
+
+    def __iter__(self):
+        indices = list(range(len(self.dataset)))
+        # add extra samples to make it evenly divisible
+        indices += [indices[-1]] * (self.total_size - len(indices))
+        # subsample
+        indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
+        return iter(indices)
+
+    def __len__(self):
+        return self.num_samples
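+
+# Partitioning sketch: with len(dataset)=10, batch_size=1 and num_replicas=4,
+# num_samples = ceil(10 / 1 / 4) * 1 = 3 and total_size = 12; rank 0 draws indices
+# [0, 1, 2], ..., and rank 3 draws [9, 9, 9] (the last index is repeated as padding).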
+
+
+class TransferDataset(Dataset):
+    def __init__(self, args, generate_heatmap_cfgs) -> None:
+        super().__init__()
+        pose_root = args.pose_data_path
+        sigma = generate_heatmap_cfgs['heatmap_generator_args']['sigma']
+        self.dataset_name = args.dataset_name
+        assert self.dataset_name.lower() in ["sustech1k", "grew", "ccpg", "oumvlp", "ou-mvlp", "gait3d", "casiab", "casiae"], f"Invalid dataset name: {self.dataset_name}"
+        self.save_root = os.path.join(args.save_root, f"{self.dataset_name}_sigma_{sigma}_{args.ext_name}")
+        os.makedirs(self.save_root, exist_ok=True)
+
+        self.heatmap_transform = GenerateHeatmapTransform(**generate_heatmap_cfgs)
+
+        if self.dataset_name.lower() == "sustech1k":
+            self.all_ps_data_paths = sorted(glob(os.path.join(pose_root, "*/*/*/03*.pkl")))
+        else:
+            self.all_ps_data_paths = sorted(glob(os.path.join(pose_root, "*/*/*/*.pkl")))
+
+    def __len__(self):
+        return len(self.all_ps_data_paths)
+    
+    def __getitem__(self, index):
+        pose_path = self.all_ps_data_paths[index]
+        with open(pose_path, "rb") as f:
+            pose_data = pickle.load(f)
+            if self.dataset_name.lower() == "grew":
+                # print(pose_data.shape)
+                pose_data = pose_data[:,2:].reshape(-1, 17, 3)
+        
+        tmp_split = pose_path.split('/')
+
+        heatmap_img = self.heatmap_transform(pose_data) # [T, 2, H, W]
+        
+        save_path_pkl = os.path.join(self.save_root, 'pkl', *tmp_split[-4:-1])
+        os.makedirs(save_path_pkl, exist_ok=True)
+
+        # save a few visualizations for sanity checking
+        if index < 10:
+            # save images
+            save_path_img = os.path.join(self.save_root, 'images', *tmp_split[-4:-1])
+            os.makedirs(save_path_img, exist_ok=True)
+            for frame_idx in range(heatmap_img.shape[0]):
+                cv2.imwrite(os.path.join(save_path_img, f'bone_{frame_idx}.jpg'), heatmap_img[frame_idx, 0])
+                cv2.imwrite(os.path.join(save_path_img, f'pose_{frame_idx}.jpg'), heatmap_img[frame_idx, 1])
+
+        with open(os.path.join(save_path_pkl, tmp_split[-1]), 'wb') as out_f:
+            pickle.dump(heatmap_img, out_f)
+        return None
+
+def mycollate(_):
+    """No-op collate_fn: each __getitem__ writes its output to disk and returns None."""
+    return None
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='Utility for generating heatmaps from pose data.')
+    parser.add_argument('--pose_data_path', type=str, required=True, help="Path to the root directory containing pose data (.pkl files, ID-level) files.")
+    parser.add_argument('--save_root', type=str, required=True, help="Root directory where generated heatmap .pkl files will be saved (ID-level).")
+    parser.add_argument('--ext_name', type=str, default='', help="Extension name to be appended to the 'save_root' for identification.")
+    parser.add_argument('--dataset_name', type=str, required=True, help="Name of the dataset being preprocessed.")
+    parser.add_argument('--heatmap_cfg_path', type=str, default='configs/skeletongait/pretreatment_heatmap.yaml', help="Path to the heatmap generator configuration file.")
+    parser.add_argument("--local_rank", type=int, default=0, help="Local rank for distributed processing, defaults to 0 for non-distributed setups.")
+    opt = parser.parse_args()
+    return opt
+
+def replace_variables(data, context=None):
+    """Recursively resolve "${a.b}"-style placeholder strings in a config against `context`."""
+    if context is None:
+        context = {}
+
+    if isinstance(data, dict):
+        for key, value in data.items():
+            data[key] = replace_variables(value, context)
+    elif isinstance(data, list):
+        data = [replace_variables(item, context) for item in data]
+    elif isinstance(data, str):
+        if data.startswith('${') and data.endswith('}'):
+            var_path = data[2:-1].split('.')
+            var_value = context
+            try:
+                for part in var_path:
+                    var_value = var_value[part]
+                return var_value
+            except KeyError:
+                raise ValueError(f"Variable {data} not found in context")
+    return data
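+
+# Substitution sketch: with cfg = {"size": 128, "args": {"img_h": "${size}"}},
+# replace_variables(cfg, cfg) resolves "${size}" to 128.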
+
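+# Launch sketch (assumes a torchrun-style launcher that sets RANK, WORLD_SIZE and
+# MASTER_ADDR/MASTER_PORT, as init_method='env://' requires):
+#   torchrun --nproc_per_node=4 datasets/pretreatment_heatmap.py \
+#       --pose_data_path <pose_root> --save_root <save_root> --dataset_name sustech1k
+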
+if __name__ == "__main__":
+    dist.init_process_group("nccl", init_method='env://')
+    local_rank = torch.distributed.get_rank()  # global rank; equal to the local rank on a single node
+    world_size = torch.distributed.get_world_size()
+
+    args = get_args()
+
+    # Load the heatmap generator configuration
+    with open(args.heatmap_cfg_path, 'r') as stream:
+        generate_heatmap_cfgs = yaml.safe_load(stream)
+        generate_heatmap_cfgs = replace_variables(generate_heatmap_cfgs, generate_heatmap_cfgs)
+    # Create the dataset
+    dataset = TransferDataset(args, generate_heatmap_cfgs)
+
+    # Create the dataloader
+    dist_sampler = SequentialDistributedSampler(dataset, batch_size=1, rank=local_rank, num_replicas=world_size)
+    dataloader = DataLoader(dataset=dataset, batch_size=1, sampler=dist_sampler, num_workers=8, collate_fn=mycollate)
+    for _ in tqdm(dataloader, total=len(dataloader)):
+        pass