Diff of /model/utils/sampler.py [000000] .. [40f229]

Switch to unified view

a b/model/utils/sampler.py
1
import torch.utils.data as tordata
2
import random
3
4
5
class TripletSampler(tordata.sampler.Sampler):
6
    def __init__(self, dataset, batch_size):
7
        self.dataset = dataset
8
        self.batch_size = batch_size
9
10
    def __iter__(self):
11
        while (True):
12
            sample_indices = list()
13
            pid_list = random.sample(
14
                list(self.dataset.label_set),
15
                self.batch_size[0])
16
            for pid in pid_list:
17
                _index = self.dataset.index_dict.loc[pid, :, :].values
18
                _index = _index[_index > 0].flatten().tolist()
19
                _index = random.choices(
20
                    _index,
21
                    k=self.batch_size[1])
22
                sample_indices += _index
23
            yield sample_indices
24
25
    def __len__(self):
26
        return self.dataset.data_size