--- a
+++ b/clinical_ts/eval_utils_cafa.py
@@ -0,0 +1,278 @@
+# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/03_eval_utils_cafa.ipynb (unless otherwise specified).
+
+__all__ = ['auc_prrc_uninterpolated', 'multiclass_roc_curve', 'single_eval_prrc', 'eval_prrc', 'eval_prrc_parallel',
+           'eval_scores', 'eval_scores_bootstrap']
+
+# Cell
+import warnings
+import numpy as np
+import pandas as pd
+from sklearn.metrics import roc_auc_score, auc
+from scipy.interpolate import interp1d
+
+from sklearn.metrics import roc_curve, precision_recall_curve
+from sklearn.utils import resample
+
+from tqdm import tqdm
+
+# Cell
+def auc_prrc_uninterpolated(recall,precision):
+    '''uninterpolated AUC as used by sklearn (https://github.com/scikit-learn/scikit-learn/blob/1495f6924/sklearn/metrics/ranking.py); see also the discussion at https://github.com/scikit-learn/scikit-learn/pull/9583'''
+    #print(-np.sum(np.diff(recall) * np.array(precision)[:-1]),auc(recall,precision))
+    return -np.sum(np.diff(recall) * np.array(precision)[:-1])
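+
+# Minimal usage sketch (illustrative only, not part of the original notebook; the toy recall/precision values are assumed):
+# >>> recall = np.array([1.0, 0.5, 0.0])
+# >>> precision = np.array([0.5, 1.0, 1.0])
+# >>> auc_prrc_uninterpolated(recall, precision)  # 0.5*0.5 + 0.5*1.0 = 0.75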
+
+# Cell
+#label-centric metrics
+def multiclass_roc_curve(y_true, y_pred, classes=None, precision_recall=False):
+    '''Compute ROC curve and ROC area for each class "0"..."n_classes - 1" (or the class names passed via classes) as well as "micro" and "macro" averages.
+    For precision_recall=False returns fpr, tpr, roc_auc (dictionaries);
+    for precision_recall=True the same slots hold recall, precision, average_precision.
+    '''
+
+    fpr = dict()
+    tpr = dict()
+    roc_auc = dict()
+    n_classes=len(y_pred[0])
+    if(classes is None):
+        classes = [str(i) for i in range(n_classes)]
+
+    for i,c in enumerate(classes):
+        if(precision_recall):
+            #note: in precision_recall mode the tpr/fpr dicts hold precision/recall respectively
+            tpr[c], fpr[c], _ = precision_recall_curve(y_true[:, i], y_pred[:, i])
+            roc_auc[c] = auc_prrc_uninterpolated(fpr[c], tpr[c])
+        else:
+            fpr[c], tpr[c], _ = roc_curve(y_true[:, i], y_pred[:, i])
+            roc_auc[c] = auc(fpr[c], tpr[c])
+
+    # Compute micro-average curve and area
+    if(precision_recall):
+        tpr["micro"], fpr["micro"], _ = precision_recall_curve(y_true.ravel(), y_pred.ravel())
+        roc_auc["micro"] = auc_prrc_uninterpolated(fpr["micro"], tpr["micro"])
+    else:
+        fpr["micro"], tpr["micro"], _ = roc_curve(y_true.ravel(), y_pred.ravel())
+        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
+
+    # Compute macro-average curve and area (linear interpolation is incorrect for precision-recall curves, therefore only done for ROC)
+    if(precision_recall is False):
+        # 1. First aggregate all unique x values (false positive rates for ROC)
+        all_fpr = np.unique(np.concatenate([fpr[c] for c in classes]))
+
+        # 2. Then interpolate all curves at these points
+        mean_tpr=None
+        for c in classes:
+            f = interp1d(fpr[c], tpr[c])
+            if(mean_tpr is None):
+                mean_tpr = f(all_fpr)
+            else:
+                mean_tpr += f(all_fpr)
+
+        # 3. Finally average it and compute area
+        mean_tpr /= n_classes
+
+        fpr["macro"] = all_fpr
+        tpr["macro"] = mean_tpr
+        #macro2 differs slightly from macro due to interpolation effects
+        #roc_auc["macro2"] = auc(fpr["macro"], tpr["macro"])
+
+    #calculate the macro auc directly by averaging the per-class aucs
+    roc_auc_macro = 0
+    macro_auc_nans = 0 #number of classes whose auc could not be computed due to an insufficient amount of pos/neg labels
+    for c in classes:
+        if(np.isnan(roc_auc[c])):#conservative choice: replace auc by 0.5 if it could not be calculated
+            roc_auc_macro += 0.5
+            macro_auc_nans += 1
+        else:
+            roc_auc_macro += roc_auc[c]
+    roc_auc["macro"]=roc_auc_macro/n_classes
+    roc_auc["macro_nans"] = macro_auc_nans
+
+    return fpr, tpr, roc_auc
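+
+# Minimal usage sketch (illustrative only; the toy labels/probabilities below are assumed):
+# >>> y_true = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
+# >>> y_pred = np.array([[0.9, 0.2], [0.3, 0.8], [0.7, 0.6], [0.1, 0.4]])
+# >>> fpr, tpr, roc_auc = multiclass_roc_curve(y_true, y_pred, classes=["A", "B"])
+# >>> sorted(roc_auc.keys())
+# ['A', 'B', 'macro', 'macro_nans', 'micro']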
+
+# Cell
+def single_eval_prrc(y_true,y_pred,threshold):
+    '''evaluate instance-wise scores for a single sample and a single threshold'''
+    y_pred_bin = (y_pred >= threshold)
+    TP = np.sum(np.logical_and(y_true == y_pred_bin,y_true>0))
+    count = np.sum(y_pred_bin)#TP+FP
+
+    # Find precision: TP / (TP + FP)
+    precision = TP / count if count > 0 else np.nan
+    # Find recall/TPR/sensitivity: TP / (TP + FN)
+    recall = TP/np.sum(y_true>0)
+    # Find the false positive rate: FPR = FP / (FP + TN) = FP / N (note: despite the variable name below, this is 1 - specificity)
+    FP = np.sum(np.logical_and(y_true != y_pred_bin,y_pred_bin>0))
+    specificity = FP/ np.sum(y_true==0)
+    return precision, recall, specificity
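+
+# Minimal usage sketch (illustrative only; the toy sample below is assumed):
+# >>> y_true = np.array([1, 0, 1, 0])
+# >>> y_pred = np.array([0.9, 0.8, 0.2, 0.1])
+# >>> single_eval_prrc(y_true, y_pred, threshold=0.5)  # TP=1, FP=1 -> (0.5, 0.5, 0.5); the last value is the FPR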
+
+# Cell
+def eval_prrc(y_true,y_pred,threshold):
+    '''eval instance-wise scores across all samples for a single threshold'''
+    # Initialize Variables
+    PR = 0.0
+    RC = 0.0
+    SP = 0.0
+
+    counts_above_threshold = 0
+
+    for i in range(len(y_true)):
+        pr,rc,sp = single_eval_prrc(y_true[i],y_pred[i],threshold)
+        if not np.isnan(pr): #precision is undefined (nan) if no prediction exceeds the threshold
+            PR += pr
+            counts_above_threshold += 1
+        RC += rc
+        SP += sp
+
+    recall = RC/len(y_true)
+    specificity = SP/len(y_true)
+
+    if counts_above_threshold > 0:
+        precision = PR/counts_above_threshold
+    else:
+        precision = np.nan
+        if(threshold<1.0):
+            print("No prediction is made above the %.2f threshold\n" % threshold)
+    return precision, recall, specificity, counts_above_threshold/len(y_true)
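+
+# Minimal usage sketch (illustrative only; the toy arrays below are assumed):
+# >>> y_true = np.array([[1, 0, 1], [0, 1, 0]])
+# >>> y_pred = np.array([[0.9, 0.1, 0.6], [0.2, 0.8, 0.3]])
+# >>> eval_prrc(y_true, y_pred, 0.5)  # (precision, recall, fpr, coverage) averaged over samples -> (1.0, 1.0, 0.0, 1.0)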
+
+# Cell
+def eval_prrc_parallel(y_true,y_pred,thresholds):
+    '''vectorized version of eval_prrc: evaluates instance-wise scores across all samples for all thresholds at once'''
+    y_pred_bin = np.repeat(y_pred[None, :, :], len(thresholds), axis=0) >= thresholds[:,None,None] #shape: thresholds x samples x classes
+    TP = np.sum(np.logical_and(y_true == True, y_pred_bin == True), axis=2) #shape: thresholds x samples
+
+    with np.errstate(divide='ignore', invalid='ignore'):
+        den = np.sum(y_pred_bin,axis=2)>0
+        precision = TP/np.sum(y_pred_bin,axis=2)
+        precision[den==0] = np.nan
+
+    recall = TP/np.sum(y_true==True, axis=1)#threshold,samples/samples=threshold,samples
+
+    FP = np.sum(np.logical_and((y_true == False),(y_pred_bin == True)), axis=2)
+    specificity = FP/np.sum(y_true == False, axis=1) #false positive rate (named specificity here for consistency with single_eval_prrc)
+
+    with warnings.catch_warnings(): #for nan slices
+        warnings.simplefilter("ignore", category=RuntimeWarning)
+        av_precision = np.nanmean(precision,axis=1)
+
+    av_recall = np.mean(recall,axis=1)
+    av_specificity = np.mean(specificity,axis=1)
+    av_coverage = np.mean(den,axis=1)
+
+    return av_precision, av_recall, av_specificity, av_coverage
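+
+# Minimal consistency sketch (illustrative only; the toy arrays below are assumed): for any threshold grid,
+# eval_prrc_parallel should reproduce the per-threshold results of the eval_prrc loop above.
+# >>> y_true = np.array([[1, 0, 1], [0, 1, 0]])
+# >>> y_pred = np.array([[0.9, 0.1, 0.6], [0.2, 0.8, 0.3]])
+# >>> thresholds = np.array([0.25, 0.5, 0.75])
+# >>> pr, rc, sp, cov = eval_prrc_parallel(y_true, y_pred, thresholds)
+# >>> (pr[1], rc[1], sp[1], cov[1])  # matches eval_prrc(y_true, y_pred, 0.5)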
+
+
+# Cell
+def eval_scores(y_true,y_pred,classes=None,num_thresholds=100,full_output=False,parallel=True):
+    '''returns a dictionary of performance metrics:
+    sample-centric (cf. https://github.com/ashleyzhou972/CAFA_assessment_tool/blob/master/precrec/precRec.py,
+    https://www.nature.com/articles/nmeth.2340 vs https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3694662/ and https://arxiv.org/pdf/1601.00891):
+    * Fmax, sample AUC, sample Average Precision (as in sklearn)
+
+    label-centric: micro, macro and individual AUC and Average Precision
+    '''
+    results = {}
+
+    thresholds = np.arange(0.00, 1.01, 1./num_thresholds, float)
+    if(parallel is False):
+        PR = np.zeros(len(thresholds))
+        RC = np.zeros(len(thresholds))
+        SP = np.zeros(len(thresholds))
+        COV = np.zeros(len(thresholds))
+
+        for i,t in enumerate(thresholds):
+            PR[i],RC[i],SP[i],COV[i] = eval_prrc(y_true,y_pred,t)
+    else:
+        PR,RC,SP,COV = eval_prrc_parallel(y_true,y_pred,thresholds)
+
+    with np.errstate(divide='ignore', invalid='ignore'): #F1 is undefined (nan) where PR+RC==0 or PR is nan
+        F = (2*PR*RC)/(PR+RC)
+
+    if(full_output is True):
+        results["PR"] = PR
+        results["RC"] = RC
+        results["SP"] = SP
+        results["F"] = F
+        results["COV"] = COV
+
+    if np.isnan(F).sum() == len(F):
+        results["Fmax"] = 0
+        results["precision_at_Fmax"] = 0
+        results["recall_at_Fmax"] = 0
+        results["threshold_at_Fmax"] = 0
+        results["coverage_at_Fmax"] = 0
+    else:
+        imax = np.nanargmax(F)
+        results["Fmax"] = F[imax]
+        results["precision_at_Fmax"] = PR[imax]
+        results["recall_at_Fmax"] = RC[imax]
+        results["threshold_at_Fmax"] = thresholds[imax]
+        results["coverage_at_Fmax"] = COV[imax]
+
+    results["sample_AUC"] = auc(1-SP,RC)
+    #https://github.com/scikit-learn/scikit-learn/blob/1495f6924/sklearn/metrics/ranking.py set final PR value to 1
+    PR[-1] = 1
+    results["sample_APR"] = auc_prrc_uninterpolated(RC,PR) #the last point (with undefined precision) is excluded by the uninterpolated auc
+    ###########################################################
+    #label-centric
+    #"micro","macro",i=0...n_classes-1
+    fpr, tpr, roc_auc = multiclass_roc_curve(y_true, y_pred,classes=classes,precision_recall=False)
+    if(full_output is True):
+        results["fpr"]=fpr
+        results["tpr"]=tpr
+    results["label_AUC"]=roc_auc
+
+    rc, pr, prrc_auc = multiclass_roc_curve(y_true, y_pred,classes=classes,precision_recall=True)
+    if(full_output is True):
+        results["pr"] = pr
+        results["rc"] = rc
+    results["label_APR"] = prrc_auc
+
+    return results
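+
+# Minimal usage sketch (illustrative only; the toy labels below are assumed, predictions are random):
+# >>> y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0], [0, 0, 1]])
+# >>> y_pred = np.random.rand(4, 3)
+# >>> res = eval_scores(y_true, y_pred, classes=["A", "B", "C"])
+# >>> res["Fmax"], res["sample_AUC"], res["label_AUC"]["macro"]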
+
+# Cell
+def eval_scores_bootstrap(y_true, y_pred, classes=None, n_iterations=10000, alpha=0.95):
+    '''empirical bootstrap confidence intervals (rather than bootstrap percentiles) for eval_scores, cf. https://ocw.mit.edu/courses/mathematics/18-05-introduction-to-probability-and-statistics-spring-2014/readings/MIT18_05S14_Reading24.pdf'''
+    Fmax_diff = []
+    sample_AUC_diff = []
+    sample_APR_diff = []
+    label_AUC_diff = []
+    label_APR_diff = []
+    label_AUC_keys = None
+
+    #point estimate
+    res_point = eval_scores(y_true,y_pred,classes=classes)
+    Fmax_point = res_point["Fmax"]
+    sample_AUC_point = res_point["sample_AUC"]
+    sample_APR_point = res_point["sample_APR"]
+    label_AUC_point = np.array(list(res_point["label_AUC"].values()))
+    label_APR_point = np.array(list(res_point["label_APR"].values()))
+
+    #bootstrap
+    for i in tqdm(range(n_iterations)):
+        ids = resample(range(len(y_true)), n_samples=len(y_true))
+        res = eval_scores(y_true[ids],y_pred[ids],classes=classes)
+        Fmax_diff.append(res["Fmax"]-Fmax_point)
+        sample_AUC_diff.append(res["sample_AUC"]-sample_AUC_point)
+        sample_APR_diff.append(res["sample_APR"]-sample_APR_point)
+        label_AUC_keys = list(res["label_AUC"].keys())
+        label_AUC_diff.append(np.array(list(res["label_AUC"].values()))-label_AUC_point)
+        label_APR_diff.append(np.array(list(res["label_APR"].values()))-label_APR_point)
+
+    #empirical bootstrap CI: [point - d_(1-alpha/2), point - d_(alpha/2)] where d are percentiles of the bootstrapped differences
+    p_low = ((1.0-alpha)/2.0) * 100
+    p_high = (alpha+((1.0-alpha)/2.0)) * 100
+    Fmax_low = Fmax_point - np.percentile(Fmax_diff, p_high)
+    sample_AUC_low = sample_AUC_point - np.percentile(sample_AUC_diff, p_high)
+    sample_APR_low = sample_APR_point - np.percentile(sample_APR_diff, p_high)
+    label_AUC_low = label_AUC_point - np.percentile(label_AUC_diff, p_high, axis=0)
+    label_APR_low = label_APR_point - np.percentile(label_APR_diff, p_high, axis=0)
+    Fmax_high = Fmax_point - np.percentile(Fmax_diff, p_low)
+    sample_AUC_high = sample_AUC_point - np.percentile(sample_AUC_diff, p_low)
+    sample_APR_high = sample_APR_point - np.percentile(sample_APR_diff, p_low)
+    label_AUC_high = label_AUC_point - np.percentile(label_AUC_diff, p_low, axis=0)
+    label_APR_high = label_APR_point - np.percentile(label_APR_diff, p_low, axis=0)
+
+    return {"Fmax":[Fmax_low,Fmax_point,Fmax_high], "sample_AUC":[sample_AUC_low,sample_AUC_point,sample_AUC_high], "sample_APR":[sample_APR_low,sample_APR_point,sample_APR_high], "label_AUC":{k:[v1,v2,v3] for k,v1,v2,v3 in zip(label_AUC_keys,label_AUC_low,label_AUC_point,label_AUC_high)}, "label_APR":{k:[v1,v2,v3] for k,v1,v2,v3 in zip(label_AUC_keys,label_APR_low,label_APR_point,label_APR_high)}}
\ No newline at end of file