--- a
+++ b/model/utils/sampler.py
@@ -0,0 +1,53 @@
+import torch.utils.data as tordata
+import random
+
+
+class TripletSampler(tordata.sampler.Sampler):
+    """Endless sampler yielding batches of ``P`` labels x ``K`` indices.
+
+    Every batch is a flat list of dataset indices: ``batch_size[0]``
+    distinct labels are drawn from ``dataset.label_set`` without
+    replacement, then ``batch_size[1]`` indices are drawn *with*
+    replacement from each chosen label's entries in
+    ``dataset.index_dict``.
+
+    Args:
+        dataset: object exposing ``label_set``, ``data_size`` and an
+            ``index_dict`` whose ``.loc[label, :, :].values`` yields an
+            array of indices (presumably an xarray DataArray -- confirm).
+        batch_size: indexable pair ``(labels_per_batch,
+            samples_per_label)``.
+    """
+
+    def __init__(self, dataset, batch_size):
+        self.dataset = dataset
+        self.batch_size = batch_size
+
+    def __iter__(self):
+        """Yield one flat list of sample indices per batch, forever."""
+        while True:
+            chosen_labels = random.sample(
+                list(self.dataset.label_set), self.batch_size[0])
+            batch = []
+            for label in chosen_labels:
+                batch.extend(self._draw_for_label(label))
+            yield batch
+
+    def _draw_for_label(self, label):
+        """Draw ``batch_size[1]`` indices for *label*, with replacement."""
+        cells = self.dataset.index_dict.loc[label, :, :].values
+        # NOTE(review): the ``> 0`` mask keeps only positive entries,
+        # presumably to skip empty-cell placeholders. If the placeholder
+        # is -1, this also drops the sample stored at index 0 -- confirm
+        # how index_dict is filled; ``>= 0`` may be what was intended.
+        candidates = cells[cells > 0].flatten().tolist()
+        # random.choices raises IndexError if ``candidates`` is empty.
+        return random.choices(candidates, k=self.batch_size[1])
+
+    def __len__(self):
+        """Return ``dataset.data_size``.
+
+        NOTE(review): ``__iter__`` never terminates, so this is not the
+        number of batches produced; it is presumably a nominal size only.
+        """
+        return self.dataset.data_size