Diff of /utils.py [000000] .. [857d1b]

Switch to unified view

a b/utils.py
1
# -*- coding:utf-8 -*-
2
import os
3
import cv2
4
import pandas as pd
5
import numpy as np
6
import configparser as cp
7
import matplotlib.pyplot as plt
8
9
RAW_DATA_PATH = '/home/tony/fall_research/fall_data/MobiAct_Dataset_v2.0/Annotated Data/'
10
11
Label = {'STD': 1, 'WAL': 2, 'JOG': 3, 'JUM': 4, 'STU': 5, 'STN': 6, 'SCH': 7, 'SIT': 8, 'CHU': 9,
12
         'LYI': 10, 'FOL': 0, 'FKL': 0, 'BSC': 0, 'SDL': 0, 'CSI': 15, 'CSO': 16}
13
14
def extract_data(data_file, sampling_frequency):
15
    """
16
    从mobileFall中提取数据,用于做实验测试
17
    :param data_file:  原始数据文件
18
    :param sampling_frequency: 原始数据采集频率
19
    :return:
20
    """
21
    data = pd.read_csv(data_file, index_col=0)
22
    data_size = len(data.label)
23
    for i in range(data_size):
24
        data.iat[i, 10] = Label[data.iloc[i, 10]]
25
26
    col_data = np.arange(0, data_size, int(sampling_frequency/50))
27
    extract_data = data.iloc[col_data, [1, 2, 3, 4, 5, 6, 10]]
28
29
    save_path = './dataset/raw/' + os.path.abspath(os.path.dirname(data_file)+os.path.sep+".").replace(RAW_DATA_PATH, '')
30
    if not os.path.exists(save_path):
31
        os.makedirs(save_path)
32
    save_path = './dataset/raw/' + data_file.replace(RAW_DATA_PATH, '')
33
    extract_data.to_csv(save_path, index=0)
34
35
def find_all_data_and_extract(path):
36
    """
37
    递归的查找所有文件并进行转化
38
    :param path:
39
    :return:
40
    """
41
    if not os.path.exists(path):
42
        print('路径存在问题:', path)
43
        return None
44
45
    for i in os.listdir(path):
46
        if os.path.isfile(path+"/"+i):
47
            if 'csv' in i:
48
                extract_data(path+"/"+i, 200)
49
        else:
50
            find_all_data_and_extract(path+"/"+i)
51
52
def parser_cfg_file(cfg_file):
53
    """
54
    读取配置文件中的信息
55
    :param cfg_file: 文件路径
56
    :return:
57
    """
58
    content_params = {}
59
60
    config = cp.ConfigParser()
61
    config.read(cfg_file)
62
63
    for section in config.sections():
64
        # 获取配置文件中的net信息
65
        if section == 'net':
66
            for option in config.options(section):
67
                content_params[option] = config.get(section,option)
68
69
        # 获取配置文件中的train信息
70
        if section == 'train':
71
            for option in config.options(section):
72
                content_params[option] = config.get(section,option)
73
74
    return content_params
75
76
def show_data(data, name=None):
77
    '''
78
    show data
79
    :param data: DataFrame
80
    :return:
81
    '''
82
    num = data.acc_x.size
83
84
    x = np.arange(num)
85
    fig = plt.figure(1, figsize=(100, 60))
86
    # 子表1绘制加速度传感器数据
87
    plt.subplot(2, 1, 1)
88
    plt.title('acc')
89
    plt.plot(x, data.acc_x, label='x')
90
    plt.plot(x, data.acc_y, label='y')
91
    plt.plot(x, data.acc_z, label='z')
92
93
    # 添加解释图标
94
    plt.legend()
95
    x_flag = np.arange(0, num, num / 10)
96
    plt.xticks(x_flag)
97
98
    # 子表2绘制陀螺仪传感器数据
99
    plt.subplot(2, 1, 2)
100
    plt.title('gyro')
101
    plt.plot(x, data.gyro_x, label='x')
102
    plt.plot(x, data.gyro_y, label='y')
103
    plt.plot(x, data.gyro_z, label='z')
104
105
    plt.legend()
106
    plt.xticks(x_flag)
107
    #plt.show()
108
    if name is None:
109
        plt.show()
110
    else:
111
        plt.savefig(name)
112
    plt.close()
113
114
def kalman_filter(data):
115
    kalman = cv2.KalmanFilter(6, 6)
116
    kalman.measurementMatrix = np.array([[1, 0, 0, 0, 0, 0],
117
                                         [0, 1, 0, 0, 0, 0],
118
                                         [0, 0, 1, 0, 0, 0],
119
                                         [0, 0, 0, 1, 0, 0],
120
                                         [0, 0, 0, 0, 1, 0],
121
                                         [0, 0, 0, 0, 0, 1]], np.float32)
122
    kalman.transitionMatrix = np.array([[1, 0, 0, 0, 0, 0],
123
                                         [0, 1, 0, 0, 0, 0],
124
                                         [0, 0, 1, 0, 0, 0],
125
                                         [0, 0, 0, 1, 0, 0],
126
                                         [0, 0, 0, 0, 1, 0],
127
                                         [0, 0, 0, 0, 0, 1]], np.float32)
128
    kalman.processNoiseCov = np.array([[1, 0, 0, 0, 0, 0],
129
                                       [0, 1, 0, 0, 0, 0],
130
                                       [0, 0, 1, 0, 0, 0],
131
                                       [0, 0, 0, 1, 0, 0],
132
                                       [0, 0, 0, 0, 1, 0],
133
                                       [0, 0, 0, 0, 0, 1]], np.float32) * 0.003
134
    kalman.measurementNoiseCov = np.array([[1, 0, 0, 0, 0, 0],
135
                                          [0, 1, 0, 0, 0, 0],
136
                                          [0, 0, 1, 0, 0, 0],
137
                                          [0, 0, 0, 1, 0, 0],
138
                                          [0, 0, 0, 0, 1, 0],
139
                                          [0, 0, 0, 0, 0, 1]], np.float32) * 1
140
141
    row_num = data.acc_x.size
142
143
    for i in range(row_num):
144
        correct = np.array(data.iloc[i, 0:6].values, np.float32).reshape([6, 1])
145
        kalman.correct(correct)
146
        predict = kalman.predict()
147
        data.iloc[i, 0] = predict[0]
148
        data.iloc[i, 1] = predict[1]
149
        data.iloc[i, 2] = predict[2]
150
        data.iloc[i, 3] = predict[3]
151
        data.iloc[i, 4] = predict[4]
152
        data.iloc[i, 5] = predict[5]
153
154
    return data
155
156
def find_all_data_and_filtrate(path):
157
    """
158
    递归的查找所有文件并进行kalman过滤
159
    :param path:
160
    :return:
161
    """
162
    if not os.path.exists(path):
163
        print('路径存在问题:', path)
164
        return None
165
166
    for i in os.listdir(path):
167
        if os.path.isfile(path+"/"+i):
168
            if 'csv' in i:
169
                data = pd.read_csv(path+"/"+i)
170
                data = kalman_filter(data)
171
                data.to_csv(path+"/"+i, index=False)
172
        else:
173
            find_all_data_and_filtrate(path+"/"+i)
174
175
def main():
176
    #find_all_data_and_extract(RAW_DATA_PATH)
177
    find_all_data_and_filtrate('./dataset/kalman/')
178
179
if __name__ == '__main__':
180
    main()
181
    # if os.path.exists('./dataset/train/BSC_1_1_annotated.csv') == False:
182
    #     print('./dataset/train/BSC_1_1_annotated.csv', '文件不存在!')
183
    # data = pd.read_csv('./dataset/train/BSC_1_1_annotated.csv')
184
    #
185
    # #show_data(data)
186
    # data = kalman_filter(data)
187
    # data.to_csv('./dataset/train/BSC_1_1_annotated.csv', index=False)
188
    # #show_data(data)
189
    # # a = data.iloc[4:5,0]
190
    # # print(a)
191
    # data = pd.read_csv('./dataset/train/STU_1_1_annotated.csv')
192
    #
193
    # show_data(data)
194