--- a
+++ b/demo.py
@@ -0,0 +1,67 @@
+import itertools
+import numpy as np
+from torch.utils.data.sampler import Sampler
+
+
+class TwoStreamBatchSampler(Sampler):
+    """Iterate two sets of indices
+
+    An 'epoch' is one iteration through the primary indices.
+    During the epoch, the secondary indices are iterated through
+    as many times as needed.
+    """
+    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
+        # indices of the labeled samples
+        self.primary_indices = primary_indices
+        # indices of the unlabeled samples
+        self.secondary_indices = secondary_indices
+        self.secondary_batch_size = secondary_batch_size
+        self.primary_batch_size = batch_size - secondary_batch_size
+
+        assert len(self.primary_indices) >= self.primary_batch_size > 0
+        assert len(self.secondary_indices) >= self.secondary_batch_size > 0
+
+    def __iter__(self):
+        # reshuffle both index sets; each batch is labeled indices followed by unlabeled ones
+        primary_iter = iterate_once(self.primary_indices)
+        secondary_iter = iterate_eternally(self.secondary_indices)
+        return (
+            primary_batch + secondary_batch
+            for (primary_batch, secondary_batch)
+            in zip(grouper(primary_iter, self.primary_batch_size),
+                   grouper(secondary_iter, self.secondary_batch_size))
+        )
+
+    def __len__(self):
+        return len(self.primary_indices) // self.primary_batch_size
+
+
+def iterate_once(iterable):
+    # print('shuffle labeled_idxs')
+    return np.random.permutation(iterable)
+
+
+def iterate_eternally(indices):
+    # print('shuffle unlabeled_idxs')
+    def infinite_shuffles():
+        while True:
+            yield np.random.permutation(indices)
+    return itertools.chain.from_iterable(infinite_shuffles())
+
+
+def grouper(iterable, n):
+    "Collect data into fixed-length chunks; leftover items that do not fill a chunk are dropped"
+    # grouper('ABCDEFG', 3) --> ABC DEF
+    args = [iter(iterable)] * n
+    return zip(*args)
+
+
+if __name__ == '__main__':
+    labeled_idxs = list(range(12))
+    unlabeled_idxs = list(range(12, 60))
+    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, 4, 2)
+    for _ in range(2):
+        i = 0
+        for x in batch_sampler:
+            i += 1
+            print('%02d' % i, '\t', x)
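
For context, here is a minimal sketch (not part of the patch) of how `TwoStreamBatchSampler` is typically wired into a `torch.utils.data.DataLoader` through its `batch_sampler` argument. The `TensorDataset` with dummy tensors is a hypothetical stand-in for a real labeled/unlabeled dataset; the index split mirrors the `__main__` demo above, and it assumes the file above is saved as `demo.py`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from demo import TwoStreamBatchSampler  # the class added in this diff

# Dummy dataset: 60 samples, where indices 0-11 stand in for labeled data
# and 12-59 for unlabeled data (targets for the unlabeled part are ignored).
inputs = torch.arange(60, dtype=torch.float32).unsqueeze(1)
targets = torch.zeros(60, dtype=torch.long)
dataset = TensorDataset(inputs, targets)

labeled_idxs = list(range(12))
unlabeled_idxs = list(range(12, 60))

# batch_size=4 with secondary_batch_size=2 -> 2 labeled + 2 unlabeled per batch.
batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, 4, 2)

# The sampler yields complete index batches, so batch_size/shuffle must be left at their defaults.
loader = DataLoader(dataset, batch_sampler=batch_sampler)

for x, y in loader:
    # Labeled samples come first in each batch, followed by the unlabeled ones.
    print(x.shape, y.shape)  # torch.Size([4, 1]) torch.Size([4])
```

With this setup one epoch is 12 // 2 = 6 batches (as given by `__len__`), and any labeled indices that do not fill a final batch are dropped by `grouper`.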