--- a +++ b/src/utils/metrics.py @@ -0,0 +1,2240 @@ +""" +Custom binary prediction metrics using Avalanche +https://github.com/ContinualAI/avalanche/blob/master/notebooks/from-zero-to-hero-tutorial/05_evaluation.ipynb +""" + +from typing import List, Union, Dict +from collections import defaultdict + +import torch +import numpy as np +from torch import Tensor, arange +from avalanche.evaluation import Metric, PluginMetric, GenericPluginMetric +from avalanche.evaluation.metrics.mean import Mean +from avalanche.evaluation.metric_utils import phase_and_task + +from sklearn.metrics import average_precision_score, roc_auc_score + + +def confusion(prediction, truth): + """Returns the confusion matrix for the values in the `prediction` and `truth` + tensors, i.e. the amount of positions where the values of `prediction` + and `truth` are + - 1 and 1 (True Positive) + - 1 and 0 (False Positive) + - 0 and 0 (True Negative) + - 0 and 1 (False Negative) + + Source: https://gist.github.com/the-bass/cae9f3976866776dea17a5049013258d + """ + + confusion_vector = prediction / truth + # Element-wise division of the 2 tensors returns a new tensor which holds a + # unique value for each case: + # 1 where prediction and truth are 1 (True Positive) + # inf where prediction is 1 and truth is 0 (False Positive) + # nan where prediction and truth are 0 (True Negative) + # 0 where prediction is 0 and truth is 1 (False Negative) + + true_positives = torch.sum(confusion_vector == 1).item() + false_positives = torch.sum(confusion_vector == float("inf")).item() + true_negatives = torch.sum(torch.isnan(confusion_vector)).item() + false_negatives = torch.sum(confusion_vector == 0).item() + + return true_positives, false_positives, true_negatives, false_negatives + + +# https://github.com/ContinualAI/avalanche/blob/master/avalanche/evaluation/metrics/mean_scores.py +# Use above for AUPRC etc templates. + + +class BalancedAccuracy(Metric[float]): + """ + The BalancedAccuracy metric. This is a standalone metric. + + The metric keeps a dictionary of <task_label, balancedaccuracy value> pairs. + and update the values through a running average over multiple + <prediction, target> pairs of Tensors, provided incrementally. + The "prediction" and "target" tensors may contain plain labels or + one-hot/logit vectors. + + Each time `result` is called, this metric emits the average balancedaccuracy + across all predictions made since the last `reset`. + + The reset method will bring the metric to its initial state. By default + this metric in its initial state will return an balancedaccuracy value of 0. + """ + + def __init__(self): + """ + Creates an instance of the standalone BalancedAccuracy metric. + + By default this metric in its initial state will return an balancedaccuracy + value of 0. The metric can be updated by using the `update` method + while the running balancedaccuracy can be retrieved using the `result` method. + """ + super().__init__() + self._mean_balancedaccuracy = defaultdict(Mean) + """ + The mean utility that will be used to store the running balancedaccuracy + for each task label. + """ + + @torch.no_grad() + def update( + self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[float, Tensor] + ) -> None: + """ + Update the running balancedaccuracy given the true and predicted labels. + Parameter `task_labels` is used to decide how to update the inner + dictionary: if Float, only the dictionary value related to that task + is updated. 
If Tensor, all the dictionary elements belonging to the + task labels will be updated. + + :param predicted_y: The model prediction. Both labels and logit vectors + are supported. + :param true_y: The ground truth. Both labels and one-hot vectors + are supported. + :param task_labels: the int task label associated to the current + experience or the task labels vector showing the task label + for each pattern. + + :return: None. + """ + if len(true_y) != len(predicted_y): + raise ValueError("Size mismatch for true_y and predicted_y tensors") + + if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y): + raise ValueError("Size mismatch for true_y and task_labels tensors") + + true_y = torch.as_tensor(true_y) + predicted_y = torch.as_tensor(predicted_y) + + # Check if logits or labels + if len(predicted_y.shape) > 1: + # Logits -> transform to labels + predicted_y = torch.max(predicted_y, 1)[1] + + if len(true_y.shape) > 1: + # Logits -> transform to labels + true_y = torch.max(true_y, 1)[1] + + if isinstance(task_labels, int): + ( + true_positives, + false_positives, + true_negatives, + false_negatives, + ) = confusion(predicted_y, true_y) + + try: + tpr = true_positives / (true_positives + false_negatives) + except ZeroDivisionError: + tpr = 1 + + try: + tnr = true_negatives / (true_negatives + false_positives) + except ZeroDivisionError: + tnr = 1 + + self._mean_balancedaccuracy[task_labels].update( + (tpr + tnr) / 2, len(predicted_y) + ) + elif isinstance(task_labels, Tensor): + raise NotImplementedError + else: + raise ValueError( + f"Task label type: {type(task_labels)}, expected int/float or Tensor" + ) + + def result(self, task_label=None) -> Dict[int, float]: + """ + Retrieves the running balancedaccuracy. + + Calling this method will not change the internal state of the metric. + + :param task_label: if None, return the entire dictionary of balanced accuracies + for each task. Otherwise return the dictionary + `{task_label: balancedaccuracy}`. + :return: A dict of running balanced accuracies for each task label, + where each value is a float value between 0 and 1. + """ + assert task_label is None or isinstance(task_label, int) + if task_label is None: + return {k: v.result() for k, v in self._mean_balancedaccuracy.items()} + else: + return {task_label: self._mean_balancedaccuracy[task_label].result()} + + def reset(self, task_label=None) -> None: + """ + Resets the metric. + :param task_label: if None, reset the entire dictionary. + Otherwise, reset the value associated to `task_label`. + + :return: None. 
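+
+        A minimal standalone usage sketch (illustrative only; assumes
+        binary 0/1 labels, as required by the module-level ``confusion``
+        helper)::
+
+            import torch
+
+            metric = BalancedAccuracy()
+            # 2 TP, 1 FP, 1 TN, 0 FN  ->  TPR = 1.0, TNR = 0.5
+            metric.update(
+                torch.tensor([1, 0, 1, 1]), torch.tensor([1, 0, 0, 1]), 0
+            )
+            print(metric.result())  # {0: 0.75}
+            metric.reset()
+            print(metric.result())  # {}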
+ """ + assert task_label is None or isinstance(task_label, int) + if task_label is None: + self._mean_balancedaccuracy = defaultdict(Mean) + else: + self._mean_balancedaccuracy[task_label].reset() + + +class BalancedAccuracyPluginMetric(GenericPluginMetric[float]): + """ + Base class for all balanced accuracies plugin metrics + """ + + def __init__(self, reset_at, emit_at, mode): + self._balancedaccuracy = BalancedAccuracy() + super(BalancedAccuracyPluginMetric, self).__init__( + self._balancedaccuracy, reset_at=reset_at, emit_at=emit_at, mode=mode + ) + + def reset(self, strategy=None) -> None: + if self._reset_at == "stream" or strategy is None: + self._metric.reset() + else: + self._metric.reset(phase_and_task(strategy)[1]) + + def result(self, strategy=None) -> float: + if self._emit_at == "stream" or strategy is None: + return self._metric.result() + else: + return self._metric.result(phase_and_task(strategy)[1]) + + def update(self, strategy): + # task labels defined for each experience + task_labels = strategy.experience.task_labels + if len(task_labels) > 1: + # task labels defined for each pattern + task_labels = strategy.mb_task_id + else: + task_labels = task_labels[0] + self._balancedaccuracy.update(strategy.mb_output, strategy.mb_y, task_labels) + + +class MinibatchBalancedAccuracy(BalancedAccuracyPluginMetric): + """ + The minibatch plugin balancedaccuracy metric. + This metric only works at training time. + + This metric computes the average balancedaccuracy over patterns + from a single minibatch. + It reports the result after each iteration. + + If a more coarse-grained logging is needed, consider using + :class:`EpochBalancedAccuracy` instead. + """ + + def __init__(self): + """ + Creates an instance of the MinibatchBalancedAccuracy metric. + """ + super(MinibatchBalancedAccuracy, self).__init__( + reset_at="iteration", emit_at="iteration", mode="train" + ) + + def __str__(self): + return "BalAcc_MB" + + +class EpochBalancedAccuracy(BalancedAccuracyPluginMetric): + """ + The average balancedaccuracy over a single training epoch. + This plugin metric only works at training time. + + The balancedaccuracy will be logged after each training epoch by computing + the number of correctly predicted patterns during the epoch divided by + the overall number of patterns encountered in that epoch. + """ + + def __init__(self): + """ + Creates an instance of the EpochBalancedAccuracy metric. + """ + + super(EpochBalancedAccuracy, self).__init__( + reset_at="epoch", emit_at="epoch", mode="train" + ) + + def __str__(self): + return "BalAcc_Epoch" + + +class RunningEpochBalancedAccuracy(BalancedAccuracyPluginMetric): + """ + The average balancedaccuracy across all minibatches up to the current + epoch iteration. + This plugin metric only works at training time. + + At each iteration, this metric logs the balancedaccuracy averaged over all patterns + seen so far in the current epoch. + The metric resets its state after each training epoch. + """ + + def __init__(self): + """ + Creates an instance of the RunningEpochBalancedAccuracy metric. + """ + + super(RunningEpochBalancedAccuracy, self).__init__( + reset_at="epoch", emit_at="iteration", mode="train" + ) + + def __str__(self): + return "RunningBalAcc_Epoch" + + +class ExperienceBalancedAccuracy(BalancedAccuracyPluginMetric): + """ + At the end of each experience, this plugin metric reports + the average balancedaccuracy over all patterns seen in that experience. + This metric only works at eval time. 
+ """ + + def __init__(self): + """ + Creates an instance of ExperienceBalancedAccuracy metric + """ + super(ExperienceBalancedAccuracy, self).__init__( + reset_at="experience", emit_at="experience", mode="eval" + ) + + def __str__(self): + return "BalAcc_Exp" + + +class StreamBalancedAccuracy(BalancedAccuracyPluginMetric): + """ + At the end of the entire stream of experiences, this plugin metric + reports the average balancedaccuracy over all patterns seen in all experiences. + This metric only works at eval time. + """ + + def __init__(self): + """ + Creates an instance of StreamBalancedAccuracy metric + """ + super(StreamBalancedAccuracy, self).__init__( + reset_at="stream", emit_at="stream", mode="eval" + ) + + def __str__(self): + return "BalAcc_Stream" + + +class TrainedExperienceBalancedAccuracy(BalancedAccuracyPluginMetric): + """ + At the end of each experience, this plugin metric reports the average + balancedaccuracy for only the experiences that the model has been trained on so far. + + This metric only works at eval time. + """ + + def __init__(self): + """ + Creates an instance of TrainedExperienceBalancedAccuracy metric by first + constructing BalancedAccuracyPluginMetric + """ + super(TrainedExperienceBalancedAccuracy, self).__init__( + reset_at="stream", emit_at="stream", mode="eval" + ) + self._current_experience = 0 + + def after_training_exp(self, strategy) -> None: + self._current_experience = strategy.experience.current_experience + # Reset average after learning from a new experience + BalancedAccuracyPluginMetric.reset(self, strategy) + return BalancedAccuracyPluginMetric.after_training_exp(self, strategy) + + def update(self, strategy): + """ + Only update the balancedaccuracy with results from experiences that have been + trained on + """ + if strategy.experience.current_experience <= self._current_experience: + BalancedAccuracyPluginMetric.update(self, strategy) + + def __str__(self): + return "BalancedAccuracy_On_Trained_Experiences" + + +def balancedaccuracy_metrics( + *, + minibatch=False, + epoch=False, + epoch_running=False, + experience=False, + stream=False, + trained_experience=False, +) -> List[PluginMetric]: + """ + Helper method that can be used to obtain the desired set of + plugin metrics. + + :param minibatch: If True, will return a metric able to log + the minibatch balancedaccuracy at training time. + :param epoch: If True, will return a metric able to log + the epoch balancedaccuracy at training time. + :param epoch_running: If True, will return a metric able to log + the running epoch balancedaccuracy at training time. + :param experience: If True, will return a metric able to log + the balancedaccuracy on each evaluation experience. + :param stream: If True, will return a metric able to log + the balancedaccuracy averaged over the entire evaluation stream of experiences. + :param trained_experience: If True, will return a metric able to log + the average evaluation balancedaccuracy only for experiences that the + model has been trained on + + :return: A list of plugin metrics. 
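+
+    A minimal usage sketch (assuming the standard Avalanche
+    ``EvaluationPlugin`` and ``InteractiveLogger`` APIs)::
+
+        from avalanche.logging import InteractiveLogger
+        from avalanche.training.plugins import EvaluationPlugin
+
+        eval_plugin = EvaluationPlugin(
+            *balancedaccuracy_metrics(epoch=True, experience=True, stream=True),
+            loggers=[InteractiveLogger()],
+        )
+        # then pass ``evaluator=eval_plugin`` to an Avalanche strategy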
+ """ + + metrics = [] + if minibatch: + metrics.append(MinibatchBalancedAccuracy()) + + if epoch: + metrics.append(EpochBalancedAccuracy()) + + if epoch_running: + metrics.append(RunningEpochBalancedAccuracy()) + + if experience: + metrics.append(ExperienceBalancedAccuracy()) + + if stream: + metrics.append(StreamBalancedAccuracy()) + + if trained_experience: + metrics.append(TrainedExperienceBalancedAccuracy()) + + return metrics + + +class Sensitivity(Metric[float]): + """ + The Sensitivity metric. This is a standalone metric. + + The metric keeps a dictionary of <task_label, Sensitivity value> pairs. + and update the values through a running average over multiple + <prediction, target> pairs of Tensors, provided incrementally. + The "prediction" and "target" tensors may contain plain labels or + one-hot/logit vectors. + + Each time `result` is called, this metric emits the average Sensitivity + across all predictions made since the last `reset`. + + The reset method will bring the metric to its initial state. By default + this metric in its initial state will return an Sensitivity value of 0. + """ + + def __init__(self): + """ + Creates an instance of the standalone Sensitivity metric. + + By default this metric in its initial state will return an Sensitivity + value of 0. The metric can be updated by using the `update` method + while the running Sensitivity can be retrieved using the `result` method. + """ + super().__init__() + self._mean_Sensitivity = defaultdict(Mean) + """ + The mean utility that will be used to store the running Sensitivity + for each task label. + """ + + @torch.no_grad() + def update( + self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[float, Tensor] + ) -> None: + """ + Update the running Sensitivity given the true and predicted labels. + Parameter `task_labels` is used to decide how to update the inner + dictionary: if Float, only the dictionary value related to that task + is updated. If Tensor, all the dictionary elements belonging to the + task labels will be updated. + + :param predicted_y: The model prediction. Both labels and logit vectors + are supported. + :param true_y: The ground truth. Both labels and one-hot vectors + are supported. + :param task_labels: the int task label associated to the current + experience or the task labels vector showing the task label + for each pattern. + + :return: None. 
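+
+        A small illustrative sketch (logit rows are reduced to labels via
+        argmax before the confusion counts are computed)::
+
+            import torch
+
+            metric = Sensitivity()
+            logits = torch.tensor([[0.2, 0.8], [0.9, 0.1], [0.3, 0.7]])
+            targets = torch.tensor([1, 0, 0])
+            metric.update(logits, targets, 0)
+            # predicted labels are [1, 0, 1] -> TP = 1, FN = 0 -> TPR = 1.0
+            print(metric.result())  # {0: 1.0}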
+ """ + if len(true_y) != len(predicted_y): + raise ValueError("Size mismatch for true_y and predicted_y tensors") + + if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y): + raise ValueError("Size mismatch for true_y and task_labels tensors") + + true_y = torch.as_tensor(true_y) + predicted_y = torch.as_tensor(predicted_y) + + # Check if logits or labels + if len(predicted_y.shape) > 1: + # Logits -> transform to labels + predicted_y = torch.max(predicted_y, 1)[1] + + if len(true_y.shape) > 1: + # Logits -> transform to labels + true_y = torch.max(true_y, 1)[1] + + if isinstance(task_labels, int): + ( + true_positives, + false_positives, + true_negatives, + false_negatives, + ) = confusion(predicted_y, true_y) + + try: + tpr = true_positives / (true_positives + false_negatives) + except ZeroDivisionError: + tpr = 1 + + self._mean_Sensitivity[task_labels].update(tpr, len(predicted_y)) + elif isinstance(task_labels, Tensor): + raise NotImplementedError + else: + raise ValueError( + f"Task label type: {type(task_labels)}, expected int/float or Tensor" + ) + + def result(self, task_label=None) -> Dict[int, float]: + """ + Retrieves the running Sensitivity. + + Calling this method will not change the internal state of the metric. + + :param task_label: if None, return the entire dictionary of sensitivities + for each task. Otherwise return the dictionary + `{task_label: Sensitivity}`. + :return: A dict of running sensitivities for each task label, + where each value is a float value between 0 and 1. + """ + assert task_label is None or isinstance(task_label, int) + if task_label is None: + return {k: v.result() for k, v in self._mean_Sensitivity.items()} + else: + return {task_label: self._mean_Sensitivity[task_label].result()} + + def reset(self, task_label=None) -> None: + """ + Resets the metric. + :param task_label: if None, reset the entire dictionary. + Otherwise, reset the value associated to `task_label`. + + :return: None. + """ + assert task_label is None or isinstance(task_label, int) + if task_label is None: + self._mean_Sensitivity = defaultdict(Mean) + else: + self._mean_Sensitivity[task_label].reset() + + +class SensitivityPluginMetric(GenericPluginMetric[float]): + """ + Base class for all sensitivities plugin metrics + """ + + def __init__(self, reset_at, emit_at, mode): + self._Sensitivity = Sensitivity() + super(SensitivityPluginMetric, self).__init__( + self._Sensitivity, reset_at=reset_at, emit_at=emit_at, mode=mode + ) + + def reset(self, strategy=None) -> None: + if self._reset_at == "stream" or strategy is None: + self._metric.reset() + else: + self._metric.reset(phase_and_task(strategy)[1]) + + def result(self, strategy=None) -> float: + if self._emit_at == "stream" or strategy is None: + return self._metric.result() + else: + return self._metric.result(phase_and_task(strategy)[1]) + + def update(self, strategy): + # task labels defined for each experience + task_labels = strategy.experience.task_labels + if len(task_labels) > 1: + # task labels defined for each pattern + task_labels = strategy.mb_task_id + else: + task_labels = task_labels[0] + self._Sensitivity.update(strategy.mb_output, strategy.mb_y, task_labels) + + +class MinibatchSensitivity(SensitivityPluginMetric): + """ + The minibatch plugin Sensitivity metric. + This metric only works at training time. + + This metric computes the average Sensitivity over patterns + from a single minibatch. + It reports the result after each iteration. 
+ + If a more coarse-grained logging is needed, consider using + :class:`EpochSensitivity` instead. + """ + + def __init__(self): + """ + Creates an instance of the MinibatchSensitivity metric. + """ + super(MinibatchSensitivity, self).__init__( + reset_at="iteration", emit_at="iteration", mode="train" + ) + + def __str__(self): + return "Sens_MB" + + +class EpochSensitivity(SensitivityPluginMetric): + """ + The average Sensitivity over a single training epoch. + This plugin metric only works at training time. + + The Sensitivity will be logged after each training epoch by computing + the number of correctly predicted patterns during the epoch divided by + the overall number of patterns encountered in that epoch. + """ + + def __init__(self): + """ + Creates an instance of the EpochSensitivity metric. + """ + + super(EpochSensitivity, self).__init__( + reset_at="epoch", emit_at="epoch", mode="train" + ) + + def __str__(self): + return "Sens_Epoch" + + +class RunningEpochSensitivity(SensitivityPluginMetric): + """ + The average Sensitivity across all minibatches up to the current + epoch iteration. + This plugin metric only works at training time. + + At each iteration, this metric logs the Sensitivity averaged over all patterns + seen so far in the current epoch. + The metric resets its state after each training epoch. + """ + + def __init__(self): + """ + Creates an instance of the RunningEpochSensitivity metric. + """ + + super(RunningEpochSensitivity, self).__init__( + reset_at="epoch", emit_at="iteration", mode="train" + ) + + def __str__(self): + return "RunningSens_Epoch" + + +class ExperienceSensitivity(SensitivityPluginMetric): + """ + At the end of each experience, this plugin metric reports + the average Sensitivity over all patterns seen in that experience. + This metric only works at eval time. + """ + + def __init__(self): + """ + Creates an instance of ExperienceSensitivity metric + """ + super(ExperienceSensitivity, self).__init__( + reset_at="experience", emit_at="experience", mode="eval" + ) + + def __str__(self): + return "Sens_Exp" + + +class StreamSensitivity(SensitivityPluginMetric): + """ + At the end of the entire stream of experiences, this plugin metric + reports the average Sensitivity over all patterns seen in all experiences. + This metric only works at eval time. + """ + + def __init__(self): + """ + Creates an instance of StreamSensitivity metric + """ + super(StreamSensitivity, self).__init__( + reset_at="stream", emit_at="stream", mode="eval" + ) + + def __str__(self): + return "Sens_Stream" + + +class TrainedExperienceSensitivity(SensitivityPluginMetric): + """ + At the end of each experience, this plugin metric reports the average + Sensitivity for only the experiences that the model has been trained on so far. + + This metric only works at eval time. 
+ """ + + def __init__(self): + """ + Creates an instance of TrainedExperienceSensitivity metric by first + constructing SensitivityPluginMetric + """ + super(TrainedExperienceSensitivity, self).__init__( + reset_at="stream", emit_at="stream", mode="eval" + ) + self._current_experience = 0 + + def after_training_exp(self, strategy) -> None: + self._current_experience = strategy.experience.current_experience + # Reset average after learning from a new experience + SensitivityPluginMetric.reset(self, strategy) + return SensitivityPluginMetric.after_training_exp(self, strategy) + + def update(self, strategy): + """ + Only update the Sensitivity with results from experiences that have been + trained on + """ + if strategy.experience.current_experience <= self._current_experience: + SensitivityPluginMetric.update(self, strategy) + + def __str__(self): + return "Sensitivity_On_Trained_Experiences" + + +def sensitivity_metrics( + *, + minibatch=False, + epoch=False, + epoch_running=False, + experience=False, + stream=False, + trained_experience=False, +) -> List[PluginMetric]: + """ + Helper method that can be used to obtain the desired set of + plugin metrics. + + :param minibatch: If True, will return a metric able to log + the minibatch Sensitivity at training time. + :param epoch: If True, will return a metric able to log + the epoch Sensitivity at training time. + :param epoch_running: If True, will return a metric able to log + the running epoch Sensitivity at training time. + :param experience: If True, will return a metric able to log + the Sensitivity on each evaluation experience. + :param stream: If True, will return a metric able to log + the Sensitivity averaged over the entire evaluation stream of experiences. + :param trained_experience: If True, will return a metric able to log + the average evaluation Sensitivity only for experiences that the + model has been trained on + + :return: A list of plugin metrics. + """ + + metrics = [] + if minibatch: + metrics.append(MinibatchSensitivity()) + + if epoch: + metrics.append(EpochSensitivity()) + + if epoch_running: + metrics.append(RunningEpochSensitivity()) + + if experience: + metrics.append(ExperienceSensitivity()) + + if stream: + metrics.append(StreamSensitivity()) + + if trained_experience: + metrics.append(TrainedExperienceSensitivity()) + + return metrics + + +class Specificity(Metric[float]): + """ + The Specificity metric. This is a standalone metric. + + The metric keeps a dictionary of <task_label, Specificity value> pairs. + and update the values through a running average over multiple + <prediction, target> pairs of Tensors, provided incrementally. + The "prediction" and "target" tensors may contain plain labels or + one-hot/logit vectors. + + Each time `result` is called, this metric emits the average Specificity + across all predictions made since the last `reset`. + + The reset method will bring the metric to its initial state. By default + this metric in its initial state will return an Specificity value of 0. + """ + + def __init__(self): + """ + Creates an instance of the standalone Specificity metric. + + By default this metric in its initial state will return an Specificity + value of 0. The metric can be updated by using the `update` method + while the running Specificity can be retrieved using the `result` method. + """ + super().__init__() + self._mean_Specificity = defaultdict(Mean) + """ + The mean utility that will be used to store the running Specificity + for each task label. 
+ """ + + @torch.no_grad() + def update( + self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[float, Tensor] + ) -> None: + """ + Update the running Specificity given the true and predicted labels. + Parameter `task_labels` is used to decide how to update the inner + dictionary: if Float, only the dictionary value related to that task + is updated. If Tensor, all the dictionary elements belonging to the + task labels will be updated. + + :param predicted_y: The model prediction. Both labels and logit vectors + are supported. + :param true_y: The ground truth. Both labels and one-hot vectors + are supported. + :param task_labels: the int task label associated to the current + experience or the task labels vector showing the task label + for each pattern. + + :return: None. + """ + if len(true_y) != len(predicted_y): + raise ValueError("Size mismatch for true_y and predicted_y tensors") + + if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y): + raise ValueError("Size mismatch for true_y and task_labels tensors") + + true_y = torch.as_tensor(true_y) + predicted_y = torch.as_tensor(predicted_y) + + # Check if logits or labels + if len(predicted_y.shape) > 1: + # Logits -> transform to labels + predicted_y = torch.max(predicted_y, 1)[1] + + if len(true_y.shape) > 1: + # Logits -> transform to labels + true_y = torch.max(true_y, 1)[1] + + if isinstance(task_labels, int): + ( + true_positives, + false_positives, + true_negatives, + false_negatives, + ) = confusion(predicted_y, true_y) + + try: + tnr = true_negatives / (true_negatives + false_positives) + except ZeroDivisionError: + tnr = 1 + + self._mean_Specificity[task_labels].update(tnr, len(predicted_y)) + elif isinstance(task_labels, Tensor): + raise NotImplementedError + else: + raise ValueError( + f"Task label type: {type(task_labels)}, expected int/float or Tensor" + ) + + def result(self, task_label=None) -> Dict[int, float]: + """ + Retrieves the running Specificity. + + Calling this method will not change the internal state of the metric. + + :param task_label: if None, return the entire dictionary of specificities + for each task. Otherwise return the dictionary + `{task_label: Specificity}`. + :return: A dict of running specificities for each task label, + where each value is a float value between 0 and 1. + """ + assert task_label is None or isinstance(task_label, int) + if task_label is None: + return {k: v.result() for k, v in self._mean_Specificity.items()} + else: + return {task_label: self._mean_Specificity[task_label].result()} + + def reset(self, task_label=None) -> None: + """ + Resets the metric. + :param task_label: if None, reset the entire dictionary. + Otherwise, reset the value associated to `task_label`. + + :return: None. 
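+
+        An illustrative sketch of per-task reset (integer task labels, as
+        in ``update``; the expected values follow from the confusion counts
+        and from the documented "reset returns 0" behaviour)::
+
+            import torch
+
+            metric = Specificity()
+            metric.update(torch.tensor([0, 1]), torch.tensor([0, 0]), 0)  # TNR = 0.5
+            metric.update(torch.tensor([0, 0]), torch.tensor([0, 1]), 1)  # TNR = 1.0
+            metric.reset(task_label=0)
+            print(metric.result())  # {0: 0.0, 1: 1.0}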
+ """ + assert task_label is None or isinstance(task_label, int) + if task_label is None: + self._mean_Specificity = defaultdict(Mean) + else: + self._mean_Specificity[task_label].reset() + + +class SpecificityPluginMetric(GenericPluginMetric[float]): + """ + Base class for all specificities plugin metrics + """ + + def __init__(self, reset_at, emit_at, mode): + self._Specificity = Specificity() + super(SpecificityPluginMetric, self).__init__( + self._Specificity, reset_at=reset_at, emit_at=emit_at, mode=mode + ) + + def reset(self, strategy=None) -> None: + if self._reset_at == "stream" or strategy is None: + self._metric.reset() + else: + self._metric.reset(phase_and_task(strategy)[1]) + + def result(self, strategy=None) -> float: + if self._emit_at == "stream" or strategy is None: + return self._metric.result() + else: + return self._metric.result(phase_and_task(strategy)[1]) + + def update(self, strategy): + # task labels defined for each experience + task_labels = strategy.experience.task_labels + if len(task_labels) > 1: + # task labels defined for each pattern + task_labels = strategy.mb_task_id + else: + task_labels = task_labels[0] + self._Specificity.update(strategy.mb_output, strategy.mb_y, task_labels) + + +class MinibatchSpecificity(SpecificityPluginMetric): + """ + The minibatch plugin Specificity metric. + This metric only works at training time. + + This metric computes the average Specificity over patterns + from a single minibatch. + It reports the result after each iteration. + + If a more coarse-grained logging is needed, consider using + :class:`EpochSpecificity` instead. + """ + + def __init__(self): + """ + Creates an instance of the MinibatchSpecificity metric. + """ + super(MinibatchSpecificity, self).__init__( + reset_at="iteration", emit_at="iteration", mode="train" + ) + + def __str__(self): + return "Spec_MB" + + +class EpochSpecificity(SpecificityPluginMetric): + """ + The average Specificity over a single training epoch. + This plugin metric only works at training time. + + The Specificity will be logged after each training epoch by computing + the number of correctly predicted patterns during the epoch divided by + the overall number of patterns encountered in that epoch. + """ + + def __init__(self): + """ + Creates an instance of the EpochSpecificity metric. + """ + + super(EpochSpecificity, self).__init__( + reset_at="epoch", emit_at="epoch", mode="train" + ) + + def __str__(self): + return "Spec_Epoch" + + +class RunningEpochSpecificity(SpecificityPluginMetric): + """ + The average Specificity across all minibatches up to the current + epoch iteration. + This plugin metric only works at training time. + + At each iteration, this metric logs the Specificity averaged over all patterns + seen so far in the current epoch. + The metric resets its state after each training epoch. + """ + + def __init__(self): + """ + Creates an instance of the RunningEpochSpecificity metric. + """ + + super(RunningEpochSpecificity, self).__init__( + reset_at="epoch", emit_at="iteration", mode="train" + ) + + def __str__(self): + return "RunningSpec_Epoch" + + +class ExperienceSpecificity(SpecificityPluginMetric): + """ + At the end of each experience, this plugin metric reports + the average Specificity over all patterns seen in that experience. + This metric only works at eval time. 
+ """ + + def __init__(self): + """ + Creates an instance of ExperienceSpecificity metric + """ + super(ExperienceSpecificity, self).__init__( + reset_at="experience", emit_at="experience", mode="eval" + ) + + def __str__(self): + return "Spec_Exp" + + +class StreamSpecificity(SpecificityPluginMetric): + """ + At the end of the entire stream of experiences, this plugin metric + reports the average Specificity over all patterns seen in all experiences. + This metric only works at eval time. + """ + + def __init__(self): + """ + Creates an instance of StreamSpecificity metric + """ + super(StreamSpecificity, self).__init__( + reset_at="stream", emit_at="stream", mode="eval" + ) + + def __str__(self): + return "Spec_Stream" + + +class TrainedExperienceSpecificity(SpecificityPluginMetric): + """ + At the end of each experience, this plugin metric reports the average + Specificity for only the experiences that the model has been trained on so far. + + This metric only works at eval time. + """ + + def __init__(self): + """ + Creates an instance of TrainedExperienceSpecificity metric by first + constructing SpecificityPluginMetric + """ + super(TrainedExperienceSpecificity, self).__init__( + reset_at="stream", emit_at="stream", mode="eval" + ) + self._current_experience = 0 + + def after_training_exp(self, strategy) -> None: + self._current_experience = strategy.experience.current_experience + # Reset average after learning from a new experience + SpecificityPluginMetric.reset(self, strategy) + return SpecificityPluginMetric.after_training_exp(self, strategy) + + def update(self, strategy): + """ + Only update the Specificity with results from experiences that have been + trained on + """ + if strategy.experience.current_experience <= self._current_experience: + SpecificityPluginMetric.update(self, strategy) + + def __str__(self): + return "Specificity_On_Trained_Experiences" + + +def specificity_metrics( + *, + minibatch=False, + epoch=False, + epoch_running=False, + experience=False, + stream=False, + trained_experience=False, +) -> List[PluginMetric]: + """ + Helper method that can be used to obtain the desired set of + plugin metrics. + + :param minibatch: If True, will return a metric able to log + the minibatch Specificity at training time. + :param epoch: If True, will return a metric able to log + the epoch Specificity at training time. + :param epoch_running: If True, will return a metric able to log + the running epoch Specificity at training time. + :param experience: If True, will return a metric able to log + the Specificity on each evaluation experience. + :param stream: If True, will return a metric able to log + the Specificity averaged over the entire evaluation stream of experiences. + :param trained_experience: If True, will return a metric able to log + the average evaluation Specificity only for experiences that the + model has been trained on + + :return: A list of plugin metrics. + """ + + metrics = [] + if minibatch: + metrics.append(MinibatchSpecificity()) + + if epoch: + metrics.append(EpochSpecificity()) + + if epoch_running: + metrics.append(RunningEpochSpecificity()) + + if experience: + metrics.append(ExperienceSpecificity()) + + if stream: + metrics.append(StreamSpecificity()) + + if trained_experience: + metrics.append(TrainedExperienceSpecificity()) + + return metrics + + +class Precision(Metric[float]): + """ + The Precision metric. This is a standalone metric. + + The metric keeps a dictionary of <task_label, Precision value> pairs. 
+ and update the values through a running average over multiple + <prediction, target> pairs of Tensors, provided incrementally. + The "prediction" and "target" tensors may contain plain labels or + one-hot/logit vectors. + + Each time `result` is called, this metric emits the average Precision + across all predictions made since the last `reset`. + + The reset method will bring the metric to its initial state. By default + this metric in its initial state will return an Precision value of 0. + """ + + def __init__(self): + """ + Creates an instance of the standalone Precision metric. + + By default this metric in its initial state will return a Precision + value of 0. The metric can be updated by using the `update` method + while the running Precision can be retrieved using the `result` method. + """ + super().__init__() + self._mean_Precision = defaultdict(Mean) + """ + The mean utility that will be used to store the running Precision + for each task label. + """ + + @torch.no_grad() + def update( + self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[float, Tensor] + ) -> None: + """ + Update the running Precision given the true and predicted labels. + Parameter `task_labels` is used to decide how to update the inner + dictionary: if Float, only the dictionary value related to that task + is updated. If Tensor, all the dictionary elements belonging to the + task labels will be updated. + + :param predicted_y: The model prediction. Both labels and logit vectors + are supported. + :param true_y: The ground truth. Both labels and one-hot vectors + are supported. + :param task_labels: the int task label associated to the current + experience or the task labels vector showing the task label + for each pattern. + + :return: None. + """ + if len(true_y) != len(predicted_y): + raise ValueError("Size mismatch for true_y and predicted_y tensors") + + if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y): + raise ValueError("Size mismatch for true_y and task_labels tensors") + + true_y = torch.as_tensor(true_y) + predicted_y = torch.as_tensor(predicted_y) + + # Check if logits or labels + if len(predicted_y.shape) > 1: + # Logits -> transform to labels + predicted_y = torch.max(predicted_y, 1)[1] + + if len(true_y.shape) > 1: + # Logits -> transform to labels + true_y = torch.max(true_y, 1)[1] + + if isinstance(task_labels, int): + ( + true_positives, + false_positives, + true_negatives, + false_negatives, + ) = confusion(predicted_y, true_y) + + try: + ppv = true_positives / (true_positives + false_positives) + except ZeroDivisionError: + ppv = 1 + + self._mean_Precision[task_labels].update(ppv, len(predicted_y)) + elif isinstance(task_labels, Tensor): + raise NotImplementedError + else: + raise ValueError( + f"Task label type: {type(task_labels)}, expected int/float or Tensor" + ) + + def result(self, task_label=None) -> Dict[int, float]: + """ + Retrieves the running Precision. + + Calling this method will not change the internal state of the metric. + + :param task_label: if None, return the entire dictionary of precisions + for each task. Otherwise return the dictionary + `{task_label: Precision}`. + :return: A dict of running precisions for each task label, + where each value is a float value between 0 and 1. 
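+
+        A short illustrative sketch (single update for task 0)::
+
+            import torch
+
+            metric = Precision()
+            # TP = 1, FP = 1 -> PPV = 0.5
+            metric.update(torch.tensor([1, 1, 0]), torch.tensor([1, 0, 0]), 0)
+            print(metric.result())              # {0: 0.5}
+            print(metric.result(task_label=0))  # {0: 0.5}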
+ """ + assert task_label is None or isinstance(task_label, int) + if task_label is None: + return {k: v.result() for k, v in self._mean_Precision.items()} + else: + return {task_label: self._mean_Precision[task_label].result()} + + def reset(self, task_label=None) -> None: + """ + Resets the metric. + :param task_label: if None, reset the entire dictionary. + Otherwise, reset the value associated to `task_label`. + + :return: None. + """ + assert task_label is None or isinstance(task_label, int) + if task_label is None: + self._mean_Precision = defaultdict(Mean) + else: + self._mean_Precision[task_label].reset() + + +class PrecisionPluginMetric(GenericPluginMetric[float]): + """ + Base class for all precisions plugin metrics + """ + + def __init__(self, reset_at, emit_at, mode): + self._Precision = Precision() + super(PrecisionPluginMetric, self).__init__( + self._Precision, reset_at=reset_at, emit_at=emit_at, mode=mode + ) + + def reset(self, strategy=None) -> None: + if self._reset_at == "stream" or strategy is None: + self._metric.reset() + else: + self._metric.reset(phase_and_task(strategy)[1]) + + def result(self, strategy=None) -> float: + if self._emit_at == "stream" or strategy is None: + return self._metric.result() + else: + return self._metric.result(phase_and_task(strategy)[1]) + + def update(self, strategy): + # task labels defined for each experience + task_labels = strategy.experience.task_labels + if len(task_labels) > 1: + # task labels defined for each pattern + task_labels = strategy.mb_task_id + else: + task_labels = task_labels[0] + self._Precision.update(strategy.mb_output, strategy.mb_y, task_labels) + + +class MinibatchPrecision(PrecisionPluginMetric): + """ + The minibatch plugin Precision metric. + This metric only works at training time. + + This metric computes the average Precision over patterns + from a single minibatch. + It reports the result after each iteration. + + If a more coarse-grained logging is needed, consider using + :class:`EpochPrecision` instead. + """ + + def __init__(self): + """ + Creates an instance of the MinibatchPrecision metric. + """ + super(MinibatchPrecision, self).__init__( + reset_at="iteration", emit_at="iteration", mode="train" + ) + + def __str__(self): + return "Prec_MB" + + +class EpochPrecision(PrecisionPluginMetric): + """ + The average Precision over a single training epoch. + This plugin metric only works at training time. + + The Precision will be logged after each training epoch by computing + the number of correctly predicted patterns during the epoch divided by + the overall number of patterns encountered in that epoch. + """ + + def __init__(self): + """ + Creates an instance of the EpochPrecision metric. + """ + + super(EpochPrecision, self).__init__( + reset_at="epoch", emit_at="epoch", mode="train" + ) + + def __str__(self): + return "Prec_Epoch" + + +class RunningEpochPrecision(PrecisionPluginMetric): + """ + The average Precision across all minibatches up to the current + epoch iteration. + This plugin metric only works at training time. + + At each iteration, this metric logs the Precision averaged over all patterns + seen so far in the current epoch. + The metric resets its state after each training epoch. + """ + + def __init__(self): + """ + Creates an instance of the RunningEpochPrecision metric. 
+ """ + + super(RunningEpochPrecision, self).__init__( + reset_at="epoch", emit_at="iteration", mode="train" + ) + + def __str__(self): + return "RunningPrec_Epoch" + + +class ExperiencePrecision(PrecisionPluginMetric): + """ + At the end of each experience, this plugin metric reports + the average Precision over all patterns seen in that experience. + This metric only works at eval time. + """ + + def __init__(self): + """ + Creates an instance of ExperiencePrecision metric + """ + super(ExperiencePrecision, self).__init__( + reset_at="experience", emit_at="experience", mode="eval" + ) + + def __str__(self): + return "Prec_Exp" + + +class StreamPrecision(PrecisionPluginMetric): + """ + At the end of the entire stream of experiences, this plugin metric + reports the average Precision over all patterns seen in all experiences. + This metric only works at eval time. + """ + + def __init__(self): + """ + Creates an instance of StreamPrecision metric + """ + super(StreamPrecision, self).__init__( + reset_at="stream", emit_at="stream", mode="eval" + ) + + def __str__(self): + return "Prec_Stream" + + +class TrainedExperiencePrecision(PrecisionPluginMetric): + """ + At the end of each experience, this plugin metric reports the average + Precision for only the experiences that the model has been trained on so far. + + This metric only works at eval time. + """ + + def __init__(self): + """ + Creates an instance of TrainedExperiencePrecision metric by first + constructing PrecisionPluginMetric + """ + super(TrainedExperiencePrecision, self).__init__( + reset_at="stream", emit_at="stream", mode="eval" + ) + self._current_experience = 0 + + def after_training_exp(self, strategy) -> None: + self._current_experience = strategy.experience.current_experience + # Reset average after learning from a new experience + PrecisionPluginMetric.reset(self, strategy) + return PrecisionPluginMetric.after_training_exp(self, strategy) + + def update(self, strategy): + """ + Only update the Precision with results from experiences that have been + trained on + """ + if strategy.experience.current_experience <= self._current_experience: + PrecisionPluginMetric.update(self, strategy) + + def __str__(self): + return "Precision_On_Trained_Experiences" + + +def precision_metrics( + *, + minibatch=False, + epoch=False, + epoch_running=False, + experience=False, + stream=False, + trained_experience=False, +) -> List[PluginMetric]: + """ + Helper method that can be used to obtain the desired set of + plugin metrics. + + :param minibatch: If True, will return a metric able to log + the minibatch Precision at training time. + :param epoch: If True, will return a metric able to log + the epoch Precision at training time. + :param epoch_running: If True, will return a metric able to log + the running epoch Precision at training time. + :param experience: If True, will return a metric able to log + the Precision on each evaluation experience. + :param stream: If True, will return a metric able to log + the Precision averaged over the entire evaluation stream of experiences. + :param trained_experience: If True, will return a metric able to log + the average evaluation Precision only for experiences that the + model has been trained on + + :return: A list of plugin metrics. 
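+
+    A minimal combination sketch (assuming the standard Avalanche
+    ``EvaluationPlugin`` API; the other ``*_metrics`` helpers in this
+    module compose the same way)::
+
+        from avalanche.training.plugins import EvaluationPlugin
+
+        eval_plugin = EvaluationPlugin(
+            *precision_metrics(experience=True, trained_experience=True),
+            *sensitivity_metrics(experience=True),
+            *specificity_metrics(experience=True),
+        )
+        # add loggers=[...] to actually report the values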
+ """ + + metrics = [] + if minibatch: + metrics.append(MinibatchPrecision()) + + if epoch: + metrics.append(EpochPrecision()) + + if epoch_running: + metrics.append(RunningEpochPrecision()) + + if experience: + metrics.append(ExperiencePrecision()) + + if stream: + metrics.append(StreamPrecision()) + + if trained_experience: + metrics.append(TrainedExperiencePrecision()) + + return metrics + + +class AUPRC(Metric[float]): + """ + The AUPRC metric. This is a standalone metric. + + The metric keeps a dictionary of <task_label, AUPRC value> pairs. + and update the values through a running average over multiple + <prediction, target> pairs of Tensors, provided incrementally. + The "prediction" and "target" tensors may contain plain labels or + one-hot/logit vectors. + + Each time `result` is called, this metric emits the average AUPRC + across all predictions made since the last `reset`. + + The reset method will bring the metric to its initial state. By default + this metric in its initial state will return an AUPRC value of 0. + """ + + def __init__(self): + """ + Creates an instance of the standalone AUPRC metric. + + By default this metric in its initial state will return a AUPRC + value of 0. The metric can be updated by using the `update` method + while the running AUPRC can be retrieved using the `result` method. + """ + super().__init__() + self._mean_AUPRC = defaultdict(Mean) + """ + The mean utility that will be used to store the running AUPRC + for each task label. + """ + + @torch.no_grad() + def update( + self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[float, Tensor] + ) -> None: + """ + Update the running AUPRC given the true and predicted labels. + Parameter `task_labels` is used to decide how to update the inner + dictionary: if Float, only the dictionary value related to that task + is updated. If Tensor, all the dictionary elements belonging to the + task labels will be updated. + + :param predicted_y: The model prediction. Both labels and logit vectors + are supported. + :param true_y: The ground truth. Both labels and one-hot vectors + are supported. + :param task_labels: the int task label associated to the current + experience or the task labels vector showing the task label + for each pattern. + + :return: None. + """ + if len(true_y) != len(predicted_y): + raise ValueError("Size mismatch for true_y and predicted_y tensors") + + if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y): + raise ValueError("Size mismatch for true_y and task_labels tensors") + + true_y = torch.as_tensor(true_y) + predicted_y = torch.as_tensor(predicted_y) + + assert len(predicted_y.size()) == 2, ( + "Predictions need to be logits or scores, not labels" + ) + + if len(true_y.shape) > 1: + # Logits -> transform to labels + true_y = torch.max(true_y, 1)[1] + + scores = predicted_y[arange(len(true_y)), true_y] + + with np.errstate(divide="ignore", invalid="ignore"): + average_precision_score_val = average_precision_score( + true_y.cpu(), scores.cpu() + ) + + if np.isnan(average_precision_score_val): + average_precision_score_val = 0 + + if isinstance(task_labels, int): + self._mean_AUPRC[task_labels].update( + average_precision_score_val, len(predicted_y) + ) + elif isinstance(task_labels, Tensor): + raise NotImplementedError + else: + raise ValueError( + f"Task label type: {type(task_labels)}, expected int/float or Tensor" + ) + + def result(self, task_label=None) -> Dict[int, float]: + """ + Retrieves the running AUPRC. 
+ + Calling this method will not change the internal state of the metric. + + :param task_label: if None, return the entire dictionary of AUPRCs + for each task. Otherwise return the dictionary + `{task_label: AUPRC}`. + :return: A dict of running AUPRCs for each task label, + where each value is a float value between 0 and 1. + """ + assert task_label is None or isinstance(task_label, int) + if task_label is None: + return {k: v.result() for k, v in self._mean_AUPRC.items()} + else: + return {task_label: self._mean_AUPRC[task_label].result()} + + def reset(self, task_label=None) -> None: + """ + Resets the metric. + :param task_label: if None, reset the entire dictionary. + Otherwise, reset the value associated to `task_label`. + + :return: None. + """ + assert task_label is None or isinstance(task_label, int) + if task_label is None: + self._mean_AUPRC = defaultdict(Mean) + else: + self._mean_AUPRC[task_label].reset() + + +class AUPRCPluginMetric(GenericPluginMetric[float]): + """ + Base class for all AUPRCs plugin metrics + """ + + def __init__(self, reset_at, emit_at, mode): + self._AUPRC = AUPRC() + super(AUPRCPluginMetric, self).__init__( + self._AUPRC, reset_at=reset_at, emit_at=emit_at, mode=mode + ) + + def reset(self, strategy=None) -> None: + if self._reset_at == "stream" or strategy is None: + self._metric.reset() + else: + self._metric.reset(phase_and_task(strategy)[1]) + + def result(self, strategy=None) -> float: + if self._emit_at == "stream" or strategy is None: + return self._metric.result() + else: + return self._metric.result(phase_and_task(strategy)[1]) + + def update(self, strategy): + # task labels defined for each experience + task_labels = strategy.experience.task_labels + if len(task_labels) > 1: + # task labels defined for each pattern + task_labels = strategy.mb_task_id + else: + task_labels = task_labels[0] + self._AUPRC.update(strategy.mb_output, strategy.mb_y, task_labels) + + +class MinibatchAUPRC(AUPRCPluginMetric): + """ + The minibatch plugin AUPRC metric. + This metric only works at training time. + + This metric computes the average AUPRC over patterns + from a single minibatch. + It reports the result after each iteration. + + If a more coarse-grained logging is needed, consider using + :class:`EpochAUPRC` instead. + """ + + def __init__(self): + """ + Creates an instance of the MinibatchAUPRC metric. + """ + super(MinibatchAUPRC, self).__init__( + reset_at="iteration", emit_at="iteration", mode="train" + ) + + def __str__(self): + return "AUPRC_MB" + + +class EpochAUPRC(AUPRCPluginMetric): + """ + The average AUPRC over a single training epoch. + This plugin metric only works at training time. + + The AUPRC will be logged after each training epoch by computing + the number of correctly predicted patterns during the epoch divided by + the overall number of patterns encountered in that epoch. + """ + + def __init__(self): + """ + Creates an instance of the EpochAUPRC metric. + """ + + super(EpochAUPRC, self).__init__( + reset_at="epoch", emit_at="epoch", mode="train" + ) + + def __str__(self): + return "AUPRC_Epoch" + + +class RunningEpochAUPRC(AUPRCPluginMetric): + """ + The average AUPRC across all minibatches up to the current + epoch iteration. + This plugin metric only works at training time. + + At each iteration, this metric logs the AUPRC averaged over all patterns + seen so far in the current epoch. + The metric resets its state after each training epoch. + """ + + def __init__(self): + """ + Creates an instance of the RunningEpochAUPRC metric. 
+        """
+
+        super(RunningEpochAUPRC, self).__init__(
+            reset_at="epoch", emit_at="iteration", mode="train"
+        )
+
+    def __str__(self):
+        return "RunningAUPRC_Epoch"
+
+
+class ExperienceAUPRC(AUPRCPluginMetric):
+    """
+    At the end of each experience, this plugin metric reports
+    the average AUPRC over all patterns seen in that experience.
+    This metric only works at eval time.
+    """
+
+    def __init__(self):
+        """
+        Creates an instance of ExperienceAUPRC metric
+        """
+        super(ExperienceAUPRC, self).__init__(
+            reset_at="experience", emit_at="experience", mode="eval"
+        )
+
+    def __str__(self):
+        return "AUPRC_Exp"
+
+
+class StreamAUPRC(AUPRCPluginMetric):
+    """
+    At the end of the entire stream of experiences, this plugin metric
+    reports the average AUPRC over all patterns seen in all experiences.
+    This metric only works at eval time.
+    """
+
+    def __init__(self):
+        """
+        Creates an instance of StreamAUPRC metric
+        """
+        super(StreamAUPRC, self).__init__(
+            reset_at="stream", emit_at="stream", mode="eval"
+        )
+
+    def __str__(self):
+        return "AUPRC_Stream"
+
+
+class TrainedExperienceAUPRC(AUPRCPluginMetric):
+    """
+    At the end of each experience, this plugin metric reports the average
+    AUPRC for only the experiences that the model has been trained on so far.
+
+    This metric only works at eval time.
+    """
+
+    def __init__(self):
+        """
+        Creates an instance of TrainedExperienceAUPRC metric by first
+        constructing AUPRCPluginMetric
+        """
+        super(TrainedExperienceAUPRC, self).__init__(
+            reset_at="stream", emit_at="stream", mode="eval"
+        )
+        self._current_experience = 0
+
+    def after_training_exp(self, strategy) -> None:
+        self._current_experience = strategy.experience.current_experience
+        # Reset average after learning from a new experience
+        AUPRCPluginMetric.reset(self, strategy)
+        return AUPRCPluginMetric.after_training_exp(self, strategy)
+
+    def update(self, strategy):
+        """
+        Only update the AUPRC with results from experiences that have been
+        trained on
+        """
+        if strategy.experience.current_experience <= self._current_experience:
+            AUPRCPluginMetric.update(self, strategy)
+
+    def __str__(self):
+        return "AUPRC_On_Trained_Experiences"
+
+
+def auprc_metrics(
+    *,
+    minibatch=False,
+    epoch=False,
+    epoch_running=False,
+    experience=False,
+    stream=False,
+    trained_experience=False,
+) -> List[PluginMetric]:
+    """
+    Helper method that can be used to obtain the desired set of
+    plugin metrics.
+
+    :param minibatch: If True, will return a metric able to log
+        the minibatch AUPRC at training time.
+    :param epoch: If True, will return a metric able to log
+        the epoch AUPRC at training time.
+    :param epoch_running: If True, will return a metric able to log
+        the running epoch AUPRC at training time.
+    :param experience: If True, will return a metric able to log
+        the AUPRC on each evaluation experience.
+    :param stream: If True, will return a metric able to log
+        the AUPRC averaged over the entire evaluation stream of experiences.
+    :param trained_experience: If True, will return a metric able to log
+        the average evaluation AUPRC only for experiences that the
+        model has been trained on
+
+    :return: A list of plugin metrics.
+    """
+
+    metrics = []
+    if minibatch:
+        metrics.append(MinibatchAUPRC())
+
+    if epoch:
+        metrics.append(EpochAUPRC())
+
+    if epoch_running:
+        metrics.append(RunningEpochAUPRC())
+
+    if experience:
+        metrics.append(ExperienceAUPRC())
+
+    if stream:
+        metrics.append(StreamAUPRC())
+
+    if trained_experience:
+        metrics.append(TrainedExperienceAUPRC())
+
+    return metrics
+
+
+class ROCAUC(Metric[float]):
+    """
+    The ROCAUC metric.
This is a standalone metric. + + The metric keeps a dictionary of <task_label, ROCAUC value> pairs. + and update the values through a running average over multiple + <prediction, target> pairs of Tensors, provided incrementally. + The "prediction" and "target" tensors may contain plain labels or + one-hot/logit vectors. + + Each time `result` is called, this metric emits the average ROCAUC + across all predictions made since the last `reset`. + + The reset method will bring the metric to its initial state. By default + this metric in its initial state will return an ROCAUC value of 0. + """ + + def __init__(self): + """ + Creates an instance of the standalone ROCAUC metric. + + By default this metric in its initial state will return a ROCAUC + value of 0. The metric can be updated by using the `update` method + while the running ROCAUC can be retrieved using the `result` method. + """ + super().__init__() + self._mean_ROCAUC = defaultdict(Mean) + """ + The mean utility that will be used to store the running ROCAUC + for each task label. + """ + + @torch.no_grad() + def update( + self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[float, Tensor] + ) -> None: + """ + Update the running ROCAUC given the true and predicted labels. + Parameter `task_labels` is used to decide how to update the inner + dictionary: if Float, only the dictionary value related to that task + is updated. If Tensor, all the dictionary elements belonging to the + task labels will be updated. + + :param predicted_y: The model prediction. Both labels and logit vectors + are supported. + :param true_y: The ground truth. Both labels and one-hot vectors + are supported. + :param task_labels: the int task label associated to the current + experience or the task labels vector showing the task label + for each pattern. + + :return: None. + """ + if len(true_y) != len(predicted_y): + raise ValueError("Size mismatch for true_y and predicted_y tensors") + + if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y): + raise ValueError("Size mismatch for true_y and task_labels tensors") + + true_y = torch.as_tensor(true_y) + predicted_y = torch.as_tensor(predicted_y) + + assert len(predicted_y.size()) == 2, ( + "Predictions need to be logits or scores, not labels" + ) + + if len(true_y.shape) > 1: + # Logits -> transform to labels + true_y = torch.max(true_y, 1)[1] + + scores = predicted_y[arange(len(true_y)), true_y] + + try: + roc_auc_score_val = roc_auc_score(true_y.cpu(), scores.cpu()) + except ValueError: + roc_auc_score_val = 1 + + if isinstance(task_labels, int): + self._mean_ROCAUC[task_labels].update(roc_auc_score_val, len(predicted_y)) + elif isinstance(task_labels, Tensor): + raise NotImplementedError + else: + raise ValueError( + f"Task label type: {type(task_labels)}, expected int/float or Tensor" + ) + + def result(self, task_label=None) -> Dict[int, float]: + """ + Retrieves the running ROCAUC. + + Calling this method will not change the internal state of the metric. + + :param task_label: if None, return the entire dictionary of ROCAUCs + for each task. Otherwise return the dictionary + `{task_label: ROCAUC}`. + :return: A dict of running ROCAUCs for each task label, + where each value is a float value between 0 and 1. 
+ """ + assert task_label is None or isinstance(task_label, int) + if task_label is None: + return {k: v.result() for k, v in self._mean_ROCAUC.items()} + else: + return {task_label: self._mean_ROCAUC[task_label].result()} + + def reset(self, task_label=None) -> None: + """ + Resets the metric. + :param task_label: if None, reset the entire dictionary. + Otherwise, reset the value associated to `task_label`. + + :return: None. + """ + assert task_label is None or isinstance(task_label, int) + if task_label is None: + self._mean_ROCAUC = defaultdict(Mean) + else: + self._mean_ROCAUC[task_label].reset() + + +class ROCAUCPluginMetric(GenericPluginMetric[float]): + """ + Base class for all ROCAUCs plugin metrics + """ + + def __init__(self, reset_at, emit_at, mode): + self._ROCAUC = ROCAUC() + super(ROCAUCPluginMetric, self).__init__( + self._ROCAUC, reset_at=reset_at, emit_at=emit_at, mode=mode + ) + + def reset(self, strategy=None) -> None: + if self._reset_at == "stream" or strategy is None: + self._metric.reset() + else: + self._metric.reset(phase_and_task(strategy)[1]) + + def result(self, strategy=None) -> float: + if self._emit_at == "stream" or strategy is None: + return self._metric.result() + else: + return self._metric.result(phase_and_task(strategy)[1]) + + def update(self, strategy): + # task labels defined for each experience + task_labels = strategy.experience.task_labels + if len(task_labels) > 1: + # task labels defined for each pattern + task_labels = strategy.mb_task_id + else: + task_labels = task_labels[0] + self._ROCAUC.update(strategy.mb_output, strategy.mb_y, task_labels) + + +class MinibatchROCAUC(ROCAUCPluginMetric): + """ + The minibatch plugin ROCAUC metric. + This metric only works at training time. + + This metric computes the average ROCAUC over patterns + from a single minibatch. + It reports the result after each iteration. + + If a more coarse-grained logging is needed, consider using + :class:`EpochROCAUC` instead. + """ + + def __init__(self): + """ + Creates an instance of the MinibatchROCAUC metric. + """ + super(MinibatchROCAUC, self).__init__( + reset_at="iteration", emit_at="iteration", mode="train" + ) + + def __str__(self): + return "ROCAUC_MB" + + +class EpochROCAUC(ROCAUCPluginMetric): + """ + The average ROCAUC over a single training epoch. + This plugin metric only works at training time. + + The ROCAUC will be logged after each training epoch by computing + the number of correctly predicted patterns during the epoch divided by + the overall number of patterns encountered in that epoch. + """ + + def __init__(self): + """ + Creates an instance of the EpochROCAUC metric. + """ + + super(EpochROCAUC, self).__init__( + reset_at="epoch", emit_at="epoch", mode="train" + ) + + def __str__(self): + return "ROCAUC_Epoch" + + +class RunningEpochROCAUC(ROCAUCPluginMetric): + """ + The average ROCAUC across all minibatches up to the current + epoch iteration. + This plugin metric only works at training time. + + At each iteration, this metric logs the ROCAUC averaged over all patterns + seen so far in the current epoch. + The metric resets its state after each training epoch. + """ + + def __init__(self): + """ + Creates an instance of the RunningEpochROCAUC metric. 
+ """ + + super(RunningEpochROCAUC, self).__init__( + reset_at="epoch", emit_at="iteration", mode="train" + ) + + def __str__(self): + return "RunningROCAUC_Epoch" + + +class ExperienceROCAUC(ROCAUCPluginMetric): + """ + At the end of each experience, this plugin metric reports + the average ROCAUC over all patterns seen in that experience. + This metric only works at eval time. + """ + + def __init__(self): + """ + Creates an instance of ExperienceROCAUC metric + """ + super(ExperienceROCAUC, self).__init__( + reset_at="experience", emit_at="experience", mode="eval" + ) + + def __str__(self): + return "ROCAUC_Exp" + + +class StreamROCAUC(ROCAUCPluginMetric): + """ + At the end of the entire stream of experiences, this plugin metric + reports the average ROCAUC over all patterns seen in all experiences. + This metric only works at eval time. + """ + + def __init__(self): + """ + Creates an instance of StreamROCAUC metric + """ + super(StreamROCAUC, self).__init__( + reset_at="stream", emit_at="stream", mode="eval" + ) + + def __str__(self): + return "ROCAUC_Stream" + + +class TrainedExperienceROCAUC(ROCAUCPluginMetric): + """ + At the end of each experience, this plugin metric reports the average + ROCAUC for only the experiences that the model has been trained on so far. + + This metric only works at eval time. + """ + + def __init__(self): + """ + Creates an instance of TrainedExperienceROCAUC metric by first + constructing ROCAUCPluginMetric + """ + super(TrainedExperienceROCAUC, self).__init__( + reset_at="stream", emit_at="stream", mode="eval" + ) + self._current_experience = 0 + + def after_training_exp(self, strategy) -> None: + self._current_experience = strategy.experience.current_experience + # Reset average after learning from a new experience + ROCAUCPluginMetric.reset(self, strategy) + return ROCAUCPluginMetric.after_training_exp(self, strategy) + + def update(self, strategy): + """ + Only update the ROCAUC with results from experiences that have been + trained on + """ + if strategy.experience.current_experience <= self._current_experience: + ROCAUCPluginMetric.update(self, strategy) + + def __str__(self): + return "ROCAUC_On_Trained_Experiences" + + +def rocauc_metrics( + *, + minibatch=False, + epoch=False, + epoch_running=False, + experience=False, + stream=False, + trained_experience=False, +) -> List[PluginMetric]: + """ + Helper method that can be used to obtain the desired set of + plugin metrics. + + :param minibatch: If True, will return a metric able to log + the minibatch ROCAUC at training time. + :param epoch: If True, will return a metric able to log + the epoch ROCAUC at training time. + :param epoch_running: If True, will return a metric able to log + the running epoch ROCAUC at training time. + :param experience: If True, will return a metric able to log + the ROCAUC on each evaluation experience. + :param stream: If True, will return a metric able to logROCAUCperiences. + :param trained_experience: If True, will return a metric able to log + the average evaluation ROCAUC only for experiences that the + model has been trained on + + :return: A list of plugin metrics. 
+ """ + + metrics = [] + if minibatch: + metrics.append(MinibatchROCAUC()) + + if epoch: + metrics.append(EpochROCAUC()) + + if epoch_running: + metrics.append(RunningEpochROCAUC()) + + if experience: + metrics.append(ExperienceROCAUC()) + + if stream: + metrics.append(StreamROCAUC()) + + if trained_experience: + metrics.append(TrainedExperienceROCAUC()) + + return metrics + + +__all__ = [ + "BalancedAccuracy", + "MinibatchBalancedAccuracy", + "EpochBalancedAccuracy", + "RunningEpochBalancedAccuracy", + "ExperienceBalancedAccuracy", + "StreamBalancedAccuracy", + "TrainedExperienceBalancedAccuracy", + "balancedaccuracy_metrics", + "Sensitivity", + "MinibatchSensitivity", + "EpochSensitivity", + "RunningEpochSensitivity", + "ExperienceSensitivity", + "StreamSensitivity", + "TrainedExperienceSensitivity", + "sensitivity_metrics", + "Specificity", + "MinibatchSpecificity", + "EpochSpecificity", + "RunningEpochSpecificity", + "ExperienceSpecificity", + "StreamSpecificity", + "TrainedExperienceSpecificity", + "specificity_metrics", + "Precision", + "MinibatchPrecision", + "EpochPrecision", + "RunningEpochPrecision", + "ExperiencePrecision", + "StreamPrecision", + "TrainedExperiencePrecision", + "precision_metrics", + "AUPRC", + "MinibatchAUPRC", + "EpochAUPRC", + "RunningEpochAUPRC", + "ExperienceAUPRC", + "StreamAUPRC", + "TrainedExperienceAUPRC", + "auprc_metrics", + "ROCAUC", + "MinibatchROCAUC", + "EpochROCAUC", + "RunningEpochROCAUC", + "ExperienceROCAUC", + "StreamROCAUC", + "TrainedExperienceROCAUC", + "rocauc_metrics", +]