Switch to unified view

a b/scripts/dataset-generation.py
1
import argparse
2
import os
3
import os.path as osp
4
5
import cv2
6
import matplotlib.pyplot as plt
7
import numpy as np
8
import wfdb
9
from sklearn.preprocessing import scale
10
from wfdb import rdrecord
11
12
# Choose from peak to peak or centered
13
# mode = [20, 20]
14
mode = 128
15
16
image_size = 128
17
output_dir = "../data"
18
19
# dpi fix
20
fig = plt.figure(frameon=False)
21
dpi = fig.dpi
22
23
# fig size / image size
24
figsize = (image_size / dpi, image_size / dpi)
25
image_size = (image_size, image_size)
26
27
28
def plot(signal, filename):
29
    plt.figure(figsize=figsize, frameon=False)
30
    plt.axis("off")
31
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
32
    # plt.margins(0, 0) # use for generation images with no margin
33
    plt.plot(signal)
34
    plt.savefig(filename)
35
36
    plt.close()
37
38
    im_gray = cv2.imread(filename, cv2.IMREAD_GRAYSCALE)
39
    im_gray = cv2.resize(im_gray, image_size, interpolation=cv2.INTER_LANCZOS4)
40
    cv2.imwrite(filename, im_gray)
41
42
43
if __name__ == "__main__":
44
45
    parser = argparse.ArgumentParser()
46
    parser.add_argument("--file", required=True)
47
    args = parser.parse_args()
48
49
    ecg = args.file
50
    name = osp.basename(ecg)
51
    record = rdrecord(ecg)
52
    ann = wfdb.rdann(ecg, extension="atr")
53
    for sig_name, signal in zip(record.sig_name, record.p_signal.T):
54
        if not np.all(np.isfinite(signal)):
55
            continue
56
        signal = scale(signal)
57
        for i, (label, peak) in enumerate(zip(ann.symbol, ann.sample)):
58
            if label == "/":
59
                label = "\\"
60
            print("\r{} [{}/{}]".format(sig_name, i + 1, len(ann.symbol)), end="")
61
            if isinstance(mode, list):
62
                if np.all([i > 0, i + 1 < len(ann.sample)]):
63
                    left = ann.sample[i - 1] + mode[0]
64
                    right = ann.sample[i + 1] - mode[1]
65
                else:
66
                    continue
67
            elif isinstance(mode, int):
68
                left, right = peak - mode // 2, peak + mode // 2
69
            else:
70
                raise Exception("Wrong mode in script beginning")
71
72
            if np.all([left > 0, right < len(signal)]):
73
                one_dim_data_dir = osp.join(output_dir, "1D", name, sig_name, label)
74
                two_dim_data_dir = osp.join(output_dir, "2D", name, sig_name, label)
75
                os.makedirs(one_dim_data_dir, exist_ok=True)
76
                os.makedirs(two_dim_data_dir, exist_ok=True)
77
78
                filename = osp.join(one_dim_data_dir, "{}.npy".format(peak))
79
                np.save(filename, signal[left:right])
80
                filename = osp.join(two_dim_data_dir, "{}.png".format(peak))
81
82
                plot(signal[left:right], filename)