Diff of /dl/utils/sampler.py [000000] .. [4807fa]

import random
import numpy as np
import collections

try:
    import torch
    from torch.autograd import Variable
except ImportError:
    pass

__all__ = ['BatchSequentialSampler', 'RepeatedBatchSampler', 'balanced_sampler', 'BatchLoader']

try:
    # Select CUDA tensor constructors when a GPU is available, CPU ones otherwise.
    if torch.cuda.is_available():
        dtype = {'float': torch.cuda.FloatTensor, 'long': torch.cuda.LongTensor,
                 'byte': torch.cuda.ByteTensor}  # pylint: disable=no-member
    else:
        dtype = {'float': torch.FloatTensor, 'long': torch.LongTensor, 'byte': torch.ByteTensor}
except NameError:  # torch failed to import above; no dtype dict in that case
    pass
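# Usage sketch (assumes torch imported successfully above):
#   >>> x = dtype['float'](3, 4)  # uninitialized 3x4 tensor, on GPU if available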


class BatchSequentialSampler(object):
    """Return an iterator over batches (same implementation as torch.utils.data.sampler.BatchSampler).

    Args:
        sampler: an iterable, e.g., range(100)
        batch_size: int
        drop_last: bool
    Return:
        an iterator; each iteration yields a batch of batch_size items from sampler
    """
    def __init__(self, sampler, batch_size=1, drop_last=False):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for i in self.sampler:
            batch.append(i)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # yield the final, possibly smaller batch unless drop_last is set
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            # ceiling division: the last partial batch counts
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
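# Usage sketch (doctest-style; outputs assume the implementation above):
#   >>> list(BatchSequentialSampler(range(10), batch_size=3))
#   [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
#   >>> list(BatchSequentialSampler(range(10), batch_size=3, drop_last=True))
#   [[0, 1, 2], [3, 4, 5], [6, 7, 8]]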


class RepeatedBatchSampler(object):
    """Generate num_iter batches, each of size batch_size.

    Args:
        sampler: a sequence (indexable, with a length); it is tiled into an
                 extended list so batches can span repeats
        batch_size: int
        num_iter: int, default: None (one pass over sampler)
        shuffle: bool, default: True
        allow_duplicate: bool, default: True
    Return:
        an iterator of length num_iter
    """
    def __init__(self, sampler, batch_size=1, num_iter=None, shuffle=True, allow_duplicate=True, seed=None):
        assert len(sampler) > 0
        self.sampler = sampler  # kept only for __next__()
        if len(sampler) < batch_size and not allow_duplicate:
            batch_size = len(sampler)
        assert batch_size > 0
        self.batch_size = batch_size
        if num_iter is None:
            num_iter = (len(sampler) + batch_size - 1) // batch_size
        assert num_iter > 0
        self.num_iter = num_iter
        # tile the sampler enough times to cover num_iter * batch_size items
        num_repeats = (num_iter * batch_size + len(sampler) - 1) // len(sampler)
        self.sampler_ext = []
        if seed is not None:  # `if seed:` would be wrong: it skips seed=0
            np.random.seed(seed)
        for _ in range(num_repeats):
            if shuffle:
                idx = np.random.permutation(len(sampler))
            else:
                idx = range(len(sampler))
            self.sampler_ext += [sampler[i] for i in idx]

    def __iter__(self):
        cnt = 0
        for _ in range(self.num_iter):
            yield self.sampler_ext[cnt:(cnt + self.batch_size)]
            cnt += self.batch_size

    def __len__(self):
        return self.num_iter

    def __next__(self):
        # independent of __iter__: draw one random batch without replacement
        indices = np.random.permutation(len(self.sampler))[:self.batch_size]
        sampler = list(self.sampler)
        batch = [sampler[i] for i in indices]
        return batch
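# Usage sketch: two batches of 4 drawn from a 5-element sampler; the sampler
# is tiled (reshuffled per repeat), so batches may span repeats:
#   >>> s = RepeatedBatchSampler(list(range(5)), batch_size=4, num_iter=2, seed=0)
#   >>> [len(b) for b in s]
#   [4, 4]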


def balanced_sampler(y, batch_size=10, num_iter=None, allow_duplicate=False,
                     max_redundancy=3, shuffle=True, seed=None):
    """Given class labels y, return a balanced batch sampler, i.e.,
       each class appears the same number of times in each batch.

    Args:
        y: list, tuple, or 1-d numpy array of class labels
        batch_size: int; how many instances of each class to include in a batch,
                    so the real batch size is batch_size * num_classes in most cases
        num_iter: number of batches. If None, calculated from y, batch_size, etc.
        allow_duplicate: if batch_size exceeds the smallest class size and
                         allow_duplicate is False, reduce batch_size
        max_redundancy: default 3; if num_iter is initially None, the calculated
                num_iter would exceed that of a 'traditional' epoch by a factor
                of num_classes; max_redundancy caps this factor
        shuffle: default True; always shuffle the batches
        seed: if not None, call np.random.seed(seed). For unit tests
    Return:
        a numpy array of shape (num_iter, real_batch_size)
    """
    # group instance indices by class label
    z = collections.defaultdict(list)
    for i, e in enumerate(y):
        # torch.Tensor elements compare by identity here, so equal values would
        # land in different buckets; convert to a plain Python scalar when possible
        key = e.item() if hasattr(e, 'item') else e
        z[key].append(i)
    least_size = min(len(v) for v in z.values())
    if least_size < batch_size and not allow_duplicate:
        batch_size = least_size
    if num_iter is None:
        num_iter = (len(y) + batch_size - 1) // batch_size
        if len(z) > max_redundancy:
            num_iter = (num_iter * max_redundancy + len(z) - 1) // len(z)

    # one sampler per class, each producing num_iter batches of batch_size
    bs = [RepeatedBatchSampler(v, batch_size=batch_size, num_iter=num_iter, shuffle=shuffle,
                               allow_duplicate=allow_duplicate, seed=seed)
          for v in z.values()]
    bs = [list(s) for s in bs]
    # (num_classes, num_iter, batch_size) -> (num_iter, num_classes * batch_size)
    indices = np.array(bs).transpose(1, 0, 2).reshape(num_iter, -1)
    # Within each batch, shuffle instances so that instances of the same class
    # don't cluster together; may not be necessary
    if shuffle:
        if seed is not None:
            np.random.seed(seed)
        for v in indices:
            np.random.shuffle(v)
    return indices
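# Usage sketch: with 2 classes and batch_size=2 per class, each row of the
# result holds 4 indices, 2 per class:
#   >>> y = [0, 0, 0, 1, 1, 1]
#   >>> idx = balanced_sampler(y, batch_size=2, seed=0)
#   >>> idx.shape
#   (3, 4)
#   >>> all(sorted(y[i] for i in row) == [0, 0, 1, 1] for row in idx)
#   True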


class BatchLoader(object):
    """Return an iterator of data batches.

    Args:
        data: a single np.array/torch.Tensor, or a list/tuple of them
        labels: class labels, e.g., a list of ints, used by balanced_sampler;
                if None, the last element of data is used as the labels
        batch_size: int
        balanced: if True, use balanced_sampler, else use RepeatedBatchSampler
        The remaining parameters are passed through to the sampler
    """
    def __init__(self, data, batch_size=10, labels=None, balanced=True, num_iter=None,
                 allow_duplicate=False, max_redundancy=3, shuffle=True, seed=None):
        # either labels are given explicitly, or data must be a list/tuple whose
        # last element can serve as the labels
        assert (labels is None and isinstance(data, (tuple, list)) and len(data) > 1) or (
            labels is not None)
        if labels is None:
            labels = data[-1]
        if not isinstance(data, (tuple, list)):
            data = [data]
        assert len(data) > 0 and len(data[0]) == len(labels)
        self.data = data
        N = len(labels)
        if balanced:
            self.indices = balanced_sampler(labels, batch_size=batch_size, num_iter=num_iter,
                                            allow_duplicate=allow_duplicate, max_redundancy=max_redundancy,
                                            shuffle=shuffle, seed=seed)
        else:
            idx = range(N)
            if shuffle:
                idx = np.random.permutation(N).tolist()
            self.indices = RepeatedBatchSampler(idx, batch_size=batch_size,
                num_iter=num_iter, shuffle=shuffle, allow_duplicate=allow_duplicate, seed=seed)

    def __iter__(self):
        for idx in self.indices:
            batch = []
            for data in self.data:
                # use a per-array index so the torch conversion of idx does not
                # leak into the indexing of subsequent numpy arrays
                index = idx
                try:
                    if isinstance(data, torch.Tensor):
                        index = torch.LongTensor(idx)
                except NameError:  # torch is not installed
                    pass
                batch.append(data[index])
            yield batch

    def __len__(self):
        return len(self.indices)
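# Usage sketch: balanced mini-batches over a toy dataset; the last element of
# `data` doubles as the labels when labels=None:
#   >>> X = np.arange(12).reshape(6, 2)
#   >>> y = np.array([0, 0, 0, 1, 1, 1])
#   >>> loader = BatchLoader([X, y], batch_size=2, seed=0)
#   >>> for xb, yb in loader:
#   ...     assert sorted(yb.tolist()) == [0, 0, 1, 1]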