[40f229]: / model / utils / sampler.py

Download this file

27 lines (22 with data), 828 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch.utils.data as tordata
import random
class TripletSampler(tordata.sampler.Sampler):
    """Endless batch sampler that draws P labels x K samples per batch.

    Each yielded batch is a flat list of dataset indices:
    ``batch_size[0]`` distinct labels are drawn without replacement from
    ``dataset.label_set``, and for each of them ``batch_size[1]`` indices
    are drawn *with* replacement from that label's entries in
    ``dataset.index_dict``.

    NOTE(review): ``__iter__`` never terminates, whereas ``__len__``
    reports ``dataset.data_size``. The training loop presumably stops on its
    own iteration budget; confirm with the caller before relying on len().
    """

    def __init__(self, dataset, batch_size):
        """Store the dataset and the ``(P, K)`` batch shape.

        Args:
            dataset: object exposing ``label_set`` (iterable of labels),
                ``index_dict`` (indexable as ``.loc[label, :, :].values``,
                yielding an array of dataset indices) and ``data_size``.
            batch_size: pair ``(P, K)``. P is the number of distinct labels
                per batch and K is the number of indices drawn per label.
        """
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        """Yield batches of ``P * K`` dataset indices forever.

        Raises:
            ValueError: from ``random.sample`` if P exceeds the number of
                labels.
            IndexError: from ``random.choices`` if a drawn label has no
                valid index.
        """
        num_labels, per_label = self.batch_size[0], self.batch_size[1]
        while True:
            sample_indices = []
            pid_list = random.sample(list(self.dataset.label_set), num_labels)
            for pid in pid_list:
                _index = self.dataset.index_dict.loc[pid, :, :].values
                # Empty cells hold a negative sentinel (presumably -1; confirm
                # against the dataset). Index 0 is a real dataset position,
                # so keep it: the filter must be ">= 0", not "> 0".
                _index = _index[_index >= 0].flatten().tolist()
                # Sampling with replacement lets a label with fewer than K
                # samples still fill its K slots.
                sample_indices += random.choices(_index, k=per_label)
            yield sample_indices

    def __len__(self):
        """Return ``dataset.data_size`` (not the number of batches yielded)."""
        return self.dataset.data_size