--- /dev/null
+++ b/sybil/utils/metrics.py
@@ -0,0 +1,385 @@
+from collections import OrderedDict
+from sklearn.metrics import (
+    accuracy_score,
+    precision_score,
+    recall_score,
+    f1_score,
+    roc_auc_score,
+    precision_recall_curve,
+    auc,
+    average_precision_score,
+)
+import numpy as np
+import warnings
+
+EPSILON = 1e-6
+BINARY_CLASSIF_THRESHOLD = 0.5
+
+
+def get_classification_metrics(logging_dict, args):
+    stats_dict = OrderedDict()
+
+    golds = np.array(logging_dict["golds"]).reshape(-1)
+    probs = np.array(logging_dict["probs"])
+    preds = (probs[:, -1] > BINARY_CLASSIF_THRESHOLD).reshape(-1)
+    probs = probs.reshape((-1, probs.shape[-1]))
+
+    stats_dict["accuracy"] = accuracy_score(y_true=golds, y_pred=preds)
+
+    if (args.num_classes == 2) and not (
+        np.unique(golds)[-1] > 1 or np.unique(preds)[-1] > 1
+    ):
+        stats_dict["precision"] = precision_score(y_true=golds, y_pred=preds)
+        stats_dict["recall"] = recall_score(y_true=golds, y_pred=preds)
+        stats_dict["f1"] = f1_score(y_true=golds, y_pred=preds)
+        num_pos = golds.sum()
+        if num_pos > 0 and num_pos < len(golds):
+            stats_dict["auc"] = roc_auc_score(golds, probs[:, -1], average="samples")
+            stats_dict["ap_score"] = average_precision_score(
+                golds, probs[:, -1], average="samples"
+            )
+            precision, recall, _ = precision_recall_curve(golds, probs[:, -1])
+            stats_dict["prauc"] = auc(recall, precision)
+
+    return stats_dict
+
+
+def get_survival_metrics(logging_dict, args):
+    stats_dict = {}
+
+    censor_times, probs, golds = (
+        logging_dict["censors"],
+        logging_dict["probs"],
+        logging_dict["golds"],
+    )
+    for followup in range(args.max_followup):
+        min_followup_if_neg = followup + 1
+        roc_auc, ap_score, pr_auc = compute_auc_at_followup(
+            probs, censor_times, golds, followup
+        )
+        stats_dict["{}_year_auc".format(min_followup_if_neg)] = roc_auc
+        stats_dict["{}_year_apscore".format(min_followup_if_neg)] = ap_score
+        stats_dict["{}_year_prauc".format(min_followup_if_neg)] = pr_auc
+
+    if np.array(golds).sum() > 0:
+        stats_dict["c_index"] = concordance_index(
+            logging_dict["censors"], probs, golds, args.censoring_distribution
+        )
+    else:
+        stats_dict["c_index"] = -1.0
+    return stats_dict
+
+
+def get_alignment_metrics(logging_dict, args):
+    stats_dict = OrderedDict()
+
+    golds = np.array(logging_dict["discrim_golds"]).reshape(-1)
+    probs = np.array(logging_dict["discrim_probs"])
+    preds = probs.argmax(axis=-1).reshape(-1)
+    probs = probs.reshape((-1, probs.shape[-1]))
+
+    stats_dict["discrim_accuracy"] = accuracy_score(y_true=golds, y_pred=preds)
+    stats_dict["discrim_precision"] = precision_score(y_true=golds, y_pred=preds)
+    stats_dict["discrim_recall"] = recall_score(y_true=golds, y_pred=preds)
+    stats_dict["discrim_f1"] = f1_score(y_true=golds, y_pred=preds)
+    num_pos = golds.sum()
+    if num_pos > 0 and num_pos < len(golds):
+        try:
+            stats_dict["discrim_auc"] = roc_auc_score(
+                golds, probs[:, -1], average="samples"
+            )
+            stats_dict["discrim_ap_score"] = average_precision_score(
+                golds, probs[:, -1], average="samples"
+            )
+            precision, recall, _ = precision_recall_curve(golds, probs[:, -1])
+            stats_dict["discrim_prauc"] = auc(recall, precision)
+        except Exception as e:
+            print(e)
+
+    return stats_dict
+
+
+def get_risk_metrics(logging_dict, args):
+    stats_dict = {}
+    censor_times, probs, golds = (
+        logging_dict["censors"],
+        logging_dict["probs"],
+        logging_dict["golds"],
+    )
+    for followup in range(args.max_followup):
+        min_followup_if_neg = followup + 1
+        roc_auc, ap_score, pr_auc = compute_auc_at_followup(
+            probs, censor_times, golds, followup, fup_lower_bound=0
+        )
+        stats_dict["{}_year_risk_auc".format(min_followup_if_neg)] = roc_auc
+        stats_dict["{}_year_risk_apscore".format(min_followup_if_neg)] = ap_score
+        stats_dict["{}_year_risk_prauc".format(min_followup_if_neg)] = pr_auc
+
+    return stats_dict
+
+
+def compute_auc_at_followup(probs, censor_times, golds, followup, fup_lower_bound=-1):
+    golds, censor_times = golds.ravel(), censor_times.ravel()
+    if len(probs.shape) == 3:
+        probs = probs.reshape(probs.shape[0] * probs.shape[1], probs.shape[2])
+
+    def include_exam_and_determine_label(prob_arr, censor_time, gold):
+        valid_pos = gold and censor_time <= followup and censor_time > fup_lower_bound
+        valid_neg = censor_time >= followup
+        included, label = (valid_pos or valid_neg), valid_pos
+        return included, label
+
+    probs_for_eval, golds_for_eval = [], []
+    for prob_arr, censor_time, gold in zip(probs, censor_times, golds):
+        include, label = include_exam_and_determine_label(prob_arr, censor_time, gold)
+        if include:
+            probs_for_eval.append(prob_arr[followup])
+            golds_for_eval.append(label)
+
+    try:
+        roc_auc = roc_auc_score(golds_for_eval, probs_for_eval, average="samples")
+        ap_score = average_precision_score(
+            golds_for_eval, probs_for_eval, average="samples"
+        )
+        precision, recall, _ = precision_recall_curve(golds_for_eval, probs_for_eval)
+        pr_auc = auc(recall, precision)
+    except Exception as e:
+        warnings.warn("Failed to calculate AUC because {}".format(e))
+        roc_auc = -1.0
+        ap_score = -1.0
+        pr_auc = -1.0
+    return roc_auc, ap_score, pr_auc
+
+
+def get_censoring_dist(train_dataset):
+    from lifelines import KaplanMeierFitter
+    _dataset = train_dataset.dataset
+    times, event_observed = (
+        [d["time_at_event"] for d in _dataset],
+        [d["y"] for d in _dataset],
+    )
+    all_observed_times = set(times)
+    kmf = KaplanMeierFitter()
+    kmf.fit(times, event_observed)
+
+    censoring_dist = {str(time): kmf.predict(time) for time in all_observed_times}
+    return censoring_dist
+
+
+def concordance_index(
+    event_times, predicted_scores, event_observed=None, censoring_dist=None
+):
+    """
+    ## Adapted from: https://raw.githubusercontent.com/CamDavidsonPilon/lifelines/master/lifelines/utils/concordance.py
+    ## Modified to weight by ipcw (inverse probability of censoring weight) to fit Uno's C-index
+    ## Modified to use a time-dependent score
+
+    Calculates the concordance index (C-index) between two series
+    of event times. The first is the real survival times from
+    the experimental data, and the other is the predicted survival
+    times from a model of some kind.
+
+    The c-index is the average of how often a model says X is greater than Y when, in the observed
+    data, X is indeed greater than Y. The c-index also accounts for censored values
+    (obviously, if Y is censored, it's hard to know if X is truly greater than Y).
+
+    The concordance index is a value between 0 and 1 where:
+
+    - 0.5 is the expected result from random predictions,
+    - 1.0 is perfect concordance and,
+    - 0.0 is perfect anti-concordance (multiply predictions with -1 to get 1.0)
+
+    Parameters
+    ----------
+    event_times: iterable
+        a length-n iterable of observed survival times.
+    predicted_scores: iterable
+        a length-n iterable of predicted scores - these could be survival times, or hazards, etc. See https://stats.stackexchange.com/questions/352183/use-median-survival-time-to-calculate-cph-c-statistic/352435#352435
+    event_observed: iterable, optional
+        a length-n iterable of censorship flags, 1 if observed, 0 if not. Default None assumes all observed.
+    censoring_dist: dict, optional
+        a mapping from str(time) to the Kaplan-Meier estimate at that time, as returned by
+        get_censoring_dist; used to compute the inverse-probability-of-censoring weights.
+
+    Returns
+    -------
+    c-index: float
+      a value between 0 and 1.
+
+    References
+    -----------
+    Harrell FE, Lee KL, Mark DB. Multivariable prognostic models: issues in
+    developing models, evaluating assumptions and adequacy, and measuring and
+    reducing errors. Statistics in Medicine 1996;15(4):361-87.
+
+    Examples
+    --------
+
+    >>> from lifelines.utils import concordance_index
+    >>> cph = CoxPHFitter().fit(df, 'T', 'E')
+    >>> concordance_index(df['T'], -cph.predict_partial_hazard(df), df['E'])
+
+    """
+    event_times = np.array(event_times).ravel()
+    predicted_scores = 1 - np.asarray(predicted_scores, dtype=float)
+    if len(predicted_scores.shape) == 3:
+        predicted_scores = predicted_scores.reshape(
+            [
+                predicted_scores.shape[0] * predicted_scores.shape[1],
+                predicted_scores.shape[2],
+            ]
+        )
+
+    if event_observed is None:
+        event_observed = np.ones(event_times.shape[0], dtype=float)
+    else:
+        event_observed = np.asarray(event_observed, dtype=float).ravel()
+        if event_observed.shape != event_times.shape:
+            raise ValueError(
+                "Observed events must be 1-dimensional of same length as event times"
+            )
+
+    num_correct, num_tied, num_pairs = _concordance_summary_statistics(
+        event_times, predicted_scores, event_observed, censoring_dist
+    )
+
+    return _concordance_ratio(num_correct, num_tied, num_pairs)
+
+
+def _concordance_ratio(num_correct, num_tied, num_pairs):
+    if num_pairs == 0:
+        raise ZeroDivisionError("No admissible pairs in the dataset.")
+    return (num_correct + num_tied / 2) / num_pairs
+
+
+def _concordance_summary_statistics(
+    event_times, predicted_event_times, event_observed, censoring_dist
+):  # pylint: disable=too-many-locals
+    """Find the concordance index in n * log(n) time.
+
+    Assumes the data has been verified by lifelines.utils.concordance_index first.
+    """
+    # Here's how this works.
+    #
+    # It would be pretty easy to do if we had no censored data and no ties. There, the basic idea
+    # would be to iterate over the cases in order of their true event time (from least to greatest),
+    # while keeping track of a pool of *predicted* event times for all cases previously seen (= all
+    # cases that we know should be ranked lower than the case we're looking at currently).
+    #
+    # If the pool has O(log n) insert and O(log n) RANK (i.e., "how many things in the pool have
+    # value less than x"), then the following algorithm is n log n:
+    #
+    # Sort the times and predictions by time, increasing
+    # n_pairs, n_correct := 0
+    # pool := {}
+    # for each prediction p:
+    #     n_pairs += len(pool)
+    #     n_correct += rank(pool, p)
+    #     add p to pool
+    #
+    # There are three complications: tied ground truth values, tied predictions, and censored
+    # observations.
+    #
+    # - To handle tied true event times, we modify the inner loop to work in *batches* of observations
+    #   p_1, ..., p_n whose true event times are tied, and then add them all to the pool
+    #   simultaneously at the end.
+    #
+    # - To handle tied predictions, which should each count for 0.5, we switch to
+    #     n_correct += min_rank(pool, p)
+    #     n_tied += count(pool, p)
+    #
+    # - To handle censored observations, we handle each batch of tied, censored observations just
+    #   after the batch of observations that died at the same time (since those censored observations
+    #   are comparable to all the observations that died at the same time or previously). However, we do
+    #   NOT add them to the pool at the end, because they are NOT comparable with any observations
+    #   that leave the study afterward--whether or not those observations get censored.
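+    #
+    # A concrete illustration of the scheme above (all events observed, no censoring,
+    # no tied predictions, and ignoring the time-dependent indexing and IPCW weights
+    # used below): with true event times [1, 1, 2] and predicted scores [0.2, 0.4, 0.9],
+    # the two cases exiting at time 1 are processed first; their tied true times mean
+    # they form no comparable pair with each other, and their predictions {0.2, 0.4}
+    # then enter the pool. The case exiting at time 2 contributes
+    # n_pairs += len(pool) = 2 and, since both pooled predictions are below 0.9,
+    # n_correct += rank(pool, 0.9) = 2, giving a concordance of (2 + 0 / 2) / 2 = 1.0.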
+    if np.logical_not(event_observed).all():
+        return (0, 0, 0)
+
+    observed_times = set(event_times)
+
+    died_mask = event_observed.astype(bool)
+    # TODO: is event_times already sorted? That would be nice...
+    died_truth = event_times[died_mask]
+    ix = np.argsort(died_truth)
+    died_truth = died_truth[ix]
+    died_pred = predicted_event_times[died_mask][ix]
+
+    censored_truth = event_times[~died_mask]
+    ix = np.argsort(censored_truth)
+    censored_truth = censored_truth[ix]
+    censored_pred = predicted_event_times[~died_mask][ix]
+
+    from lifelines.utils.btree import _BTree
+    censored_ix = 0
+    died_ix = 0
+    times_to_compare = {}
+    for time in observed_times:
+        times_to_compare[time] = _BTree(np.unique(died_pred[:, int(time)]))
+    num_pairs = np.int64(0)
+    num_correct = np.int64(0)
+    num_tied = np.int64(0)
+
+    # we iterate through cases sorted by exit time:
+    # - First, all cases that died at time t0. We add these to the sortedlist of died times.
+    # - Then, all cases that were censored at time t0. We DON'T add these since they are NOT
+    #   comparable to subsequent elements.
+    while True:
+        has_more_censored = censored_ix < len(censored_truth)
+        has_more_died = died_ix < len(died_truth)
+        # Should we look at some censored indices next, or died indices?
+        if has_more_censored and (
+            not has_more_died or died_truth[died_ix] > censored_truth[censored_ix]
+        ):
+            pairs, correct, tied, next_ix, weight = _handle_pairs(
+                censored_truth,
+                censored_pred,
+                censored_ix,
+                times_to_compare,
+                censoring_dist,
+            )
+            censored_ix = next_ix
+        elif has_more_died and (
+            not has_more_censored or died_truth[died_ix] <= censored_truth[censored_ix]
+        ):
+            pairs, correct, tied, next_ix, weight = _handle_pairs(
+                died_truth, died_pred, died_ix, times_to_compare, censoring_dist
+            )
+            for pred in died_pred[died_ix:next_ix]:
+                for time in observed_times:
+                    times_to_compare[time].insert(pred[int(time)])
+            died_ix = next_ix
+        else:
+            assert not (has_more_died or has_more_censored)
+            break
+
+        num_pairs += pairs * weight
+        num_correct += correct * weight
+        num_tied += tied * weight
+
+    return (num_correct, num_tied, num_pairs)
+
+
+def _handle_pairs(truth, pred, first_ix, times_to_compare, censoring_dist):
+    """
+    Handle all pairs that exited at the same time as truth[first_ix].
+
+    Returns
+    -------
+    (pairs, correct, tied, next_ix, weight)
+      new_pairs: The number of new comparisons performed
+      new_correct: The number of comparisons correctly predicted
+      new_tied: The number of comparisons that are ties in the predicted score
+      next_ix: The next index that needs to be handled
+      weight: The inverse-probability-of-censoring weight, 1 / censoring_dist[t] ** 2, applied to this batch
+    """
+    next_ix = first_ix
+    truth_time = truth[first_ix]
+    weight = 1.0 / (censoring_dist[str(int(truth_time))] ** 2)
+    while next_ix < len(truth) and truth[next_ix] == truth[first_ix]:
+        next_ix += 1
+    pairs = len(times_to_compare[truth_time]) * (next_ix - first_ix)
+    correct = np.int64(0)
+    tied = np.int64(0)
+    for i in range(first_ix, next_ix):
+        rank, count = times_to_compare[truth_time].rank(pred[i][int(truth_time)])
+        correct += rank
+        tied += count
+
+    return (pairs, correct, tied, next_ix, weight)
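
A minimal usage sketch of the survival pieces, assuming the module is importable at the path in the header, that `probs` holds per-exam cumulative risk over `max_followup` years, and that censor times index years since the exam. The toy arrays and the hand-built censoring dictionary are illustrative stand-ins; in training, the dictionary comes from `get_censoring_dist` on the train split and is passed around as `args.censoring_distribution`:

    import numpy as np

    from sybil.utils.metrics import compute_auc_at_followup, concordance_index

    # Toy cohort: 4 exams, 6-year follow-up, monotone cumulative risk per year.
    rng = np.random.default_rng(0)
    probs = np.sort(rng.random((4, 6)), axis=1)
    golds = np.array([1, 0, 1, 0])    # 1 = event observed within follow-up
    censors = np.array([2, 5, 0, 3])  # year of event (gold=1) or of censoring (gold=0)

    # Discrimination at a single follow-up year, as in get_survival_metrics / get_risk_metrics.
    roc_auc, ap_score, pr_auc = compute_auc_at_followup(probs, censors, golds, followup=2)

    # Uno-style C-index; the {str(time): KM estimate} dict mirrors get_censoring_dist's output.
    censoring_dist = {str(t): p for t, p in enumerate([1.0, 0.95, 0.9, 0.85, 0.8, 0.75])}
    c_index = concordance_index(censors, probs, golds, censoring_dist)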