|
a |
|
b/pretreatment.py |
|
|
1 |
# -*- coding: utf-8 -*- |
|
|
2 |
# @Author : Abner |
|
|
3 |
# @Time : 2018/12/19 |
|
|
4 |
|
|
|
5 |
import os |
|
|
6 |
from scipy import misc as scisc |
|
|
7 |
import cv2 |
|
|
8 |
import numpy as np |
|
|
9 |
from warnings import warn |
|
|
10 |
from time import sleep |
|
|
11 |
import argparse |
|
|
12 |
|
|
|
13 |
from multiprocessing import Pool |
|
|
14 |
from multiprocessing import TimeoutError as MP_TimeoutError |
|
|
15 |
|
|
|
16 |
START = "START" |
|
|
17 |
FINISH = "FINISH" |
|
|
18 |
WARNING = "WARNING" |
|
|
19 |
FAIL = "FAIL" |
|
|
20 |
|
|
|
21 |
|
|
|
22 |
def boolean_string(s): |
|
|
23 |
if s.upper() not in {'FALSE', 'TRUE'}: |
|
|
24 |
raise ValueError('Not a valid boolean string') |
|
|
25 |
return s.upper() == 'TRUE' |
|
|
26 |
|
|
|
27 |
|
|
|
28 |
parser = argparse.ArgumentParser(description='Test') |
|
|
29 |
parser.add_argument('--input_path', default='', type=str, |
|
|
30 |
help='Root path of raw dataset.') |
|
|
31 |
parser.add_argument('--output_path', default='', type=str, |
|
|
32 |
help='Root path for output.') |
|
|
33 |
parser.add_argument('--log_file', default='./pretreatment.log', type=str, |
|
|
34 |
help='Log file path. Default: ./pretreatment.log') |
|
|
35 |
parser.add_argument('--log', default=False, type=boolean_string, |
|
|
36 |
help='If set as True, all logs will be saved. ' |
|
|
37 |
'Otherwise, only warnings and errors will be saved.' |
|
|
38 |
'Default: False') |
|
|
39 |
parser.add_argument('--worker_num', default=1, type=int, |
|
|
40 |
help='How many subprocesses to use for data pretreatment. ' |
|
|
41 |
'Default: 1') |
|
|
42 |
opt = parser.parse_args() |
|
|
43 |
|
|
|
44 |
INPUT_PATH = opt.input_path |
|
|
45 |
OUTPUT_PATH = opt.output_path |
|
|
46 |
IF_LOG = opt.log |
|
|
47 |
LOG_PATH = opt.log_file |
|
|
48 |
WORKERS = opt.worker_num |
|
|
49 |
|
|
|
50 |
T_H = 64 |
|
|
51 |
T_W = 64 |
|
|
52 |
|
|
|
53 |
|
|
|
54 |
def log2str(pid, comment, logs): |
|
|
55 |
str_log = '' |
|
|
56 |
if type(logs) is str: |
|
|
57 |
logs = [logs] |
|
|
58 |
for log in logs: |
|
|
59 |
str_log += "# JOB %d : --%s-- %s\n" % ( |
|
|
60 |
pid, comment, log) |
|
|
61 |
return str_log |
|
|
62 |
|
|
|
63 |
|
|
|
64 |
def log_print(pid, comment, logs): |
|
|
65 |
str_log = log2str(pid, comment, logs) |
|
|
66 |
if comment in [WARNING, FAIL]: |
|
|
67 |
with open(LOG_PATH, 'a') as log_f: |
|
|
68 |
log_f.write(str_log) |
|
|
69 |
if comment in [START, FINISH]: |
|
|
70 |
if pid % 500 != 0: |
|
|
71 |
return |
|
|
72 |
print(str_log, end='') |
|
|
73 |
|
|
|
74 |
|
|
|
75 |
def cut_img(img, seq_info, frame_name, pid): |
|
|
76 |
# A silhouette contains too little white pixels |
|
|
77 |
# might be not valid for identification. |
|
|
78 |
if img.sum() <= 10000: |
|
|
79 |
message = 'seq:%s, frame:%s, no data, %d.' % ( |
|
|
80 |
'-'.join(seq_info), frame_name, img.sum()) |
|
|
81 |
warn(message) |
|
|
82 |
log_print(pid, WARNING, message) |
|
|
83 |
return None |
|
|
84 |
# Get the top and bottom point |
|
|
85 |
y = img.sum(axis=1) |
|
|
86 |
y_top = (y != 0).argmax(axis=0) |
|
|
87 |
y_btm = (y != 0).cumsum(axis=0).argmax(axis=0) |
|
|
88 |
img = img[y_top:y_btm + 1, :] |
|
|
89 |
# As the height of a person is larger than the width, |
|
|
90 |
# use the height to calculate resize ratio. |
|
|
91 |
_r = img.shape[1] / img.shape[0] |
|
|
92 |
_t_w = int(T_H * _r) |
|
|
93 |
img = cv2.resize(img, (_t_w, T_H), interpolation=cv2.INTER_CUBIC) |
|
|
94 |
# Get the median of x axis and regard it as the x center of the person. |
|
|
95 |
sum_point = img.sum() |
|
|
96 |
sum_column = img.sum(axis=0).cumsum() |
|
|
97 |
x_center = -1 |
|
|
98 |
for i in range(sum_column.size): |
|
|
99 |
if sum_column[i] > sum_point / 2: |
|
|
100 |
x_center = i |
|
|
101 |
break |
|
|
102 |
if x_center < 0: |
|
|
103 |
message = 'seq:%s, frame:%s, no center.' % ( |
|
|
104 |
'-'.join(seq_info), frame_name) |
|
|
105 |
warn(message) |
|
|
106 |
log_print(pid, WARNING, message) |
|
|
107 |
return None |
|
|
108 |
h_T_W = int(T_W / 2) |
|
|
109 |
left = x_center - h_T_W |
|
|
110 |
right = x_center + h_T_W |
|
|
111 |
if left <= 0 or right >= img.shape[1]: |
|
|
112 |
left += h_T_W |
|
|
113 |
right += h_T_W |
|
|
114 |
_ = np.zeros((img.shape[0], h_T_W)) |
|
|
115 |
img = np.concatenate([_, img, _], axis=1) |
|
|
116 |
img = img[:, left:right] |
|
|
117 |
return img.astype('uint8') |
|
|
118 |
|
|
|
119 |
|
|
|
120 |
def cut_pickle(seq_info, pid): |
|
|
121 |
seq_name = '-'.join(seq_info) |
|
|
122 |
log_print(pid, START, seq_name) |
|
|
123 |
seq_path = os.path.join(INPUT_PATH, *seq_info) |
|
|
124 |
out_dir = os.path.join(OUTPUT_PATH, *seq_info) |
|
|
125 |
frame_list = os.listdir(seq_path) |
|
|
126 |
frame_list.sort() |
|
|
127 |
count_frame = 0 |
|
|
128 |
for _frame_name in frame_list: |
|
|
129 |
frame_path = os.path.join(seq_path, _frame_name) |
|
|
130 |
img = cv2.imread(frame_path)[:, :, 0] |
|
|
131 |
img = cut_img(img, seq_info, _frame_name, pid) |
|
|
132 |
if img is not None: |
|
|
133 |
# Save the cut img |
|
|
134 |
save_path = os.path.join(out_dir, _frame_name) |
|
|
135 |
scisc.imsave(save_path, img) |
|
|
136 |
count_frame += 1 |
|
|
137 |
# Warn if the sequence contains less than 5 frames |
|
|
138 |
if count_frame < 5: |
|
|
139 |
message = 'seq:%s, less than 5 valid data.' % ( |
|
|
140 |
'-'.join(seq_info)) |
|
|
141 |
warn(message) |
|
|
142 |
log_print(pid, WARNING, message) |
|
|
143 |
|
|
|
144 |
log_print(pid, FINISH, |
|
|
145 |
'Contain %d valid frames. Saved to %s.' |
|
|
146 |
% (count_frame, out_dir)) |
|
|
147 |
|
|
|
148 |
|
|
|
149 |
pool = Pool(WORKERS) |
|
|
150 |
results = list() |
|
|
151 |
pid = 0 |
|
|
152 |
|
|
|
153 |
print('Pretreatment Start.\n' |
|
|
154 |
'Input path: %s\n' |
|
|
155 |
'Output path: %s\n' |
|
|
156 |
'Log file: %s\n' |
|
|
157 |
'Worker num: %d' % ( |
|
|
158 |
INPUT_PATH, OUTPUT_PATH, LOG_PATH, WORKERS)) |
|
|
159 |
|
|
|
160 |
id_list = os.listdir(INPUT_PATH) |
|
|
161 |
id_list.sort() |
|
|
162 |
# Walk the input path |
|
|
163 |
for _id in id_list: |
|
|
164 |
seq_type = os.listdir(os.path.join(INPUT_PATH, _id)) |
|
|
165 |
seq_type.sort() |
|
|
166 |
for _seq_type in seq_type: |
|
|
167 |
view = os.listdir(os.path.join(INPUT_PATH, _id, _seq_type)) |
|
|
168 |
view.sort() |
|
|
169 |
for _view in view: |
|
|
170 |
seq_info = [_id, _seq_type, _view] |
|
|
171 |
out_dir = os.path.join(OUTPUT_PATH, *seq_info) |
|
|
172 |
os.makedirs(out_dir) |
|
|
173 |
results.append( |
|
|
174 |
pool.apply_async( |
|
|
175 |
cut_pickle, |
|
|
176 |
args=(seq_info, pid))) |
|
|
177 |
sleep(0.02) |
|
|
178 |
pid += 1 |
|
|
179 |
|
|
|
180 |
pool.close() |
|
|
181 |
unfinish = 1 |
|
|
182 |
while unfinish > 0: |
|
|
183 |
unfinish = 0 |
|
|
184 |
for i, res in enumerate(results): |
|
|
185 |
try: |
|
|
186 |
res.get(timeout=0.1) |
|
|
187 |
except Exception as e: |
|
|
188 |
if type(e) == MP_TimeoutError: |
|
|
189 |
unfinish += 1 |
|
|
190 |
continue |
|
|
191 |
else: |
|
|
192 |
print('\n\n\nERROR OCCUR: PID ##%d##, ERRORTYPE: %s\n\n\n', |
|
|
193 |
i, type(e)) |
|
|
194 |
raise e |
|
|
195 |
pool.join() |