Diff of /demo.py [000000] .. [903821]

Switch to unified view

a b/demo.py
1
import itertools
2
import numpy as np
3
from torch.utils.data.sampler import Sampler
4
5
6
class TwoStreamBatchSampler(Sampler):
    """Yield batches that mix indices from two separate index pools.

    An 'epoch' is one pass through the primary indices. During the epoch
    the secondary indices are iterated through as many times as needed.

    Each yielded batch is a tuple holding ``primary_batch_size`` primary
    indices followed by ``secondary_batch_size`` secondary indices.
    Primary indices that do not fill a complete batch at the end of an
    epoch are dropped, which is why ``__len__`` uses floor division.

    Args:
        primary_indices: Indices of the labeled samples.
        secondary_indices: Indices of the unlabeled samples.
        batch_size: Total number of indices per batch (primary + secondary).
        secondary_batch_size: Number of secondary indices per batch.

    Raises:
        ValueError: If either per-stream batch size is not positive or is
            larger than the number of indices available in that stream.
    """

    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        # Indices of the labeled samples.
        self.primary_indices = primary_indices
        # Indices of the unlabeled samples.
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        # Whatever is left of the batch after the secondary share.
        self.primary_batch_size = batch_size - secondary_batch_size

        # Explicit exceptions instead of assert: asserts vanish under -O.
        if not 0 < self.primary_batch_size <= len(self.primary_indices):
            raise ValueError(
                "primary batch size (batch_size - secondary_batch_size) must be "
                "in [1, len(primary_indices)], got %d" % self.primary_batch_size
            )
        if not 0 < self.secondary_batch_size <= len(self.secondary_indices):
            raise ValueError(
                "secondary_batch_size must be in [1, len(secondary_indices)], "
                "got %d" % self.secondary_batch_size
            )

    def __iter__(self):
        # Both helpers return the indices in a freshly shuffled order; the
        # primary stream is finite, the secondary stream never ends.
        primary_iter = iterate_once(self.primary_indices)
        secondary_iter = iterate_eternally(self.secondary_indices)
        # zip stops as soon as the primary stream runs out of full groups,
        # which ends the epoch.
        return (
            primary_batch + secondary_batch
            for primary_batch, secondary_batch in zip(
                grouper(primary_iter, self.primary_batch_size),
                grouper(secondary_iter, self.secondary_batch_size),
            )
        )

    def __len__(self):
        # Number of complete primary batches per epoch.
        return len(self.primary_indices) // self.primary_batch_size
37
38
39
def iterate_once(iterable):
    """Return the items of ``iterable`` in a random order.

    The input is not modified; ``np.random.permutation`` returns a new,
    shuffled NumPy array, so the result can be iterated exactly once per
    call in a fresh random order.
    """
    return np.random.permutation(iterable)
42
43
44
def iterate_eternally(indices):
    """Return an endless iterator over ``indices``, reshuffled on every pass.

    Each pass yields every index exactly once in a new random order; when
    a pass is exhausted the next one starts immediately.

    Raises:
        ValueError: On the first ``next()`` call if ``indices`` is empty.
            Chaining empty permutations forever would never yield an item
            and would hang the consumer.
    """
    def infinite_shuffles():
        while True:
            shuffled = np.random.permutation(indices)
            if len(shuffled) == 0:
                raise ValueError("iterate_eternally() requires at least one index")
            yield shuffled

    # Flatten the endless stream of permutations into a stream of indices.
    return itertools.chain.from_iterable(infinite_shuffles())
50
51
52
def grouper(iterable, n):
    """Collect data into fixed-length chunks, dropping an incomplete tail.

    Example: ``grouper('ABCDEFG', 3)`` yields ``('A', 'B', 'C')`` and then
    ``('D', 'E', 'F')``; the trailing ``'G'`` is discarded.

    Returns a lazy iterator of ``n``-tuples.
    """
    # The list holds the SAME iterator n times, so zip pulls n consecutive
    # items from it for every tuple it builds.
    args = [iter(iterable)] * n
    return zip(*args)
57
58
59
if __name__ == '__main__':
    # Demo: 12 labeled and 48 unlabeled indices, batches of 4 with 2 taken
    # from the unlabeled pool. Two epochs are printed to show that the
    # order is reshuffled each time the sampler is iterated.
    labeled_idxs = list(range(12))
    unlabeled_idxs = list(range(12, 60))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, 4, 2)
    for _ in range(2):
        # Batch numbering restarts at 1 for every epoch.
        for i, x in enumerate(batch_sampler, start=1):
            print('%02d' % i, '\t', x)