--- /dev/null
+++ b/src/extensions/baal/active_loop.py
@@ -0,0 +1,138 @@
+# Base Dependencies
+# -----------------
+import os
+import pickle
+import types
+import numpy as np
+import structlog
+
+from typing import Callable
+
+
+# 3rd-Party Dependencies
+# ----------------------
+import torch.utils.data as torchdata
+
+from baal.active import ActiveLearningLoop
+from baal.active.heuristics import heuristics
+from baal.active.dataset import ActiveLearningDataset
+
+log = structlog.get_logger(__name__)
+pjoin = os.path.join
+
+
+class MyActiveLearningLoop(ActiveLearningLoop):
+    def __init__(
+        self,
+        dataset: ActiveLearningDataset,
+        get_probabilities: Callable,
+        heuristic: heuristics.AbstractHeuristic = heuristics.Random(),
+        query_size: int = 1,
+        max_sample=-1,
+        uncertainty_folder=None,
+        ndata_to_label=None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            dataset,
+            get_probabilities,
+            heuristic,
+            query_size,
+            max_sample,
+            uncertainty_folder,
+            ndata_to_label,
+            **kwargs,
+        )
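+        # per-step metrics on the queried examples: accuracy and average true-class probability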
+        self.step_acc = []
+        self.step_score = []
+
+    def compute_step_metrics(self, probs: np.ndarray, to_label: list) -> None:
+        """
+        Record the accuracy and the average prediction score of the trained model on the queried instances.
+
+        Args:
+            probs (np.ndarray): probabilities predicted by the model on the unlabelled pool
+            to_label (list): indices of the pool examples selected for labelling
+        """
+        pool = self.dataset.pool
+
+        # obtain true labels of queried examples
+        y_true = []
+        for idx in to_label[: self.query_size]:
+            y_true.append(pool[idx]["label"])
+        y_true = np.array(y_true)
+
+        # obtain predicted labels of the queried examples
+        # (assumes `probs` rows are aligned with the pool, i.e. max_sample == -1)
+        # 1. avg over MC Dropout iterations to obtain prob per class
+        avg_iter_probs = np.mean(probs[np.asarray(to_label[: self.query_size])], axis=2)
+        # 2. get class with highest prob
+        y_pred = np.argmax(avg_iter_probs, axis=1)
+        assert len(y_pred) == len(y_true)
+
+        # accuracy on the queried examples
+        acc = np.mean(y_true == y_pred)
+        self.step_acc.append(acc)
+
+        # average predicted probability assigned to the true class
+        avg_probs = []
+        for true_class, classes_probs in zip(y_true, avg_iter_probs):
+            avg_probs.append(classes_probs[true_class])
+
+        self.step_score.append(np.mean(avg_probs))
+
+    def step(self, pool=None) -> bool:
+        """
+        Perform an active learning step.
+        Args:
+            pool (iterable): Optional dataset pool indices.
+                             If not set, will use pool from the active set.
+        Returns:
+            bool: Flag indicating whether training should continue.
+        """
+        if pool is None:
+            pool = self.dataset.pool
+            if len(pool) > 0:
+                # Limit number of samples
+                if self.max_sample != -1 and self.max_sample < len(pool):
+                    indices = np.random.choice(
+                        len(pool), self.max_sample, replace=False
+                    )
+                    pool = torchdata.Subset(pool, indices)
+                else:
+                    indices = np.arange(len(pool))
+        else:
+            indices = None
+
+        if len(pool) > 0:
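+            # probabilities over the pool, typically shaped (n_pool, n_classes, n_mc_iterations), or a generator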
+            probs = self.get_probabilities(pool, **self.kwargs)
+            if probs is not None and (
+                isinstance(probs, types.GeneratorType) or len(probs) > 0
+            ):
+                to_label, uncertainty = self.heuristic.get_ranks(probs)
+                if indices is not None:
+                    # re-order the sampled indices based on the uncertainty
+                    to_label = indices[np.array(to_label)]
+                if self.uncertainty_folder is not None:
+                    # We save uncertainty in a file.
+                    uncertainty_name = (
+                        f"uncertainty_pool={len(pool)}"
+                        f"_labelled={len(self.dataset)}.pkl"
+                    )
+                    with open(pjoin(self.uncertainty_folder, uncertainty_name), "wb") as f:
+                        pickle.dump(
+                            {
+                                "indices": indices,
+                                "uncertainty": uncertainty,
+                                "dataset": self.dataset.state_dict(),
+                            },
+                            f,
+                        )
+                if len(to_label) > 0:
+                    self.compute_step_metrics(probs, to_label)
+                    self.dataset.label(to_label[: self.query_size])
+                    return True
+
+        return False
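
For reference, a minimal usage sketch of the new loop (not part of this patch). It assumes a dict-style dataset whose items carry a "label" key (as compute_step_metrics expects) and a baal ModelWrapper trained between query steps; `my_dict_dataset` and `wrapper` are placeholders, and the hyperparameters are illustrative.

from baal.active import ActiveLearningDataset
from baal.active.heuristics import BALD

al_dataset = ActiveLearningDataset(my_dict_dataset)  # items look like {"input": ..., "label": ...}
al_dataset.label_randomly(100)                       # seed the initial labelled set

loop = MyActiveLearningLoop(
    dataset=al_dataset,
    get_probabilities=wrapper.predict_on_dataset,    # returns (n_pool, n_classes, n_iterations)
    heuristic=BALD(),
    query_size=20,
    # remaining kwargs are forwarded to get_probabilities on every step
    batch_size=16,
    iterations=20,
    use_cuda=False,
)

for _ in range(10):
    # (re)train `wrapper` on the currently labelled data here, then query
    if not loop.step():
        break  # unlabelled pool exhausted
print(loop.step_acc, loop.step_score)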