In [21]:
import keras
import math
import numpy as np
import os
import sklearn.metrics as skm
import sys
sys.path.append("../../../ecg")

import load
import util
%matplotlib inline

In [2]:
model_path = "/deep/group/awni/ecg_models/default/1527627404-9/0.337-0.880-012-0.255-0.906.hdf5"
data_json = "../dev.json"

preproc = util.load(os.path.dirname(model_path))
dataset = load.load_dataset(data_json)
ecgs, labels = preproc.process(*dataset)

model = keras.models.load_model(model_path)
probs = model.predict(ecgs, verbose=1)

100%|██████████| 8761/8761 [00:03<00:00, 2561.29it/s]




In [25]:
def stats(ground_truth, preds):
    labels = range(ground_truth.shape[2])
    g = np.argmax(ground_truth, axis=2).ravel()
    p = np.argmax(preds, axis=2).ravel()
    stat_dict = {}
    for i in labels:
        # compute all the stats for each label
        tp = np.sum(g[g==i] == p[g==i])
        fp = np.sum(g[p==i] != p[p==i])
        fn = np.sum(g==i) - tp
        tn = np.sum(g!=i) - fp
        stat_dict[i] = (tp, fp, fn, tn)
    return stat_dict

def to_set(preds):
    idxs = np.argmax(preds, axis=2)
    return [list(set(r)) for r in idxs]

def set_stats(ground_truth, preds):
    labels = range(ground_truth.shape[2])
    ground_truth = to_set(ground_truth)
    preds = to_set(preds)
    stat_dict = {}
    for x in labels:
        tp = 0; fp = 0; fn = 0; tn = 0;
        for g, p in zip(ground_truth, preds):
            if x in g and x in p: # tp
                tp += 1
            if x not in g and x in p: # fp
                fp += 1
            if x in g and x not in p:
                fn += 1
            if x not in g and x not in p:
                tn += 1
        stat_dict[x] = (tp, fp, fn, tn)
    return stat_dict

def compute_f1(tp, fp, fn, tn):
    precision = tp / float(tp + fp)
    recall = tp / float(tp + fn)
    specificity = tn / float(tn + fp)
    npv = tn / float(tn + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return f1, tp + fn

def print_results(seq_sd, set_sd):
    print "\t\t Seq F1    Set F1"
    seq_tf1 = 0; seq_tot = 0
    set_tf1 = 0; set_tot = 0
    for k, v in seq_sd.items():
        set_f1, n = compute_f1(*set_sd[k])
        set_tf1 += n * set_f1
        set_tot += n
        seq_f1, n = compute_f1(*v)
        seq_tf1 += n * seq_f1
        seq_tot += n
        print "{:>10} {:10.3f} {:10.3f}".format(
            preproc.classes[k], seq_f1, set_f1)
    print "{:>10} {:10.3f} {:10.3f}".format(
        "Average", seq_tf1 / float(seq_tot), set_tf1 / float(set_tot))
    
def c_statistic_with_95p_confidence_interval(cstat, num_positives, num_negatives, z_alpha_2=1.96):
    """
    Calculates the confidence interval of an ROC curve (c-statistic), using the method described
    under "Confidence Interval for AUC" here:
      https://ncss-wpengine.netdna-ssl.com/wp-content/themes/ncss/pdf/Procedures/PASS/Confidence_Intervals_for_the_Area_Under_an_ROC_Curve.pdf
    Args:
        cstat: the c-statistic (equivalent to area under the ROC curve)
        num_positives: number of positive examples in the set.
        num_negatives: number of negative examples in the set.
        z_alpha_2 (optional): the critical value for an N% confidence interval, e.g., 1.96 for 95%,
            2.326 for 98%, 2.576 for 99%, etc.
    Returns:
        The 95% confidence interval half-width, e.g., the Y in X ± Y.
    """
    q1 = cstat / (2 - cstat)
    q2 = 2 * cstat**2 / (1 + cstat)
    numerator = cstat * (1 - cstat) \
        + (num_positives - 1) * (q1 - cstat**2) \
        + (num_negatives - 1) * (q2 - cstat**2)
    standard_error_auc = math.sqrt(numerator / (num_positives * num_negatives))
    return z_alpha_2 * standard_error_auc

def roc_auc(ground_truth, probs, index):
    gts = np.argmax(ground_truth, axis=2)
    n_gts = np.zeros_like(gts)
    n_gts[gts==index] = 1
    n_pos = np.sum(n_gts == 1)
    n_neg = n_gts.size - n_pos
    n_ps = probs[..., index].squeeze()
    n_gts, n_ps = n_gts.ravel(), n_ps.ravel()
    return n_pos, n_neg, skm.roc_auc_score(n_gts, n_ps)

def roc_auc_set(ground_truth, probs, index):
    gts = np.argmax(ground_truth, axis=2)
    max_ps = np.max(probs[...,index], axis=1)
    max_gts = np.any(gts==index, axis=1)
    pos = np.sum(max_gts)
    neg = max_gts.size - pos
    return pos, neg, skm.roc_auc_score(max_gts, max_ps)

def print_aucs(ground_truth, probs):
    seq_tauc = 0.0; seq_tot = 0.0
    set_tauc = 0.0; set_tot = 0.0
    print "\t        AUC"
    for idx, cname in preproc.int_to_class.items():
        pos, neg, seq_auc = roc_auc(ground_truth, probs, idx)
        seq_tot += pos
        seq_tauc += pos * seq_auc
        seq_conf = c_statistic_with_95p_confidence_interval(seq_auc, pos, neg)
        pos, neg, set_auc = roc_auc_set(ground_truth, probs, idx)
        set_tot += pos
        set_tauc += pos * set_auc
        set_conf = c_statistic_with_95p_confidence_interval(set_auc, pos, neg)
        print "{: <8}\t{:.3f} ({:.3f}-{:.3f})\t{:.3f} ({:.3f}-{:.3f})".format(
            cname, seq_auc, seq_auc-seq_conf,seq_auc+seq_conf,
            set_auc, set_auc-set_conf, set_auc+set_conf)
    print "Average\t\t{:.3f}\t{:.3f}".format(seq_tauc/seq_tot, set_tauc/set_tot)

In [26]:
print_results(stats(labels, probs), set_stats(labels, probs))
print_aucs(labels, probs)

		 Seq F1    Set F1
        AF      0.914      0.914
       AVB      0.805      0.839
  BIGEMINY      0.917      0.896
       EAR      0.652      0.699
       IVR      0.721      0.758
JUNCTIONAL      0.706      0.740
     NOISE      0.911      0.847
     SINUS      0.920      0.960
       SVT      0.700      0.812
 TRIGEMINY      0.924      0.918
        VT      0.769      0.848
WENCKEBACH      0.779      0.822
   Average      0.879      0.889
	        AUC
AF      	0.994 (0.994-0.995)	0.994 (0.991-0.996)
AVB     	0.992 (0.990-0.993)	0.990 (0.985-0.995)
BIGEMINY	0.999 (0.998-1.000)	0.998 (0.994-1.001)
EAR     	0.977 (0.975-0.980)	0.967 (0.957-0.977)
IVR     	0.996 (0.994-0.998)	0.991 (0.984-0.998)
JUNCTIONAL	0.987 (0.985-0.989)	0.984 (0.976-0.992)
NOISE   	0.994 (0.993-0.994)	0.978 (0.973-0.984)
SINUS   	0.979 (0.979-0.980)	0.987 (0.985-0.989)
SVT     	0.986 (0.984-0.988)	0.983 (0.977-0.989)
TRIGEMINY	0.999 (0.999-1.000)	0.998 (0.994-1.001)
VT      	0.997 (0.995-0.998)	0.992 (0.988-0.9