Diff of /opengait/data/sampler.py [000000] .. [fd9ef4]

Switch to side-by-side view

--- a
+++ b/opengait/data/sampler.py
@@ -0,0 +1,174 @@
+import math
+import random
+import torch
+import torch.distributed as dist
+import torch.utils.data as tordata
+
+
class TripletSampler(tordata.sampler.Sampler):
    """Infinite sampler yielding (P x K) index batches for triplet training.

    Each global batch holds P identities with K samples each; every rank
    receives a strided 1/world_size share of it. Sampling decisions are
    synchronized across ranks (rank 0 broadcasts its choices).

    Args:
        dataset: object exposing `label_set` (iterable of person IDs) and
            `indices_dict` (pid -> list of sample indices).
        batch_size: pair (P, K).
        batch_shuffle: if True, shuffle the assembled batch before it is
            split across ranks.

    Raises:
        ValueError: if `batch_size` is not a pair, or P*K is not divisible
            by the distributed world size.
    """

    def __init__(self, dataset, batch_size, batch_shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        if len(self.batch_size) != 2:
            raise ValueError(
                "batch_size should be (P x K) not {}".format(batch_size))
        self.batch_shuffle = batch_shuffle

        self.world_size = dist.get_world_size()
        if (self.batch_size[0]*self.batch_size[1]) % self.world_size != 0:
            # Fixed message: the check is P*K % world_size, i.e. batch_size
            # must be divisible by the world size (not the other way around).
            raise ValueError("batch_size ({} x {}) is not divisible by the world size ({})".format(
                batch_size[0], batch_size[1], self.world_size))
        self.rank = dist.get_rank()

    def __iter__(self):
        # Infinite stream: each pass builds one global (P x K) batch and
        # yields only this rank's strided share of it.
        while True:
            sample_indices = []
            # Pick P identities, identically on every rank.
            pid_list = sync_random_sample_list(
                self.dataset.label_set, self.batch_size[0])

            for pid in pid_list:
                indices = self.dataset.indices_dict[pid]
                # K samples per identity, again synchronized across ranks.
                indices = sync_random_sample_list(
                    indices, k=self.batch_size[1])
                sample_indices += indices

            if self.batch_shuffle:
                sample_indices = sync_random_sample_list(
                    sample_indices, len(sample_indices))

            total_batch_size = self.batch_size[0] * self.batch_size[1]
            total_size = int(math.ceil(total_batch_size /
                                       self.world_size)) * self.world_size
            # Wrap-around pad; a no-op given the divisibility check in
            # __init__ (len(sample_indices) == total_batch_size here).
            sample_indices += sample_indices[:(
                total_batch_size - len(sample_indices))]

            # Round-robin split: rank r takes positions r, r+world_size, ...
            sample_indices = sample_indices[self.rank:total_size:self.world_size]
            yield sample_indices

    def __len__(self):
        # Dataset size, not the number of batches (the stream is infinite).
        return len(self.dataset)
+
+
def sync_random_sample_list(obj_list, k, common_choice=False):
    """Randomly pick k items from `obj_list`, identically on every rank.

    Rank 0's selection is broadcast to all ranks so every process in the
    distributed group ends up with the same sample.

    Args:
        obj_list: sequence to sample from.
        k: number of items to draw.
        common_choice: if True, sample with replacement even when the pool
            is large enough; otherwise replacement is used only when
            ``len(obj_list) < k``.

    Returns:
        list of k items drawn from `obj_list`.
    """
    # Bug fix: previously the `common_choice` result was computed and then
    # unconditionally overwritten by the following if/else, so the flag was
    # a no-op. Merge the with-replacement paths into a single branch.
    if common_choice or len(obj_list) < k:
        # With replacement (mandatory when the pool is smaller than k).
        idx = torch.tensor(random.choices(range(len(obj_list)), k=k))
    else:
        # Without replacement.
        idx = torch.randperm(len(obj_list))[:k]
    if torch.cuda.is_available():
        # NCCL-style backends require CUDA tensors for collectives.
        idx = idx.cuda()
    # Synchronize: every rank adopts rank 0's indices.
    torch.distributed.broadcast(idx, src=0)
    idx = idx.tolist()
    return [obj_list[i] for i in idx]
+
+
class InferenceSampler(tordata.sampler.Sampler):
    """Deterministic sampler for evaluation.

    Splits `range(len(dataset))` into fixed-size per-rank chunks (padding
    by wrap-around so every global batch is full) and assigns the chunks
    to ranks round-robin.

    Args:
        dataset: sized dataset to draw indices for.
        batch_size: global batch size; must be divisible by the world size.

    Raises:
        ValueError: if `batch_size` is not divisible by the world size.
    """

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

        self.size = len(dataset)
        indices = list(range(self.size))

        world_size = dist.get_world_size()
        rank = dist.get_rank()

        if batch_size % world_size != 0:
            # Fixed message: batch_size is the quantity that must be
            # divisible by the world size.
            raise ValueError("batch_size ({}) is not divisible by the world size ({})".format(
                batch_size, world_size))

        if batch_size != 1:
            # Pad by wrapping to the start so the last batch is full.
            complement_size = math.ceil(self.size / batch_size) * \
                batch_size
            indices += indices[:(complement_size - self.size)]
            self.size = complement_size

        # Each rank consumes batch_size / world_size indices per global batch.
        batch_size_per_rank = self.batch_size // world_size
        indx_batch_per_rank = [
            indices[i * batch_size_per_rank:(i + 1) * batch_size_per_rank]
            for i in range(self.size // batch_size_per_rank)
        ]

        # Round-robin: rank r serves chunks r, r+world_size, ...
        self.idx_batch_this_rank = indx_batch_per_rank[rank::world_size]

    def __iter__(self):
        yield from self.idx_batch_this_rank

    def __len__(self):
        # Dataset size, not the number of per-rank batches.
        return len(self.dataset)
+
+
class CommonSampler(tordata.sampler.Sampler):
    """Infinite sampler yielding flat batches of `batch_size` indices,
    drawn with replacement and split round-robin across ranks.

    Args:
        dataset: sized dataset to sample from.
        batch_size: global batch size (a single int).
        batch_shuffle: kept for interface parity with the other samplers.

    Raises:
        ValueError: if `batch_size` is not an int, or is not divisible by
            the world size.
    """

    def __init__(self, dataset, batch_size, batch_shuffle):
        self.dataset = dataset
        self.size = len(dataset)
        self.batch_size = batch_size
        if not isinstance(self.batch_size, int):
            # Fixed typo: "shoude" -> "should".
            raise ValueError(
                "batch_size should be (B) not {}".format(batch_size))
        self.batch_shuffle = batch_shuffle

        self.world_size = dist.get_world_size()
        if self.batch_size % self.world_size != 0:
            # Fixed typo ("divisble") and direction: batch_size must be
            # divisible by the world size.
            raise ValueError("batch_size ({}) is not divisible by the world size ({})".format(
                batch_size, self.world_size))
        self.rank = dist.get_rank()

    def __iter__(self):
        while True:
            indices_list = list(range(self.size))
            # Draw a full batch with replacement, identical on all ranks.
            sample_indices = sync_random_sample_list(
                indices_list, self.batch_size, common_choice=True)
            total_batch_size = self.batch_size
            total_size = int(math.ceil(total_batch_size /
                                       self.world_size)) * self.world_size
            # Wrap-around pad; a no-op given the divisibility check above.
            sample_indices += sample_indices[:(
                total_batch_size - len(sample_indices))]
            # Rank r takes positions r, r+world_size, ...
            sample_indices = sample_indices[self.rank:total_size:self.world_size]
            yield sample_indices

    def __len__(self):
        return len(self.dataset)
+
+# **************** For GaitSSB ****************
+# Fan, et al: Learning Gait Representation from Massive Unlabelled Walking Videos: A Benchmark, T-PAMI2023
+import random
class BilateralSampler(tordata.sampler.Sampler):
    """Infinite sampler for GaitSSB-style self-supervised training.

    Walks a shuffled permutation of the dataset in windows of
    P*K indices, reshuffling whenever the next window would run past the
    end, and yields each rank's strided share duplicated twice (one copy
    per augmented view of the batch).
    """

    def __init__(self, dataset, batch_size, batch_shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.batch_shuffle = batch_shuffle

        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()

        self.dataset_length = len(self.dataset)
        self.total_indices = list(range(self.dataset_length))

    def __iter__(self):
        random.shuffle(self.total_indices)
        cursor = 0
        global_batch = self.batch_size[0] * self.batch_size[1]
        while True:
            # Restart the epoch when the upcoming window would overrun.
            if (cursor + 1) * global_batch >= self.dataset_length:
                cursor = 0
                random.shuffle(self.total_indices)

            start = cursor * global_batch
            batch = self.total_indices[start:start + global_batch]
            # Re-order the window identically on every rank.
            batch = sync_random_sample_list(batch, len(batch))

            padded_size = int(math.ceil(global_batch / self.world_size)) * self.world_size
            # Wrap-around pad (no-op here: the window is always full).
            batch += batch[:(global_batch - len(batch))]

            # Rank r takes positions r, r+world_size, ...
            batch = batch[self.rank:padded_size:self.world_size]
            cursor += 1

            # Two identical copies: one per augmentation branch.
            yield batch * 2

    def __len__(self):
        return len(self.dataset)
\ No newline at end of file