Diff of /utils.py [000000] .. [857d1b]

--- a
+++ b/utils.py
@@ -0,0 +1,194 @@
+# -*- coding:utf-8 -*-
+import os
+import cv2
+import pandas as pd
+import numpy as np
+import configparser as cp
+import matplotlib.pyplot as plt
+
+RAW_DATA_PATH = '/home/tony/fall_research/fall_data/MobiAct_Dataset_v2.0/Annotated Data/'
+
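+# String activity codes from MobiAct mapped to integer class labels; the four
+# fall types (FOL, FKL, BSC, SDL) all collapse to class 0.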
+Label = {'STD': 1, 'WAL': 2, 'JOG': 3, 'JUM': 4, 'STU': 5, 'STN': 6, 'SCH': 7, 'SIT': 8, 'CHU': 9,
+         'LYI': 10, 'FOL': 0, 'FKL': 0, 'BSC': 0, 'SDL': 0, 'CSI': 15, 'CSO': 16}
+
+def extract_data(data_file, sampling_frequency):
+    """
+    从mobileFall中提取数据,用于做实验测试
+    :param data_file:  原始数据文件
+    :param sampling_frequency: 原始数据采集频率
+    :return:
+    """
+    data = pd.read_csv(data_file, index_col=0)
+    data_size = len(data.label)
+    # Map string activity labels to their numeric codes (column 10 is the label).
+    for i in range(data_size):
+        data.iat[i, 10] = Label[data.iloc[i, 10]]
+
+    # Downsample to 50 Hz by keeping every (sampling_frequency / 50)-th row,
+    # then keep the six sensor columns and the label column.
+    row_idx = np.arange(0, data_size, int(sampling_frequency / 50))
+    extracted = data.iloc[row_idx, [1, 2, 3, 4, 5, 6, 10]]
+
+    # Mirror the source directory layout under ./dataset/raw/.
+    save_path = './dataset/raw/' + os.path.abspath(os.path.dirname(data_file) + os.path.sep + ".").replace(RAW_DATA_PATH, '')
+    if not os.path.exists(save_path):
+        os.makedirs(save_path)
+    save_path = './dataset/raw/' + data_file.replace(RAW_DATA_PATH, '')
+    extracted.to_csv(save_path, index=False)
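+
+# Usage sketch for extract_data. The file name below is hypothetical; the 200 Hz
+# value matches the rate passed in find_all_data_and_extract():
+#
+#   extract_data(RAW_DATA_PATH + 'FOL/FOL_1_1_annotated.csv', 200)
+#
+# This writes a 50 Hz copy of the six sensor channels plus the numeric label
+# under ./dataset/raw/, mirroring the original directory structure.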
+
+def find_all_data_and_extract(path):
+    """
+    递归的查找所有文件并进行转化
+    :param path:
+    :return:
+    """
+    if not os.path.exists(path):
+        print('路径存在问题:', path)
+        return None
+
+    for i in os.listdir(path):
+        if os.path.isfile(path+"/"+i):
+            if i.endswith('.csv'):
+                extract_data(path+"/"+i, 200)
+        else:
+            find_all_data_and_extract(path+"/"+i)
+
+def parser_cfg_file(cfg_file):
+    """
+    读取配置文件中的信息
+    :param cfg_file: 文件路径
+    :return:
+    """
+    content_params = {}
+
+    config = cp.ConfigParser()
+    config.read(cfg_file)
+
+    for section in config.sections():
+        # Collect the [net] options.
+        if section == 'net':
+            for option in config.options(section):
+                content_params[option] = config.get(section, option)
+
+        # Collect the [train] options.
+        if section == 'train':
+            for option in config.options(section):
+                content_params[option] = config.get(section, option)
+
+    return content_params
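+
+# Usage sketch. The option names below are illustrative, not taken from a real
+# config in this repo; parser_cfg_file only reads the [net] and [train] sections,
+# and ConfigParser returns every value as a string:
+#
+#   # model.cfg
+#   # [net]
+#   # input_size = 6
+#   # [train]
+#   # batch_size = 32
+#
+#   params = parser_cfg_file('model.cfg')   # -> {'input_size': '6', 'batch_size': '32'}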
+
+def show_data(data, name=None):
+    """
+    Plot the accelerometer and gyroscope channels of a recording.
+    :param data: DataFrame with acc_x/acc_y/acc_z and gyro_x/gyro_y/gyro_z columns
+    :param name: if given, save the figure to this path instead of showing it
+    :return:
+    """
+    num = data.acc_x.size
+
+    x = np.arange(num)
+    fig = plt.figure(1, figsize=(100, 60))
+    # Subplot 1: accelerometer data
+    plt.subplot(2, 1, 1)
+    plt.title('acc')
+    plt.plot(x, data.acc_x, label='x')
+    plt.plot(x, data.acc_y, label='y')
+    plt.plot(x, data.acc_z, label='z')
+
+    # Add the legend
+    plt.legend()
+    x_flag = np.arange(0, num, num / 10)
+    plt.xticks(x_flag)
+
+    # Subplot 2: gyroscope data
+    plt.subplot(2, 1, 2)
+    plt.title('gyro')
+    plt.plot(x, data.gyro_x, label='x')
+    plt.plot(x, data.gyro_y, label='y')
+    plt.plot(x, data.gyro_z, label='z')
+
+    plt.legend()
+    plt.xticks(x_flag)
+    if name is None:
+        plt.show()
+    else:
+        plt.savefig(name)
+    plt.close()
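+
+# Usage sketch ('./dataset/train/STU_1_1_annotated.csv' is the file used in the
+# commented-out test code at the bottom of this module; the .png name is hypothetical):
+#
+#   data = pd.read_csv('./dataset/train/STU_1_1_annotated.csv')
+#   show_data(data)                        # display interactively
+#   show_data(data, name='stu_1_1.png')    # or save the figure to disk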
+
+def kalman_filter(data):
+    """
+    Smooth the six sensor channels (acc_x..gyro_z) with a Kalman filter.
+    :param data: DataFrame whose first six columns are the sensor readings
+    :return: the same DataFrame with the readings replaced by filtered values
+    """
+    # 6 state variables and 6 measurement variables; identity transition and
+    # measurement matrices, so each channel is modelled as a constant state
+    # observed with noise and the filter acts as a per-channel smoother.
+    kalman = cv2.KalmanFilter(6, 6)
+    kalman.measurementMatrix = np.eye(6, dtype=np.float32)
+    kalman.transitionMatrix = np.eye(6, dtype=np.float32)
+    kalman.processNoiseCov = np.eye(6, dtype=np.float32) * 0.003
+    kalman.measurementNoiseCov = np.eye(6, dtype=np.float32)
+
+    row_num = data.acc_x.size
+
+    for i in range(row_num):
+        # Feed the current measurement in, then write the prediction back.
+        correct = np.array(data.iloc[i, 0:6].values, np.float32).reshape([6, 1])
+        kalman.correct(correct)
+        predict = kalman.predict()
+        data.iloc[i, 0:6] = predict[:, 0]
+
+    return data
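+
+# Usage sketch, mirroring the commented-out test code at the bottom of this module:
+#
+#   data = pd.read_csv('./dataset/train/BSC_1_1_annotated.csv')
+#   data = kalman_filter(data)
+#   data.to_csv('./dataset/train/BSC_1_1_annotated.csv', index=False)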
+
+def find_all_data_and_filtrate(path):
+    """
+    递归的查找所有文件并进行kalman过滤
+    :param path:
+    :return:
+    """
+    if not os.path.exists(path):
+        print('路径存在问题:', path)
+        return None
+
+    for i in os.listdir(path):
+        if os.path.isfile(path+"/"+i):
+            if i.endswith('.csv'):
+                data = pd.read_csv(path+"/"+i)
+                data = kalman_filter(data)
+                data.to_csv(path+"/"+i, index=False)
+        else:
+            find_all_data_and_filtrate(path+"/"+i)
+
+def main():
+    #find_all_data_and_extract(RAW_DATA_PATH)
+    find_all_data_and_filtrate('./dataset/kalman/')
+
+if __name__ == '__main__':
+    main()
+    # if os.path.exists('./dataset/train/BSC_1_1_annotated.csv') == False:
+    #     print('./dataset/train/BSC_1_1_annotated.csv', 'file does not exist!')
+    # data = pd.read_csv('./dataset/train/BSC_1_1_annotated.csv')
+    #
+    # #show_data(data)
+    # data = kalman_filter(data)
+    # data.to_csv('./dataset/train/BSC_1_1_annotated.csv', index=False)
+    # #show_data(data)
+    # # a = data.iloc[4:5,0]
+    # # print(a)
+    # data = pd.read_csv('./dataset/train/STU_1_1_annotated.csv')
+    #
+    # show_data(data)
+