"""
Custom binary prediction metrics using Avalanche
https://github.com/ContinualAI/avalanche/blob/master/notebooks/from-zero-to-hero-tutorial/05_evaluation.ipynb
"""
from typing import List, Union, Dict
from collections import defaultdict
import torch
import numpy as np
from torch import Tensor, arange
from avalanche.evaluation import Metric, PluginMetric, GenericPluginMetric
from avalanche.evaluation.metrics.mean import Mean
from avalanche.evaluation.metric_utils import phase_and_task
from sklearn.metrics import average_precision_score, roc_auc_score
def confusion(prediction, truth):
"""Returns the confusion matrix for the values in the `prediction` and `truth`
tensors, i.e. the amount of positions where the values of `prediction`
and `truth` are
- 1 and 1 (True Positive)
- 1 and 0 (False Positive)
- 0 and 0 (True Negative)
- 0 and 1 (False Negative)
Source: https://gist.github.com/the-bass/cae9f3976866776dea17a5049013258d
"""
confusion_vector = prediction / truth
# Element-wise division of the 2 tensors returns a new tensor which holds a
# unique value for each case:
# 1 where prediction and truth are 1 (True Positive)
# inf where prediction is 1 and truth is 0 (False Positive)
# nan where prediction and truth are 0 (True Negative)
# 0 where prediction is 0 and truth is 1 (False Negative)
true_positives = torch.sum(confusion_vector == 1).item()
false_positives = torch.sum(confusion_vector == float("inf")).item()
true_negatives = torch.sum(torch.isnan(confusion_vector)).item()
false_negatives = torch.sum(confusion_vector == 0).item()
return true_positives, false_positives, true_negatives, false_negatives
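# A quick sanity check for confusion() (illustrative, not part of the original
# module): with prediction = torch.tensor([1., 0., 1., 0.]) and
# truth = torch.tensor([1., 1., 0., 0.]), the element-wise division yields
# [1., 0., inf, nan], so confusion() returns (1, 1, 1, 1), i.e. one true
# positive, one false positive, one true negative and one false negative.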
# https://github.com/ContinualAI/avalanche/blob/master/avalanche/evaluation/metrics/mean_scores.py
# Use the above as a template for AUPRC and similar metrics.
class BalancedAccuracy(Metric[float]):
"""
The BalancedAccuracy metric. This is a standalone metric.
The metric keeps a dictionary of <task_label, balanced accuracy value> pairs
and updates the values through a running average over multiple
<prediction, target> pairs of Tensors, provided incrementally.
The "prediction" and "target" tensors may contain plain labels or
one-hot/logit vectors.
Each time `result` is called, this metric emits the average balanced accuracy
across all predictions made since the last `reset`.
The reset method will bring the metric to its initial state. By default
this metric in its initial state will return a balanced accuracy value of 0.
"""
def __init__(self):
"""
Creates an instance of the standalone BalancedAccuracy metric.
By default this metric in its initial state will return a balanced accuracy
value of 0. The metric can be updated by using the `update` method
while the running balanced accuracy can be retrieved using the `result` method.
"""
super().__init__()
self._mean_balancedaccuracy = defaultdict(Mean)
"""
The mean utility that will be used to store the running balancedaccuracy
for each task label.
"""
@torch.no_grad()
def update(
self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[float, Tensor]
) -> None:
"""
Update the running balancedaccuracy given the true and predicted labels.
Parameter `task_labels` is used to decide how to update the inner
dictionary: if an int, only the dictionary value related to that task
is updated. Per-pattern task labels passed as a Tensor are not yet
supported and raise a NotImplementedError.
:param predicted_y: The model prediction. Both labels and logit vectors
are supported.
:param true_y: The ground truth. Both labels and one-hot vectors
are supported.
:param task_labels: the int task label associated to the current
experience or the task labels vector showing the task label
for each pattern.
:return: None.
"""
if len(true_y) != len(predicted_y):
raise ValueError("Size mismatch for true_y and predicted_y tensors")
if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y):
raise ValueError("Size mismatch for true_y and task_labels tensors")
true_y = torch.as_tensor(true_y)
predicted_y = torch.as_tensor(predicted_y)
# Check if logits or labels
if len(predicted_y.shape) > 1:
# Logits -> transform to labels
predicted_y = torch.max(predicted_y, 1)[1]
if len(true_y.shape) > 1:
# Logits -> transform to labels
true_y = torch.max(true_y, 1)[1]
if isinstance(task_labels, int):
(
true_positives,
false_positives,
true_negatives,
false_negatives,
) = confusion(predicted_y, true_y)
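# If the batch contains no positives (or no negatives) the corresponding
# denominator is 0; the rate is defined as 1 in that degenerate case.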
try:
tpr = true_positives / (true_positives + false_negatives)
except ZeroDivisionError:
tpr = 1
try:
tnr = true_negatives / (true_negatives + false_positives)
except ZeroDivisionError:
tnr = 1
self._mean_balancedaccuracy[task_labels].update(
(tpr + tnr) / 2, len(predicted_y)
)
elif isinstance(task_labels, Tensor):
raise NotImplementedError
else:
raise ValueError(
f"Task label type: {type(task_labels)}, expected int/float or Tensor"
)
def result(self, task_label=None) -> Dict[int, float]:
"""
Retrieves the running balancedaccuracy.
Calling this method will not change the internal state of the metric.
:param task_label: if None, return the entire dictionary of balanced accuracies
for each task. Otherwise return the dictionary
`{task_label: balancedaccuracy}`.
:return: A dict of running balanced accuracies for each task label,
where each value is a float value between 0 and 1.
"""
assert task_label is None or isinstance(task_label, int)
if task_label is None:
return {k: v.result() for k, v in self._mean_balancedaccuracy.items()}
else:
return {task_label: self._mean_balancedaccuracy[task_label].result()}
def reset(self, task_label=None) -> None:
"""
Resets the metric.
:param task_label: if None, reset the entire dictionary.
Otherwise, reset the value associated to `task_label`.
:return: None.
"""
assert task_label is None or isinstance(task_label, int)
if task_label is None:
self._mean_balancedaccuracy = defaultdict(Mean)
else:
self._mean_balancedaccuracy[task_label].reset()
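# Standalone usage (a sketch; the variable names below are illustrative, not
# from the original code):
#
#     metric = BalancedAccuracy()
#     metric.update(model_logits, targets, task_labels=0)
#     print(metric.result())   # e.g. {0: 0.75}
#     metric.reset()
#
# The same pattern applies to the Sensitivity, Specificity and Precision
# metrics defined below.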
class BalancedAccuracyPluginMetric(GenericPluginMetric[float]):
"""
Base class for all balanced accuracies plugin metrics
"""
def __init__(self, reset_at, emit_at, mode):
self._balancedaccuracy = BalancedAccuracy()
super(BalancedAccuracyPluginMetric, self).__init__(
self._balancedaccuracy, reset_at=reset_at, emit_at=emit_at, mode=mode
)
def reset(self, strategy=None) -> None:
if self._reset_at == "stream" or strategy is None:
self._metric.reset()
else:
self._metric.reset(phase_and_task(strategy)[1])
def result(self, strategy=None) -> float:
if self._emit_at == "stream" or strategy is None:
return self._metric.result()
else:
return self._metric.result(phase_and_task(strategy)[1])
def update(self, strategy):
# task labels defined for each experience
task_labels = strategy.experience.task_labels
if len(task_labels) > 1:
# task labels defined for each pattern
task_labels = strategy.mb_task_id
else:
task_labels = task_labels[0]
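# mb_output / mb_y are the current minibatch logits and targets exposed
# by the Avalanche strategy.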
self._balancedaccuracy.update(strategy.mb_output, strategy.mb_y, task_labels)
class MinibatchBalancedAccuracy(BalancedAccuracyPluginMetric):
"""
The minibatch plugin balancedaccuracy metric.
This metric only works at training time.
This metric computes the average balancedaccuracy over patterns
from a single minibatch.
It reports the result after each iteration.
If a more coarse-grained logging is needed, consider using
:class:`EpochBalancedAccuracy` instead.
"""
def __init__(self):
"""
Creates an instance of the MinibatchBalancedAccuracy metric.
"""
super(MinibatchBalancedAccuracy, self).__init__(
reset_at="iteration", emit_at="iteration", mode="train"
)
def __str__(self):
return "BalAcc_MB"
class EpochBalancedAccuracy(BalancedAccuracyPluginMetric):
"""
The average balancedaccuracy over a single training epoch.
This plugin metric only works at training time.
The balanced accuracy will be logged after each training epoch as the
running average of the per-minibatch balanced accuracy values, weighted
by the number of patterns in each minibatch.
"""
def __init__(self):
"""
Creates an instance of the EpochBalancedAccuracy metric.
"""
super(EpochBalancedAccuracy, self).__init__(
reset_at="epoch", emit_at="epoch", mode="train"
)
def __str__(self):
return "BalAcc_Epoch"
class RunningEpochBalancedAccuracy(BalancedAccuracyPluginMetric):
"""
The average balancedaccuracy across all minibatches up to the current
epoch iteration.
This plugin metric only works at training time.
At each iteration, this metric logs the balancedaccuracy averaged over all patterns
seen so far in the current epoch.
The metric resets its state after each training epoch.
"""
def __init__(self):
"""
Creates an instance of the RunningEpochBalancedAccuracy metric.
"""
super(RunningEpochBalancedAccuracy, self).__init__(
reset_at="epoch", emit_at="iteration", mode="train"
)
def __str__(self):
return "RunningBalAcc_Epoch"
class ExperienceBalancedAccuracy(BalancedAccuracyPluginMetric):
"""
At the end of each experience, this plugin metric reports
the average balancedaccuracy over all patterns seen in that experience.
This metric only works at eval time.
"""
def __init__(self):
"""
Creates an instance of ExperienceBalancedAccuracy metric
"""
super(ExperienceBalancedAccuracy, self).__init__(
reset_at="experience", emit_at="experience", mode="eval"
)
def __str__(self):
return "BalAcc_Exp"
class StreamBalancedAccuracy(BalancedAccuracyPluginMetric):
"""
At the end of the entire stream of experiences, this plugin metric
reports the average balancedaccuracy over all patterns seen in all experiences.
This metric only works at eval time.
"""
def __init__(self):
"""
Creates an instance of StreamBalancedAccuracy metric
"""
super(StreamBalancedAccuracy, self).__init__(
reset_at="stream", emit_at="stream", mode="eval"
)
def __str__(self):
return "BalAcc_Stream"
class TrainedExperienceBalancedAccuracy(BalancedAccuracyPluginMetric):
"""
At the end of each experience, this plugin metric reports the average
balancedaccuracy for only the experiences that the model has been trained on so far.
This metric only works at eval time.
"""
def __init__(self):
"""
Creates an instance of TrainedExperienceBalancedAccuracy metric by first
constructing BalancedAccuracyPluginMetric
"""
super(TrainedExperienceBalancedAccuracy, self).__init__(
reset_at="stream", emit_at="stream", mode="eval"
)
self._current_experience = 0
def after_training_exp(self, strategy) -> None:
self._current_experience = strategy.experience.current_experience
# Reset average after learning from a new experience
BalancedAccuracyPluginMetric.reset(self, strategy)
return BalancedAccuracyPluginMetric.after_training_exp(self, strategy)
def update(self, strategy):
"""
Only update the balancedaccuracy with results from experiences that have been
trained on
"""
if strategy.experience.current_experience <= self._current_experience:
BalancedAccuracyPluginMetric.update(self, strategy)
def __str__(self):
return "BalancedAccuracy_On_Trained_Experiences"
def balancedaccuracy_metrics(
*,
minibatch=False,
epoch=False,
epoch_running=False,
experience=False,
stream=False,
trained_experience=False,
) -> List[PluginMetric]:
"""
Helper method that can be used to obtain the desired set of
plugin metrics.
:param minibatch: If True, will return a metric able to log
the minibatch balancedaccuracy at training time.
:param epoch: If True, will return a metric able to log
the epoch balancedaccuracy at training time.
:param epoch_running: If True, will return a metric able to log
the running epoch balancedaccuracy at training time.
:param experience: If True, will return a metric able to log
the balancedaccuracy on each evaluation experience.
:param stream: If True, will return a metric able to log
the balancedaccuracy averaged over the entire evaluation stream of experiences.
:param trained_experience: If True, will return a metric able to log
the average evaluation balancedaccuracy only for experiences that the
model has been trained on
:return: A list of plugin metrics.
"""
metrics = []
if minibatch:
metrics.append(MinibatchBalancedAccuracy())
if epoch:
metrics.append(EpochBalancedAccuracy())
if epoch_running:
metrics.append(RunningEpochBalancedAccuracy())
if experience:
metrics.append(ExperienceBalancedAccuracy())
if stream:
metrics.append(StreamBalancedAccuracy())
if trained_experience:
metrics.append(TrainedExperienceBalancedAccuracy())
return metrics
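# Typical wiring into an Avalanche evaluation setup (a sketch, assuming the
# usual EvaluationPlugin / InteractiveLogger combination; adapt to your
# strategy and loggers):
#
#     from avalanche.training.plugins import EvaluationPlugin
#     from avalanche.logging import InteractiveLogger
#
#     eval_plugin = EvaluationPlugin(
#         *balancedaccuracy_metrics(epoch=True, experience=True, stream=True),
#         loggers=[InteractiveLogger()],
#     )
#
# The helper functions below (sensitivity_metrics, specificity_metrics, ...)
# can be combined in the same EvaluationPlugin call.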
class Sensitivity(Metric[float]):
"""
The Sensitivity metric. This is a standalone metric.
The metric keeps a dictionary of <task_label, Sensitivity value> pairs
and updates the values through a running average over multiple
<prediction, target> pairs of Tensors, provided incrementally.
The "prediction" and "target" tensors may contain plain labels or
one-hot/logit vectors.
Each time `result` is called, this metric emits the average Sensitivity
across all predictions made since the last `reset`.
The reset method will bring the metric to its initial state. By default
this metric in its initial state will return a Sensitivity value of 0.
"""
def __init__(self):
"""
Creates an instance of the standalone Sensitivity metric.
By default this metric in its initial state will return a Sensitivity
value of 0. The metric can be updated by using the `update` method
while the running Sensitivity can be retrieved using the `result` method.
"""
super().__init__()
self._mean_Sensitivity = defaultdict(Mean)
"""
The mean utility that will be used to store the running Sensitivity
for each task label.
"""
@torch.no_grad()
def update(
self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[float, Tensor]
) -> None:
"""
Update the running Sensitivity given the true and predicted labels.
Parameter `task_labels` is used to decide how to update the inner
dictionary: if an int, only the dictionary value related to that task
is updated. Per-pattern task labels passed as a Tensor are not yet
supported and raise a NotImplementedError.
:param predicted_y: The model prediction. Both labels and logit vectors
are supported.
:param true_y: The ground truth. Both labels and one-hot vectors
are supported.
:param task_labels: the int task label associated to the current
experience or the task labels vector showing the task label
for each pattern.
:return: None.
"""
if len(true_y) != len(predicted_y):
raise ValueError("Size mismatch for true_y and predicted_y tensors")
if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y):
raise ValueError("Size mismatch for true_y and task_labels tensors")
true_y = torch.as_tensor(true_y)
predicted_y = torch.as_tensor(predicted_y)
# Check if logits or labels
if len(predicted_y.shape) > 1:
# Logits -> transform to labels
predicted_y = torch.max(predicted_y, 1)[1]
if len(true_y.shape) > 1:
# Logits -> transform to labels
true_y = torch.max(true_y, 1)[1]
if isinstance(task_labels, int):
(
true_positives,
false_positives,
true_negatives,
false_negatives,
) = confusion(predicted_y, true_y)
try:
tpr = true_positives / (true_positives + false_negatives)
except ZeroDivisionError:
tpr = 1
self._mean_Sensitivity[task_labels].update(tpr, len(predicted_y))
elif isinstance(task_labels, Tensor):
raise NotImplementedError
else:
raise ValueError(
f"Task label type: {type(task_labels)}, expected int/float or Tensor"
)
def result(self, task_label=None) -> Dict[int, float]:
"""
Retrieves the running Sensitivity.
Calling this method will not change the internal state of the metric.
:param task_label: if None, return the entire dictionary of sensitivities
for each task. Otherwise return the dictionary
`{task_label: Sensitivity}`.
:return: A dict of running sensitivities for each task label,
where each value is a float value between 0 and 1.
"""
assert task_label is None or isinstance(task_label, int)
if task_label is None:
return {k: v.result() for k, v in self._mean_Sensitivity.items()}
else:
return {task_label: self._mean_Sensitivity[task_label].result()}
def reset(self, task_label=None) -> None:
"""
Resets the metric.
:param task_label: if None, reset the entire dictionary.
Otherwise, reset the value associated to `task_label`.
:return: None.
"""
assert task_label is None or isinstance(task_label, int)
if task_label is None:
self._mean_Sensitivity = defaultdict(Mean)
else:
self._mean_Sensitivity[task_label].reset()
class SensitivityPluginMetric(GenericPluginMetric[float]):
"""
Base class for all sensitivities plugin metrics
"""
def __init__(self, reset_at, emit_at, mode):
self._Sensitivity = Sensitivity()
super(SensitivityPluginMetric, self).__init__(
self._Sensitivity, reset_at=reset_at, emit_at=emit_at, mode=mode
)
def reset(self, strategy=None) -> None:
if self._reset_at == "stream" or strategy is None:
self._metric.reset()
else:
self._metric.reset(phase_and_task(strategy)[1])
def result(self, strategy=None) -> float:
if self._emit_at == "stream" or strategy is None:
return self._metric.result()
else:
return self._metric.result(phase_and_task(strategy)[1])
def update(self, strategy):
# task labels defined for each experience
task_labels = strategy.experience.task_labels
if len(task_labels) > 1:
# task labels defined for each pattern
task_labels = strategy.mb_task_id
else:
task_labels = task_labels[0]
self._Sensitivity.update(strategy.mb_output, strategy.mb_y, task_labels)
class MinibatchSensitivity(SensitivityPluginMetric):
"""
The minibatch plugin Sensitivity metric.
This metric only works at training time.
This metric computes the average Sensitivity over patterns
from a single minibatch.
It reports the result after each iteration.
If a more coarse-grained logging is needed, consider using
:class:`EpochSensitivity` instead.
"""
def __init__(self):
"""
Creates an instance of the MinibatchSensitivity metric.
"""
super(MinibatchSensitivity, self).__init__(
reset_at="iteration", emit_at="iteration", mode="train"
)
def __str__(self):
return "Sens_MB"
class EpochSensitivity(SensitivityPluginMetric):
"""
The average Sensitivity over a single training epoch.
This plugin metric only works at training time.
The Sensitivity will be logged after each training epoch as the running
average of the per-minibatch Sensitivity values, weighted by the number
of patterns in each minibatch.
"""
def __init__(self):
"""
Creates an instance of the EpochSensitivity metric.
"""
super(EpochSensitivity, self).__init__(
reset_at="epoch", emit_at="epoch", mode="train"
)
def __str__(self):
return "Sens_Epoch"
class RunningEpochSensitivity(SensitivityPluginMetric):
"""
The average Sensitivity across all minibatches up to the current
epoch iteration.
This plugin metric only works at training time.
At each iteration, this metric logs the Sensitivity averaged over all patterns
seen so far in the current epoch.
The metric resets its state after each training epoch.
"""
def __init__(self):
"""
Creates an instance of the RunningEpochSensitivity metric.
"""
super(RunningEpochSensitivity, self).__init__(
reset_at="epoch", emit_at="iteration", mode="train"
)
def __str__(self):
return "RunningSens_Epoch"
class ExperienceSensitivity(SensitivityPluginMetric):
"""
At the end of each experience, this plugin metric reports
the average Sensitivity over all patterns seen in that experience.
This metric only works at eval time.
"""
def __init__(self):
"""
Creates an instance of ExperienceSensitivity metric
"""
super(ExperienceSensitivity, self).__init__(
reset_at="experience", emit_at="experience", mode="eval"
)
def __str__(self):
return "Sens_Exp"
class StreamSensitivity(SensitivityPluginMetric):
"""
At the end of the entire stream of experiences, this plugin metric
reports the average Sensitivity over all patterns seen in all experiences.
This metric only works at eval time.
"""
def __init__(self):
"""
Creates an instance of StreamSensitivity metric
"""
super(StreamSensitivity, self).__init__(
reset_at="stream", emit_at="stream", mode="eval"
)
def __str__(self):
return "Sens_Stream"
class TrainedExperienceSensitivity(SensitivityPluginMetric):
"""
At the end of each experience, this plugin metric reports the average
Sensitivity for only the experiences that the model has been trained on so far.
This metric only works at eval time.
"""
def __init__(self):
"""
Creates an instance of TrainedExperienceSensitivity metric by first
constructing SensitivityPluginMetric
"""
super(TrainedExperienceSensitivity, self).__init__(
reset_at="stream", emit_at="stream", mode="eval"
)
self._current_experience = 0
def after_training_exp(self, strategy) -> None:
self._current_experience = strategy.experience.current_experience
# Reset average after learning from a new experience
SensitivityPluginMetric.reset(self, strategy)
return SensitivityPluginMetric.after_training_exp(self, strategy)
def update(self, strategy):
"""
Only update the Sensitivity with results from experiences that have been
trained on
"""
if strategy.experience.current_experience <= self._current_experience:
SensitivityPluginMetric.update(self, strategy)
def __str__(self):
return "Sensitivity_On_Trained_Experiences"
def sensitivity_metrics(
*,
minibatch=False,
epoch=False,
epoch_running=False,
experience=False,
stream=False,
trained_experience=False,
) -> List[PluginMetric]:
"""
Helper method that can be used to obtain the desired set of
plugin metrics.
:param minibatch: If True, will return a metric able to log
the minibatch Sensitivity at training time.
:param epoch: If True, will return a metric able to log
the epoch Sensitivity at training time.
:param epoch_running: If True, will return a metric able to log
the running epoch Sensitivity at training time.
:param experience: If True, will return a metric able to log
the Sensitivity on each evaluation experience.
:param stream: If True, will return a metric able to log
the Sensitivity averaged over the entire evaluation stream of experiences.
:param trained_experience: If True, will return a metric able to log
the average evaluation Sensitivity only for experiences that the
model has been trained on
:return: A list of plugin metrics.
"""
metrics = []
if minibatch:
metrics.append(MinibatchSensitivity())
if epoch:
metrics.append(EpochSensitivity())
if epoch_running:
metrics.append(RunningEpochSensitivity())
if experience:
metrics.append(ExperienceSensitivity())
if stream:
metrics.append(StreamSensitivity())
if trained_experience:
metrics.append(TrainedExperienceSensitivity())
return metrics
class Specificity(Metric[float]):
"""
The Specificity metric. This is a standalone metric.
The metric keeps a dictionary of <task_label, Specificity value> pairs
and updates the values through a running average over multiple
<prediction, target> pairs of Tensors, provided incrementally.
The "prediction" and "target" tensors may contain plain labels or
one-hot/logit vectors.
Each time `result` is called, this metric emits the average Specificity
across all predictions made since the last `reset`.
The reset method will bring the metric to its initial state. By default
this metric in its initial state will return a Specificity value of 0.
"""
def __init__(self):
"""
Creates an instance of the standalone Specificity metric.
By default this metric in its initial state will return a Specificity
value of 0. The metric can be updated by using the `update` method
while the running Specificity can be retrieved using the `result` method.
"""
super().__init__()
self._mean_Specificity = defaultdict(Mean)
"""
The mean utility that will be used to store the running Specificity
for each task label.
"""
@torch.no_grad()
def update(
self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[float, Tensor]
) -> None:
"""
Update the running Specificity given the true and predicted labels.
Parameter `task_labels` is used to decide how to update the inner
dictionary: if an int, only the dictionary value related to that task
is updated. Per-pattern task labels passed as a Tensor are not yet
supported and raise a NotImplementedError.
:param predicted_y: The model prediction. Both labels and logit vectors
are supported.
:param true_y: The ground truth. Both labels and one-hot vectors
are supported.
:param task_labels: the int task label associated to the current
experience or the task labels vector showing the task label
for each pattern.
:return: None.
"""
if len(true_y) != len(predicted_y):
raise ValueError("Size mismatch for true_y and predicted_y tensors")
if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y):
raise ValueError("Size mismatch for true_y and task_labels tensors")
true_y = torch.as_tensor(true_y)
predicted_y = torch.as_tensor(predicted_y)
# Check if logits or labels
if len(predicted_y.shape) > 1:
# Logits -> transform to labels
predicted_y = torch.max(predicted_y, 1)[1]
if len(true_y.shape) > 1:
# Logits -> transform to labels
true_y = torch.max(true_y, 1)[1]
if isinstance(task_labels, int):
(
true_positives,
false_positives,
true_negatives,
false_negatives,
) = confusion(predicted_y, true_y)
try:
tnr = true_negatives / (true_negatives + false_positives)
except ZeroDivisionError:
tnr = 1
self._mean_Specificity[task_labels].update(tnr, len(predicted_y))
elif isinstance(task_labels, Tensor):
raise NotImplementedError
else:
raise ValueError(
f"Task label type: {type(task_labels)}, expected int/float or Tensor"
)
def result(self, task_label=None) -> Dict[int, float]:
"""
Retrieves the running Specificity.
Calling this method will not change the internal state of the metric.
:param task_label: if None, return the entire dictionary of specificities
for each task. Otherwise return the dictionary
`{task_label: Specificity}`.
:return: A dict of running specificities for each task label,
where each value is a float value between 0 and 1.
"""
assert task_label is None or isinstance(task_label, int)
if task_label is None:
return {k: v.result() for k, v in self._mean_Specificity.items()}
else:
return {task_label: self._mean_Specificity[task_label].result()}
def reset(self, task_label=None) -> None:
"""
Resets the metric.
:param task_label: if None, reset the entire dictionary.
Otherwise, reset the value associated to `task_label`.
:return: None.
"""
assert task_label is None or isinstance(task_label, int)
if task_label is None:
self._mean_Specificity = defaultdict(Mean)
else:
self._mean_Specificity[task_label].reset()
class SpecificityPluginMetric(GenericPluginMetric[float]):
"""
Base class for all specificities plugin metrics
"""
def __init__(self, reset_at, emit_at, mode):
self._Specificity = Specificity()
super(SpecificityPluginMetric, self).__init__(
self._Specificity, reset_at=reset_at, emit_at=emit_at, mode=mode
)
def reset(self, strategy=None) -> None:
if self._reset_at == "stream" or strategy is None:
self._metric.reset()
else:
self._metric.reset(phase_and_task(strategy)[1])
def result(self, strategy=None) -> float:
if self._emit_at == "stream" or strategy is None:
return self._metric.result()
else:
return self._metric.result(phase_and_task(strategy)[1])
def update(self, strategy):
# task labels defined for each experience
task_labels = strategy.experience.task_labels
if len(task_labels) > 1:
# task labels defined for each pattern
task_labels = strategy.mb_task_id
else:
task_labels = task_labels[0]
self._Specificity.update(strategy.mb_output, strategy.mb_y, task_labels)
class MinibatchSpecificity(SpecificityPluginMetric):
"""
The minibatch plugin Specificity metric.
This metric only works at training time.
This metric computes the average Specificity over patterns
from a single minibatch.
It reports the result after each iteration.
If a more coarse-grained logging is needed, consider using
:class:`EpochSpecificity` instead.
"""
def __init__(self):
"""
Creates an instance of the MinibatchSpecificity metric.
"""
super(MinibatchSpecificity, self).__init__(
reset_at="iteration", emit_at="iteration", mode="train"
)
def __str__(self):
return "Spec_MB"
class EpochSpecificity(SpecificityPluginMetric):
"""
The average Specificity over a single training epoch.
This plugin metric only works at training time.
The Specificity will be logged after each training epoch as the running
average of the per-minibatch Specificity values, weighted by the number
of patterns in each minibatch.
"""
def __init__(self):
"""
Creates an instance of the EpochSpecificity metric.
"""
super(EpochSpecificity, self).__init__(
reset_at="epoch", emit_at="epoch", mode="train"
)
def __str__(self):
return "Spec_Epoch"
class RunningEpochSpecificity(SpecificityPluginMetric):
"""
The average Specificity across all minibatches up to the current
epoch iteration.
This plugin metric only works at training time.
At each iteration, this metric logs the Specificity averaged over all patterns
seen so far in the current epoch.
The metric resets its state after each training epoch.
"""
def __init__(self):
"""
Creates an instance of the RunningEpochSpecificity metric.
"""
super(RunningEpochSpecificity, self).__init__(
reset_at="epoch", emit_at="iteration", mode="train"
)
def __str__(self):
return "RunningSpec_Epoch"
class ExperienceSpecificity(SpecificityPluginMetric):
"""
At the end of each experience, this plugin metric reports
the average Specificity over all patterns seen in that experience.
This metric only works at eval time.
"""
def __init__(self):
"""
Creates an instance of ExperienceSpecificity metric
"""
super(ExperienceSpecificity, self).__init__(
reset_at="experience", emit_at="experience", mode="eval"
)
def __str__(self):
return "Spec_Exp"
class StreamSpecificity(SpecificityPluginMetric):
"""
At the end of the entire stream of experiences, this plugin metric
reports the average Specificity over all patterns seen in all experiences.
This metric only works at eval time.
"""
def __init__(self):
"""
Creates an instance of StreamSpecificity metric
"""
super(StreamSpecificity, self).__init__(
reset_at="stream", emit_at="stream", mode="eval"
)
def __str__(self):
return "Spec_Stream"
class TrainedExperienceSpecificity(SpecificityPluginMetric):
"""
At the end of each experience, this plugin metric reports the average
Specificity for only the experiences that the model has been trained on so far.
This metric only works at eval time.
"""
def __init__(self):
"""
Creates an instance of TrainedExperienceSpecificity metric by first
constructing SpecificityPluginMetric
"""
super(TrainedExperienceSpecificity, self).__init__(
reset_at="stream", emit_at="stream", mode="eval"
)
self._current_experience = 0
def after_training_exp(self, strategy) -> None:
self._current_experience = strategy.experience.current_experience
# Reset average after learning from a new experience
SpecificityPluginMetric.reset(self, strategy)
return SpecificityPluginMetric.after_training_exp(self, strategy)
def update(self, strategy):
"""
Only update the Specificity with results from experiences that have been
trained on
"""
if strategy.experience.current_experience <= self._current_experience:
SpecificityPluginMetric.update(self, strategy)
def __str__(self):
return "Specificity_On_Trained_Experiences"
def specificity_metrics(
*,
minibatch=False,
epoch=False,
epoch_running=False,
experience=False,
stream=False,
trained_experience=False,
) -> List[PluginMetric]:
"""
Helper method that can be used to obtain the desired set of
plugin metrics.
:param minibatch: If True, will return a metric able to log
the minibatch Specificity at training time.
:param epoch: If True, will return a metric able to log
the epoch Specificity at training time.
:param epoch_running: If True, will return a metric able to log
the running epoch Specificity at training time.
:param experience: If True, will return a metric able to log
the Specificity on each evaluation experience.
:param stream: If True, will return a metric able to log
the Specificity averaged over the entire evaluation stream of experiences.
:param trained_experience: If True, will return a metric able to log
the average evaluation Specificity only for experiences that the
model has been trained on
:return: A list of plugin metrics.
"""
metrics = []
if minibatch:
metrics.append(MinibatchSpecificity())
if epoch:
metrics.append(EpochSpecificity())
if epoch_running:
metrics.append(RunningEpochSpecificity())
if experience:
metrics.append(ExperienceSpecificity())
if stream:
metrics.append(StreamSpecificity())
if trained_experience:
metrics.append(TrainedExperienceSpecificity())
return metrics
class Precision(Metric[float]):
"""
The Precision metric. This is a standalone metric.
The metric keeps a dictionary of <task_label, Precision value> pairs
and updates the values through a running average over multiple
<prediction, target> pairs of Tensors, provided incrementally.
The "prediction" and "target" tensors may contain plain labels or
one-hot/logit vectors.
Each time `result` is called, this metric emits the average Precision
across all predictions made since the last `reset`.
The reset method will bring the metric to its initial state. By default
this metric in its initial state will return a Precision value of 0.
"""
def __init__(self):
"""
Creates an instance of the standalone Precision metric.
By default this metric in its initial state will return a Precision
value of 0. The metric can be updated by using the `update` method
while the running Precision can be retrieved using the `result` method.
"""
super().__init__()
self._mean_Precision = defaultdict(Mean)
"""
The mean utility that will be used to store the running Precision
for each task label.
"""
@torch.no_grad()
def update(
self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[float, Tensor]
) -> None:
"""
Update the running Precision given the true and predicted labels.
Parameter `task_labels` is used to decide how to update the inner
dictionary: if an int, only the dictionary value related to that task
is updated. Per-pattern task labels passed as a Tensor are not yet
supported and raise a NotImplementedError.
:param predicted_y: The model prediction. Both labels and logit vectors
are supported.
:param true_y: The ground truth. Both labels and one-hot vectors
are supported.
:param task_labels: the int task label associated to the current
experience or the task labels vector showing the task label
for each pattern.
:return: None.
"""
if len(true_y) != len(predicted_y):
raise ValueError("Size mismatch for true_y and predicted_y tensors")
if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y):
raise ValueError("Size mismatch for true_y and task_labels tensors")
true_y = torch.as_tensor(true_y)
predicted_y = torch.as_tensor(predicted_y)
# Check if logits or labels
if len(predicted_y.shape) > 1:
# Logits -> transform to labels
predicted_y = torch.max(predicted_y, 1)[1]
if len(true_y.shape) > 1:
# Logits -> transform to labels
true_y = torch.max(true_y, 1)[1]
if isinstance(task_labels, int):
(
true_positives,
false_positives,
true_negatives,
false_negatives,
) = confusion(predicted_y, true_y)
try:
ppv = true_positives / (true_positives + false_positives)
except ZeroDivisionError:
ppv = 1
self._mean_Precision[task_labels].update(ppv, len(predicted_y))
elif isinstance(task_labels, Tensor):
raise NotImplementedError
else:
raise ValueError(
f"Task label type: {type(task_labels)}, expected int/float or Tensor"
)
def result(self, task_label=None) -> Dict[int, float]:
"""
Retrieves the running Precision.
Calling this method will not change the internal state of the metric.
:param task_label: if None, return the entire dictionary of precisions
for each task. Otherwise return the dictionary
`{task_label: Precision}`.
:return: A dict of running precisions for each task label,
where each value is a float value between 0 and 1.
"""
assert task_label is None or isinstance(task_label, int)
if task_label is None:
return {k: v.result() for k, v in self._mean_Precision.items()}
else:
return {task_label: self._mean_Precision[task_label].result()}
def reset(self, task_label=None) -> None:
"""
Resets the metric.
:param task_label: if None, reset the entire dictionary.
Otherwise, reset the value associated to `task_label`.
:return: None.
"""
assert task_label is None or isinstance(task_label, int)
if task_label is None:
self._mean_Precision = defaultdict(Mean)
else:
self._mean_Precision[task_label].reset()
class PrecisionPluginMetric(GenericPluginMetric[float]):
"""
Base class for all precisions plugin metrics
"""
def __init__(self, reset_at, emit_at, mode):
self._Precision = Precision()
super(PrecisionPluginMetric, self).__init__(
self._Precision, reset_at=reset_at, emit_at=emit_at, mode=mode
)
def reset(self, strategy=None) -> None:
if self._reset_at == "stream" or strategy is None:
self._metric.reset()
else:
self._metric.reset(phase_and_task(strategy)[1])
def result(self, strategy=None) -> float:
if self._emit_at == "stream" or strategy is None:
return self._metric.result()
else:
return self._metric.result(phase_and_task(strategy)[1])
def update(self, strategy):
# task labels defined for each experience
task_labels = strategy.experience.task_labels
if len(task_labels) > 1:
# task labels defined for each pattern
task_labels = strategy.mb_task_id
else:
task_labels = task_labels[0]
self._Precision.update(strategy.mb_output, strategy.mb_y, task_labels)
class MinibatchPrecision(PrecisionPluginMetric):
"""
The minibatch plugin Precision metric.
This metric only works at training time.
This metric computes the average Precision over patterns
from a single minibatch.
It reports the result after each iteration.
If a more coarse-grained logging is needed, consider using
:class:`EpochPrecision` instead.
"""
def __init__(self):
"""
Creates an instance of the MinibatchPrecision metric.
"""
super(MinibatchPrecision, self).__init__(
reset_at="iteration", emit_at="iteration", mode="train"
)
def __str__(self):
return "Prec_MB"
class EpochPrecision(PrecisionPluginMetric):
"""
The average Precision over a single training epoch.
This plugin metric only works at training time.
The Precision will be logged after each training epoch as the running
average of the per-minibatch Precision values, weighted by the number
of patterns in each minibatch.
"""
def __init__(self):
"""
Creates an instance of the EpochPrecision metric.
"""
super(EpochPrecision, self).__init__(
reset_at="epoch", emit_at="epoch", mode="train"
)
def __str__(self):
return "Prec_Epoch"
class RunningEpochPrecision(PrecisionPluginMetric):
"""
The average Precision across all minibatches up to the current
epoch iteration.
This plugin metric only works at training time.
At each iteration, this metric logs the Precision averaged over all patterns
seen so far in the current epoch.
The metric resets its state after each training epoch.
"""
def __init__(self):
"""
Creates an instance of the RunningEpochPrecision metric.
"""
super(RunningEpochPrecision, self).__init__(
reset_at="epoch", emit_at="iteration", mode="train"
)
def __str__(self):
return "RunningPrec_Epoch"
class ExperiencePrecision(PrecisionPluginMetric):
"""
At the end of each experience, this plugin metric reports
the average Precision over all patterns seen in that experience.
This metric only works at eval time.
"""
def __init__(self):
"""
Creates an instance of ExperiencePrecision metric
"""
super(ExperiencePrecision, self).__init__(
reset_at="experience", emit_at="experience", mode="eval"
)
def __str__(self):
return "Prec_Exp"
class StreamPrecision(PrecisionPluginMetric):
"""
At the end of the entire stream of experiences, this plugin metric
reports the average Precision over all patterns seen in all experiences.
This metric only works at eval time.
"""
def __init__(self):
"""
Creates an instance of StreamPrecision metric
"""
super(StreamPrecision, self).__init__(
reset_at="stream", emit_at="stream", mode="eval"
)
def __str__(self):
return "Prec_Stream"
class TrainedExperiencePrecision(PrecisionPluginMetric):
"""
At the end of each experience, this plugin metric reports the average
Precision for only the experiences that the model has been trained on so far.
This metric only works at eval time.
"""
def __init__(self):
"""
Creates an instance of TrainedExperiencePrecision metric by first
constructing PrecisionPluginMetric
"""
super(TrainedExperiencePrecision, self).__init__(
reset_at="stream", emit_at="stream", mode="eval"
)
self._current_experience = 0
def after_training_exp(self, strategy) -> None:
self._current_experience = strategy.experience.current_experience
# Reset average after learning from a new experience
PrecisionPluginMetric.reset(self, strategy)
return PrecisionPluginMetric.after_training_exp(self, strategy)
def update(self, strategy):
"""
Only update the Precision with results from experiences that have been
trained on
"""
if strategy.experience.current_experience <= self._current_experience:
PrecisionPluginMetric.update(self, strategy)
def __str__(self):
return "Precision_On_Trained_Experiences"
def precision_metrics(
*,
minibatch=False,
epoch=False,
epoch_running=False,
experience=False,
stream=False,
trained_experience=False,
) -> List[PluginMetric]:
"""
Helper method that can be used to obtain the desired set of
plugin metrics.
:param minibatch: If True, will return a metric able to log
the minibatch Precision at training time.
:param epoch: If True, will return a metric able to log
the epoch Precision at training time.
:param epoch_running: If True, will return a metric able to log
the running epoch Precision at training time.
:param experience: If True, will return a metric able to log
the Precision on each evaluation experience.
:param stream: If True, will return a metric able to log
the Precision averaged over the entire evaluation stream of experiences.
:param trained_experience: If True, will return a metric able to log
the average evaluation Precision only for experiences that the
model has been trained on
:return: A list of plugin metrics.
"""
metrics = []
if minibatch:
metrics.append(MinibatchPrecision())
if epoch:
metrics.append(EpochPrecision())
if epoch_running:
metrics.append(RunningEpochPrecision())
if experience:
metrics.append(ExperiencePrecision())
if stream:
metrics.append(StreamPrecision())
if trained_experience:
metrics.append(TrainedExperiencePrecision())
return metrics
class AUPRC(Metric[float]):
"""
The AUPRC metric. This is a standalone metric.
The metric keeps a dictionary of <task_label, AUPRC value> pairs
and updates the values through a running average over multiple
<prediction, target> pairs of Tensors, provided incrementally.
The "prediction" tensor must contain logit or score vectors (plain labels
are not supported), while the "target" tensor may contain plain labels or
one-hot vectors.
Each time `result` is called, this metric emits the average AUPRC
across all predictions made since the last `reset`.
The reset method will bring the metric to its initial state. By default
this metric in its initial state will return an AUPRC value of 0.
"""
def __init__(self):
"""
Creates an instance of the standalone AUPRC metric.
By default this metric in its initial state will return an AUPRC
value of 0. The metric can be updated by using the `update` method
while the running AUPRC can be retrieved using the `result` method.
"""
super().__init__()
self._mean_AUPRC = defaultdict(Mean)
"""
The mean utility that will be used to store the running AUPRC
for each task label.
"""
@torch.no_grad()
def update(
self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[float, Tensor]
) -> None:
"""
Update the running AUPRC given the true and predicted labels.
Parameter `task_labels` is used to decide how to update the inner
dictionary: if an int, only the dictionary value related to that task
is updated. Per-pattern task labels passed as a Tensor are not yet
supported and raise a NotImplementedError.
:param predicted_y: The model prediction as a 2D tensor of per-class
logits or scores. Plain labels are not supported.
:param true_y: The ground truth. Both labels and one-hot vectors
are supported.
:param task_labels: the int task label associated to the current
experience or the task labels vector showing the task label
for each pattern.
:return: None.
"""
if len(true_y) != len(predicted_y):
raise ValueError("Size mismatch for true_y and predicted_y tensors")
if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y):
raise ValueError("Size mismatch for true_y and task_labels tensors")
true_y = torch.as_tensor(true_y)
predicted_y = torch.as_tensor(predicted_y)
assert len(predicted_y.size()) == 2, (
"Predictions need to be logits or scores, not labels"
)
if len(true_y.shape) > 1:
# Logits -> transform to labels
true_y = torch.max(true_y, 1)[1]
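# For each sample, take the score assigned to its ground-truth class and
# feed those scores, together with the labels, to scikit-learn's
# average_precision_score.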
scores = predicted_y[arange(len(true_y)), true_y]
with np.errstate(divide="ignore", invalid="ignore"):
average_precision_score_val = average_precision_score(
true_y.cpu(), scores.cpu()
)
if np.isnan(average_precision_score_val):
average_precision_score_val = 0
if isinstance(task_labels, int):
self._mean_AUPRC[task_labels].update(
average_precision_score_val, len(predicted_y)
)
elif isinstance(task_labels, Tensor):
raise NotImplementedError
else:
raise ValueError(
f"Task label type: {type(task_labels)}, expected int/float or Tensor"
)
def result(self, task_label=None) -> Dict[int, float]:
"""
Retrieves the running AUPRC.
Calling this method will not change the internal state of the metric.
:param task_label: if None, return the entire dictionary of AUPRCs
for each task. Otherwise return the dictionary
`{task_label: AUPRC}`.
:return: A dict of running AUPRCs for each task label,
where each value is a float value between 0 and 1.
"""
assert task_label is None or isinstance(task_label, int)
if task_label is None:
return {k: v.result() for k, v in self._mean_AUPRC.items()}
else:
return {task_label: self._mean_AUPRC[task_label].result()}
def reset(self, task_label=None) -> None:
"""
Resets the metric.
:param task_label: if None, reset the entire dictionary.
Otherwise, reset the value associated to `task_label`.
:return: None.
"""
assert task_label is None or isinstance(task_label, int)
if task_label is None:
self._mean_AUPRC = defaultdict(Mean)
else:
self._mean_AUPRC[task_label].reset()
class AUPRCPluginMetric(GenericPluginMetric[float]):
"""
Base class for all AUPRCs plugin metrics
"""
def __init__(self, reset_at, emit_at, mode):
self._AUPRC = AUPRC()
super(AUPRCPluginMetric, self).__init__(
self._AUPRC, reset_at=reset_at, emit_at=emit_at, mode=mode
)
def reset(self, strategy=None) -> None:
if self._reset_at == "stream" or strategy is None:
self._metric.reset()
else:
self._metric.reset(phase_and_task(strategy)[1])
def result(self, strategy=None) -> float:
if self._emit_at == "stream" or strategy is None:
return self._metric.result()
else:
return self._metric.result(phase_and_task(strategy)[1])
def update(self, strategy):
# task labels defined for each experience
task_labels = strategy.experience.task_labels
if len(task_labels) > 1:
# task labels defined for each pattern
task_labels = strategy.mb_task_id
else:
task_labels = task_labels[0]
self._AUPRC.update(strategy.mb_output, strategy.mb_y, task_labels)
class MinibatchAUPRC(AUPRCPluginMetric):
"""
The minibatch plugin AUPRC metric.
This metric only works at training time.
This metric computes the average AUPRC over patterns
from a single minibatch.
It reports the result after each iteration.
If a more coarse-grained logging is needed, consider using
:class:`EpochAUPRC` instead.
"""
def __init__(self):
"""
Creates an instance of the MinibatchAUPRC metric.
"""
super(MinibatchAUPRC, self).__init__(
reset_at="iteration", emit_at="iteration", mode="train"
)
def __str__(self):
return "AUPRC_MB"
class EpochAUPRC(AUPRCPluginMetric):
"""
The average AUPRC over a single training epoch.
This plugin metric only works at training time.
The AUPRC will be logged after each training epoch as the running
average of the per-minibatch AUPRC values, weighted by the number
of patterns in each minibatch.
"""
def __init__(self):
"""
Creates an instance of the EpochAUPRC metric.
"""
super(EpochAUPRC, self).__init__(
reset_at="epoch", emit_at="epoch", mode="train"
)
def __str__(self):
return "AUPRC_Epoch"
class RunningEpochAUPRC(AUPRCPluginMetric):
"""
The average AUPRC across all minibatches up to the current
epoch iteration.
This plugin metric only works at training time.
At each iteration, this metric logs the AUPRC averaged over all patterns
seen so far in the current epoch.
The metric resets its state after each training epoch.
"""
def __init__(self):
"""
Creates an instance of the RunningEpochAUPRC metric.
"""
super(RunningEpochAUPRC, self).__init__(
reset_at="epoch", emit_at="iteration", mode="train"
)
def __str__(self):
return "RunningAUPRC_Epoch"
class ExperienceAUPRC(AUPRCPluginMetric):
"""
At the end of each experience, this plugin metric reports
the average AUPRC over all patterns seen in that experience.
This metric only works at eval time.
"""
def __init__(self):
"""
Creates an instance of ExperienceAUPRC metric
"""
super(ExperienceAUPRC, self).__init__(
reset_at="experience", emit_at="experience", mode="eval"
)
def __str__(self):
return "AUPRC_Exp"
class StreamAUPRC(AUPRCPluginMetric):
"""
At the end of the entire stream of experiences, this plugin metric
reports the average AUPRC over all patterns seen in all experiences.
This metric only works at eval time.
"""
def __init__(self):
"""
Creates an instance of StreamAUPRC metric
"""
super(StreamAUPRC, self).__init__(
reset_at="stream", emit_at="stream", mode="eval"
)
def __str__(self):
return "AUPRC_Stream"
class TrainedExperienceAUPRC(AUPRCPluginMetric):
"""
At the end of each experience, this plugin metric reports the average
AUPRC for only the experiences that the model has been trained on so far.
This metric only works at eval time.
"""
def __init__(self):
"""
Creates an instance of TrainedExperienceAUPRC metric by first
constructing AUPRCPluginMetric
"""
super(TrainedExperienceAUPRC, self).__init__(
reset_at="stream", emit_at="stream", mode="eval"
)
self._current_experience = 0
def after_training_exp(self, strategy) -> None:
self._current_experience = strategy.experience.current_experience
# Reset average after learning from a new experience
AUPRCPluginMetric.reset(self, strategy)
return AUPRCPluginMetric.after_training_exp(self, strategy)
def update(self, strategy):
"""
Only update the AUPRC with results from experiences that have been
trained on
"""
if strategy.experience.current_experience <= self._current_experience:
AUPRCPluginMetric.update(self, strategy)
def __str__(self):
return "AUPRC_On_Trained_Experiences"
def auprc_metrics(
*,
minibatch=False,
epoch=False,
epoch_running=False,
experience=False,
stream=False,
trained_experience=False,
) -> List[PluginMetric]:
"""
Helper method that can be used to obtain the desired set of
plugin metrics.
:param minibatch: If True, will return a metric able to log
the minibatch AUPRC at training time.
:param epoch: If True, will return a metric able to log
the epoch AUPRC at training time.
:param epoch_running: If True, will return a metric able to log
the running epoch AUPRC at training time.
:param experience: If True, will return a metric able to log
the AUPRC on each evaluation experience.
:param stream: If True, will return a metric able to log
the AUPRC averaged over the entire evaluation stream of experiences.
:param trained_experience: If True, will return a metric able to log
the average evaluation AUPRC only for experiences that the
model has been trained on
:return: A list of plugin metrics.
"""
metrics = []
if minibatch:
metrics.append(MinibatchAUPRC())
if epoch:
metrics.append(EpochAUPRC())
if epoch_running:
metrics.append(RunningEpochAUPRC())
if experience:
metrics.append(ExperienceAUPRC())
if stream:
metrics.append(StreamAUPRC())
if trained_experience:
metrics.append(TrainedExperienceAUPRC())
return metrics
class ROCAUC(Metric[float]):
"""
The ROCAUC metric. This is a standalone metric.
The metric keeps a dictionary of <task_label, ROCAUC value> pairs
and updates the values through a running average over multiple
<prediction, target> pairs of Tensors, provided incrementally.
The "prediction" tensor must contain logit or score vectors (plain labels
are not supported), while the "target" tensor may contain plain labels or
one-hot vectors.
Each time `result` is called, this metric emits the average ROCAUC
across all predictions made since the last `reset`.
The reset method will bring the metric to its initial state. By default
this metric in its initial state will return an ROCAUC value of 0.
"""
def __init__(self):
"""
Creates an instance of the standalone ROCAUC metric.
By default this metric in its initial state will return a ROCAUC
value of 0. The metric can be updated by using the `update` method
while the running ROCAUC can be retrieved using the `result` method.
"""
super().__init__()
self._mean_ROCAUC = defaultdict(Mean)
"""
The mean utility that will be used to store the running ROCAUC
for each task label.
"""
@torch.no_grad()
def update(
self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[float, Tensor]
) -> None:
"""
Update the running ROCAUC given the true and predicted labels.
Parameter `task_labels` is used to decide how to update the inner
dictionary: if an int, only the dictionary value related to that task
is updated. Per-pattern task labels passed as a Tensor are not yet
supported and raise a NotImplementedError.
:param predicted_y: The model prediction as a 2D tensor of per-class
logits or scores. Plain labels are not supported.
:param true_y: The ground truth. Both labels and one-hot vectors
are supported.
:param task_labels: the int task label associated to the current
experience or the task labels vector showing the task label
for each pattern.
:return: None.
"""
if len(true_y) != len(predicted_y):
raise ValueError("Size mismatch for true_y and predicted_y tensors")
if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y):
raise ValueError("Size mismatch for true_y and task_labels tensors")
true_y = torch.as_tensor(true_y)
predicted_y = torch.as_tensor(predicted_y)
assert len(predicted_y.size()) == 2, (
"Predictions need to be logits or scores, not labels"
)
if len(true_y.shape) > 1:
# Logits -> transform to labels
true_y = torch.max(true_y, 1)[1]
scores = predicted_y[arange(len(true_y)), true_y]
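# roc_auc_score raises a ValueError when only one class is present in
# true_y; fall back to 1 in that degenerate case.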
try:
roc_auc_score_val = roc_auc_score(true_y.cpu(), scores.cpu())
except ValueError:
roc_auc_score_val = 1
if isinstance(task_labels, int):
self._mean_ROCAUC[task_labels].update(roc_auc_score_val, len(predicted_y))
elif isinstance(task_labels, Tensor):
raise NotImplementedError
else:
raise ValueError(
f"Task label type: {type(task_labels)}, expected int/float or Tensor"
)
def result(self, task_label=None) -> Dict[int, float]:
"""
Retrieves the running ROCAUC.
Calling this method will not change the internal state of the metric.
:param task_label: if None, return the entire dictionary of ROCAUCs
for each task. Otherwise return the dictionary
`{task_label: ROCAUC}`.
:return: A dict of running ROCAUCs for each task label,
where each value is a float value between 0 and 1.
"""
assert task_label is None or isinstance(task_label, int)
if task_label is None:
return {k: v.result() for k, v in self._mean_ROCAUC.items()}
else:
return {task_label: self._mean_ROCAUC[task_label].result()}
def reset(self, task_label=None) -> None:
"""
Resets the metric.
:param task_label: if None, reset the entire dictionary.
Otherwise, reset the value associated to `task_label`.
:return: None.
"""
assert task_label is None or isinstance(task_label, int)
if task_label is None:
self._mean_ROCAUC = defaultdict(Mean)
else:
self._mean_ROCAUC[task_label].reset()
class ROCAUCPluginMetric(GenericPluginMetric[float]):
"""
Base class for all ROCAUCs plugin metrics
"""
def __init__(self, reset_at, emit_at, mode):
self._ROCAUC = ROCAUC()
super(ROCAUCPluginMetric, self).__init__(
self._ROCAUC, reset_at=reset_at, emit_at=emit_at, mode=mode
)
def reset(self, strategy=None) -> None:
if self._reset_at == "stream" or strategy is None:
self._metric.reset()
else:
self._metric.reset(phase_and_task(strategy)[1])
def result(self, strategy=None) -> float:
if self._emit_at == "stream" or strategy is None:
return self._metric.result()
else:
return self._metric.result(phase_and_task(strategy)[1])
def update(self, strategy):
# task labels defined for each experience
task_labels = strategy.experience.task_labels
if len(task_labels) > 1:
# task labels defined for each pattern
task_labels = strategy.mb_task_id
else:
task_labels = task_labels[0]
self._ROCAUC.update(strategy.mb_output, strategy.mb_y, task_labels)
class MinibatchROCAUC(ROCAUCPluginMetric):
"""
The minibatch plugin ROCAUC metric.
This metric only works at training time.
This metric computes the average ROCAUC over patterns
from a single minibatch.
It reports the result after each iteration.
If a more coarse-grained logging is needed, consider using
:class:`EpochROCAUC` instead.
"""
def __init__(self):
"""
Creates an instance of the MinibatchROCAUC metric.
"""
super(MinibatchROCAUC, self).__init__(
reset_at="iteration", emit_at="iteration", mode="train"
)
def __str__(self):
return "ROCAUC_MB"
class EpochROCAUC(ROCAUCPluginMetric):
"""
The average ROCAUC over a single training epoch.
This plugin metric only works at training time.
The ROCAUC will be logged after each training epoch as the running
average of the per-minibatch ROCAUC values, weighted by the number
of patterns in each minibatch.
"""
def __init__(self):
"""
Creates an instance of the EpochROCAUC metric.
"""
super(EpochROCAUC, self).__init__(
reset_at="epoch", emit_at="epoch", mode="train"
)
def __str__(self):
return "ROCAUC_Epoch"
class RunningEpochROCAUC(ROCAUCPluginMetric):
"""
The average ROCAUC across all minibatches up to the current
epoch iteration.
This plugin metric only works at training time.
At each iteration, this metric logs the ROCAUC averaged over all patterns
seen so far in the current epoch.
The metric resets its state after each training epoch.
"""
def __init__(self):
"""
Creates an instance of the RunningEpochROCAUC metric.
"""
super(RunningEpochROCAUC, self).__init__(
reset_at="epoch", emit_at="iteration", mode="train"
)
def __str__(self):
return "RunningROCAUC_Epoch"
class ExperienceROCAUC(ROCAUCPluginMetric):
"""
At the end of each experience, this plugin metric reports
the average ROCAUC over all patterns seen in that experience.
This metric only works at eval time.
"""
def __init__(self):
"""
Creates an instance of ExperienceROCAUC metric
"""
super(ExperienceROCAUC, self).__init__(
reset_at="experience", emit_at="experience", mode="eval"
)
def __str__(self):
return "ROCAUC_Exp"
class StreamROCAUC(ROCAUCPluginMetric):
"""
At the end of the entire stream of experiences, this plugin metric
reports the average ROCAUC over all patterns seen in all experiences.
This metric only works at eval time.
"""
def __init__(self):
"""
Creates an instance of StreamROCAUC metric
"""
super(StreamROCAUC, self).__init__(
reset_at="stream", emit_at="stream", mode="eval"
)
def __str__(self):
return "ROCAUC_Stream"
class TrainedExperienceROCAUC(ROCAUCPluginMetric):
"""
At the end of each experience, this plugin metric reports the average
ROCAUC for only the experiences that the model has been trained on so far.
This metric only works at eval time.
"""
def __init__(self):
"""
Creates an instance of TrainedExperienceROCAUC metric by first
constructing ROCAUCPluginMetric
"""
super(TrainedExperienceROCAUC, self).__init__(
reset_at="stream", emit_at="stream", mode="eval"
)
self._current_experience = 0
def after_training_exp(self, strategy) -> None:
self._current_experience = strategy.experience.current_experience
# Reset average after learning from a new experience
ROCAUCPluginMetric.reset(self, strategy)
return ROCAUCPluginMetric.after_training_exp(self, strategy)
def update(self, strategy):
"""
Only update the ROCAUC with results from experiences that have been
trained on
"""
if strategy.experience.current_experience <= self._current_experience:
ROCAUCPluginMetric.update(self, strategy)
def __str__(self):
return "ROCAUC_On_Trained_Experiences"
def rocauc_metrics(
*,
minibatch=False,
epoch=False,
epoch_running=False,
experience=False,
stream=False,
trained_experience=False,
) -> List[PluginMetric]:
"""
Helper method that can be used to obtain the desired set of
plugin metrics.
:param minibatch: If True, will return a metric able to log
the minibatch ROCAUC at training time.
:param epoch: If True, will return a metric able to log
the epoch ROCAUC at training time.
:param epoch_running: If True, will return a metric able to log
the running epoch ROCAUC at training time.
:param experience: If True, will return a metric able to log
the ROCAUC on each evaluation experience.
:param stream: If True, will return a metric able to log
the ROCAUC averaged over the entire evaluation stream of experiences.
:param trained_experience: If True, will return a metric able to log
the average evaluation ROCAUC only for experiences that the
model has been trained on
:return: A list of plugin metrics.
"""
metrics = []
if minibatch:
metrics.append(MinibatchROCAUC())
if epoch:
metrics.append(EpochROCAUC())
if epoch_running:
metrics.append(RunningEpochROCAUC())
if experience:
metrics.append(ExperienceROCAUC())
if stream:
metrics.append(StreamROCAUC())
if trained_experience:
metrics.append(TrainedExperienceROCAUC())
return metrics
__all__ = [
"BalancedAccuracy",
"MinibatchBalancedAccuracy",
"EpochBalancedAccuracy",
"RunningEpochBalancedAccuracy",
"ExperienceBalancedAccuracy",
"StreamBalancedAccuracy",
"TrainedExperienceBalancedAccuracy",
"balancedaccuracy_metrics",
"Sensitivity",
"MinibatchSensitivity",
"EpochSensitivity",
"RunningEpochSensitivity",
"ExperienceSensitivity",
"StreamSensitivity",
"TrainedExperienceSensitivity",
"sensitivity_metrics",
"Specificity",
"MinibatchSpecificity",
"EpochSpecificity",
"RunningEpochSpecificity",
"ExperienceSpecificity",
"StreamSpecificity",
"TrainedExperienceSpecificity",
"specificity_metrics",
"Precision",
"MinibatchPrecision",
"EpochPrecision",
"RunningEpochPrecision",
"ExperiencePrecision",
"StreamPrecision",
"TrainedExperiencePrecision",
"precision_metrics",
"AUPRC",
"MinibatchAUPRC",
"EpochAUPRC",
"RunningEpochAUPRC",
"ExperienceAUPRC",
"StreamAUPRC",
"TrainedExperienceAUPRC",
"auprc_metrics",
"ROCAUC",
"MinibatchROCAUC",
"EpochROCAUC",
"RunningEpochROCAUC",
"ExperienceROCAUC",
"StreamROCAUC",
"TrainedExperienceROCAUC",
"rocauc_metrics",
]
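# Minimal smoke test (an illustrative sketch, not part of the original module):
# exercises the standalone metrics on a random binary batch. Only names defined
# above plus torch are used.
if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(16, 2)           # per-class scores for 16 samples
    targets = torch.randint(0, 2, (16,))  # binary ground truth
    # Label-based metrics: logits are argmax-ed to labels inside update().
    for metric in (BalancedAccuracy(), Sensitivity(), Specificity(), Precision()):
        metric.update(logits, targets, task_labels=0)
        print(type(metric).__name__, metric.result())
    # Score-based metrics: these require 2D logits/scores.
    for metric in (AUPRC(), ROCAUC()):
        metric.update(logits, targets, task_labels=0)
        print(type(metric).__name__, metric.result())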