Switch to side-by-side view

--- a
+++ b/src/extensions/baal/dataset.py
@@ -0,0 +1,188 @@
+# Base Dependencies
+# -----------------
+import numpy as np
+from typing import List, Union, Any
+
+# 3rd-Party Dependencies
+# ----------------------
+import torch
+
+from baal.active import ActiveLearningDataset
+from baal.active.dataset.base import Dataset
+from datasets import Dataset as HFDataset
+
+
+# Auxiliar Functions 
+# ------------------
+def my_active_huggingface_dataset(
+    dataset,
+    tokenizer=None,
+    target_key: str = "label",
+    input_key: str = "sentence",
+    max_seq_len: int = 128,
+    **kwargs
+):
+    """
+    Wrapping huggingface.datasets with baal.active.ActiveLearningDataset.
+    Args:
+        dataset (torch.utils.data.Dataset): a dataset provided by huggingface.
+        tokenizer (transformers.PreTrainedTokenizer): a tokenizer provided by huggingface.
+        target_key (str): target key used in the dataset's dictionary.
+        input_key (str): input key used in the dataset's dictionary.
+        max_seq_len (int): max length of a sequence to be used for padding the shorter sequences.
+        kwargs (Dict): Parameters forwarded to 'ActiveLearningDataset'.
+    Returns:
+        an baal.active.ActiveLearningDataset object.
+    """
+
+    return MyActiveLearningDatasetBert(
+        MyHuggingFaceDatasets(dataset, tokenizer, target_key, input_key, max_seq_len),
+        **kwargs
+    )
+
+
+# Datasets
+# --------
+class MyHuggingFaceDatasets(Dataset):
+    """
+    Support for `huggingface.datasets`: (https://github.com/huggingface/datasets).
+    The purpose of this wrapper is to separate the labels from the rest of the sample information
+    and make the dataset ready to be used by `baal.active.ActiveLearningDataset`.
+    Args:
+        dataset (Dataset): a dataset provided by huggingface.
+        tokenizer (transformers.PreTrainedTokenizer): a tokenizer provided by huggingface.
+        target_key (str): target key used in the dataset's dictionary.
+        input_key (str): input key used in the dataset's dictionary.
+        max_seq_len (int): max length of a sequence to be used for padding the shorter
+            sequences.
+    """
+
+    def __init__(
+        self,
+        dataset: HFDataset,
+        tokenizer=None,
+        target_key: str = "label",
+        input_key: str = "sentence",
+        max_seq_len: int = 128,
+    ):
+        self.dataset = dataset
+        self.targets, self.texts = self.dataset[target_key], self.dataset[input_key]
+        self.targets_list: List = np.unique(self.targets).tolist()
+
+        if tokenizer:
+            self.input_ids, self.attention_masks = self._tokenize(
+                tokenizer, max_seq_len
+            )
+        else:
+            self.input_ids = self.dataset["input_ids"]
+            self.attention_masks = self.dataset["attention_mask"]
+
+    @property
+    def num_classes(self):
+        return len(self.targets_list)
+
+    def _tokenize(self, tokenizer, max_seq_len):
+        # For speed purposes, we should use fast tokenizers here, but that is up to the caller
+        tokenized = tokenizer(
+            self.texts,
+            add_special_tokens=True,
+            max_length=max_seq_len,
+            return_token_type_ids=False,
+            padding="max_length",
+            return_attention_mask=True,
+            return_tensors="pt",
+            truncation=True,
+        )
+        return tokenized["input_ids"], tokenized["attention_mask"]
+
+    def label(self, idx: int, value: int):
+        """Label the item.
+        Args:
+            idx: index to label
+            value: Value to label the index.
+        """
+        self.targets[idx] = value
+
+    def __len__(self):
+        return len(self.texts)
+
+    def __getitem__(self, idx):
+        target = self.targets_list.index(self.targets[idx])
+
+        return {
+            "input_ids": self.input_ids[idx].flatten()
+            if len(self.input_ids) > 0
+            else None,
+            "inputs": self.texts[idx],
+            "attention_mask": self.attention_masks[idx].flatten()
+            if len(self.attention_masks) > 0
+            else None,
+            "label": torch.tensor(target, dtype=torch.long),
+        }
+
+
+class MyActiveLearningDatasetBert(ActiveLearningDataset):
+    """
+    MyActiveLearningDataset
+
+    Modification of ActiveLearningDataset to allow the indexing with a
+    a list of integers.
+    """
+    
+    
+    @property
+    def labels(self) -> List[int]:
+        return self._dataset[self.get_indices_for_active_step()]["label"]
+
+    @property
+    def n_labelled_tokens(self) -> int:
+        return (
+            self._dataset.dataset[self.get_indices_for_active_step()]["seq_length"]
+            .sum()
+            .item()
+        )
+
+    @property
+    def n_labelled_chars(self) -> int:
+        return (
+            self._dataset.dataset[self.get_indices_for_active_step()]["char_length"]
+            .sum()
+            .item()
+        )
+
+    def __getitem__(self, index: Union[int, List[int]]) -> Any:
+        """Return items from the original dataset based on the labelled index."""
+        _index = np.array(self.get_indices_for_active_step())[index]
+        return self._dataset[_index]
+
+
+class MyActiveLearningDatasetBilstm(ActiveLearningDataset):
+    """
+    MyActiveLearningDataset
+
+    Modification of ActiveLearningDataset to allow the indexing with a
+    a list of integers.
+    """
+
+    @property
+    def labels(self) -> List[int]:
+        return self._dataset[self.get_indices_for_active_step()]["label"]
+
+    @property
+    def n_labelled_tokens(self) -> int:
+        return (
+            self._dataset[self.get_indices_for_active_step()]["seq_length"].sum().item()
+        )
+
+    @property
+    def n_labelled_chars(self) -> int:
+        return (
+            self._dataset[self.get_indices_for_active_step()]["char_length"]
+            .sum()
+            .item()
+        )
+
+    def __getitem__(self, index: Union[int, List[int]]) -> Any:
+        """Return items from the original dataset based on the labelled index."""
+        _index = np.array(self.get_indices_for_active_step())[index]
+        return self._dataset[_index]