Diff of /sybil/utils/metrics.py [000000] .. [d9566e]

--- a
+++ b/sybil/utils/metrics.py
@@ -0,0 +1,385 @@
+from collections import OrderedDict
+from sklearn.metrics import (
+    accuracy_score,
+    precision_score,
+    recall_score,
+    f1_score,
+    roc_auc_score,
+    precision_recall_curve,
+    auc,
+    average_precision_score,
+)
+import numpy as np
+import warnings
+
+EPSILON = 1e-6
+BINARY_CLASSIF_THRESHOLD = 0.5
+
+
+def get_classification_metrics(logging_dict, args):
+    stats_dict = OrderedDict()
+
+    golds = np.array(logging_dict["golds"]).reshape(-1)
+    probs = np.array(logging_dict["probs"])
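+    # binarize by thresholding the positive-class probability (last column of probs)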
+    preds = (probs[:, -1] > BINARY_CLASSIF_THRESHOLD).reshape(-1)
+    probs = probs.reshape((-1, probs.shape[-1]))
+
+    stats_dict["accuracy"] = accuracy_score(y_true=golds, y_pred=preds)
+
+    if (args.num_classes == 2) and not (
+        np.unique(golds)[-1] > 1 or np.unique(preds)[-1] > 1
+    ):
+        stats_dict["precision"] = precision_score(y_true=golds, y_pred=preds)
+        stats_dict["recall"] = recall_score(y_true=golds, y_pred=preds)
+        stats_dict["f1"] = f1_score(y_true=golds, y_pred=preds)
+        num_pos = golds.sum()
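+        # ROC-AUC, average precision and PR-AUC are only defined when both classes are present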
+        if num_pos > 0 and num_pos < len(golds):
+            stats_dict["auc"] = roc_auc_score(golds, probs[:, -1], average="samples")
+            stats_dict["ap_score"] = average_precision_score(
+                golds, probs[:, -1], average="samples"
+            )
+            precision, recall, _ = precision_recall_curve(golds, probs[:, -1])
+            stats_dict["prauc"] = auc(recall, precision)
+
+    return stats_dict
+
+
+def get_survival_metrics(logging_dict, args):
+    stats_dict = {}
+
+    censor_times, probs, golds = (
+        logging_dict["censors"],
+        logging_dict["probs"],
+        logging_dict["golds"],
+    )
+    for followup in range(args.max_followup):
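+        # column `followup` of probs holds the (followup + 1)-year prediction evaluated below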
+        min_followup_if_neg = followup + 1
+        roc_auc, ap_score, pr_auc = compute_auc_at_followup(
+            probs, censor_times, golds, followup
+        )
+        stats_dict["{}_year_auc".format(min_followup_if_neg)] = roc_auc
+        stats_dict["{}_year_apscore".format(min_followup_if_neg)] = ap_score
+        stats_dict["{}_year_prauc".format(min_followup_if_neg)] = pr_auc
+
+    if np.array(golds).sum() > 0:
+        stats_dict["c_index"] = concordance_index(
+            logging_dict["censors"], probs, golds, args.censoring_distribution
+        )
+    else:
+        stats_dict["c_index"] = -1.0
+    return stats_dict
+
+
+def get_alignment_metrics(logging_dict, args):
+    stats_dict = OrderedDict()
+
+    golds = np.array(logging_dict["discrim_golds"]).reshape(-1)
+    probs = np.array(logging_dict["discrim_probs"])
+    preds = probs.argmax(axis=-1).reshape(-1)
+    probs = probs.reshape((-1, probs.shape[-1]))
+
+    stats_dict["discrim_accuracy"] = accuracy_score(y_true=golds, y_pred=preds)
+    stats_dict["discrim_precision"] = precision_score(y_true=golds, y_pred=preds)
+    stats_dict["discrim_recall"] = recall_score(y_true=golds, y_pred=preds)
+    stats_dict["discrim_f1"] = f1_score(y_true=golds, y_pred=preds)
+    num_pos = golds.sum()
+    if num_pos > 0 and num_pos < len(golds):
+        try:
+            stats_dict["discrim_auc"] = roc_auc_score(
+                golds, probs[:, -1], average="samples"
+            )
+            stats_dict["discrim_ap_score"] = average_precision_score(
+                golds, probs[:, -1], average="samples"
+            )
+            precision, recall, _ = precision_recall_curve(golds, probs[:, -1])
+            stats_dict["discrim_prauc"] = auc(recall, precision)
+        except Exception as e:
+            print(e)
+
+    return stats_dict
+
+
+def get_risk_metrics(logging_dict, args):
+    stats_dict = {}
+    censor_times, probs, golds = (
+        logging_dict["censors"],
+        logging_dict["probs"],
+        logging_dict["golds"],
+    )
+    for followup in range(args.max_followup):
+        min_followup_if_neg = followup + 1
+        roc_auc, ap_score, pr_auc = compute_auc_at_followup(
+            probs, censor_times, golds, followup, fup_lower_bound=0
+        )
+        stats_dict["{}_year_risk_auc".format(min_followup_if_neg)] = roc_auc
+        stats_dict["{}_year_risk_apscore".format(min_followup_if_neg)] = ap_score
+        stats_dict["{}_year_risk_prauc".format(min_followup_if_neg)] = pr_auc
+
+    return stats_dict
+
+
+def compute_auc_at_followup(probs, censor_times, golds, followup, fup_lower_bound=-1):
+    golds, censor_times = golds.ravel(), censor_times.ravel()
+    if len(probs.shape) == 3:
+        probs = probs.reshape(probs.shape[0] * probs.shape[1], probs.shape[2])
+
+    def include_exam_and_determine_label(prob_arr, censor_time, gold):
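+        # positive: the event (gold) occurred by `followup` (and after `fup_lower_bound`);
+        # negative: followed to at least `followup` without qualifying as a positive;
+        # anything else is excluded from the evaluation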
+        valid_pos = gold and censor_time <= followup and censor_time > fup_lower_bound
+        valid_neg = censor_time >= followup
+        included, label = (valid_pos or valid_neg), valid_pos
+        return included, label
+
+    probs_for_eval, golds_for_eval = [], []
+    for prob_arr, censor_time, gold in zip(probs, censor_times, golds):
+        include, label = include_exam_and_determine_label(prob_arr, censor_time, gold)
+        if include:
+            probs_for_eval.append(prob_arr[followup])
+            golds_for_eval.append(label)
+
+    try:
+        roc_auc = roc_auc_score(golds_for_eval, probs_for_eval, average="samples")
+        ap_score = average_precision_score(
+            golds_for_eval, probs_for_eval, average="samples"
+        )
+        precision, recall, _ = precision_recall_curve(golds_for_eval, probs_for_eval)
+        pr_auc = auc(recall, precision)
+    except Exception as e:
+        warnings.warn("Failed to calculate AUC because {}".format(e))
+        roc_auc = -1.0
+        ap_score = -1.0
+        pr_auc = -1.0
+    return roc_auc, ap_score, pr_auc
+
+
+def get_censoring_dist(train_dataset):
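+    # Kaplan-Meier fit over the training set's event times; the per-time survival
+    # estimates are passed to concordance_index as the censoring distribution used
+    # for IPCW weighting (Uno's C-index).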
+    from lifelines import KaplanMeierFitter
+    _dataset = train_dataset.dataset
+    times, event_observed = (
+        [d["time_at_event"] for d in _dataset],
+        [d["y"] for d in _dataset],
+    )
+    all_observed_times = set(times)
+    kmf = KaplanMeierFitter()
+    kmf.fit(times, event_observed)
+
+    censoring_dist = {str(time): kmf.predict(time) for time in all_observed_times}
+    return censoring_dist
+
+
+def concordance_index(
+    event_times, predicted_scores, event_observed=None, censoring_dist=None
+):
+    """
+    ## Adapted from: https://raw.githubusercontent.com/CamDavidsonPilon/lifelines/master/lifelines/utils/concordance.py
+    ## Modified to weight by IPCW (inverse probability of censoring weights) to fit Uno's C-index
+    ## Modified to use a time-dependent score
+
+    Calculates the concordance index (C-index) between two series
+    of event times. The first is the real survival times from
+    the experimental data, and the other is the predicted survival
+    times from a model of some kind.
+
+    The c-index is the average of how often a model says X is greater than Y when, in the observed
+    data, X is indeed greater than Y. The c-index also specifies how to handle censored values
+    (obviously, if Y is censored, it's hard to know whether X is truly greater than Y).
+
+
+    The concordance index is a value between 0 and 1 where:
+
+    - 0.5 is the expected result from random predictions,
+    - 1.0 is perfect concordance and,
+    - 0.0 is perfect anti-concordance (multiply predictions with -1 to get 1.0)
+
+    Parameters
+    ----------
+    event_times: iterable
+         a length-n iterable of observed survival times.
+    predicted_scores: iterable
+        a length-n iterable of predicted scores - these could be survival times, or hazards, etc. See https://stats.stackexchange.com/questions/352183/use-median-survival-time-to-calculate-cph-c-statistic/352435#352435
+    event_observed: iterable, optional
+        a length-n iterable of censorship flags, 1 if observed, 0 if not. Default None assumes all observed.
+    censoring_dist: dict, optional
+        a mapping from each observed time (as a string) to the Kaplan-Meier estimate produced by
+        get_censoring_dist; used for the IPCW weighting.
+
+    Returns
+    -------
+    c-index: float
+      a value between 0 and 1.
+
+    References
+    -----------
+    Harrell FE, Lee KL, Mark DB. Multivariable prognostic models: issues in
+    developing models, evaluating assumptions and adequacy, and measuring and
+    reducing errors. Statistics in Medicine 1996;15(4):361-87.
+
+    Examples
+    --------
+
+    >>> from lifelines.utils import concordance_index
+    >>> cph = CoxPHFitter().fit(df, 'T', 'E')
+    >>> concordance_index(df['T'], -cph.predict_partial_hazard(df), df['E'])
+
+    """
+    event_times = np.array(event_times).ravel()
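+    # the probabilities logged by the model are flipped so that a larger score
+    # corresponds to a later predicted event (survival-style ordering used below)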
+    predicted_scores = 1 - np.asarray(predicted_scores, dtype=float)
+    if len(predicted_scores.shape) == 3:
+        predicted_scores = predicted_scores.reshape(
+            [
+                predicted_scores.shape[0] * predicted_scores.shape[1],
+                predicted_scores.shape[2],
+            ]
+        )
+
+    if event_observed is None:
+        event_observed = np.ones(event_times.shape[0], dtype=float)
+    else:
+        event_observed = np.asarray(event_observed, dtype=float).ravel()
+        if event_observed.shape != event_times.shape:
+            raise ValueError(
+                "Observed events must be 1-dimensional of same length as event times"
+            )
+
+    num_correct, num_tied, num_pairs = _concordance_summary_statistics(
+        event_times, predicted_scores, event_observed, censoring_dist
+    )
+
+    return _concordance_ratio(num_correct, num_tied, num_pairs)
+
+
+def _concordance_ratio(num_correct, num_tied, num_pairs):
+    if num_pairs == 0:
+        raise ZeroDivisionError("No admissible pairs in the dataset.")
+    return (num_correct + num_tied / 2) / num_pairs
+
+
+def _concordance_summary_statistics(
+    event_times, predicted_event_times, event_observed, censoring_dist
+):  # pylint: disable=too-many-locals
+    """Find the concordance index in n * log(n) time.
+
+    Assumes the data has been verified by lifelines.utils.concordance_index first.
+    """
+    # Here's how this works.
+    #
+    # It would be pretty easy to do if we had no censored data and no ties. There, the basic idea
+    # would be to iterate over the cases in order of their true event time (from least to greatest),
+    # while keeping track of a pool of *predicted* event times for all cases previously seen (= all
+    # cases that we know should be ranked lower than the case we're looking at currently).
+    #
+    # If the pool has O(log n) insert and O(log n) RANK (i.e., "how many things in the pool have
+    # value less than x"), then the following algorithm is n log n:
+    #
+    # Sort the times and predictions by time, increasing
+    # n_pairs, n_correct := 0
+    # pool := {}
+    # for each prediction p:
+    #     n_pairs += len(pool)
+    #     n_correct += rank(pool, p)
+    #     add p to pool
+    #
+    # There are three complications: tied ground truth values, tied predictions, and censored
+    # observations.
+    #
+    # - To handle tied true event times, we modify the inner loop to work in *batches* of observations
+    # p_1, ..., p_n whose true event times are tied, and then add them all to the pool
+    # simultaneously at the end.
+    #
+    # - To handle tied predictions, which should each count for 0.5, we switch to
+    #     n_correct += min_rank(pool, p)
+    #     n_tied += count(pool, p)
+    #
+    # - To handle censored observations, we handle each batch of tied, censored observations just
+    # after the batch of observations that died at the same time (since those censored observations
+    # are comparable to all the observations that died at the same time or previously). However, we do
+    # NOT add them to the pool at the end, because they are NOT comparable with any observations
+    # that leave the study afterward--whether or not those observations get censored.
+    if np.logical_not(event_observed).all():
+        return (0, 0, 0)
+
+    observed_times = set(event_times)
+
+    died_mask = event_observed.astype(bool)
+    # TODO: is event_times already sorted? That would be nice...
+    died_truth = event_times[died_mask]
+    ix = np.argsort(died_truth)
+    died_truth = died_truth[ix]
+    died_pred = predicted_event_times[died_mask][ix]
+
+    censored_truth = event_times[~died_mask]
+    ix = np.argsort(censored_truth)
+    censored_truth = censored_truth[ix]
+    censored_pred = predicted_event_times[~died_mask][ix]
+
+    from lifelines.utils.btree import _BTree
+    censored_ix = 0
+    died_ix = 0
+    times_to_compare = {}
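+    # build one balanced search tree per observed time, spanning the unique predicted
+    # scores at that time index; deaths are inserted as they are processed so that
+    # rank queries only see earlier, comparable cases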
+    for time in observed_times:
+        times_to_compare[time] = _BTree(np.unique(died_pred[:, int(time)]))
+    num_pairs = np.int64(0)
+    num_correct = np.int64(0)
+    num_tied = np.int64(0)
+
+    # we iterate through cases sorted by exit time:
+    # - First, all cases that died at time t0. We add these to the sortedlist of died times.
+    # - Then, all cases that were censored at time t0. We DON'T add these since they are NOT
+    #   comparable to subsequent elements.
+    while True:
+        has_more_censored = censored_ix < len(censored_truth)
+        has_more_died = died_ix < len(died_truth)
+        # Should we look at some censored indices next, or died indices?
+        if has_more_censored and (
+            not has_more_died or died_truth[died_ix] > censored_truth[censored_ix]
+        ):
+            pairs, correct, tied, next_ix, weight = _handle_pairs(
+                censored_truth,
+                censored_pred,
+                censored_ix,
+                times_to_compare,
+                censoring_dist,
+            )
+            censored_ix = next_ix
+        elif has_more_died and (
+            not has_more_censored or died_truth[died_ix] <= censored_truth[censored_ix]
+        ):
+            pairs, correct, tied, next_ix, weight = _handle_pairs(
+                died_truth, died_pred, died_ix, times_to_compare, censoring_dist
+            )
+            for pred in died_pred[died_ix:next_ix]:
+                for time in observed_times:
+                    times_to_compare[time].insert(pred[int(time)])
+            died_ix = next_ix
+        else:
+            assert not (has_more_died or has_more_censored)
+            break
+
+        num_pairs += pairs * weight
+        num_correct += correct * weight
+        num_tied += tied * weight
+
+    return (num_correct, num_tied, num_pairs)
+
+
+def _handle_pairs(truth, pred, first_ix, times_to_compare, censoring_dist):
+    """
+    Handle all pairs that exited at the same time as truth[first_ix].
+
+    Returns
+    -------
+      (pairs, correct, tied, next_ix, weight)
+      pairs: The number of new comparisons performed
+      correct: The number of comparisons correctly predicted
+      tied: The number of tied comparisons
+      next_ix: The next index that needs to be handled
+      weight: The IPCW weight applied to this batch of comparisons
+    """
+    next_ix = first_ix
+    truth_time = truth[first_ix]
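+    # Uno's C-index: weight every pair anchored at this event time by the inverse of
+    # the squared censoring-survival estimate at that time (IPCW)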
+    weight = 1.0 / (censoring_dist[str(int(truth_time))] ** 2)
+    while next_ix < len(truth) and truth[next_ix] == truth[first_ix]:
+        next_ix += 1
+    pairs = len(times_to_compare[truth_time]) * (next_ix - first_ix)
+    correct = np.int64(0)
+    tied = np.int64(0)
+    for i in range(first_ix, next_ix):
+        rank, count = times_to_compare[truth_time].rank(pred[i][int(truth_time)])
+        correct += rank
+        tied += count
+
+    return (pairs, correct, tied, next_ix, weight)