Download this file

160 lines (138 with data), 5.3 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""
References
----------
[1] http://2019.icbeb.org/Challenge.html
"""
import math
from numbers import Real
from typing import Sequence, Union
import numpy as np
__all__ = [
    "compute_metrics",
]


def compute_metrics(
    rpeaks_truths: Sequence[Union[np.ndarray, Sequence[int]]],
    rpeaks_preds: Sequence[Union[np.ndarray, Sequence[int]]],
    fs: Real,
    thr: float = 0.075,
    verbose: int = 0,
) -> float:
    """
    Compute the record-level QRS detection accuracy of rpeak predictions.

    Every record receives a flag:
    1 if it has no error, 0.3 if it has exactly one false negative (and no
    false positive), 0.7 if it has exactly one false positive (and no false
    negative), and 0 if it has more than one error in total.
    The returned accuracy is the mean of these flags over all records.

    Parameters
    ----------
    rpeaks_truths: sequence,
        sequence of ground-truth rpeak locations (sample indices), one per record;
        each item may be a numpy array or a plain sequence of ints
    rpeaks_preds: sequence,
        sequence of predicted rpeak locations (sample indices), one per record;
        each item may be a numpy array or a plain sequence of numbers
    fs: real number,
        sampling frequency of the ECG signals
    thr: float, default 0.075,
        tolerance, in seconds, within which a prediction counts as a true positive
    verbose: int, default 0,
        print verbosity; >= 1 prints the setup, >= 2 also prints per-record counts

    Returns
    -------
    rec_acc: float,
        accuracy of predictions, rounded to 4 decimal places

    Notes
    -----
    Unmatched predictions are only counted as false positives if they lie in
    [0.5 s + thr, 9.5 s - thr] and outside the tolerance of every true rpeak,
    the same windows as used by the official scoring function ``score``.
    A record without any ground-truth rpeak therefore keeps the flag 1.
    The final accuracy is always printed, regardless of ``verbose``.
    """
    assert len(rpeaks_truths) == len(
        rpeaks_preds
    ), f"number of records does not match, truth indicates {len(rpeaks_truths)}, while pred indicates {len(rpeaks_preds)}"
    n_records = len(rpeaks_truths)
    record_flags = np.ones((n_records,), dtype=float)
    thr_ = thr * fs  # tolerance in number of sample points
    # right end of the scoring window; kept as a float (not truncated to int)
    # so that the last interval is identical to the one in ``score``
    end_of_window = 9.5 * fs
    if verbose >= 1:
        print(f"number of records = {n_records}")
        print(f"threshold in number of sample points = {thr_}")
    for idx, (truth_arr, pred_arr) in enumerate(zip(rpeaks_truths, rpeaks_preds)):
        # np.asarray allows plain sequences, as promised by the type hints
        truth_arr = np.asarray(truth_arr).astype(int)
        pred_arr = np.asarray(pred_arr)
        false_negative = 0
        false_positive = 0
        true_positive = 0
        # right boundary of the gap that follows each true rpeak
        next_boundaries = np.append(truth_arr[1:], end_of_window)
        for j, (t_ind, next_t_ind) in enumerate(zip(truth_arr, next_boundaries)):
            if j == 0:
                # predictions between 0.5 s and the first true rpeak
                false_positive += int(
                    np.count_nonzero(
                        (pred_arr >= 0.5 * fs + thr_) & (pred_arr <= t_ind - thr_)
                    )
                )
            # predictions in the gap between this rpeak and the next one
            # (or the end of the scoring window)
            false_positive += int(
                np.count_nonzero(
                    (pred_arr >= t_ind + thr_) & (pred_arr <= next_t_ind - thr_)
                )
            )
            # predictions within the tolerance of this true rpeak
            n_matched = int(np.count_nonzero(np.abs(pred_arr - t_ind) <= thr_))
            if n_matched >= 1:
                true_positive += 1
                # extra detections around the same rpeak are false positives
                false_positive += n_matched - 1
            else:
                false_negative += 1
        if false_negative + false_positive > 1:
            record_flags[idx] = 0
        elif false_negative == 1 and false_positive == 0:
            record_flags[idx] = 0.3
        elif false_negative == 0 and false_positive == 1:
            record_flags[idx] = 0.7
        if verbose >= 2:
            print(
                f"for the {idx}-th record,\ntrue positive = {true_positive}\nfalse positive = {false_positive}\nfalse negative = {false_negative}"
            )
    rec_acc = round(np.sum(record_flags) / n_records, 4)
    print(f"QRS_acc: {rec_acc}")
    print("Scoring complete.")
    return rec_acc
def score(r_ref, hr_ref, r_ans, hr_ans, fs_, thr_):
    """
    The official scoring function (with local fixes, see Notes).

    Parameters
    ----------
    r_ref: sequence of arrays,
        reference rpeak locations (sample indices), one array per record
    hr_ref: sequence of real numbers,
        reference heart rates, one per record
    r_ans: sequence of arrays,
        predicted rpeak locations (sample indices), one array per record
    hr_ans: sequence of real numbers,
        predicted heart rates, one per record; NaN is scored as 0
    fs_: real number,
        sampling frequency of the ECG signals
    thr_: float,
        tolerance, in seconds, within which a predicted rpeak matches a reference rpeak

    Returns
    -------
    rec_acc: float,
        QRS detection accuracy (mean of the per-record flags 1 / 0.7 / 0.3 / 0),
        rounded to 4 decimal places
    hr_acc: float,
        heart rate accuracy (mean of the per-record scores 1 / 0.75 / 0.5 / 0.25 / 0),
        rounded to 4 decimal places

    Notes
    -----
    - The input sequences are not modified.
    - False positives are counted in every interval: before the first reference
      rpeak (from 0.5 s), between consecutive reference rpeaks, and after the
      last reference rpeak (up to 9.5 s).
    """
    HR_score = 0
    n_records = len(r_ref)
    record_flags = np.ones(n_records)
    tol = thr_ * fs_  # tolerance in number of sample points
    for i in range(n_records):
        FN = 0
        FP = 0
        TP = 0
        # use a local value instead of writing 0 back into the caller's hr_ans
        hr_pred = 0 if math.isnan(hr_ans[i]) else hr_ans[i]
        hr_der = abs(int(hr_pred) - int(hr_ref[i]))
        if hr_der <= 0.02 * hr_ref[i]:
            HR_score = HR_score + 1
        elif hr_der <= 0.05 * hr_ref[i]:
            HR_score = HR_score + 0.75
        elif hr_der <= 0.1 * hr_ref[i]:
            HR_score = HR_score + 0.5
        elif hr_der <= 0.2 * hr_ref[i]:
            HR_score = HR_score + 0.25
        # integer cast added by wenh06; done on a local copy so that the
        # caller's r_ref is left untouched
        ref = np.asarray(r_ref[i]).astype(int)
        ans = np.asarray(r_ans[i])
        for j in range(len(ref)):
            loc = np.where(np.abs(ans - ref[j]) <= tol)[0]
            if j == 0:
                # predictions between 0.5 s and the first reference rpeak;
                # counted here because ``err`` is reassigned just below
                err = np.where((ans >= 0.5 * fs_ + tol) & (ans <= ref[j] - tol))[0]
                FP = FP + len(err)
            # comments by wenh06:
            # the official code used `elif j == len(r_ref[i]) - 1:` here, which
            # falsely omits the interval between the 0-th and the 1-st ref rpeaks,
            # e.g. for r_ref = [np.array([500, 1000])] and
            # r_ans = [np.array([500, 700, 1000])] a false positive is missed
            if j == len(ref) - 1:
                err = np.where((ans >= ref[j] + tol) & (ans <= 9.5 * fs_ - tol))[0]
            else:
                err = np.where((ans >= ref[j] + tol) & (ans <= ref[j + 1] - tol))[0]
            FP = FP + len(err)
            if len(loc) >= 1:
                TP += 1
                # extra detections around the same rpeak are false positives
                FP = FP + len(loc) - 1
            else:
                FN += 1
        if FN + FP > 1:
            record_flags[i] = 0
        elif FN == 1 and FP == 0:
            record_flags[i] = 0.3
        elif FN == 0 and FP == 1:
            record_flags[i] = 0.7
    rec_acc = round(np.sum(record_flags) / n_records, 4)
    hr_acc = round(HR_score / n_records, 4)
    print("QRS_acc: {}".format(rec_acc))
    print("HR_acc: {}".format(hr_acc))
    print("Scoring complete.")
    return rec_acc, hr_acc