|
a |
|
b/model/utils/sampler.py |
|
|
1 |
import torch.utils.data as tordata |
|
|
2 |
import random |
|
|
3 |
|
|
|
4 |
|
|
|
5 |
class TripletSampler(tordata.sampler.Sampler): |
|
|
6 |
def __init__(self, dataset, batch_size): |
|
|
7 |
self.dataset = dataset |
|
|
8 |
self.batch_size = batch_size |
|
|
9 |
|
|
|
10 |
def __iter__(self): |
|
|
11 |
while (True): |
|
|
12 |
sample_indices = list() |
|
|
13 |
pid_list = random.sample( |
|
|
14 |
list(self.dataset.label_set), |
|
|
15 |
self.batch_size[0]) |
|
|
16 |
for pid in pid_list: |
|
|
17 |
_index = self.dataset.index_dict.loc[pid, :, :].values |
|
|
18 |
_index = _index[_index > 0].flatten().tolist() |
|
|
19 |
_index = random.choices( |
|
|
20 |
_index, |
|
|
21 |
k=self.batch_size[1]) |
|
|
22 |
sample_indices += _index |
|
|
23 |
yield sample_indices |
|
|
24 |
|
|
|
25 |
def __len__(self): |
|
|
26 |
return self.dataset.data_size |