|
a |
|
b/demo.py |
|
|
1 |
import itertools |
|
|
2 |
import numpy as np |
|
|
3 |
from torch.utils.data.sampler import Sampler |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
class TwoStreamBatchSampler(Sampler): |
|
|
7 |
"""Iterate two sets of indices |
|
|
8 |
|
|
|
9 |
An 'epoch' is one iteration through the primary indices. |
|
|
10 |
During the epoch, the secondary indices are iterated through |
|
|
11 |
as many times as needed. |
|
|
12 |
""" |
|
|
13 |
def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): |
|
|
14 |
# 有标签的索引 |
|
|
15 |
self.primary_indices = primary_indices |
|
|
16 |
# 无标签的索引 |
|
|
17 |
self.secondary_indices = secondary_indices |
|
|
18 |
self.secondary_batch_size = secondary_batch_size |
|
|
19 |
self.primary_batch_size = batch_size - secondary_batch_size |
|
|
20 |
|
|
|
21 |
assert len(self.primary_indices) >= self.primary_batch_size > 0 |
|
|
22 |
assert len(self.secondary_indices) >= self.secondary_batch_size > 0 |
|
|
23 |
|
|
|
24 |
def __iter__(self): |
|
|
25 |
# 随机打乱索引顺序 |
|
|
26 |
primary_iter = iterate_once(self.primary_indices) |
|
|
27 |
secondary_iter = iterate_eternally(self.secondary_indices) |
|
|
28 |
return ( |
|
|
29 |
primary_batch + secondary_batch |
|
|
30 |
for (primary_batch, secondary_batch) |
|
|
31 |
in zip(grouper(primary_iter, self.primary_batch_size), |
|
|
32 |
grouper(secondary_iter, self.secondary_batch_size)) |
|
|
33 |
) |
|
|
34 |
|
|
|
35 |
def __len__(self): |
|
|
36 |
return len(self.primary_indices) // self.primary_batch_size |
|
|
37 |
|
|
|
38 |
|
|
|
39 |
def iterate_once(iterable): |
|
|
40 |
# print('shuffle labeled_idxs') |
|
|
41 |
return np.random.permutation(iterable) |
|
|
42 |
|
|
|
43 |
|
|
|
44 |
def iterate_eternally(indices): |
|
|
45 |
# print('shuffle unlabeled_idxs') |
|
|
46 |
def infinite_shuffles(): |
|
|
47 |
while True: |
|
|
48 |
yield np.random.permutation(indices) |
|
|
49 |
return itertools.chain.from_iterable(infinite_shuffles()) |
|
|
50 |
|
|
|
51 |
|
|
|
52 |
def grouper(iterable, n): |
|
|
53 |
"Collect data into fixed-length chunks or blocks" |
|
|
54 |
# grouper('ABCDEFG', 3) --> ABC DEF" |
|
|
55 |
args = [iter(iterable)] * n |
|
|
56 |
return zip(*args) |
|
|
57 |
|
|
|
58 |
|
|
|
59 |
if __name__ == '__main__': |
|
|
60 |
labeled_idxs = list(range(12)) |
|
|
61 |
unlabeled_idxs = list(range(12,60)) |
|
|
62 |
batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, 4, 2) |
|
|
63 |
for _ in range(2): |
|
|
64 |
i = 0 |
|
|
65 |
for x in batch_sampler: |
|
|
66 |
i += 1 |
|
|
67 |
print('%02d' % i, '\t', x) |