Diff of /src/utils/metrics.py [000000] .. [66326d]

"""
Custom binary prediction metrics using Avalanche
https://github.com/ContinualAI/avalanche/blob/master/notebooks/from-zero-to-hero-tutorial/05_evaluation.ipynb
"""

from typing import List, Union, Dict
from collections import defaultdict

import torch
import numpy as np
from torch import Tensor, arange
from avalanche.evaluation import Metric, PluginMetric, GenericPluginMetric
from avalanche.evaluation.metrics.mean import Mean
from avalanche.evaluation.metric_utils import phase_and_task

from sklearn.metrics import average_precision_score, roc_auc_score


def confusion(prediction, truth):
    """Returns the confusion counts for the values in the `prediction` and
    `truth` tensors, i.e. the number of positions where the values of
    `prediction` and `truth` are
    - 1 and 1 (True Positive)
    - 1 and 0 (False Positive)
    - 0 and 0 (True Negative)
    - 0 and 1 (False Negative)

    Source: https://gist.github.com/the-bass/cae9f3976866776dea17a5049013258d
    """

    confusion_vector = prediction / truth
    # Element-wise division of the 2 tensors returns a new tensor which holds a
    # unique value for each case:
    #   1     where prediction and truth are 1 (True Positive)
    #   inf   where prediction is 1 and truth is 0 (False Positive)
    #   nan   where prediction and truth are 0 (True Negative)
    #   0     where prediction is 0 and truth is 1 (False Negative)

    true_positives = torch.sum(confusion_vector == 1).item()
    false_positives = torch.sum(confusion_vector == float("inf")).item()
    true_negatives = torch.sum(torch.isnan(confusion_vector)).item()
    false_negatives = torch.sum(confusion_vector == 0).item()

    return true_positives, false_positives, true_negatives, false_negatives
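
# Usage sketch (illustrative, not part of the original file): `confusion`
# relies on true division producing inf for 1/0 and nan for 0/0, so it is
# intended for binary 0/1 prediction/target tensors. For example:
#
#     pred = torch.tensor([1.0, 1.0, 0.0, 0.0, 1.0])
#     truth = torch.tensor([1.0, 0.0, 0.0, 1.0, 1.0])
#     tp, fp, tn, fn = confusion(pred, truth)
#     # tp == 2, fp == 1, tn == 1, fn == 1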


# https://github.com/ContinualAI/avalanche/blob/master/avalanche/evaluation/metrics/mean_scores.py
# Use the above as a template for AUPRC and similar metrics.


class BalancedAccuracy(Metric[float]):
    """
    The BalancedAccuracy metric. This is a standalone metric.

    The metric keeps a dictionary of <task_label, balanced accuracy value>
    pairs and updates the values through a running average over multiple
    <prediction, target> pairs of Tensors, provided incrementally.
    The "prediction" and "target" tensors may contain plain labels or
    one-hot/logit vectors.

    Each time `result` is called, this metric emits the average balanced
    accuracy across all predictions made since the last `reset`.

    The reset method will bring the metric to its initial state. By default
    this metric in its initial state will return a balanced accuracy value
    of 0.
    """

    def __init__(self):
        """
        Creates an instance of the standalone BalancedAccuracy metric.

        By default this metric in its initial state will return a balanced
        accuracy value of 0. The metric can be updated by using the `update`
        method while the running balanced accuracy can be retrieved using
        the `result` method.
        """
        super().__init__()
        self._mean_balancedaccuracy = defaultdict(Mean)
        """
        The mean utility that will be used to store the running balanced
        accuracy for each task label.
        """

    @torch.no_grad()
    def update(
        self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[float, Tensor]
    ) -> None:
        """
        Update the running balanced accuracy given the true and predicted
        labels.
        Parameter `task_labels` is used to decide how to update the inner
        dictionary: if int, only the dictionary value related to that task
        is updated. If Tensor, all the dictionary elements belonging to the
        task labels will be updated.

        :param predicted_y: The model prediction. Both labels and logit vectors
            are supported.
        :param true_y: The ground truth. Both labels and one-hot vectors
            are supported.
        :param task_labels: the int task label associated to the current
            experience or the task labels vector showing the task label
            for each pattern.

        :return: None.
        """
        if len(true_y) != len(predicted_y):
            raise ValueError("Size mismatch for true_y and predicted_y tensors")

        if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y):
            raise ValueError("Size mismatch for true_y and task_labels tensors")

        true_y = torch.as_tensor(true_y)
        predicted_y = torch.as_tensor(predicted_y)

        # Check if logits or labels
        if len(predicted_y.shape) > 1:
            # Logits -> transform to labels
            predicted_y = torch.max(predicted_y, 1)[1]

        if len(true_y.shape) > 1:
            # Logits -> transform to labels
            true_y = torch.max(true_y, 1)[1]

        if isinstance(task_labels, int):
            (
                true_positives,
                false_positives,
                true_negatives,
                false_negatives,
            ) = confusion(predicted_y, true_y)

            try:
                tpr = true_positives / (true_positives + false_negatives)
            except ZeroDivisionError:
                tpr = 1

            try:
                tnr = true_negatives / (true_negatives + false_positives)
            except ZeroDivisionError:
                tnr = 1

            self._mean_balancedaccuracy[task_labels].update(
                (tpr + tnr) / 2, len(predicted_y)
            )
        elif isinstance(task_labels, Tensor):
            raise NotImplementedError
        else:
            raise ValueError(
                f"Task label type: {type(task_labels)}, expected int/float or Tensor"
            )

    def result(self, task_label=None) -> Dict[int, float]:
        """
        Retrieves the running balanced accuracy.

        Calling this method will not change the internal state of the metric.

        :param task_label: if None, return the entire dictionary of balanced
            accuracies for each task. Otherwise return the dictionary
            `{task_label: balanced accuracy}`.
        :return: A dict of running balanced accuracies for each task label,
            where each value is a float value between 0 and 1.
        """
        assert task_label is None or isinstance(task_label, int)
        if task_label is None:
            return {k: v.result() for k, v in self._mean_balancedaccuracy.items()}
        else:
            return {task_label: self._mean_balancedaccuracy[task_label].result()}

    def reset(self, task_label=None) -> None:
        """
        Resets the metric.

        :param task_label: if None, reset the entire dictionary.
            Otherwise, reset the value associated to `task_label`.

        :return: None.
        """
        assert task_label is None or isinstance(task_label, int)
        if task_label is None:
            self._mean_balancedaccuracy = defaultdict(Mean)
        else:
            self._mean_balancedaccuracy[task_label].reset()


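# Usage sketch (illustrative, not part of the original file): the standalone
# metric can be driven directly with tensors and an int task label, e.g.
#
#     metric = BalancedAccuracy()
#     logits = torch.tensor([[0.2, 0.8], [0.9, 0.1], [0.3, 0.7]])
#     targets = torch.tensor([1, 0, 0])
#     metric.update(logits, targets, 0)
#     metric.result()   # -> {0: 0.75}, the running balanced accuracy per task
#     metric.reset()
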
class BalancedAccuracyPluginMetric(GenericPluginMetric[float]):
    """
    Base class for all balanced accuracy plugin metrics.
    """

    def __init__(self, reset_at, emit_at, mode):
        self._balancedaccuracy = BalancedAccuracy()
        super(BalancedAccuracyPluginMetric, self).__init__(
            self._balancedaccuracy, reset_at=reset_at, emit_at=emit_at, mode=mode
        )

    def reset(self, strategy=None) -> None:
        if self._reset_at == "stream" or strategy is None:
            self._metric.reset()
        else:
            self._metric.reset(phase_and_task(strategy)[1])

    def result(self, strategy=None) -> float:
        if self._emit_at == "stream" or strategy is None:
            return self._metric.result()
        else:
            return self._metric.result(phase_and_task(strategy)[1])

    def update(self, strategy):
        # task labels defined for each experience
        task_labels = strategy.experience.task_labels
        if len(task_labels) > 1:
            # task labels defined for each pattern
            task_labels = strategy.mb_task_id
        else:
            task_labels = task_labels[0]
        self._balancedaccuracy.update(strategy.mb_output, strategy.mb_y, task_labels)


class MinibatchBalancedAccuracy(BalancedAccuracyPluginMetric):
    """
    The minibatch plugin balanced accuracy metric.
    This metric only works at training time.

    This metric computes the average balanced accuracy over patterns
    from a single minibatch.
    It reports the result after each iteration.

    If a more coarse-grained logging is needed, consider using
    :class:`EpochBalancedAccuracy` instead.
    """

    def __init__(self):
        """
        Creates an instance of the MinibatchBalancedAccuracy metric.
        """
        super(MinibatchBalancedAccuracy, self).__init__(
            reset_at="iteration", emit_at="iteration", mode="train"
        )

    def __str__(self):
        return "BalAcc_MB"


class EpochBalancedAccuracy(BalancedAccuracyPluginMetric):
    """
    The average balanced accuracy over a single training epoch.
    This plugin metric only works at training time.

    The balanced accuracy is logged after each training epoch as the running
    average of the per-minibatch balanced accuracy, weighted by the number of
    patterns encountered in that epoch.
    """

    def __init__(self):
        """
        Creates an instance of the EpochBalancedAccuracy metric.
        """

        super(EpochBalancedAccuracy, self).__init__(
            reset_at="epoch", emit_at="epoch", mode="train"
        )

    def __str__(self):
        return "BalAcc_Epoch"


class RunningEpochBalancedAccuracy(BalancedAccuracyPluginMetric):
    """
    The average balanced accuracy across all minibatches up to the current
    epoch iteration.
    This plugin metric only works at training time.

    At each iteration, this metric logs the balanced accuracy averaged over
    all patterns seen so far in the current epoch.
    The metric resets its state after each training epoch.
    """

    def __init__(self):
        """
        Creates an instance of the RunningEpochBalancedAccuracy metric.
        """

        super(RunningEpochBalancedAccuracy, self).__init__(
            reset_at="epoch", emit_at="iteration", mode="train"
        )

    def __str__(self):
        return "RunningBalAcc_Epoch"


class ExperienceBalancedAccuracy(BalancedAccuracyPluginMetric):
    """
    At the end of each experience, this plugin metric reports
    the average balanced accuracy over all patterns seen in that experience.
    This metric only works at eval time.
    """

    def __init__(self):
        """
        Creates an instance of the ExperienceBalancedAccuracy metric.
        """
        super(ExperienceBalancedAccuracy, self).__init__(
            reset_at="experience", emit_at="experience", mode="eval"
        )

    def __str__(self):
        return "BalAcc_Exp"


class StreamBalancedAccuracy(BalancedAccuracyPluginMetric):
    """
    At the end of the entire stream of experiences, this plugin metric
    reports the average balanced accuracy over all patterns seen in all
    experiences.
    This metric only works at eval time.
    """

    def __init__(self):
        """
        Creates an instance of the StreamBalancedAccuracy metric.
        """
        super(StreamBalancedAccuracy, self).__init__(
            reset_at="stream", emit_at="stream", mode="eval"
        )

    def __str__(self):
        return "BalAcc_Stream"


class TrainedExperienceBalancedAccuracy(BalancedAccuracyPluginMetric):
    """
    At the end of each experience, this plugin metric reports the average
    balanced accuracy for only the experiences that the model has been
    trained on so far.

    This metric only works at eval time.
    """

    def __init__(self):
        """
        Creates an instance of the TrainedExperienceBalancedAccuracy metric by
        first constructing BalancedAccuracyPluginMetric.
        """
        super(TrainedExperienceBalancedAccuracy, self).__init__(
            reset_at="stream", emit_at="stream", mode="eval"
        )
        self._current_experience = 0

    def after_training_exp(self, strategy) -> None:
        self._current_experience = strategy.experience.current_experience
        # Reset average after learning from a new experience
        BalancedAccuracyPluginMetric.reset(self, strategy)
        return BalancedAccuracyPluginMetric.after_training_exp(self, strategy)

    def update(self, strategy):
        """
        Only update the balanced accuracy with results from experiences that
        have been trained on.
        """
        if strategy.experience.current_experience <= self._current_experience:
            BalancedAccuracyPluginMetric.update(self, strategy)

    def __str__(self):
        return "BalancedAccuracy_On_Trained_Experiences"


def balancedaccuracy_metrics(
    *,
    minibatch=False,
    epoch=False,
    epoch_running=False,
    experience=False,
    stream=False,
    trained_experience=False,
) -> List[PluginMetric]:
    """
    Helper method that can be used to obtain the desired set of
    plugin metrics.

    :param minibatch: If True, will return a metric able to log
        the minibatch balanced accuracy at training time.
    :param epoch: If True, will return a metric able to log
        the epoch balanced accuracy at training time.
    :param epoch_running: If True, will return a metric able to log
        the running epoch balanced accuracy at training time.
    :param experience: If True, will return a metric able to log
        the balanced accuracy on each evaluation experience.
    :param stream: If True, will return a metric able to log
        the balanced accuracy averaged over the entire evaluation stream
        of experiences.
    :param trained_experience: If True, will return a metric able to log
        the average evaluation balanced accuracy only for experiences that
        the model has been trained on.

    :return: A list of plugin metrics.
    """

    metrics = []
    if minibatch:
        metrics.append(MinibatchBalancedAccuracy())

    if epoch:
        metrics.append(EpochBalancedAccuracy())

    if epoch_running:
        metrics.append(RunningEpochBalancedAccuracy())

    if experience:
        metrics.append(ExperienceBalancedAccuracy())

    if stream:
        metrics.append(StreamBalancedAccuracy())

    if trained_experience:
        metrics.append(TrainedExperienceBalancedAccuracy())

    return metrics


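# Usage sketch (illustrative, not part of the original file): this helper,
# like the sensitivity/specificity/precision helpers below, is intended to be
# passed to Avalanche's EvaluationPlugin, e.g.
#
#     from avalanche.training.plugins import EvaluationPlugin
#     from avalanche.logging import InteractiveLogger
#
#     eval_plugin = EvaluationPlugin(
#         balancedaccuracy_metrics(epoch=True, experience=True, stream=True),
#         loggers=[InteractiveLogger()],
#     )
#
# The resulting plugin is then handed to a strategy via its `evaluator`
# argument.

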
class Sensitivity(Metric[float]):
    """
    The Sensitivity metric. This is a standalone metric.

    The metric keeps a dictionary of <task_label, sensitivity value> pairs
    and updates the values through a running average over multiple
    <prediction, target> pairs of Tensors, provided incrementally.
    The "prediction" and "target" tensors may contain plain labels or
    one-hot/logit vectors.

    Each time `result` is called, this metric emits the average sensitivity
    across all predictions made since the last `reset`.

    The reset method will bring the metric to its initial state. By default
    this metric in its initial state will return a sensitivity value of 0.
    """

    def __init__(self):
        """
        Creates an instance of the standalone Sensitivity metric.

        By default this metric in its initial state will return a sensitivity
        value of 0. The metric can be updated by using the `update` method
        while the running sensitivity can be retrieved using the `result`
        method.
        """
        super().__init__()
        self._mean_Sensitivity = defaultdict(Mean)
        """
        The mean utility that will be used to store the running sensitivity
        for each task label.
        """

    @torch.no_grad()
    def update(
        self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[float, Tensor]
    ) -> None:
        """
        Update the running sensitivity given the true and predicted labels.
        Parameter `task_labels` is used to decide how to update the inner
        dictionary: if int, only the dictionary value related to that task
        is updated. If Tensor, all the dictionary elements belonging to the
        task labels will be updated.

        :param predicted_y: The model prediction. Both labels and logit vectors
            are supported.
        :param true_y: The ground truth. Both labels and one-hot vectors
            are supported.
        :param task_labels: the int task label associated to the current
            experience or the task labels vector showing the task label
            for each pattern.

        :return: None.
        """
        if len(true_y) != len(predicted_y):
            raise ValueError("Size mismatch for true_y and predicted_y tensors")

        if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y):
            raise ValueError("Size mismatch for true_y and task_labels tensors")

        true_y = torch.as_tensor(true_y)
        predicted_y = torch.as_tensor(predicted_y)

        # Check if logits or labels
        if len(predicted_y.shape) > 1:
            # Logits -> transform to labels
            predicted_y = torch.max(predicted_y, 1)[1]

        if len(true_y.shape) > 1:
            # Logits -> transform to labels
            true_y = torch.max(true_y, 1)[1]

        if isinstance(task_labels, int):
            (
                true_positives,
                false_positives,
                true_negatives,
                false_negatives,
            ) = confusion(predicted_y, true_y)

            try:
                tpr = true_positives / (true_positives + false_negatives)
            except ZeroDivisionError:
                tpr = 1

            self._mean_Sensitivity[task_labels].update(tpr, len(predicted_y))
        elif isinstance(task_labels, Tensor):
            raise NotImplementedError
        else:
            raise ValueError(
                f"Task label type: {type(task_labels)}, expected int/float or Tensor"
            )

    def result(self, task_label=None) -> Dict[int, float]:
        """
        Retrieves the running sensitivity.

        Calling this method will not change the internal state of the metric.

        :param task_label: if None, return the entire dictionary of
            sensitivities for each task. Otherwise return the dictionary
            `{task_label: sensitivity}`.
        :return: A dict of running sensitivities for each task label,
            where each value is a float value between 0 and 1.
        """
        assert task_label is None or isinstance(task_label, int)
        if task_label is None:
            return {k: v.result() for k, v in self._mean_Sensitivity.items()}
        else:
            return {task_label: self._mean_Sensitivity[task_label].result()}

    def reset(self, task_label=None) -> None:
        """
        Resets the metric.

        :param task_label: if None, reset the entire dictionary.
            Otherwise, reset the value associated to `task_label`.

        :return: None.
        """
        assert task_label is None or isinstance(task_label, int)
        if task_label is None:
            self._mean_Sensitivity = defaultdict(Mean)
        else:
            self._mean_Sensitivity[task_label].reset()


class SensitivityPluginMetric(GenericPluginMetric[float]):
    """
    Base class for all sensitivity plugin metrics.
    """

    def __init__(self, reset_at, emit_at, mode):
        self._Sensitivity = Sensitivity()
        super(SensitivityPluginMetric, self).__init__(
            self._Sensitivity, reset_at=reset_at, emit_at=emit_at, mode=mode
        )

    def reset(self, strategy=None) -> None:
        if self._reset_at == "stream" or strategy is None:
            self._metric.reset()
        else:
            self._metric.reset(phase_and_task(strategy)[1])

    def result(self, strategy=None) -> float:
        if self._emit_at == "stream" or strategy is None:
            return self._metric.result()
        else:
            return self._metric.result(phase_and_task(strategy)[1])

    def update(self, strategy):
        # task labels defined for each experience
        task_labels = strategy.experience.task_labels
        if len(task_labels) > 1:
            # task labels defined for each pattern
            task_labels = strategy.mb_task_id
        else:
            task_labels = task_labels[0]
        self._Sensitivity.update(strategy.mb_output, strategy.mb_y, task_labels)


class MinibatchSensitivity(SensitivityPluginMetric):
    """
    The minibatch plugin sensitivity metric.
    This metric only works at training time.

    This metric computes the average sensitivity over patterns
    from a single minibatch.
    It reports the result after each iteration.

    If a more coarse-grained logging is needed, consider using
    :class:`EpochSensitivity` instead.
    """

    def __init__(self):
        """
        Creates an instance of the MinibatchSensitivity metric.
        """
        super(MinibatchSensitivity, self).__init__(
            reset_at="iteration", emit_at="iteration", mode="train"
        )

    def __str__(self):
        return "Sens_MB"


class EpochSensitivity(SensitivityPluginMetric):
    """
    The average sensitivity over a single training epoch.
    This plugin metric only works at training time.

    The sensitivity is logged after each training epoch as the running
    average of the per-minibatch sensitivity (true positive rate), weighted
    by the number of patterns encountered in that epoch.
    """

    def __init__(self):
        """
        Creates an instance of the EpochSensitivity metric.
        """

        super(EpochSensitivity, self).__init__(
            reset_at="epoch", emit_at="epoch", mode="train"
        )

    def __str__(self):
        return "Sens_Epoch"


class RunningEpochSensitivity(SensitivityPluginMetric):
    """
    The average sensitivity across all minibatches up to the current
    epoch iteration.
    This plugin metric only works at training time.

    At each iteration, this metric logs the sensitivity averaged over all
    patterns seen so far in the current epoch.
    The metric resets its state after each training epoch.
    """

    def __init__(self):
        """
        Creates an instance of the RunningEpochSensitivity metric.
        """

        super(RunningEpochSensitivity, self).__init__(
            reset_at="epoch", emit_at="iteration", mode="train"
        )

    def __str__(self):
        return "RunningSens_Epoch"


class ExperienceSensitivity(SensitivityPluginMetric):
    """
    At the end of each experience, this plugin metric reports
    the average sensitivity over all patterns seen in that experience.
    This metric only works at eval time.
    """

    def __init__(self):
        """
        Creates an instance of the ExperienceSensitivity metric.
        """
        super(ExperienceSensitivity, self).__init__(
            reset_at="experience", emit_at="experience", mode="eval"
        )

    def __str__(self):
        return "Sens_Exp"


class StreamSensitivity(SensitivityPluginMetric):
    """
    At the end of the entire stream of experiences, this plugin metric
    reports the average sensitivity over all patterns seen in all experiences.
    This metric only works at eval time.
    """

    def __init__(self):
        """
        Creates an instance of the StreamSensitivity metric.
        """
        super(StreamSensitivity, self).__init__(
            reset_at="stream", emit_at="stream", mode="eval"
        )

    def __str__(self):
        return "Sens_Stream"


class TrainedExperienceSensitivity(SensitivityPluginMetric):
    """
    At the end of each experience, this plugin metric reports the average
    sensitivity for only the experiences that the model has been trained on
    so far.

    This metric only works at eval time.
    """

    def __init__(self):
        """
        Creates an instance of the TrainedExperienceSensitivity metric by
        first constructing SensitivityPluginMetric.
        """
        super(TrainedExperienceSensitivity, self).__init__(
            reset_at="stream", emit_at="stream", mode="eval"
        )
        self._current_experience = 0

    def after_training_exp(self, strategy) -> None:
        self._current_experience = strategy.experience.current_experience
        # Reset average after learning from a new experience
        SensitivityPluginMetric.reset(self, strategy)
        return SensitivityPluginMetric.after_training_exp(self, strategy)

    def update(self, strategy):
        """
        Only update the sensitivity with results from experiences that have
        been trained on.
        """
        if strategy.experience.current_experience <= self._current_experience:
            SensitivityPluginMetric.update(self, strategy)

    def __str__(self):
        return "Sensitivity_On_Trained_Experiences"


def sensitivity_metrics(
    *,
    minibatch=False,
    epoch=False,
    epoch_running=False,
    experience=False,
    stream=False,
    trained_experience=False,
) -> List[PluginMetric]:
    """
    Helper method that can be used to obtain the desired set of
    plugin metrics.

    :param minibatch: If True, will return a metric able to log
        the minibatch sensitivity at training time.
    :param epoch: If True, will return a metric able to log
        the epoch sensitivity at training time.
    :param epoch_running: If True, will return a metric able to log
        the running epoch sensitivity at training time.
    :param experience: If True, will return a metric able to log
        the sensitivity on each evaluation experience.
    :param stream: If True, will return a metric able to log
        the sensitivity averaged over the entire evaluation stream
        of experiences.
    :param trained_experience: If True, will return a metric able to log
        the average evaluation sensitivity only for experiences that the
        model has been trained on.

    :return: A list of plugin metrics.
    """

    metrics = []
    if minibatch:
        metrics.append(MinibatchSensitivity())

    if epoch:
        metrics.append(EpochSensitivity())

    if epoch_running:
        metrics.append(RunningEpochSensitivity())

    if experience:
        metrics.append(ExperienceSensitivity())

    if stream:
        metrics.append(StreamSensitivity())

    if trained_experience:
        metrics.append(TrainedExperienceSensitivity())

    return metrics


class Specificity(Metric[float]):
    """
    The Specificity metric. This is a standalone metric.

    The metric keeps a dictionary of <task_label, specificity value> pairs
    and updates the values through a running average over multiple
    <prediction, target> pairs of Tensors, provided incrementally.
    The "prediction" and "target" tensors may contain plain labels or
    one-hot/logit vectors.

    Each time `result` is called, this metric emits the average specificity
    across all predictions made since the last `reset`.

    The reset method will bring the metric to its initial state. By default
    this metric in its initial state will return a specificity value of 0.
    """

    def __init__(self):
        """
        Creates an instance of the standalone Specificity metric.

        By default this metric in its initial state will return a specificity
        value of 0. The metric can be updated by using the `update` method
        while the running specificity can be retrieved using the `result`
        method.
        """
        super().__init__()
        self._mean_Specificity = defaultdict(Mean)
        """
        The mean utility that will be used to store the running specificity
        for each task label.
        """

    @torch.no_grad()
    def update(
        self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[float, Tensor]
    ) -> None:
        """
        Update the running specificity given the true and predicted labels.
        Parameter `task_labels` is used to decide how to update the inner
        dictionary: if int, only the dictionary value related to that task
        is updated. If Tensor, all the dictionary elements belonging to the
        task labels will be updated.

        :param predicted_y: The model prediction. Both labels and logit vectors
            are supported.
        :param true_y: The ground truth. Both labels and one-hot vectors
            are supported.
        :param task_labels: the int task label associated to the current
            experience or the task labels vector showing the task label
            for each pattern.

        :return: None.
        """
        if len(true_y) != len(predicted_y):
            raise ValueError("Size mismatch for true_y and predicted_y tensors")

        if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y):
            raise ValueError("Size mismatch for true_y and task_labels tensors")

        true_y = torch.as_tensor(true_y)
        predicted_y = torch.as_tensor(predicted_y)

        # Check if logits or labels
        if len(predicted_y.shape) > 1:
            # Logits -> transform to labels
            predicted_y = torch.max(predicted_y, 1)[1]

        if len(true_y.shape) > 1:
            # Logits -> transform to labels
            true_y = torch.max(true_y, 1)[1]

        if isinstance(task_labels, int):
            (
                true_positives,
                false_positives,
                true_negatives,
                false_negatives,
            ) = confusion(predicted_y, true_y)

            try:
                tnr = true_negatives / (true_negatives + false_positives)
            except ZeroDivisionError:
                tnr = 1

            self._mean_Specificity[task_labels].update(tnr, len(predicted_y))
        elif isinstance(task_labels, Tensor):
            raise NotImplementedError
        else:
            raise ValueError(
                f"Task label type: {type(task_labels)}, expected int/float or Tensor"
            )

    def result(self, task_label=None) -> Dict[int, float]:
        """
        Retrieves the running specificity.

        Calling this method will not change the internal state of the metric.

        :param task_label: if None, return the entire dictionary of
            specificities for each task. Otherwise return the dictionary
            `{task_label: specificity}`.
        :return: A dict of running specificities for each task label,
            where each value is a float value between 0 and 1.
        """
        assert task_label is None or isinstance(task_label, int)
        if task_label is None:
            return {k: v.result() for k, v in self._mean_Specificity.items()}
        else:
            return {task_label: self._mean_Specificity[task_label].result()}

    def reset(self, task_label=None) -> None:
        """
        Resets the metric.

        :param task_label: if None, reset the entire dictionary.
            Otherwise, reset the value associated to `task_label`.

        :return: None.
        """
        assert task_label is None or isinstance(task_label, int)
        if task_label is None:
            self._mean_Specificity = defaultdict(Mean)
        else:
            self._mean_Specificity[task_label].reset()


class SpecificityPluginMetric(GenericPluginMetric[float]):
    """
    Base class for all specificity plugin metrics.
    """

    def __init__(self, reset_at, emit_at, mode):
        self._Specificity = Specificity()
        super(SpecificityPluginMetric, self).__init__(
            self._Specificity, reset_at=reset_at, emit_at=emit_at, mode=mode
        )

    def reset(self, strategy=None) -> None:
        if self._reset_at == "stream" or strategy is None:
            self._metric.reset()
        else:
            self._metric.reset(phase_and_task(strategy)[1])

    def result(self, strategy=None) -> float:
        if self._emit_at == "stream" or strategy is None:
            return self._metric.result()
        else:
            return self._metric.result(phase_and_task(strategy)[1])

    def update(self, strategy):
        # task labels defined for each experience
        task_labels = strategy.experience.task_labels
        if len(task_labels) > 1:
            # task labels defined for each pattern
            task_labels = strategy.mb_task_id
        else:
            task_labels = task_labels[0]
        self._Specificity.update(strategy.mb_output, strategy.mb_y, task_labels)


class MinibatchSpecificity(SpecificityPluginMetric):
    """
    The minibatch plugin specificity metric.
    This metric only works at training time.

    This metric computes the average specificity over patterns
    from a single minibatch.
    It reports the result after each iteration.

    If a more coarse-grained logging is needed, consider using
    :class:`EpochSpecificity` instead.
    """

    def __init__(self):
        """
        Creates an instance of the MinibatchSpecificity metric.
        """
        super(MinibatchSpecificity, self).__init__(
            reset_at="iteration", emit_at="iteration", mode="train"
        )

    def __str__(self):
        return "Spec_MB"


class EpochSpecificity(SpecificityPluginMetric):
    """
    The average specificity over a single training epoch.
    This plugin metric only works at training time.

    The specificity is logged after each training epoch as the running
    average of the per-minibatch specificity (true negative rate), weighted
    by the number of patterns encountered in that epoch.
    """

    def __init__(self):
        """
        Creates an instance of the EpochSpecificity metric.
        """

        super(EpochSpecificity, self).__init__(
            reset_at="epoch", emit_at="epoch", mode="train"
        )

    def __str__(self):
        return "Spec_Epoch"


class RunningEpochSpecificity(SpecificityPluginMetric):
    """
    The average specificity across all minibatches up to the current
    epoch iteration.
    This plugin metric only works at training time.

    At each iteration, this metric logs the specificity averaged over all
    patterns seen so far in the current epoch.
    The metric resets its state after each training epoch.
    """

    def __init__(self):
        """
        Creates an instance of the RunningEpochSpecificity metric.
        """

        super(RunningEpochSpecificity, self).__init__(
            reset_at="epoch", emit_at="iteration", mode="train"
        )

    def __str__(self):
        return "RunningSpec_Epoch"


class ExperienceSpecificity(SpecificityPluginMetric):
    """
    At the end of each experience, this plugin metric reports
    the average specificity over all patterns seen in that experience.
    This metric only works at eval time.
    """

    def __init__(self):
        """
        Creates an instance of the ExperienceSpecificity metric.
        """
        super(ExperienceSpecificity, self).__init__(
            reset_at="experience", emit_at="experience", mode="eval"
        )

    def __str__(self):
        return "Spec_Exp"


class StreamSpecificity(SpecificityPluginMetric):
    """
    At the end of the entire stream of experiences, this plugin metric
    reports the average specificity over all patterns seen in all experiences.
    This metric only works at eval time.
    """

    def __init__(self):
        """
        Creates an instance of the StreamSpecificity metric.
        """
        super(StreamSpecificity, self).__init__(
            reset_at="stream", emit_at="stream", mode="eval"
        )

    def __str__(self):
        return "Spec_Stream"


class TrainedExperienceSpecificity(SpecificityPluginMetric):
    """
    At the end of each experience, this plugin metric reports the average
    specificity for only the experiences that the model has been trained on
    so far.

    This metric only works at eval time.
    """

    def __init__(self):
        """
        Creates an instance of the TrainedExperienceSpecificity metric by
        first constructing SpecificityPluginMetric.
        """
        super(TrainedExperienceSpecificity, self).__init__(
            reset_at="stream", emit_at="stream", mode="eval"
        )
        self._current_experience = 0

    def after_training_exp(self, strategy) -> None:
        self._current_experience = strategy.experience.current_experience
        # Reset average after learning from a new experience
        SpecificityPluginMetric.reset(self, strategy)
        return SpecificityPluginMetric.after_training_exp(self, strategy)

    def update(self, strategy):
        """
        Only update the specificity with results from experiences that have
        been trained on.
        """
        if strategy.experience.current_experience <= self._current_experience:
            SpecificityPluginMetric.update(self, strategy)

    def __str__(self):
        return "Specificity_On_Trained_Experiences"


def specificity_metrics(
    *,
    minibatch=False,
    epoch=False,
    epoch_running=False,
    experience=False,
    stream=False,
    trained_experience=False,
) -> List[PluginMetric]:
    """
    Helper method that can be used to obtain the desired set of
    plugin metrics.

    :param minibatch: If True, will return a metric able to log
        the minibatch specificity at training time.
    :param epoch: If True, will return a metric able to log
        the epoch specificity at training time.
    :param epoch_running: If True, will return a metric able to log
        the running epoch specificity at training time.
    :param experience: If True, will return a metric able to log
        the specificity on each evaluation experience.
    :param stream: If True, will return a metric able to log
        the specificity averaged over the entire evaluation stream
        of experiences.
    :param trained_experience: If True, will return a metric able to log
        the average evaluation specificity only for experiences that the
        model has been trained on.

    :return: A list of plugin metrics.
    """

    metrics = []
    if minibatch:
        metrics.append(MinibatchSpecificity())

    if epoch:
        metrics.append(EpochSpecificity())

    if epoch_running:
        metrics.append(RunningEpochSpecificity())

    if experience:
        metrics.append(ExperienceSpecificity())

    if stream:
        metrics.append(StreamSpecificity())

    if trained_experience:
        metrics.append(TrainedExperienceSpecificity())

    return metrics


class Precision(Metric[float]):
    """
    The Precision metric. This is a standalone metric.

    The metric keeps a dictionary of <task_label, precision value> pairs
    and updates the values through a running average over multiple
    <prediction, target> pairs of Tensors, provided incrementally.
    The "prediction" and "target" tensors may contain plain labels or
    one-hot/logit vectors.

    Each time `result` is called, this metric emits the average precision
    across all predictions made since the last `reset`.

    The reset method will bring the metric to its initial state. By default
    this metric in its initial state will return a precision value of 0.
    """

    def __init__(self):
        """
        Creates an instance of the standalone Precision metric.

        By default this metric in its initial state will return a precision
        value of 0. The metric can be updated by using the `update` method
        while the running precision can be retrieved using the `result`
        method.
        """
        super().__init__()
        self._mean_Precision = defaultdict(Mean)
        """
        The mean utility that will be used to store the running precision
        for each task label.
        """

    @torch.no_grad()
    def update(
        self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[float, Tensor]
    ) -> None:
        """
        Update the running precision given the true and predicted labels.
        Parameter `task_labels` is used to decide how to update the inner
        dictionary: if int, only the dictionary value related to that task
        is updated. If Tensor, all the dictionary elements belonging to the
        task labels will be updated.

        :param predicted_y: The model prediction. Both labels and logit vectors
            are supported.
        :param true_y: The ground truth. Both labels and one-hot vectors
            are supported.
        :param task_labels: the int task label associated to the current
            experience or the task labels vector showing the task label
            for each pattern.

        :return: None.
        """
        if len(true_y) != len(predicted_y):
            raise ValueError("Size mismatch for true_y and predicted_y tensors")

        if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y):
            raise ValueError("Size mismatch for true_y and task_labels tensors")

        true_y = torch.as_tensor(true_y)
        predicted_y = torch.as_tensor(predicted_y)

        # Check if logits or labels
        if len(predicted_y.shape) > 1:
            # Logits -> transform to labels
            predicted_y = torch.max(predicted_y, 1)[1]

        if len(true_y.shape) > 1:
            # Logits -> transform to labels
            true_y = torch.max(true_y, 1)[1]

        if isinstance(task_labels, int):
            (
                true_positives,
                false_positives,
                true_negatives,
                false_negatives,
            ) = confusion(predicted_y, true_y)

            try:
                ppv = true_positives / (true_positives + false_positives)
            except ZeroDivisionError:
                ppv = 1

            self._mean_Precision[task_labels].update(ppv, len(predicted_y))
        elif isinstance(task_labels, Tensor):
            raise NotImplementedError
        else:
            raise ValueError(
                f"Task label type: {type(task_labels)}, expected int/float or Tensor"
            )

    def result(self, task_label=None) -> Dict[int, float]:
        """
        Retrieves the running precision.

        Calling this method will not change the internal state of the metric.

        :param task_label: if None, return the entire dictionary of
            precisions for each task. Otherwise return the dictionary
            `{task_label: precision}`.
        :return: A dict of running precisions for each task label,
            where each value is a float value between 0 and 1.
        """
        assert task_label is None or isinstance(task_label, int)
        if task_label is None:
            return {k: v.result() for k, v in self._mean_Precision.items()}
        else:
            return {task_label: self._mean_Precision[task_label].result()}

    def reset(self, task_label=None) -> None:
        """
        Resets the metric.

        :param task_label: if None, reset the entire dictionary.
            Otherwise, reset the value associated to `task_label`.

        :return: None.
        """
        assert task_label is None or isinstance(task_label, int)
        if task_label is None:
            self._mean_Precision = defaultdict(Mean)
        else:
            self._mean_Precision[task_label].reset()


class PrecisionPluginMetric(GenericPluginMetric[float]):
    """
    Base class for all precision plugin metrics.
    """

    def __init__(self, reset_at, emit_at, mode):
        self._Precision = Precision()
        super(PrecisionPluginMetric, self).__init__(
            self._Precision, reset_at=reset_at, emit_at=emit_at, mode=mode
        )

    def reset(self, strategy=None) -> None:
        if self._reset_at == "stream" or strategy is None:
            self._metric.reset()
        else:
            self._metric.reset(phase_and_task(strategy)[1])

    def result(self, strategy=None) -> float:
        if self._emit_at == "stream" or strategy is None:
            return self._metric.result()
        else:
            return self._metric.result(phase_and_task(strategy)[1])

    def update(self, strategy):
        # task labels defined for each experience
        task_labels = strategy.experience.task_labels
        if len(task_labels) > 1:
            # task labels defined for each pattern
            task_labels = strategy.mb_task_id
        else:
            task_labels = task_labels[0]
        self._Precision.update(strategy.mb_output, strategy.mb_y, task_labels)


class MinibatchPrecision(PrecisionPluginMetric):
    """
    The minibatch plugin precision metric.
    This metric only works at training time.

    This metric computes the average precision over patterns
    from a single minibatch.
    It reports the result after each iteration.

    If a more coarse-grained logging is needed, consider using
    :class:`EpochPrecision` instead.
    """

    def __init__(self):
        """
        Creates an instance of the MinibatchPrecision metric.
        """
        super(MinibatchPrecision, self).__init__(
            reset_at="iteration", emit_at="iteration", mode="train"
        )

    def __str__(self):
        return "Prec_MB"


class EpochPrecision(PrecisionPluginMetric):
    """
    The average precision over a single training epoch.
    This plugin metric only works at training time.

    The precision is logged after each training epoch as the running average
    of the per-minibatch precision (positive predictive value), weighted by
    the number of patterns encountered in that epoch.
    """

    def __init__(self):
        """
        Creates an instance of the EpochPrecision metric.
        """

        super(EpochPrecision, self).__init__(
            reset_at="epoch", emit_at="epoch", mode="train"
        )

    def __str__(self):
        return "Prec_Epoch"


class RunningEpochPrecision(PrecisionPluginMetric):
    """
    The average precision across all minibatches up to the current
    epoch iteration.
    This plugin metric only works at training time.

    At each iteration, this metric logs the precision averaged over all
    patterns seen so far in the current epoch.
    The metric resets its state after each training epoch.
    """

    def __init__(self):
        """
        Creates an instance of the RunningEpochPrecision metric.
        """

        super(RunningEpochPrecision, self).__init__(
            reset_at="epoch", emit_at="iteration", mode="train"
        )

    def __str__(self):
        return "RunningPrec_Epoch"


class ExperiencePrecision(PrecisionPluginMetric):
    """
    At the end of each experience, this plugin metric reports
    the average precision over all patterns seen in that experience.
    This metric only works at eval time.
    """

    def __init__(self):
        """
        Creates an instance of the ExperiencePrecision metric.
        """
        super(ExperiencePrecision, self).__init__(
            reset_at="experience", emit_at="experience", mode="eval"
        )

    def __str__(self):
        return "Prec_Exp"


class StreamPrecision(PrecisionPluginMetric):
    """
    At the end of the entire stream of experiences, this plugin metric
    reports the average precision over all patterns seen in all experiences.
    This metric only works at eval time.
    """

    def __init__(self):
        """
        Creates an instance of the StreamPrecision metric.
        """
        super(StreamPrecision, self).__init__(
            reset_at="stream", emit_at="stream", mode="eval"
        )

    def __str__(self):
        return "Prec_Stream"


class TrainedExperiencePrecision(PrecisionPluginMetric):
    """
    At the end of each experience, this plugin metric reports the average
    precision for only the experiences that the model has been trained on
    so far.

    This metric only works at eval time.
    """

    def __init__(self):
        """
        Creates an instance of the TrainedExperiencePrecision metric by
        first constructing PrecisionPluginMetric.
        """
        super(TrainedExperiencePrecision, self).__init__(
            reset_at="stream", emit_at="stream", mode="eval"
        )
        self._current_experience = 0

    def after_training_exp(self, strategy) -> None:
        self._current_experience = strategy.experience.current_experience
        # Reset average after learning from a new experience
        PrecisionPluginMetric.reset(self, strategy)
        return PrecisionPluginMetric.after_training_exp(self, strategy)

    def update(self, strategy):
        """
        Only update the precision with results from experiences that have
        been trained on.
        """
        if strategy.experience.current_experience <= self._current_experience:
            PrecisionPluginMetric.update(self, strategy)

    def __str__(self):
        return "Precision_On_Trained_Experiences"


def precision_metrics(
    *,
    minibatch=False,
    epoch=False,
    epoch_running=False,
    experience=False,
    stream=False,
    trained_experience=False,
) -> List[PluginMetric]:
    """
    Helper method that can be used to obtain the desired set of
    plugin metrics.

    :param minibatch: If True, will return a metric able to log
        the minibatch precision at training time.
    :param epoch: If True, will return a metric able to log
        the epoch precision at training time.
    :param epoch_running: If True, will return a metric able to log
        the running epoch precision at training time.
    :param experience: If True, will return a metric able to log
        the precision on each evaluation experience.
    :param stream: If True, will return a metric able to log
        the precision averaged over the entire evaluation stream
        of experiences.
    :param trained_experience: If True, will return a metric able to log
        the average evaluation precision only for experiences that the
        model has been trained on.

    :return: A list of plugin metrics.
    """

    metrics = []
    if minibatch:
        metrics.append(MinibatchPrecision())

    if epoch:
        metrics.append(EpochPrecision())

    if epoch_running:
        metrics.append(RunningEpochPrecision())

    if experience:
        metrics.append(ExperiencePrecision())

    if stream:
        metrics.append(StreamPrecision())

    if trained_experience:
        metrics.append(TrainedExperiencePrecision())

    return metrics


1486
class AUPRC(Metric[float]):
    """
    The AUPRC metric. This is a standalone metric.

    The metric keeps a dictionary of <task_label, AUPRC value> pairs
    and updates the values through a running average over multiple
    <prediction, target> pairs of Tensors, provided incrementally.
    The "prediction" and "target" tensors may contain plain labels or
    one-hot/logit vectors.

    Each time `result` is called, this metric emits the average AUPRC
    across all predictions made since the last `reset`.

    The reset method will bring the metric to its initial state. By default
    this metric in its initial state will return an AUPRC value of 0.
    """

    def __init__(self):
        """
        Creates an instance of the standalone AUPRC metric.

        By default this metric in its initial state will return an AUPRC
        value of 0. The metric can be updated by using the `update` method
        while the running AUPRC can be retrieved using the `result` method.
        """
        super().__init__()
        self._mean_AUPRC = defaultdict(Mean)
        """
        The mean utility that will be used to store the running AUPRC
        for each task label.
        """

    @torch.no_grad()
    def update(
        self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[float, Tensor]
    ) -> None:
        """
        Update the running AUPRC given the true and predicted labels.
        Parameter `task_labels` is used to decide how to update the inner
        dictionary: if Float, only the dictionary value related to that task
        is updated. If Tensor, all the dictionary elements belonging to the
        task labels will be updated.

        :param predicted_y: The model prediction. Both labels and logit vectors
            are supported.
        :param true_y: The ground truth. Both labels and one-hot vectors
            are supported.
        :param task_labels: the int task label associated to the current
            experience or the task labels vector showing the task label
            for each pattern.

        :return: None.
        """
        if len(true_y) != len(predicted_y):
            raise ValueError("Size mismatch for true_y and predicted_y tensors")

        if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y):
            raise ValueError("Size mismatch for true_y and task_labels tensors")

        true_y = torch.as_tensor(true_y)
        predicted_y = torch.as_tensor(predicted_y)

        assert len(predicted_y.size()) == 2, (
            "Predictions need to be logits or scores, not labels"
        )

        if len(true_y.shape) > 1:
            # Logits -> transform to labels
            true_y = torch.max(true_y, 1)[1]

        # Score assigned by the model to the true class of each pattern
        scores = predicted_y[arange(len(true_y)), true_y]

        # Suppress divide/invalid warnings; if the average precision is
        # undefined for this batch (NaN), fall back to 0.
        with np.errstate(divide="ignore", invalid="ignore"):
            average_precision_score_val = average_precision_score(
                true_y.cpu(), scores.cpu()
            )

            if np.isnan(average_precision_score_val):
                average_precision_score_val = 0

        if isinstance(task_labels, int):
            self._mean_AUPRC[task_labels].update(
                average_precision_score_val, len(predicted_y)
            )
        elif isinstance(task_labels, Tensor):
            raise NotImplementedError
        else:
            raise ValueError(
                f"Task label type: {type(task_labels)}, expected int/float or Tensor"
            )

    def result(self, task_label=None) -> Dict[int, float]:
        """
        Retrieves the running AUPRC.

        Calling this method will not change the internal state of the metric.

        :param task_label: if None, return the entire dictionary of AUPRCs
            for each task. Otherwise return the dictionary
            `{task_label: AUPRC}`.
        :return: A dict of running AUPRCs for each task label,
            where each value is a float value between 0 and 1.
        """
        assert task_label is None or isinstance(task_label, int)
        if task_label is None:
            return {k: v.result() for k, v in self._mean_AUPRC.items()}
        else:
            return {task_label: self._mean_AUPRC[task_label].result()}

    def reset(self, task_label=None) -> None:
        """
        Resets the metric.
        :param task_label: if None, reset the entire dictionary.
            Otherwise, reset the value associated to `task_label`.

        :return: None.
        """
        assert task_label is None or isinstance(task_label, int)
        if task_label is None:
            self._mean_AUPRC = defaultdict(Mean)
        else:
            self._mean_AUPRC[task_label].reset()


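# Illustrative usage (a sketch, not part of the original module): driving the
# standalone AUPRC metric directly, assuming binary logits of shape [N, 2] and
# a single task labelled 0.
#
#     metric = AUPRC()
#     logits = torch.randn(8, 2)           # one score column per class
#     targets = torch.randint(0, 2, (8,))  # plain binary labels
#     metric.update(logits, targets, task_labels=0)
#     metric.result()                      # {0: <running AUPRC>}
#     metric.reset()

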
class AUPRCPluginMetric(GenericPluginMetric[float]):
    """
    Base class for all AUPRCs plugin metrics
    """

    def __init__(self, reset_at, emit_at, mode):
        self._AUPRC = AUPRC()
        super(AUPRCPluginMetric, self).__init__(
            self._AUPRC, reset_at=reset_at, emit_at=emit_at, mode=mode
        )

    def reset(self, strategy=None) -> None:
        if self._reset_at == "stream" or strategy is None:
            self._metric.reset()
        else:
            self._metric.reset(phase_and_task(strategy)[1])

    def result(self, strategy=None) -> float:
        if self._emit_at == "stream" or strategy is None:
            return self._metric.result()
        else:
            return self._metric.result(phase_and_task(strategy)[1])

    def update(self, strategy):
        # task labels defined for each experience
        task_labels = strategy.experience.task_labels
        if len(task_labels) > 1:
            # task labels defined for each pattern
            task_labels = strategy.mb_task_id
        else:
            task_labels = task_labels[0]
        self._AUPRC.update(strategy.mb_output, strategy.mb_y, task_labels)


class MinibatchAUPRC(AUPRCPluginMetric):
    """
    The minibatch plugin AUPRC metric.
    This metric only works at training time.

    This metric computes the average AUPRC over patterns
    from a single minibatch.
    It reports the result after each iteration.

    If a more coarse-grained logging is needed, consider using
    :class:`EpochAUPRC` instead.
    """

    def __init__(self):
        """
        Creates an instance of the MinibatchAUPRC metric.
        """
        super(MinibatchAUPRC, self).__init__(
            reset_at="iteration", emit_at="iteration", mode="train"
        )

    def __str__(self):
        return "AUPRC_MB"


class EpochAUPRC(AUPRCPluginMetric):
    """
    The average AUPRC over a single training epoch.
    This plugin metric only works at training time.

    The AUPRC will be logged after each training epoch by averaging the
    per-minibatch AUPRC values, weighted by the number of patterns in each
    minibatch.
    """

    def __init__(self):
        """
        Creates an instance of the EpochAUPRC metric.
        """

        super(EpochAUPRC, self).__init__(
            reset_at="epoch", emit_at="epoch", mode="train"
        )

    def __str__(self):
        return "AUPRC_Epoch"


class RunningEpochAUPRC(AUPRCPluginMetric):
    """
    The average AUPRC across all minibatches up to the current
    epoch iteration.
    This plugin metric only works at training time.

    At each iteration, this metric logs the AUPRC averaged over all patterns
    seen so far in the current epoch.
    The metric resets its state after each training epoch.
    """

    def __init__(self):
        """
        Creates an instance of the RunningEpochAUPRC metric.
        """

        super(RunningEpochAUPRC, self).__init__(
            reset_at="epoch", emit_at="iteration", mode="train"
        )

    def __str__(self):
        return "RunningAUPRC_Epoch"


class ExperienceAUPRC(AUPRCPluginMetric):
    """
    At the end of each experience, this plugin metric reports
    the average AUPRC over all patterns seen in that experience.
    This metric only works at eval time.
    """

    def __init__(self):
        """
        Creates an instance of ExperienceAUPRC metric
        """
        super(ExperienceAUPRC, self).__init__(
            reset_at="experience", emit_at="experience", mode="eval"
        )

    def __str__(self):
        return "AUPRC_Exp"


class StreamAUPRC(AUPRCPluginMetric):
    """
    At the end of the entire stream of experiences, this plugin metric
    reports the average AUPRC over all patterns seen in all experiences.
    This metric only works at eval time.
    """

    def __init__(self):
        """
        Creates an instance of StreamAUPRC metric
        """
        super(StreamAUPRC, self).__init__(
            reset_at="stream", emit_at="stream", mode="eval"
        )

    def __str__(self):
        return "AUPRC_Stream"


class TrainedExperienceAUPRC(AUPRCPluginMetric):
    """
    At the end of each experience, this plugin metric reports the average
    AUPRC for only the experiences that the model has been trained on so far.

    This metric only works at eval time.
    """

    def __init__(self):
        """
        Creates an instance of TrainedExperienceAUPRC metric by first
        constructing AUPRCPluginMetric
        """
        super(TrainedExperienceAUPRC, self).__init__(
            reset_at="stream", emit_at="stream", mode="eval"
        )
        self._current_experience = 0

    def after_training_exp(self, strategy) -> None:
        self._current_experience = strategy.experience.current_experience
        # Reset average after learning from a new experience
        AUPRCPluginMetric.reset(self, strategy)
        return AUPRCPluginMetric.after_training_exp(self, strategy)

    def update(self, strategy):
        """
        Only update the AUPRC with results from experiences that have been
        trained on
        """
        if strategy.experience.current_experience <= self._current_experience:
            AUPRCPluginMetric.update(self, strategy)

    def __str__(self):
        return "AUPRC_On_Trained_Experiences"


def auprc_metrics(
    *,
    minibatch=False,
    epoch=False,
    epoch_running=False,
    experience=False,
    stream=False,
    trained_experience=False,
) -> List[PluginMetric]:
    """
    Helper method that can be used to obtain the desired set of
    plugin metrics.

    :param minibatch: If True, will return a metric able to log
        the minibatch AUPRC at training time.
    :param epoch: If True, will return a metric able to log
        the epoch AUPRC at training time.
    :param epoch_running: If True, will return a metric able to log
        the running epoch AUPRC at training time.
    :param experience: If True, will return a metric able to log
        the AUPRC on each evaluation experience.
    :param stream: If True, will return a metric able to log
        the AUPRC averaged over the entire evaluation stream of experiences.
    :param trained_experience: If True, will return a metric able to log
        the average evaluation AUPRC only for experiences that the
        model has been trained on

    :return: A list of plugin metrics.
    """

    metrics = []
    if minibatch:
        metrics.append(MinibatchAUPRC())

    if epoch:
        metrics.append(EpochAUPRC())

    if epoch_running:
        metrics.append(RunningEpochAUPRC())

    if experience:
        metrics.append(ExperienceAUPRC())

    if stream:
        metrics.append(StreamAUPRC())

    if trained_experience:
        metrics.append(TrainedExperienceAUPRC())

    return metrics


class ROCAUC(Metric[float]):
    """
    The ROCAUC metric. This is a standalone metric.

    The metric keeps a dictionary of <task_label, ROCAUC value> pairs
    and updates the values through a running average over multiple
    <prediction, target> pairs of Tensors, provided incrementally.
    The "prediction" and "target" tensors may contain plain labels or
    one-hot/logit vectors.

    Each time `result` is called, this metric emits the average ROCAUC
    across all predictions made since the last `reset`.

    The reset method will bring the metric to its initial state. By default
    this metric in its initial state will return a ROCAUC value of 0.
    """

    def __init__(self):
        """
        Creates an instance of the standalone ROCAUC metric.

        By default this metric in its initial state will return a ROCAUC
        value of 0. The metric can be updated by using the `update` method
        while the running ROCAUC can be retrieved using the `result` method.
        """
        super().__init__()
        self._mean_ROCAUC = defaultdict(Mean)
        """
        The mean utility that will be used to store the running ROCAUC
        for each task label.
        """

    @torch.no_grad()
    def update(
        self, predicted_y: Tensor, true_y: Tensor, task_labels: Union[float, Tensor]
    ) -> None:
        """
        Update the running ROCAUC given the true and predicted labels.
        Parameter `task_labels` is used to decide how to update the inner
        dictionary: if Float, only the dictionary value related to that task
        is updated. If Tensor, all the dictionary elements belonging to the
        task labels will be updated.

        :param predicted_y: The model prediction. Both labels and logit vectors
            are supported.
        :param true_y: The ground truth. Both labels and one-hot vectors
            are supported.
        :param task_labels: the int task label associated to the current
            experience or the task labels vector showing the task label
            for each pattern.

        :return: None.
        """
        if len(true_y) != len(predicted_y):
            raise ValueError("Size mismatch for true_y and predicted_y tensors")

        if isinstance(task_labels, Tensor) and len(task_labels) != len(true_y):
            raise ValueError("Size mismatch for true_y and task_labels tensors")

        true_y = torch.as_tensor(true_y)
        predicted_y = torch.as_tensor(predicted_y)

        assert len(predicted_y.size()) == 2, (
            "Predictions need to be logits or scores, not labels"
        )

        if len(true_y.shape) > 1:
            # Logits -> transform to labels
            true_y = torch.max(true_y, 1)[1]

        # Score assigned by the model to the true class of each pattern
        scores = predicted_y[arange(len(true_y)), true_y]

        # roc_auc_score raises a ValueError when the score is undefined
        # (e.g. only one class present in the batch); fall back to 1.
        try:
            roc_auc_score_val = roc_auc_score(true_y.cpu(), scores.cpu())
        except ValueError:
            roc_auc_score_val = 1

        if isinstance(task_labels, int):
            self._mean_ROCAUC[task_labels].update(roc_auc_score_val, len(predicted_y))
        elif isinstance(task_labels, Tensor):
            raise NotImplementedError
        else:
            raise ValueError(
                f"Task label type: {type(task_labels)}, expected int/float or Tensor"
            )

    def result(self, task_label=None) -> Dict[int, float]:
        """
        Retrieves the running ROCAUC.

        Calling this method will not change the internal state of the metric.

        :param task_label: if None, return the entire dictionary of ROCAUCs
            for each task. Otherwise return the dictionary
            `{task_label: ROCAUC}`.
        :return: A dict of running ROCAUCs for each task label,
            where each value is a float value between 0 and 1.
        """
        assert task_label is None or isinstance(task_label, int)
        if task_label is None:
            return {k: v.result() for k, v in self._mean_ROCAUC.items()}
        else:
            return {task_label: self._mean_ROCAUC[task_label].result()}

    def reset(self, task_label=None) -> None:
        """
        Resets the metric.
        :param task_label: if None, reset the entire dictionary.
            Otherwise, reset the value associated to `task_label`.

        :return: None.
        """
        assert task_label is None or isinstance(task_label, int)
        if task_label is None:
            self._mean_ROCAUC = defaultdict(Mean)
        else:
            self._mean_ROCAUC[task_label].reset()


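# Illustrative usage (a sketch, not part of the original module): the standalone
# ROCAUC metric falls back to a score of 1 whenever roc_auc_score raises a
# ValueError, e.g. for a batch containing a single class.
#
#     metric = ROCAUC()
#     logits = torch.randn(4, 2)
#     targets = torch.tensor([1, 1, 1, 1])  # single-class batch
#     metric.update(logits, targets, task_labels=0)
#     metric.result()                       # {0: 1.0}

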
class ROCAUCPluginMetric(GenericPluginMetric[float]):
    """
    Base class for all ROCAUCs plugin metrics
    """

    def __init__(self, reset_at, emit_at, mode):
        self._ROCAUC = ROCAUC()
        super(ROCAUCPluginMetric, self).__init__(
            self._ROCAUC, reset_at=reset_at, emit_at=emit_at, mode=mode
        )

    def reset(self, strategy=None) -> None:
        if self._reset_at == "stream" or strategy is None:
            self._metric.reset()
        else:
            self._metric.reset(phase_and_task(strategy)[1])

    def result(self, strategy=None) -> float:
        if self._emit_at == "stream" or strategy is None:
            return self._metric.result()
        else:
            return self._metric.result(phase_and_task(strategy)[1])

    def update(self, strategy):
        # task labels defined for each experience
        task_labels = strategy.experience.task_labels
        if len(task_labels) > 1:
            # task labels defined for each pattern
            task_labels = strategy.mb_task_id
        else:
            task_labels = task_labels[0]
        self._ROCAUC.update(strategy.mb_output, strategy.mb_y, task_labels)


class MinibatchROCAUC(ROCAUCPluginMetric):
    """
    The minibatch plugin ROCAUC metric.
    This metric only works at training time.

    This metric computes the average ROCAUC over patterns
    from a single minibatch.
    It reports the result after each iteration.

    If a more coarse-grained logging is needed, consider using
    :class:`EpochROCAUC` instead.
    """

    def __init__(self):
        """
        Creates an instance of the MinibatchROCAUC metric.
        """
        super(MinibatchROCAUC, self).__init__(
            reset_at="iteration", emit_at="iteration", mode="train"
        )

    def __str__(self):
        return "ROCAUC_MB"


class EpochROCAUC(ROCAUCPluginMetric):
    """
    The average ROCAUC over a single training epoch.
    This plugin metric only works at training time.

    The ROCAUC will be logged after each training epoch by averaging the
    per-minibatch ROCAUC values, weighted by the number of patterns in each
    minibatch.
    """

    def __init__(self):
        """
        Creates an instance of the EpochROCAUC metric.
        """

        super(EpochROCAUC, self).__init__(
            reset_at="epoch", emit_at="epoch", mode="train"
        )

    def __str__(self):
        return "ROCAUC_Epoch"


class RunningEpochROCAUC(ROCAUCPluginMetric):
    """
    The average ROCAUC across all minibatches up to the current
    epoch iteration.
    This plugin metric only works at training time.

    At each iteration, this metric logs the ROCAUC averaged over all patterns
    seen so far in the current epoch.
    The metric resets its state after each training epoch.
    """

    def __init__(self):
        """
        Creates an instance of the RunningEpochROCAUC metric.
        """

        super(RunningEpochROCAUC, self).__init__(
            reset_at="epoch", emit_at="iteration", mode="train"
        )

    def __str__(self):
        return "RunningROCAUC_Epoch"


class ExperienceROCAUC(ROCAUCPluginMetric):
    """
    At the end of each experience, this plugin metric reports
    the average ROCAUC over all patterns seen in that experience.
    This metric only works at eval time.
    """

    def __init__(self):
        """
        Creates an instance of ExperienceROCAUC metric
        """
        super(ExperienceROCAUC, self).__init__(
            reset_at="experience", emit_at="experience", mode="eval"
        )

    def __str__(self):
        return "ROCAUC_Exp"


class StreamROCAUC(ROCAUCPluginMetric):
    """
    At the end of the entire stream of experiences, this plugin metric
    reports the average ROCAUC over all patterns seen in all experiences.
    This metric only works at eval time.
    """

    def __init__(self):
        """
        Creates an instance of StreamROCAUC metric
        """
        super(StreamROCAUC, self).__init__(
            reset_at="stream", emit_at="stream", mode="eval"
        )

    def __str__(self):
        return "ROCAUC_Stream"


class TrainedExperienceROCAUC(ROCAUCPluginMetric):
    """
    At the end of each experience, this plugin metric reports the average
    ROCAUC for only the experiences that the model has been trained on so far.

    This metric only works at eval time.
    """

    def __init__(self):
        """
        Creates an instance of TrainedExperienceROCAUC metric by first
        constructing ROCAUCPluginMetric
        """
        super(TrainedExperienceROCAUC, self).__init__(
            reset_at="stream", emit_at="stream", mode="eval"
        )
        self._current_experience = 0

    def after_training_exp(self, strategy) -> None:
        self._current_experience = strategy.experience.current_experience
        # Reset average after learning from a new experience
        ROCAUCPluginMetric.reset(self, strategy)
        return ROCAUCPluginMetric.after_training_exp(self, strategy)

    def update(self, strategy):
        """
        Only update the ROCAUC with results from experiences that have been
        trained on
        """
        if strategy.experience.current_experience <= self._current_experience:
            ROCAUCPluginMetric.update(self, strategy)

    def __str__(self):
        return "ROCAUC_On_Trained_Experiences"


def rocauc_metrics(
    *,
    minibatch=False,
    epoch=False,
    epoch_running=False,
    experience=False,
    stream=False,
    trained_experience=False,
) -> List[PluginMetric]:
    """
    Helper method that can be used to obtain the desired set of
    plugin metrics.

    :param minibatch: If True, will return a metric able to log
        the minibatch ROCAUC at training time.
    :param epoch: If True, will return a metric able to log
        the epoch ROCAUC at training time.
    :param epoch_running: If True, will return a metric able to log
        the running epoch ROCAUC at training time.
    :param experience: If True, will return a metric able to log
        the ROCAUC on each evaluation experience.
    :param stream: If True, will return a metric able to log
        the ROCAUC averaged over the entire evaluation stream of experiences.
    :param trained_experience: If True, will return a metric able to log
        the average evaluation ROCAUC only for experiences that the
        model has been trained on

    :return: A list of plugin metrics.
    """

    metrics = []
    if minibatch:
        metrics.append(MinibatchROCAUC())

    if epoch:
        metrics.append(EpochROCAUC())

    if epoch_running:
        metrics.append(RunningEpochROCAUC())

    if experience:
        metrics.append(ExperienceROCAUC())

    if stream:
        metrics.append(StreamROCAUC())

    if trained_experience:
        metrics.append(TrainedExperienceROCAUC())

    return metrics


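# Illustrative wiring (a sketch, not part of the original module): these helper
# functions might feed plugin lists into Avalanche's EvaluationPlugin. The
# imports and strategy setup below are assumptions about the caller's code,
# not something this module provides.
#
#     from avalanche.training.plugins import EvaluationPlugin
#     from avalanche.logging import InteractiveLogger
#
#     eval_plugin = EvaluationPlugin(
#         *precision_metrics(experience=True, trained_experience=True),
#         *auprc_metrics(experience=True, trained_experience=True),
#         *rocauc_metrics(experience=True, trained_experience=True),
#         loggers=[InteractiveLogger()],
#     )
#     # The plugin is then passed to a strategy via `evaluator=eval_plugin`.

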
__all__ = [
    "BalancedAccuracy",
    "MinibatchBalancedAccuracy",
    "EpochBalancedAccuracy",
    "RunningEpochBalancedAccuracy",
    "ExperienceBalancedAccuracy",
    "StreamBalancedAccuracy",
    "TrainedExperienceBalancedAccuracy",
    "balancedaccuracy_metrics",
    "Sensitivity",
    "MinibatchSensitivity",
    "EpochSensitivity",
    "RunningEpochSensitivity",
    "ExperienceSensitivity",
    "StreamSensitivity",
    "TrainedExperienceSensitivity",
    "sensitivity_metrics",
    "Specificity",
    "MinibatchSpecificity",
    "EpochSpecificity",
    "RunningEpochSpecificity",
    "ExperienceSpecificity",
    "StreamSpecificity",
    "TrainedExperienceSpecificity",
    "specificity_metrics",
    "Precision",
    "MinibatchPrecision",
    "EpochPrecision",
    "RunningEpochPrecision",
    "ExperiencePrecision",
    "StreamPrecision",
    "TrainedExperiencePrecision",
    "precision_metrics",
    "AUPRC",
    "MinibatchAUPRC",
    "EpochAUPRC",
    "RunningEpochAUPRC",
    "ExperienceAUPRC",
    "StreamAUPRC",
    "TrainedExperienceAUPRC",
    "auprc_metrics",
    "ROCAUC",
    "MinibatchROCAUC",
    "EpochROCAUC",
    "RunningEpochROCAUC",
    "ExperienceROCAUC",
    "StreamROCAUC",
    "TrainedExperienceROCAUC",
    "rocauc_metrics",
]