# datasets/pretreatment.py
1
# This source is based on https://github.com/AbnerHqC/GaitSet/blob/master/pretreatment.py
2
import argparse
3
import logging
4
import multiprocessing as mp
5
import os
6
import pickle
7
from collections import defaultdict
8
from functools import partial
9
from pathlib import Path
10
from typing import Tuple
11
12
import cv2
13
import numpy as np
14
from tqdm import tqdm
15
import json
16
17
def imgs2pickle(img_groups: Tuple, output_path: Path, img_size: int = 64, verbose: bool = False, dataset='CASIAB') -> None:
    """Reads a group of silhouette images and saves them as one pickled uint8 array.

    Args:
        img_groups (Tuple): Tuple of (sid, seq, view) and list of image paths.
        output_path (Path): Output path.
        img_size (int, optional): Image resizing size. Defaults to 64.
        verbose (bool, optional): Display debug info. Defaults to False.
        dataset (str, optional): Dataset name; 'GREW' frames are stored as-is
            without cropping/centering. Defaults to 'CASIAB'.
    """
    sinfo = img_groups[0]
    img_paths = img_groups[1]
    to_pickle = []
    for img_file in sorted(img_paths):
        if verbose:
            logging.debug(f'Reading sid {sinfo[0]}, seq {sinfo[1]}, view {sinfo[2]} from {img_file}')

        img = cv2.imread(str(img_file), cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising;
            # the original code would crash on img.sum() for unreadable files.
            logging.warning(f'{img_file} could not be read.')
            continue

        if dataset == 'GREW':
            # GREW silhouettes are kept unmodified.
            to_pickle.append(img.astype('uint8'))
            continue

        if img.sum() <= 10000:
            # Nearly-empty silhouette: too little foreground to be useful.
            if verbose:
                logging.debug(f'Image sum: {img.sum()}')
            logging.warning(f'{img_file} has no data.')
            continue

        # Get the upper and lower points: crop vertically to the rows that
        # contain any foreground pixels.
        y_sum = img.sum(axis=1)
        y_top = (y_sum != 0).argmax(axis=0)
        y_btm = (y_sum != 0).cumsum(axis=0).argmax(axis=0)
        img = img[y_top: y_btm + 1, :]

        # As the height of a person is larger than the width,
        # use the height to calculate resize ratio.
        ratio = img.shape[1] / img.shape[0]
        img = cv2.resize(img, (int(img_size * ratio), img_size), interpolation=cv2.INTER_CUBIC)

        # Get the median of the x-axis and take it as the person's x-center.
        x_csum = img.sum(axis=0).cumsum()
        x_center = None
        for idx, csum in enumerate(x_csum):
            if csum > img.sum() / 2:
                x_center = idx
                break

        if x_center is None:
            # FIX: was `if not x_center`, which also rejected a valid center
            # at column index 0.
            logging.warning(f'{img_file} has no center.')
            continue

        # Get the left and right points of an img_size-wide window around the
        # center, zero-padding both sides when the window falls off the image.
        half_width = img_size // 2
        left = x_center - half_width
        right = x_center + half_width
        if left <= 0 or right >= img.shape[1]:
            left += half_width
            right += half_width
            pad = np.zeros((img.shape[0], half_width))
            img = np.concatenate([pad, img, pad], axis=1)

        to_pickle.append(img[:, left: right].astype('uint8'))

    if to_pickle:
        to_pickle = np.asarray(to_pickle)
        dst_path = os.path.join(output_path, *sinfo)
        os.makedirs(dst_path, exist_ok=True)
        pkl_path = os.path.join(dst_path, f'{sinfo[2]}.pkl')
        if verbose:
            logging.debug(f'Saving {pkl_path}...')
        # Close the file handle deterministically (original relied on GC).
        with open(pkl_path, 'wb') as f:
            pickle.dump(to_pickle, f)
        logging.info(f'Saved {len(to_pickle)} valid frames to {pkl_path}.')

    if len(to_pickle) < 5:
        logging.warning(f'{sinfo} has less than 5 valid data.')
95
96
97
98
def pretreat(input_path: Path, output_path: Path, img_size: int = 64, workers: int = 4, verbose: bool = False, dataset: str = 'CASIAB') -> None:
    """Reads a silhouette dataset and saves the data in pickle format.

    Args:
        input_path (Path): Dataset root path.
        output_path (Path): Output path.
        img_size (int, optional): Image resizing size. Defaults to 64.
        workers (int, optional): Number of worker processes. Defaults to 4.
        verbose (bool, optional): Display debug info. Defaults to False.
        dataset (str, optional): Dataset name, forwarded to imgs2pickle.
            Defaults to 'CASIAB'.
    """
    img_groups = defaultdict(list)
    logging.info(f'Listing {input_path}')
    total_files = 0
    for img_path in input_path.rglob('*.png'):
        # Skip pre-rendered gait energy images.
        if 'gei.png' in img_path.as_posix():
            continue
        if verbose:
            logging.debug(f'Adding {img_path}')
        # Expected layout: .../sid/seq/view/<frame>.png
        *_, sid, seq, view, _ = img_path.as_posix().split('/')
        img_groups[(sid, seq, view)].append(img_path)
        total_files += 1

    logging.info(f'Total files listed: {total_files}')

    # FIX: use tqdm as a context manager so the progress bar is always closed
    # (the original leaked it, which can garble subsequent terminal output).
    with tqdm(total=len(img_groups), desc='Pretreating', unit='folder') as progress:
        with mp.Pool(workers) as pool:
            logging.info(f'Start pretreating {input_path}')
            for _ in pool.imap_unordered(partial(imgs2pickle, output_path=output_path, img_size=img_size, verbose=verbose, dataset=dataset), img_groups.items()):
                progress.update(1)
    logging.info('Done')
129
130
def txts2pickle(txt_groups: Tuple, output_path: Path, verbose: bool = False, dataset='CASIAB') -> None:
    """Reads a group of pose-keypoint files and saves them in pickle format.

    Args:
        txt_groups (Tuple): Tuple of (sid, seq, view) and list of keypoint file
            paths (OpenPose JSON for OUMVLP, comma-separated text otherwise).
        output_path (Path): Output path.
        verbose (bool, optional): Display debug info. Defaults to False.
        dataset (str, optional): Dataset name. Defaults to 'CASIAB'.
    """
    sinfo = txt_groups[0]
    txt_paths = txt_groups[1]
    to_pickle = []
    if dataset == 'OUMVLP':
        for txt_file in sorted(txt_paths):
            try:
                with open(txt_file) as f:
                    jsondata = json.load(f)
                if len(jsondata['people']) == 0:
                    # No person detected in this frame; skip it.
                    continue
                # Keypoints come flattened as (x, y, score) triplets.
                data = np.array(jsondata["people"][0]["pose_keypoints_2d"]).reshape(-1, 3)
                to_pickle.append(data)
            except (OSError, ValueError, KeyError):
                # FIX: was a bare `except:` that printed to stdout, swallowing
                # even KeyboardInterrupt; catch parse failures and log them.
                logging.warning(f'Failed to parse keypoints from {txt_file}.')
    else:
        for txt_file in sorted(txt_paths):
            if verbose:
                logging.debug(f'Reading sid {sinfo[0]}, seq {sinfo[1]}, view {sinfo[2]} from {txt_file}')
            # The first two comma-separated values are skipped; the remainder
            # is reshaped into (x, y, score) triplets.
            data = np.genfromtxt(txt_file, delimiter=',')[2:].reshape(-1, 3)
            to_pickle.append(data)

    if to_pickle:
        dst_path = os.path.join(output_path, *sinfo)
        keypoints = np.stack(to_pickle)
        os.makedirs(dst_path, exist_ok=True)
        pkl_path = os.path.join(dst_path, f'{sinfo[2]}.pkl')
        if verbose:
            logging.debug(f'Saving {pkl_path}...')
        # Close the file handle deterministically (original relied on GC).
        with open(pkl_path, 'wb') as f:
            pickle.dump(keypoints, f)
        logging.info(f'Saved {len(to_pickle)} valid frames\' keypoints to {pkl_path}.')

    if len(to_pickle) < 5:
        logging.warning(f'{sinfo} has less than 5 valid data.')
174
175
176
177
def pretreat_pose(input_path: Path, output_path: Path, workers: int = 4, verbose: bool = False, dataset='CASIAB') -> None:
    """Reads a pose-keypoint dataset and saves the data in pickle format.

    Args:
        input_path (Path): Dataset root path.
        output_path (Path): Output path.
        workers (int, optional): Number of worker processes. Defaults to 4.
        verbose (bool, optional): Display debug info. Defaults to False.
        dataset (str, optional): Dataset name; 'OUMVLP' ships OpenPose JSON
            files, other datasets use .txt keypoint files. Defaults to 'CASIAB'.
    """
    txt_groups = defaultdict(list)
    logging.info(f'Listing {input_path}')
    total_files = 0
    # OUMVLP stores one OpenPose JSON per frame; everything else uses .txt.
    pattern = '*.json' if dataset == 'OUMVLP' else '*.txt'
    for file_path in input_path.rglob(pattern):
        if verbose:
            logging.debug(f'Adding {file_path}')
        # Expected layout: .../sid/seq/view/<frame file>
        *_, sid, seq, view, _ = file_path.as_posix().split('/')
        txt_groups[(sid, seq, view)].append(file_path)
        total_files += 1

    logging.info(f'Total files listed: {total_files}')

    # Progress bar as a context manager so it is always closed.
    with tqdm(total=len(txt_groups), desc='Pretreating', unit='folder') as progress:
        with mp.Pool(workers) as pool:
            logging.info(f'Start pretreating {input_path}')
            # BUG FIX: originally passed `dataset=args.dataset`, reaching for a
            # module-level global that only exists when run as a script; the
            # function's own `dataset` parameter was silently ignored and any
            # importing caller crashed with NameError.
            for _ in pool.imap_unordered(partial(txts2pickle, output_path=output_path, verbose=verbose, dataset=dataset), txt_groups.items()):
                progress.update(1)
    logging.info('Done')
214
215
216
217
if __name__ == '__main__':
    # Command-line entry point. NOTE: `args` is intentionally a module-level
    # global here (not wrapped in a main()); other code in this file refers
    # to it, so it must remain at module scope.
    parser = argparse.ArgumentParser(description='OpenGait dataset pretreatment module.')
    parser.add_argument('-i', '--input_path', default='', type=str, help='Root path of raw dataset.')
    parser.add_argument('-o', '--output_path', default='', type=str, help='Output path of pickled dataset.')
    parser.add_argument('-l', '--log_file', default='./pretreatment.log', type=str, help='Log file path. Default: ./pretreatment.log')
    parser.add_argument('-n', '--n_workers', default=4, type=int, help='Number of thread workers. Default: 4')
    parser.add_argument('-r', '--img_size', default=64, type=int, help='Image resizing size. Default 64')
    parser.add_argument('-d', '--dataset', default='CASIAB', type=str, help='Dataset for pretreatment.')
    parser.add_argument('-v', '--verbose', default=False, action='store_true', help='Display debug info.')
    parser.add_argument('-p', '--pose', default=False, action='store_true', help='Processing pose.')
    args = parser.parse_args()

    # All log output goes to the (truncated) log file, not the console.
    logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w', format='[%(asctime)s - %(levelname)s]: %(message)s')

    if args.verbose:
        # Verbose mode lowers the threshold to DEBUG and dumps the parsed args.
        logging.getLogger().setLevel(logging.DEBUG)
        logging.info('Verbose mode is on.')
        for k, v in args.__dict__.items():
            logging.debug(f'{k}: {v}')
    # --pose selects keypoint pretreatment; otherwise silhouettes are processed.
    if args.pose:
        pretreat_pose(input_path=Path(args.input_path), output_path=Path(args.output_path), workers=args.n_workers, verbose=args.verbose, dataset=args.dataset)
    else:
        pretreat(input_path=Path(args.input_path), output_path=Path(args.output_path), img_size=args.img_size, workers=args.n_workers, verbose=args.verbose, dataset=args.dataset)