datasets/SUSTech1K/point2depth.py
# This source is based on https://github.com/AbnerHqC/GaitSet/blob/master/pretreatment.py
import argparse
import logging
import multiprocessing as mp
import os
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import Tuple

import cv2
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend: figures are rendered from worker processes (assumes no display is needed).
import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
from tqdm import tqdm


def align_img(img: np.ndarray, img_size: int = 64) -> np.ndarray:
    """Crops and resizes an image so the subject is vertically trimmed and
    horizontally centered.

    Args:
        img (np.ndarray): Image to align, shape (H, W, 3).
        img_size (int, optional): Output size in pixels. Defaults to 64.

    Returns:
        np.ndarray: Aligned image of shape (img_size, img_size, 3).
    """
    if img.sum() <= 10000:
        # Nearly empty image: keep the full height.
        y_top = 0
        y_btm = img.shape[0]
    else:
        # Find the top-most and bottom-most rows that contain any signal.
        y_sum = img.sum(axis=2).sum(axis=1)
        y_top = (y_sum != 0).argmax(axis=0)
        y_btm = (y_sum != 0).cumsum(axis=0).argmax(axis=0)

    img = img[y_top:y_btm, :, :]

    # As the height of a person is larger than the width,
    # use the height to calculate the resize ratio.
    ratio = img.shape[1] / img.shape[0]
    img = cv2.resize(img, (int(img_size * ratio), img_size), interpolation=cv2.INTER_CUBIC)

    # Take the median of the horizontal mass distribution as the person's x-center.
    x_csum = img.sum(axis=2).sum(axis=0).cumsum()
    x_center = img.shape[1] // 2
    for idx, csum in enumerate(x_csum):
        if csum > img.sum() / 2:
            x_center = idx
            break

    # Crop a window of width img_size around the x-center. If the window would
    # fall outside the image, zero-pad both sides first so the crop stays
    # centered on the person instead of sliding off the frame.
    half_width = img_size // 2
    left = x_center - half_width
    right = x_center + half_width
    if left <= 0 or right >= img.shape[1]:
        left += half_width
        right += half_width
        pad = np.zeros((img.shape[0], half_width, 3))
        img = np.concatenate([pad, img, pad], axis=1)

    img = img[:, left:right, :].astype('uint8')
    return img


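# A minimal sanity-check sketch (not part of the pipeline; the synthetic frame
# below is hypothetical): exercise align_img on a bright vertical blob and
# check the output geometry.
def _demo_align_img():
    frame = np.zeros((128, 96, 3), dtype=np.uint8)
    frame[30:90, 40:60] = 255  # Stand-in subject: a bright vertical blob.
    crop = align_img(frame, img_size=64)
    assert crop.shape == (64, 64, 3)
    return crop

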
def lidar_to_2d_front_view(points,
                           v_res,
                           h_res,
                           v_fov,
                           val="depth",
                           cmap="jet",
                           saveto=None,
                           y_fudge=0.0):
    """Takes points in 3D space from LIDAR data, projects them to a 2D
    "front view" image, and saves that image.

    Args:
        points: (np array)
            The numpy array containing the lidar points.
            The shape should be Nx3
            - Where N is the number of points, and
            - each point is specified by 3 values (x, y, z).
            An Nx4 array with a fourth reflectance column is required only
            when val="reflectance".
        v_res: (float)
            Vertical resolution of the lidar sensor used, in degrees.
        h_res: (float)
            Horizontal resolution of the lidar sensor used, in degrees.
        v_fov: (tuple of two floats)
            (minimum_negative_angle, max_positive_angle)
        val: (str)
            What value to use to encode the points that get plotted.
            One of {"depth", "height", "reflectance"}
        cmap: (str)
            Color map to use to color code the `val` values.
            NOTE: Must be a value accepted by matplotlib's scatter function.
            Examples: "jet", "gray"
        saveto: (str)
            Filename to save the image as. A '.pcd' suffix is replaced with
            '.png'. Required: the saved image is re-read for alignment.
        y_fudge: (float)
            A hacky fudge factor to use if the theoretical calculations of
            vertical range do not match the actual data.
            For a Velodyne HDL 64E, set this value to 5.
    """
    # DUMMY PROOFING
    assert len(v_fov) == 2, "v_fov must be list/tuple of length 2"
    assert v_fov[0] <= 0, "first element in v_fov must be 0 or negative"
    assert val in {"depth", "height", "reflectance"}, \
        'val must be one of {"depth", "height", "reflectance"}'
    assert saveto is not None, "saveto is required; the image is re-read for alignment"

    # Negate x and y to flip the view for this dataset's sensor orientation.
    x_lidar = -points[:, 0]
    y_lidar = -points[:, 1]
    z_lidar = points[:, 2]
    # Distance relative to origin when seen from the top
    d_lidar = np.sqrt(x_lidar ** 2 + y_lidar ** 2)

    v_fov_total = -v_fov[0] + v_fov[1]

    # Convert to radians
    v_res_rad = v_res * (np.pi / 180)
    h_res_rad = h_res * (np.pi / 180)

    # PROJECT INTO IMAGE COORDINATES
    x_img = np.arctan2(-y_lidar, x_lidar) / h_res_rad  # azimuth -> column
    y_img = np.arctan2(z_lidar, d_lidar) / v_res_rad   # elevation -> row

    # SHIFT COORDINATES TO MAKE 0,0 THE MINIMUM
    x_min = -360.0 / h_res / 2  # Theoretical min x value based on sensor specs
    x_img -= x_min              # Shift
    x_max = 360.0 / h_res       # Theoretical max x value after shifting

    y_min = v_fov[0] / v_res    # Theoretical min y value based on sensor specs
    y_img -= y_min              # Shift
    y_max = v_fov_total / v_res # Theoretical max y value after shifting

    y_max += y_fudge            # Fudge factor if the calculations based on
                                # the spec sheet do not match the range of
                                # angles collected in the data.

    # WHAT DATA TO USE TO ENCODE THE VALUE FOR EACH PIXEL
    if val == "reflectance":
        assert points.shape[1] >= 4, 'val="reflectance" requires Nx4 input'
        pixel_values = points[:, 3]
    elif val == "height":
        pixel_values = z_lidar
    else:
        pixel_values = -d_lidar  # "depth": negate so nearer points get larger values

    # PLOT THE IMAGE
    dpi = 100  # Image resolution
    fig, ax = plt.subplots(figsize=(x_max / dpi, y_max / dpi), dpi=dpi)
    ax.scatter(x_img, y_img, s=1, c=pixel_values, linewidths=0, alpha=1, cmap=cmap)
    ax.set_facecolor((0, 0, 0))  # Set regions with no points to black
    ax.axis('scaled')            # {equal, scaled}
    ax.xaxis.set_visible(False)  # Do not draw axis tick marks
    ax.yaxis.set_visible(False)  # Do not draw axis tick marks
    plt.xlim([0, x_max])  # Prevent drawing empty space outside of horizontal FOV
    plt.ylim([0, y_max])  # Prevent drawing empty space outside of vertical FOV

    saveto = saveto.replace('.pcd', '.png')
    fig.savefig(saveto, dpi=dpi, bbox_inches='tight', pad_inches=0.0)
    plt.close(fig)

    # Re-read the rendered projection, center it on the subject, and write the
    # aligned copy to a parallel 'aligned' directory tree.
    img = cv2.imread(saveto)
    img = align_img(img)
    aligned_path = saveto.replace('offline', 'aligned')
    os.makedirs(os.path.dirname(aligned_path), exist_ok=True)
    cv2.imwrite(aligned_path, img)
    # (To render binary silhouettes instead, scatter the same points with
    # c='white' and save to a parallel 'sils' tree.)


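# A quick sketch of the theoretical canvas size the projection above produces
# with the sensor parameters used in pcd2depth below: roughly 1876 x 200
# scatter units before the fudge factor.
def _demo_canvas_size(h_res=0.19188, v_res=0.2, v_fov=(-25.0, 15.0)):
    x_max = 360.0 / h_res                   # Full horizontal sweep in columns.
    y_max = (-v_fov[0] + v_fov[1]) / v_res  # Vertical field of view in rows.
    return x_max, y_max                     # ~(1876.2, 200.0)

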
def pcd2depth(img_groups: Tuple, output_path: Path, img_size: int = 64, verbose: bool = False, dataset='CASIAB') -> None:
    """Renders a group of point-cloud files as front-view depth images.

    Args:
        img_groups (Tuple): Tuple of (sid, seq, view, folder) and a list of PCD paths.
        output_path (Path): Output path.
        img_size (int, optional): Image resizing size. Defaults to 64.
        verbose (bool, optional): Display debug info. Defaults to False.
        dataset (str, optional): Dataset name. Defaults to 'CASIAB'.
    """
    sinfo = img_groups[0]
    img_paths = img_groups[1]

    # Sensor parameters for the projection.
    HRES = 0.19188        # Horizontal resolution (assuming 20Hz setting)
    VRES = 0.2            # Vertical resolution
    VFOV = (-25.0, 15.0)  # Field of view (-ve, +ve) along vertical axis
    Y_FUDGE = 0           # y fudge factor for velodyne HDL 64E

    for img_file in sorted(img_paths):
        pcd_name = os.path.basename(img_file)
        pcd = o3d.io.read_point_cloud(img_file)
        points = np.asarray(pcd.points)
        dst_path = os.path.join(output_path, *sinfo)
        os.makedirs(dst_path, exist_ok=True)
        dst_path = os.path.join(dst_path, pcd_name)
        lidar_to_2d_front_view(points, v_res=VRES, h_res=HRES, v_fov=VFOV, val="depth",
                               saveto=dst_path, y_fudge=Y_FUDGE)


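# Shape of one work item handed to pcd2depth (IDs and paths here are
# hypothetical; the real keys come from the directory walk in pretreat):
#   (('0001', '00-nm', '000', 'PCDs_offline_depths'),
#    ['<input_path>/0001/00-nm/000/PCDs/000001.pcd', ...])

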
def pretreat(input_path: Path, output_path: Path, img_size: int = 64, workers: int = 4, verbose: bool = False, dataset: str = 'CASIAB') -> None:
    """Walks the dataset tree and renders every point cloud as a depth image.

    Args:
        input_path (Path): Dataset root path.
        output_path (Path): Output path.
        img_size (int, optional): Image resizing size. Defaults to 64.
        workers (int, optional): Number of worker processes. Defaults to 4.
        verbose (bool, optional): Display debug info. Defaults to False.
        dataset (str, optional): Dataset name. Defaults to 'CASIAB'.
    """
    img_groups = defaultdict(list)
    logging.info(f'Listing {input_path}')
    total_files = 0
    # Expected layout: <input_path>/<sid>/<seq>/<view>/PCDs/*.pcd
    for sid in tqdm(sorted(os.listdir(input_path))):
        for seq in os.listdir(os.path.join(input_path, sid)):
            for view in os.listdir(os.path.join(input_path, sid, seq)):
                for img_path in os.listdir(os.path.join(input_path, sid, seq, view, 'PCDs')):
                    img_groups[(sid, seq, view, 'PCDs_offline_depths')].append(
                        os.path.join(input_path, sid, seq, view, 'PCDs', img_path))
                    total_files += 1

    logging.info(f'Total files listed: {total_files}')

    progress = tqdm(total=len(img_groups), desc='Pretreating', unit='folder')

    with mp.Pool(workers) as pool:
        logging.info(f'Start pretreating {input_path}')
        for _ in pool.imap_unordered(partial(pcd2depth, output_path=output_path, img_size=img_size, verbose=verbose, dataset=dataset), img_groups.items()):
            progress.update(1)
    progress.close()
    logging.info('Done')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='OpenGait dataset pretreatment module.')
    parser.add_argument('-i', '--input_path', default='', type=str, help='Root path of raw dataset.')
    parser.add_argument('-o', '--output_path', default='', type=str, help='Output path of the rendered depth images.')
    parser.add_argument('-l', '--log_file', default='./pretreatment.log', type=str, help='Log file path. Default: ./pretreatment.log')
    parser.add_argument('-n', '--n_workers', default=4, type=int, help='Number of worker processes. Default: 4')
    parser.add_argument('-r', '--img_size', default=64, type=int, help='Image resizing size. Default: 64')
    parser.add_argument('-d', '--dataset', default='CASIAB', type=str, help='Dataset for pretreatment.')
    parser.add_argument('-v', '--verbose', default=False, action='store_true', help='Display debug info.')
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w', format='[%(asctime)s - %(levelname)s]: %(message)s')

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)
        logging.info('Verbose mode is on.')
        for k, v in args.__dict__.items():
            logging.debug(f'{k}: {v}')

    pretreat(input_path=Path(args.input_path), output_path=Path(args.output_path), img_size=args.img_size, workers=args.n_workers, verbose=args.verbose, dataset=args.dataset)
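
# Example invocation (paths are hypothetical):
#   python datasets/SUSTech1K/point2depth.py \
#       -i /data/SUSTech1K-raw -o /data/SUSTech1K-depth -n 8 -v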