Diff of /produceDatabase.py [000000] .. [409112]

Switch to unified view

a b/produceDatabase.py
1
"""
2
    Produce train set and test set based on Apnea-ECG database.
3
"""
4
import numpy as np
5
from ECGSegment import ECGSegment
6
import os
7
8
def produce_database(database_name):
9
    """
10
    
11
    :param database_name:
12
    :return None:
13
    """
14
    
15
    if database_name == ["apnea-ecg", "train"] or database_name == ["apnea-ecg", "test"]:
16
        clear_id_set = np.load(database_name[0] + "_" + database_name[1] + "_clear_id.npy")
17
    else:
18
        raise Exception("Error database name.")
19
    dataset = []
20
    RRI_set = []
21
    RAMP_set = []
22
    EDR_set = []
23
    label_set = []
24
    for id in clear_id_set:
25
        eds = ECGSegment()
26
        eds.global_id = id
27
        eds.read_ecg_segment(1, database_name)
28
        eds.read_rri_ramp_edr()
29
        label_set.append(eds.label)
30
        RRI_set.append(eds.RR_intervals)
31
        RAMP_set.append(eds.R_peaks_amplitude)
32
        EDR_set.append(eds.EDR)
33
        dataset.append(eds)
34
        
35
    # substract mean of RRI,RAMP and EDR, RAMP * 10, EDR * 10000
36
    mean = np.mean(RRI_set, axis=1)
37
    mean = np.reshape(mean, (mean.shape[0], 1))
38
    rri_set = RRI_set - mean
39
    rri_set = np.reshape(rri_set, (rri_set.shape[0], rri_set.shape[1], 1))
40
    mean = np.mean(RAMP_set, axis=1)
41
    mean = np.reshape(mean, (mean.shape[0], 1))
42
    ramp_set = RAMP_set - mean
43
    ramp_set = np.reshape(ramp_set, (ramp_set.shape[0], ramp_set.shape[1], 1))
44
    ramp_set = ramp_set * 100
45
    mean = np.mean(EDR_set, axis=1)
46
    mean = np.reshape(mean, (mean.shape[0], 1))
47
    edr_set = EDR_set - mean
48
    edr_set = np.reshape(edr_set, (edr_set.shape[0], edr_set.shape[1], 1))
49
    edr_set = edr_set * 10000
50
    rri_amp_edr_set = np.concatenate([rri_set, ramp_set, edr_set], axis=2)
51
    
52
    if not os.path.exists("data/"):
53
        os.makedirs("data/")
54
    np.save("data/" + database_name[0] + "_" + database_name[1] + "_clear_dataset.npy", np.array(dataset))
55
    np.save("data/" + database_name[0] + "_" + database_name[1] + "_clear_rri.npy", np.array(rri_set))
56
    np.save("data/" + database_name[0] + "_" + database_name[1] + "_clear_ramp.npy", np.array(ramp_set))
57
    np.save("data/" + database_name[0] + "_" + database_name[1] + "_clear_edr.npy", np.array(edr_set))
58
    np.save("data/" + database_name[0] + "_" + database_name[1] + "_clear_rri_ramp_edr.npy", np.array(rri_amp_edr_set))
59
    np.save("data/" + database_name[0] + "_" + database_name[1] + "_clear_label.npy", np.array(label_set))
60
61
62
if __name__ == '__main__':
63
    # produce_database(["apnea-ecg", "train"])
64
    produce_database(["apnea-ecg", "test"])
65
    
66