Diff of /pretreatment.py [000000] .. [40f229]

Switch to unified view

a b/pretreatment.py
1
# -*- coding: utf-8 -*-
2
# @Author  : Abner
3
# @Time    : 2018/12/19
4
5
import os
6
from scipy import misc as scisc
7
import cv2
8
import numpy as np
9
from warnings import warn
10
from time import sleep
11
import argparse
12
13
from multiprocessing import Pool
14
from multiprocessing import TimeoutError as MP_TimeoutError
15
16
# Tags attached to every log record. log_print() looks at the tag to decide
# whether the record goes to the log file and whether it is echoed to stdout.
START, FINISH = "START", "FINISH"
WARNING, FAIL = "WARNING", "FAIL"
20
21
22
def boolean_string(s):
    """Convert the text ``true``/``false`` (any letter case) to a bool.

    Used as the ``type=`` converter of the ``--log`` command-line option.

    :param s: Text to convert.
    :return: ``True`` if ``s`` upper-cases to ``'TRUE'``, ``False`` if it
        upper-cases to ``'FALSE'``.
    :raises ValueError: If ``s`` upper-cases to neither of the two.
    """
    normalized = s.upper()
    if normalized == 'TRUE':
        return True
    if normalized == 'FALSE':
        return False
    raise ValueError('Not a valid boolean string')
26
27
28
# Command-line interface. The options are parsed at import time and exposed
# to the rest of the module through the upper-case constants below.
parser = argparse.ArgumentParser(description='Test')
parser.add_argument('--input_path', default='', type=str,
                    help='Root path of raw dataset.')
parser.add_argument('--output_path', default='', type=str,
                    help='Root path for output.')
parser.add_argument('--log_file', default='./pretreatment.log', type=str,
                    help='Log file path. Default: ./pretreatment.log')
# Adjacent string literals are concatenated verbatim, so every fragment but
# the last must end with a space.
parser.add_argument('--log', default=False, type=boolean_string,
                    help='If set as True, all logs will be saved. '
                         'Otherwise, only warnings and errors will be saved. '
                         'Default: False')
parser.add_argument('--worker_num', default=1, type=int,
                    help='How many subprocesses to use for data pretreatment. '
                         'Default: 1')
opt = parser.parse_args()

INPUT_PATH = opt.input_path
OUTPUT_PATH = opt.output_path
IF_LOG = opt.log
LOG_PATH = opt.log_file
WORKERS = opt.worker_num

# Target height and width, in pixels, that cut_img() normalises frames to.
T_H = 64
T_W = 64
52
53
54
def log2str(pid, comment, logs):
    """Format one or more log messages as tagged, newline-terminated lines.

    Each message becomes ``# JOB <pid> : --<comment>-- <message>`` followed
    by a newline.

    :param pid: Integer job id (formatted with ``%d``).
    :param comment: Tag placed between the ``--`` markers.
    :param logs: A single message string, or an iterable of messages.
    :return: All formatted lines concatenated; ``''`` if ``logs`` is empty.
    """
    # isinstance rather than an exact type check, so that a str subclass is
    # treated as one message instead of being iterated character by character.
    if isinstance(logs, str):
        logs = [logs]
    return ''.join(
        "# JOB %d : --%s-- %s\n" % (pid, comment, log) for log in logs)
62
63
64
def log_print(pid, comment, logs):
    """Write a log record to the log file and/or echo it to stdout.

    Records tagged ``WARNING`` or ``FAIL`` are always appended to
    ``LOG_PATH``; all other records are appended only when ``IF_LOG`` (the
    ``--log`` option) is true. ``START``/``FINISH`` records are echoed to
    stdout only for job ids divisible by 500; every other record is always
    echoed.

    :param pid: Integer job id.
    :param comment: Record tag (``START``, ``FINISH``, ``WARNING``, ``FAIL``).
    :param logs: A single message string, or an iterable of messages.
    """
    str_log = log2str(pid, comment, logs)
    # Honour --log: with it, every record is persisted, not only problems.
    if IF_LOG or comment in (WARNING, FAIL):
        with open(LOG_PATH, 'a') as log_f:
            log_f.write(str_log)
    # Progress records are printed for every 500th job only.
    if comment in (START, FINISH) and pid % 500 != 0:
        return
    print(str_log, end='')
73
74
75
def cut_img(img, seq_info, frame_name, pid):
    """Crop a silhouette frame to the subject and normalise its size.

    The frame is trimmed to the rows between the first and last non-empty
    one, resized to a height of ``T_H`` keeping the aspect ratio, and then
    cut to a window of ``2 * int(T_W / 2)`` columns centred on the first
    column at which the cumulative pixel sum exceeds half of the total.

    :param img: 2-D array holding one silhouette frame.
    :param seq_info: Strings identifying the sequence; joined with ``-``
        in log messages.
    :param frame_name: Frame file name, used in log messages only.
    :param pid: Job id forwarded to ``log_print``.
    :return: The cut frame as a ``uint8`` array, or ``None`` (after a
        warning has been issued and logged) if the frame's pixel sum is at
        most 10000 or no centre column can be found.
    """
    # A silhouette with too few white pixels is probably not usable
    # for identification.
    pixel_total = img.sum()
    if pixel_total <= 10000:
        message = 'seq:%s, frame:%s, no data, %d.' % (
            '-'.join(seq_info), frame_name, pixel_total)
        warn(message)
        log_print(pid, WARNING, message)
        return None

    # Trim to the rows between the top-most and bottom-most non-empty row.
    row_has_pixels = img.sum(axis=1) != 0
    first_row = row_has_pixels.argmax(axis=0)
    # The running count of non-empty rows first hits its maximum at the
    # last non-empty row.
    last_row = row_has_pixels.cumsum(axis=0).argmax(axis=0)
    img = img[first_row:last_row + 1, :]

    # A person is taller than wide, so the height drives the resize ratio.
    aspect = img.shape[1] / img.shape[0]
    resized_w = int(T_H * aspect)
    img = cv2.resize(img, (resized_w, T_H), interpolation=cv2.INTER_CUBIC)

    # The x centre of the person is taken to be the first column at which
    # the cumulative column sum exceeds half of the total sum.
    half_total = img.sum() / 2
    over_half = np.flatnonzero(img.sum(axis=0).cumsum() > half_total)
    if over_half.size == 0:
        message = 'seq:%s, frame:%s, no center.' % (
            '-'.join(seq_info), frame_name)
        warn(message)
        log_print(pid, WARNING, message)
        return None
    x_center = int(over_half[0])

    half_w = int(T_W / 2)
    left, right = x_center - half_w, x_center + half_w
    if left <= 0 or right >= img.shape[1]:
        # The window would leave the frame: pad half a window of zeros on
        # both sides and shift the window by the same amount.
        left += half_w
        right += half_w
        pad = np.zeros((img.shape[0], half_w))
        img = np.concatenate([pad, img, pad], axis=1)
    return img[:, left:right].astype('uint8')
118
119
120
def cut_pickle(seq_info, pid):
    """Cut every frame of one sequence and save the surviving frames.

    Lists ``INPUT_PATH/<seq_info...>``, processes the entries in sorted name
    order with ``cut_img`` and writes each frame that is not rejected, under
    its original file name, to ``OUTPUT_PATH/<seq_info...>``. The output
    directory is not created here. A warning is logged for unreadable
    files and when fewer than 5 frames were saved.

    :param seq_info: Path components identifying the sequence.
    :param pid: Job id used to tag log records.
    :raises OSError: If a cut frame cannot be written.
    """
    seq_name = '-'.join(seq_info)
    log_print(pid, START, seq_name)
    seq_path = os.path.join(INPUT_PATH, *seq_info)
    out_dir = os.path.join(OUTPUT_PATH, *seq_info)
    count_frame = 0
    for _frame_name in sorted(os.listdir(seq_path)):
        frame_path = os.path.join(seq_path, _frame_name)
        raw = cv2.imread(frame_path)
        if raw is None:
            # cv2.imread reports an unreadable / non-image file by returning
            # None instead of raising; skip it rather than crash on indexing.
            message = 'seq:%s, frame:%s, unreadable image.' % (
                seq_name, _frame_name)
            warn(message)
            log_print(pid, WARNING, message)
            continue
        img = cut_img(raw[:, :, 0], seq_info, _frame_name, pid)
        if img is None:
            continue
        # Save the cut img. scipy.misc.imsave no longer exists in current
        # SciPy releases, so the already-imported OpenCV writer is used.
        save_path = os.path.join(out_dir, _frame_name)
        if not cv2.imwrite(save_path, img):
            # cv2.imwrite signals failure through its return value; keep a
            # failed write fatal for the job instead of silently losing data.
            message = 'seq:%s, frame:%s, failed to save to %s.' % (
                seq_name, _frame_name, save_path)
            log_print(pid, FAIL, message)
            raise OSError(message)
        count_frame += 1
    # Warn if the sequence contains less than 5 frames
    if count_frame < 5:
        message = 'seq:%s, less than 5 valid data.' % (
            '-'.join(seq_info))
        warn(message)
        log_print(pid, WARNING, message)

    log_print(pid, FINISH,
              'Contain %d valid frames. Saved to %s.'
              % (count_frame, out_dir))
147
148
149
# Driver: submit one cut_pickle job per <id>/<seq_type>/<view> directory
# below INPUT_PATH and wait for all of them.
# Guarded so that merely importing this module (as multiprocessing's "spawn"
# start method does in every worker) does not launch the whole run again.
if __name__ == '__main__':
    pool = Pool(WORKERS)
    results = list()
    pid = 0

    print('Pretreatment Start.\n'
          'Input path: %s\n'
          'Output path: %s\n'
          'Log file: %s\n'
          'Worker num: %d' % (
              INPUT_PATH, OUTPUT_PATH, LOG_PATH, WORKERS))

    id_list = os.listdir(INPUT_PATH)
    id_list.sort()
    # Walk the input path
    for _id in id_list:
        seq_type = os.listdir(os.path.join(INPUT_PATH, _id))
        seq_type.sort()
        for _seq_type in seq_type:
            view = os.listdir(os.path.join(INPUT_PATH, _id, _seq_type))
            view.sort()
            for _view in view:
                seq_info = [_id, _seq_type, _view]
                out_dir = os.path.join(OUTPUT_PATH, *seq_info)
                # Raises if the output directory already exists.
                os.makedirs(out_dir)
                results.append(
                    pool.apply_async(
                        cut_pickle,
                        args=(seq_info, pid)))
                sleep(0.02)
                pid += 1

    pool.close()
    # Poll all jobs until none of them times out any more, so that an
    # exception raised in any worker is noticed without waiting for the
    # jobs submitted before it.
    unfinish = 1
    while unfinish > 0:
        unfinish = 0
        for i, res in enumerate(results):
            try:
                res.get(timeout=0.1)
            except MP_TimeoutError:
                # Job still running.
                unfinish += 1
            except Exception as e:
                # The values must be %-formatted into the message; passing
                # them as extra print() arguments leaves the placeholders
                # unfilled.
                print('\n\n\nERROR OCCUR: PID ##%d##, ERRORTYPE: %s\n\n\n'
                      % (i, type(e)))
                raise
    pool.join()