--- a
+++ b/pretreatment.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# @Author  : Abner
+# @Time    : 2018/12/19
+
+import os
+from scipy import misc as scisc
+import cv2
+import numpy as np
+from warnings import warn
+from time import sleep
+import argparse
+
+from multiprocessing import Pool
+from multiprocessing import TimeoutError as MP_TimeoutError
+
+# Status tags placed in the "--TAG--" field of every log line.
+# START / FINISH bracket the processing of one sequence; WARNING / FAIL
+# entries are the ones log_print always appends to the log file.
+START, FINISH = "START", "FINISH"
+WARNING, FAIL = "WARNING", "FAIL"
+
+
+def boolean_string(s):
+    """Convert the text 'true' / 'false' (any letter case) to a bool.
+
+    Intended as an argparse ``type=`` callable.
+
+    Raises:
+        ValueError: if the text is neither 'true' nor 'false'.
+    """
+    normalized = s.upper()
+    if normalized == 'TRUE':
+        return True
+    if normalized == 'FALSE':
+        return False
+    raise ValueError('Not a valid boolean string')
+
+
+# Command-line interface. Parsing happens at import time, so the derived
+# module-level constants below exist in every process that loads this module.
+parser = argparse.ArgumentParser(description='Test')
+parser.add_argument('--input_path', default='', type=str,
+                    help='Root path of raw dataset.')
+parser.add_argument('--output_path', default='', type=str,
+                    help='Root path for output.')
+parser.add_argument('--log_file', default='./pretreatment.log', type=str,
+                    help='Log file path. Default: ./pretreatment.log')
+# The adjacent help fragments are concatenated, so each one must end with a
+# space to keep the sentences apart in the rendered --help text.
+parser.add_argument('--log', default=False, type=boolean_string,
+                    help='If set as True, all logs will be saved. '
+                         'Otherwise, only warnings and errors will be saved. '
+                         'Default: False')
+parser.add_argument('--worker_num', default=1, type=int,
+                    help='How many subprocesses to use for data pretreatment. '
+                         'Default: 1')
+opt = parser.parse_args()
+
+# Parsed options exposed as module-level constants.
+INPUT_PATH = opt.input_path
+OUTPUT_PATH = opt.output_path
+IF_LOG = opt.log
+LOG_PATH = opt.log_file
+WORKERS = opt.worker_num
+
+# Target height and width, in pixels, of every output silhouette.
+T_H = 64
+T_W = 64
+
+
+def log2str(pid, comment, logs):
+    """Format one or more log messages as '# JOB ...' lines.
+
+    Args:
+        pid: Integer job index (formatted with %d).
+        comment: Status tag placed between the dashes, e.g. START or WARNING.
+        logs: A single message string, or an iterable of message strings.
+
+    Returns:
+        One newline-terminated line per message, joined into a single
+        string; the empty string when `logs` is an empty iterable.
+    """
+    # isinstance (not an exact type check) so that a str subclass is treated
+    # as one message instead of being iterated character by character.
+    if isinstance(logs, str):
+        logs = [logs]
+    return ''.join(
+        "# JOB %d : --%s-- %s\n" % (pid, comment, log) for log in logs)
+
+
+def log_print(pid, comment, logs):
+    """Emit the log lines of one job to the log file and/or stdout.
+
+    File output: WARNING and FAIL entries are always appended to LOG_PATH.
+    Every other entry is appended only when IF_LOG is true, which is what
+    the --log option's help text promises ("all logs will be saved").
+
+    Console output: START and FINISH entries are printed only for jobs whose
+    index is a multiple of 500 (presumably to limit console noise); all
+    other entries are always printed.
+
+    Args:
+        pid: Integer job index.
+        comment: Status tag (START, FINISH, WARNING or FAIL).
+        logs: A message string or an iterable of message strings.
+    """
+    str_log = log2str(pid, comment, logs)
+    if IF_LOG or comment in (WARNING, FAIL):
+        with open(LOG_PATH, 'a') as log_f:
+            log_f.write(str_log)
+    if comment in (START, FINISH) and pid % 500 != 0:
+        return
+    print(str_log, end='')
+
+
+def cut_img(img, seq_info, frame_name, pid):
+    """Crop a silhouette to its vertical extent, rescale it and centre it.
+
+    Args:
+        img: Silhouette image array; the visible caller passes a single
+            channel of the decoded frame.
+        seq_info: Sequence of path components identifying the sequence,
+            used only to build log messages.
+        frame_name: File name of the frame, used only in log messages.
+        pid: Integer job index forwarded to log_print.
+
+    Returns:
+        A uint8 array T_H rows high and 2 * int(T_W / 2) columns wide, or
+        None when the frame is rejected (a warning is issued and logged).
+    """
+    # A silhouette with too few white pixels is unlikely to be usable
+    # for identification, so it is rejected up front.
+    pixel_total = img.sum()
+    if pixel_total <= 10000:
+        message = 'seq:%s, frame:%s, no data, %d.' % (
+            '-'.join(seq_info), frame_name, pixel_total)
+        warn(message)
+        log_print(pid, WARNING, message)
+        return None
+
+    # First and last rows that contain any foreground pixel. The running
+    # count of non-empty rows first reaches its maximum on the last one.
+    row_has_pixels = img.sum(axis=1) != 0
+    first_row = row_has_pixels.argmax(axis=0)
+    last_row = row_has_pixels.cumsum(axis=0).argmax(axis=0)
+    img = img[first_row:last_row + 1, :]
+
+    # A person is taller than wide, so the height fixes the scale factor
+    # and the width follows from the aspect ratio.
+    aspect = img.shape[1] / img.shape[0]
+    scaled_w = int(T_H * aspect)
+    img = cv2.resize(img, (scaled_w, T_H), interpolation=cv2.INTER_CUBIC)
+
+    # Horizontal centre of the person: the first column at which the
+    # cumulative pixel mass exceeds half of the total mass.
+    half_total = img.sum() / 2
+    past_half = img.sum(axis=0).cumsum() > half_total
+    if not past_half.any():
+        message = 'seq:%s, frame:%s, no center.' % (
+            '-'.join(seq_info), frame_name)
+        warn(message)
+        log_print(pid, WARNING, message)
+        return None
+    x_center = int(past_half.argmax())
+
+    half_w = int(T_W / 2)
+    left = x_center - half_w
+    right = x_center + half_w
+    if left <= 0 or right >= img.shape[1]:
+        # The crop window sticks out of the image: pad half a window of
+        # black columns on both sides and shift the window accordingly.
+        left += half_w
+        right += half_w
+        padding = np.zeros((img.shape[0], half_w))
+        img = np.concatenate([padding, img, padding], axis=1)
+    return img[:, left:right].astype('uint8')
+
+
+def cut_pickle(seq_info, pid):
+    """Pretreat every frame of one sequence and save the results.
+
+    Frames are read from INPUT_PATH/<seq_info...>, passed through cut_img
+    and written under the same file name to OUTPUT_PATH/<seq_info...>,
+    which must already exist. Frames that cannot be decoded, are rejected
+    by cut_img, or cannot be written are skipped with a logged message.
+    A warning is logged when fewer than 5 frames were saved.
+
+    Args:
+        seq_info: Sequence of path components identifying the sequence.
+        pid: Integer job index used in log messages.
+    """
+    seq_name = '-'.join(seq_info)
+    log_print(pid, START, seq_name)
+    seq_path = os.path.join(INPUT_PATH, *seq_info)
+    out_dir = os.path.join(OUTPUT_PATH, *seq_info)
+    count_frame = 0
+    for _frame_name in sorted(os.listdir(seq_path)):
+        frame_path = os.path.join(seq_path, _frame_name)
+        raw = cv2.imread(frame_path)
+        if raw is None:
+            # cv2.imread signals an unreadable or non-image file by
+            # returning None rather than raising; skip it instead of
+            # crashing the whole job on the channel indexing below.
+            message = 'seq:%s, frame:%s, cannot be read as an image.' % (
+                seq_name, _frame_name)
+            warn(message)
+            log_print(pid, WARNING, message)
+            continue
+        img = cut_img(raw[:, :, 0], seq_info, _frame_name, pid)
+        if img is None:
+            continue
+        # Save the cut img. cv2.imwrite is used because
+        # scipy.misc.imsave no longer exists in SciPy >= 1.2.
+        save_path = os.path.join(out_dir, _frame_name)
+        if not cv2.imwrite(save_path, img):
+            message = 'seq:%s, frame:%s, cannot be saved to %s.' % (
+                seq_name, _frame_name, save_path)
+            warn(message)
+            log_print(pid, FAIL, message)
+            continue
+        count_frame += 1
+    # Warn if the sequence contains less than 5 frames
+    if count_frame < 5:
+        message = 'seq:%s, less than 5 valid data.' % (
+            seq_name)
+        warn(message)
+        log_print(pid, WARNING, message)
+
+    log_print(pid, FINISH,
+              'Contain %d valid frames. Saved to %s.'
+              % (count_frame, out_dir))
+
+
+def main():
+    """Pretreat every sequence found under INPUT_PATH in a worker pool.
+
+    The input tree is walked as INPUT_PATH/<id>/<seq_type>/<view>. For each
+    view the matching output directory is created and one cut_pickle job is
+    submitted. The function then polls the jobs until all have finished,
+    re-raising the first worker exception it encounters.
+    """
+    pool = Pool(WORKERS)
+    results = list()
+    pid = 0
+
+    print('Pretreatment Start.\n'
+          'Input path: %s\n'
+          'Output path: %s\n'
+          'Log file: %s\n'
+          'Worker num: %d' % (
+              INPUT_PATH, OUTPUT_PATH, LOG_PATH, WORKERS))
+
+    # Walk the input path
+    for _id in sorted(os.listdir(INPUT_PATH)):
+        id_path = os.path.join(INPUT_PATH, _id)
+        for _seq_type in sorted(os.listdir(id_path)):
+            type_path = os.path.join(id_path, _seq_type)
+            for _view in sorted(os.listdir(type_path)):
+                seq_info = [_id, _seq_type, _view]
+                out_dir = os.path.join(OUTPUT_PATH, *seq_info)
+                os.makedirs(out_dir)
+                results.append(
+                    pool.apply_async(
+                        cut_pickle,
+                        args=(seq_info, pid)))
+                sleep(0.02)
+                pid += 1
+
+    pool.close()
+    # Poll with a short timeout instead of blocking on each result in turn,
+    # so a failure in a late job is reported without waiting for all the
+    # earlier ones to finish.
+    unfinish = 1
+    while unfinish > 0:
+        unfinish = 0
+        for i, res in enumerate(results):
+            try:
+                res.get(timeout=0.1)
+            except MP_TimeoutError:
+                # Job still running; look at it again on the next pass.
+                unfinish += 1
+            except Exception as e:
+                print('\n\n\nERROR OCCUR: PID ##%d##, ERRORTYPE: %s\n\n\n'
+                      % (i, type(e)))
+                raise
+    pool.join()
+
+
+# Guard the entry point: with the "spawn" start method every worker
+# re-imports this module, and must not create a pool of its own.
+if __name__ == '__main__':
+    main()