|
a |
|
b/pathflowai/sampler.py |
|
|
"""
sampler.py
=======================
Balanced sampling based on one of the columns of the patch information.
"""
|
|
6 |
|
|
|
7 |
import torch |
|
|
8 |
import torch.utils.data |
|
|
9 |
import torchvision |
|
|
10 |
import numpy as np |
|
|
11 |
|
|
|
12 |
|
|
|
13 |
class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    """Samples elements randomly from a given list of indices for imbalanced dataset

    Every index is weighted by the inverse of the number of indices that share
    its label, and indices are then drawn with replacement from that
    distribution.

    Adapted from
    https://raw.githubusercontent.com/ufoym/imbalanced-dataset-sampler/master/sampler.py

    Arguments:
        dataset: dataset to draw from. ``len(dataset.targets)`` is read, and
            for anything other than a torchvision ``MNIST`` / ``ImageFolder``
            the label is looked up with
            ``dataset.patch_info.iloc[idx][dataset.targets]``.
        indices (list, optional): a list of indices
        num_samples (int, optional): number of samples to draw
    """

    def __init__(self, dataset, indices=None, num_samples=None):
        # if indices is not provided,
        # all elements in the dataset will be considered
        self.indices = list(range(len(dataset))) if indices is None else indices

        # _get_label uses this to decide whether a label row has to be
        # reduced with argmax (several target columns) or just unwrapped.
        self.n_targets = len(dataset.targets)

        # if num_samples is not provided,
        # draw `len(indices)` samples in each iteration
        self.num_samples = (
            len(self.indices) if num_samples is None else num_samples
        )

        # Resolve each label exactly once: the lookup may hit a DataFrame row,
        # so repeating it for the counts and again for the weights is wasteful.
        labels = [self._get_label(dataset, idx) for idx in self.indices]

        # distribution of classes in the dataset
        label_to_count = {}
        for label in labels:
            label_to_count[label] = label_to_count.get(label, 0) + 1

        # weight for each sample: inverse frequency of its class
        weights = [1.0 / label_to_count[label] for label in labels]
        self.weights = torch.DoubleTensor(weights)

    def _get_label(self, dataset, idx):
        """Return the integer class label of ``dataset`` element ``idx``."""
        dataset_type = type(dataset)
        if dataset_type is torchvision.datasets.MNIST:
            return dataset.train_labels[idx].item()
        elif dataset_type is torchvision.datasets.ImageFolder:
            return dataset.imgs[idx][1]
        else:
            y = dataset.patch_info.iloc[idx][dataset.targets]
            # `np.float` was only an alias of the builtin float and was
            # removed in NumPy 1.24; test against the real scalar types.
            if not isinstance(y, (float, np.floating)):
                # Not a scalar: take the underlying array of target values.
                y = y.values
                if self.n_targets > 1:
                    # Several target columns: the label is the position of
                    # the largest value.
                    y = np.argmax(y)
                elif isinstance(y, (list, np.ndarray)):
                    # Single target column: unwrap the one-element array.
                    y = y[0]
            return int(y)

    def __iter__(self):
        """Yield ``num_samples`` entries of ``indices`` drawn with replacement."""
        return (self.indices[i] for i in torch.multinomial(
            self.weights, self.num_samples, replacement=True))

    def __len__(self):
        """Return the number of samples drawn per iteration."""
        return self.num_samples
|
|
72 |
|
|
|
73 |
|
|
|
"""MIT License

Copyright (c) 2018 Ming

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE."""