Diff of /pathflowai/sampler.py [000000] .. [e9500f]

Switch to unified view

a b/pathflowai/sampler.py
1
"""
2
sampler.py
3
=======================
4
Balanced sampling based on one of the columns of the patch information.
5
"""
6
7
import torch
8
import torch.utils.data
9
import torchvision
10
import numpy as np
11
12
13
class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
14
    """Samples elements randomly from a given list of indices for imbalanced dataset
15
    https://raw.githubusercontent.com/ufoym/imbalanced-dataset-sampler/master/sampler.py
16
    Arguments:
17
        indices (list, optional): a list of indices
18
        num_samples (int, optional): number of samples to draw
19
    """
20
21
    def __init__(self, dataset, indices=None, num_samples=None):
22
23
        # if indices is not provided,
24
        # all elements in the dataset will be considered
25
        self.indices = list(range(len(dataset))) \
26
            if indices is None else indices
27
28
        self.n_targets=len(dataset.targets)
29
30
        # if num_samples is not provided,
31
        # draw `len(indices)` samples in each iteration
32
        self.num_samples = len(self.indices) \
33
            if num_samples is None else num_samples
34
35
        # distribution of classes in the dataset
36
        label_to_count = {}
37
        for idx in self.indices:
38
            label = self._get_label(dataset, idx)
39
            if label in label_to_count:
40
                label_to_count[label] += 1
41
            else:
42
                label_to_count[label] = 1
43
44
        # weight for each sample
45
        weights = [1.0 / label_to_count[self._get_label(dataset, idx)]
46
                   for idx in self.indices]
47
        self.weights = torch.DoubleTensor(weights)
48
49
    def _get_label(self, dataset, idx):
50
        dataset_type = type(dataset)
51
        if dataset_type is torchvision.datasets.MNIST:
52
            return dataset.train_labels[idx].item()
53
        elif dataset_type is torchvision.datasets.ImageFolder:
54
            return dataset.imgs[idx][1]
55
        else:
56
            y=dataset.patch_info.iloc[idx][dataset.targets]
57
            if not isinstance(y,np.float):
58
                y=y.values
59
            if self.n_targets>1:
60
                y=np.argmax(y)
61
            elif isinstance(y,(list,np.ndarray)):
62
                y=y[0]
63
            #print(y)
64
            return int(y)
65
66
    def __iter__(self):
67
        return (self.indices[i] for i in torch.multinomial(
68
            self.weights, self.num_samples, replacement=True))
69
70
    def __len__(self):
71
        return self.num_samples
72
73
74
"""MIT License
75
76
Copyright (c) 2018 Ming
77
78
Permission is hereby granted, free of charge, to any person obtaining a copy
79
of this software and associated documentation files (the "Software"), to deal
80
in the Software without restriction, including without limitation the rights
81
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
82
copies of the Software, and to permit persons to whom the Software is
83
furnished to do so, subject to the following conditions:
84
85
The above copyright notice and this permission notice shall be included in all
86
copies or substantial portions of the Software.
87
88
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
89
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
90
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
91
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
92
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
93
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
94
SOFTWARE."""