dl/utils/sampler.py

import collections

import numpy as np

try:
    import torch
except ImportError:  # torch is optional; tensor-specific code paths are skipped below
    torch = None

__all__ = ['BatchSequentialSampler', 'RepeatedBatchSampler', 'balanced_sampler', 'BatchLoader']

# Tensor constructors keyed by dtype name, preferring CUDA when a GPU is available.
if torch is not None:
    if torch.cuda.is_available():
        dtype = {'float': torch.cuda.FloatTensor, 'long': torch.cuda.LongTensor,
                 'byte': torch.cuda.ByteTensor}  # pylint: disable=no-member
    else:
        dtype = {'float': torch.FloatTensor, 'long': torch.LongTensor, 'byte': torch.ByteTensor}
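# Illustrative note (not in the original file): dtype['long']([0, 1, 2]) builds a
# LongTensor on whichever device was detected at import time.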


class BatchSequentialSampler(object):
    """Yield batches of indices (same implementation as torch.utils.data.sampler.BatchSampler).

    Args:
        sampler: an iterable, e.g. range(100)
        batch_size: int
        drop_last: bool; if True, drop the last batch when it is smaller than batch_size
    Return:
        an iterator; each iteration yields a batch of batch_size items from sampler
    """
    def __init__(self, sampler, batch_size=1, drop_last=False):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for i in self.sampler:
            batch.append(i)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
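
# Usage sketch (illustrative, not part of the original module):
#
#     >>> list(BatchSequentialSampler(range(5), batch_size=2))
#     [[0, 1], [2, 3], [4]]
#     >>> list(BatchSequentialSampler(range(5), batch_size=2, drop_last=True))
#     [[0, 1], [2, 3]]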


class RepeatedBatchSampler(object):
    """Generate num_iter batches of size batch_size, cycling through sampler as needed.

    Args:
        sampler: an indexable sequence; repeated (shuffled) copies are concatenated
            internally to cover all batches
        batch_size: int
        num_iter: int, default None; if None, computed as one pass over sampler
        shuffle: bool, default True
        allow_duplicate: bool, default True; if False and len(sampler) < batch_size,
            shrink batch_size to len(sampler)
        seed: if not None, passed to np.random.seed. For unit tests
    Return:
        an iterator of length num_iter
    """
    def __init__(self, sampler, batch_size=1, num_iter=None, shuffle=True, allow_duplicate=True, seed=None):
        assert len(sampler) > 0
        self.sampler = sampler  # kept only for __next__()
        if len(sampler) < batch_size and not allow_duplicate:
            batch_size = len(sampler)
        assert batch_size > 0
        self.batch_size = batch_size
        if num_iter is None:
            num_iter = (len(sampler) + batch_size - 1) // batch_size
        assert num_iter > 0
        self.num_iter = num_iter
        # Concatenate enough (independently shuffled) copies of sampler to cover
        # num_iter batches of batch_size.
        num_repeats = (num_iter * batch_size + len(sampler) - 1) // len(sampler)
        self.sampler_ext = []
        if seed is not None:  # `if seed:` would be buggy (seed=0 is falsy)
            np.random.seed(seed)
        for _ in range(num_repeats):
            if shuffle:
                idx = np.random.permutation(len(sampler))
            else:
                idx = range(len(sampler))
            self.sampler_ext += [sampler[i] for i in idx]

    def __iter__(self):
        cnt = 0
        for _ in range(self.num_iter):
            yield self.sampler_ext[cnt:(cnt + self.batch_size)]
            cnt += self.batch_size

    def __len__(self):
        return self.num_iter

    def __next__(self):
        # Unlike __iter__, draw a single random batch without replacement.
        indices = np.random.permutation(len(self.sampler))[:self.batch_size]
        sampler = list(self.sampler)
        batch = [sampler[i] for i in indices]
        return batch
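
# Usage sketch (illustrative, not part of the original module): three batches of
# two elements drawn from a five-element list, reshuffling on each pass.
#
#     >>> batches = list(RepeatedBatchSampler(list('abcde'), batch_size=2, num_iter=3, seed=0))
#     >>> [len(b) for b in batches]
#     [2, 2, 2]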


def balanced_sampler(y, batch_size=10, num_iter=None, allow_duplicate=False,
                     max_redundancy=3, shuffle=True, seed=None):
    """Given class labels y, return a balanced batch sampler, i.e.,
    each class appears the same number of times in each batch.

    Args:
        y: list, tuple, or numpy 1-d array of class labels
        batch_size: int; how many instances of each class to include in a batch.
            Thus the real batch size is batch_size * num_classes in most cases
        num_iter: number of batches. If None, calculated from y, batch_size, etc.
        allow_duplicate: if batch_size exceeds the smallest class size and
            allow_duplicate is False, reduce batch_size to that size
        max_redundancy: default 3; when num_iter is initially None, the calculated
            num_iter exceeds that of a 'traditional' epoch by a factor of
            num_classes. max_redundancy caps this factor
        shuffle: default True; shuffle each class's instances and the instances
            within each batch
        seed: if not None, call np.random.seed(seed). For unit tests
    Return:
        a numpy array of shape (num_iter, real_batch_size)
    """
    # Group instance indices by class label, normalizing each label to a plain
    # Python scalar: torch.Tensor elements with equal values would otherwise be
    # distinct dict keys, silently splitting one class into many.
    z = collections.defaultdict(list)
    for i, e in enumerate(y):
        key = e.item() if hasattr(e, 'item') else e
        z[key].append(i)
    least_size = min(len(v) for v in z.values())
    if least_size < batch_size and not allow_duplicate:
        batch_size = least_size
    if num_iter is None:
        num_iter = (len(y) + batch_size - 1) // batch_size
        if len(z) > max_redundancy:
            num_iter = (num_iter * max_redundancy + len(z) - 1) // len(z)

    # One sampler per class, each yielding num_iter batches of batch_size indices.
    bs = [RepeatedBatchSampler(v, batch_size=batch_size, num_iter=num_iter, shuffle=shuffle,
                               allow_duplicate=allow_duplicate, seed=seed)
          for v in z.values()]
    bs = [list(s) for s in bs]
    # (num_classes, num_iter, batch_size) -> (num_iter, num_classes * batch_size)
    indices = np.array(bs).transpose(1, 0, 2).reshape(num_iter, -1)
    # Within each batch, shuffle instances so that members of the same class do
    # not cluster together; may not be strictly necessary.
    if shuffle:
        if seed is not None:
            np.random.seed(seed)
        for v in indices:
            np.random.shuffle(v)
    return indices
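
# Usage sketch (illustrative, not part of the original module): with two classes
# of three instances each and batch_size=2, every row holds 2 + 2 indices.
#
#     >>> idx = balanced_sampler([0, 0, 0, 1, 1, 1], batch_size=2, num_iter=2, seed=0)
#     >>> idx.shape
#     (2, 4)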


class BatchLoader(object):
    """Return an iterator over batches of data.

    Args:
        data: a single np.ndarray/torch.Tensor, or a list/tuple of them
        labels: class labels, e.g. a list of ints, used by balanced_sampler.
            If None, the last element of data is used as the labels
        batch_size: int
        balanced: if True, use balanced_sampler; otherwise use RepeatedBatchSampler
            over (optionally shuffled) positions
        The remaining parameters are passed through to the sampler
    """
    def __init__(self, data, batch_size=10, labels=None, balanced=True, num_iter=None,
                 allow_duplicate=False, max_redundancy=3, shuffle=True, seed=None):
        assert (labels is None and isinstance(data, (tuple, list)) and len(data) > 1) or (
            labels is not None)
        if labels is None:
            labels = data[-1]
        if not isinstance(data, (tuple, list)):
            data = [data]
        assert len(data) > 0 and len(data[0]) == len(labels)
        self.data = data
        N = len(labels)
        if balanced:
            self.indices = balanced_sampler(labels, batch_size=batch_size, num_iter=num_iter,
                                            allow_duplicate=allow_duplicate, max_redundancy=max_redundancy,
                                            shuffle=shuffle, seed=seed)
        else:
            idx = range(N)
            if shuffle:
                idx = np.random.permutation(N).tolist()
            self.indices = RepeatedBatchSampler(idx, batch_size=batch_size,
                                                num_iter=num_iter, shuffle=shuffle,
                                                allow_duplicate=allow_duplicate, seed=seed)

    def __iter__(self):
        for idx in self.indices:
            batch = []
            for data in self.data:
                # Tensors are indexed with a LongTensor; keep idx itself unchanged so
                # that a numpy array later in self.data still gets a list/array index.
                if torch is not None and isinstance(data, torch.Tensor):
                    batch.append(data[torch.LongTensor(idx)])
                else:
                    batch.append(data[idx])
            yield batch

    def __len__(self):
        return len(self.indices)
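
# Usage sketch (illustrative, not part of the original module): balanced (x, y)
# batches; with two classes and batch_size=2, each batch holds 4 rows.
#
#     >>> x = np.arange(12).reshape(6, 2)
#     >>> y = np.array([0, 0, 0, 1, 1, 1])
#     >>> for bx, by in BatchLoader([x, y], batch_size=2, seed=0):
#     ...     assert bx.shape == (4, 2) and sorted(by.tolist()) == [0, 0, 1, 1]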