--- a +++ b/pathflowai/sampler.py @@ -0,0 +1,94 @@ +""" +sampler.py +======================= +Balanced sampling based on one of the columns of the patch information. +""" + +import torch +import torch.utils.data +import torchvision +import numpy as np + + +class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler): + """Samples elements randomly from a given list of indices for imbalanced dataset + https://raw.githubusercontent.com/ufoym/imbalanced-dataset-sampler/master/sampler.py + Arguments: + indices (list, optional): a list of indices + num_samples (int, optional): number of samples to draw + """ + + def __init__(self, dataset, indices=None, num_samples=None): + + # if indices is not provided, + # all elements in the dataset will be considered + self.indices = list(range(len(dataset))) \ + if indices is None else indices + + self.n_targets=len(dataset.targets) + + # if num_samples is not provided, + # draw `len(indices)` samples in each iteration + self.num_samples = len(self.indices) \ + if num_samples is None else num_samples + + # distribution of classes in the dataset + label_to_count = {} + for idx in self.indices: + label = self._get_label(dataset, idx) + if label in label_to_count: + label_to_count[label] += 1 + else: + label_to_count[label] = 1 + + # weight for each sample + weights = [1.0 / label_to_count[self._get_label(dataset, idx)] + for idx in self.indices] + self.weights = torch.DoubleTensor(weights) + + def _get_label(self, dataset, idx): + dataset_type = type(dataset) + if dataset_type is torchvision.datasets.MNIST: + return dataset.train_labels[idx].item() + elif dataset_type is torchvision.datasets.ImageFolder: + return dataset.imgs[idx][1] + else: + y=dataset.patch_info.iloc[idx][dataset.targets] + if not isinstance(y,np.float): + y=y.values + if self.n_targets>1: + y=np.argmax(y) + elif isinstance(y,(list,np.ndarray)): + y=y[0] + #print(y) + return int(y) + + def __iter__(self): + return (self.indices[i] for i in torch.multinomial( + self.weights, self.num_samples, replacement=True)) + + def __len__(self): + return self.num_samples + + +"""MIT License + +Copyright (c) 2018 Ming + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE."""