Diff of /pathflowai/sampler.py [000000] .. [e9500f]

Switch to side-by-side view

--- a
+++ b/pathflowai/sampler.py
@@ -0,0 +1,94 @@
+"""
+sampler.py
+=======================
+Balanced sampling based on one of the columns of the patch information.
+"""
+
+import torch
+import torch.utils.data
+import torchvision
+import numpy as np
+
+
+class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
+    """Samples elements randomly from a given list of indices for imbalanced dataset
+    https://raw.githubusercontent.com/ufoym/imbalanced-dataset-sampler/master/sampler.py
+    Arguments:
+        indices (list, optional): a list of indices
+        num_samples (int, optional): number of samples to draw
+    """
+
+    def __init__(self, dataset, indices=None, num_samples=None):
+
+        # if indices is not provided,
+        # all elements in the dataset will be considered
+        self.indices = list(range(len(dataset))) \
+            if indices is None else indices
+
+        self.n_targets=len(dataset.targets)
+
+        # if num_samples is not provided,
+        # draw `len(indices)` samples in each iteration
+        self.num_samples = len(self.indices) \
+            if num_samples is None else num_samples
+
+        # distribution of classes in the dataset
+        label_to_count = {}
+        for idx in self.indices:
+            label = self._get_label(dataset, idx)
+            if label in label_to_count:
+                label_to_count[label] += 1
+            else:
+                label_to_count[label] = 1
+
+        # weight for each sample
+        weights = [1.0 / label_to_count[self._get_label(dataset, idx)]
+                   for idx in self.indices]
+        self.weights = torch.DoubleTensor(weights)
+
+    def _get_label(self, dataset, idx):
+        dataset_type = type(dataset)
+        if dataset_type is torchvision.datasets.MNIST:
+            return dataset.train_labels[idx].item()
+        elif dataset_type is torchvision.datasets.ImageFolder:
+            return dataset.imgs[idx][1]
+        else:
+            y=dataset.patch_info.iloc[idx][dataset.targets]
+            if not isinstance(y,np.float):
+                y=y.values
+            if self.n_targets>1:
+                y=np.argmax(y)
+            elif isinstance(y,(list,np.ndarray)):
+                y=y[0]
+            #print(y)
+            return int(y)
+
+    def __iter__(self):
+        return (self.indices[i] for i in torch.multinomial(
+            self.weights, self.num_samples, replacement=True))
+
+    def __len__(self):
+        return self.num_samples
+
+
+"""MIT License
+
+Copyright (c) 2018 Ming
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE."""