# opengait/data/sampler.py
import math
import random
import torch
import torch.distributed as dist
import torch.utils.data as tordata

class TripletSampler(tordata.sampler.Sampler):
    """Infinite distributed sampler producing P x K triplet-style batches.

    Each global batch holds ``P = batch_size[0]`` identities with
    ``K = batch_size[1]`` samples per identity.  All random draws go
    through ``sync_random_sample_list`` (rank 0 broadcasts its draw), so
    every rank walks the same global batch and then keeps an interleaved
    ``1/world_size`` share of it.

    Args:
        dataset: object exposing ``label_set`` (iterable of person IDs)
            and ``indices_dict`` (pid -> list of sample indices).
        batch_size: pair ``(P, K)``.
        batch_shuffle: if True, shuffle samples within each global batch.

    Raises:
        ValueError: if ``batch_size`` is not a pair, or P*K is not
            divisible by the world size.
    """

    def __init__(self, dataset, batch_size, batch_shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        if len(self.batch_size) != 2:
            raise ValueError(
                "batch_size should be (P x K) not {}".format(batch_size))
        self.batch_shuffle = batch_shuffle

        self.world_size = dist.get_world_size()
        if (self.batch_size[0] * self.batch_size[1]) % self.world_size != 0:
            # BUGFIX: the original message stated the relation backwards
            # ("world size not divisible by batch size"); the check is
            # "global batch size divisible by world size".
            raise ValueError("Batch size ({} x {}) is not divisible by the world size ({})".format(
                batch_size[0], batch_size[1], self.world_size))
        self.rank = dist.get_rank()

    def __iter__(self):
        # Infinite stream of batches; never raises StopIteration.
        while True:
            sample_indices = []
            # Pick P identities, identically on every rank.
            pid_list = sync_random_sample_list(
                self.dataset.label_set, self.batch_size[0])

            # Pick K samples per chosen identity.
            for pid in pid_list:
                indices = self.dataset.indices_dict[pid]
                indices = sync_random_sample_list(
                    indices, k=self.batch_size[1])
                sample_indices += indices

            if self.batch_shuffle:
                # Full-length "sample" acts as a synchronized shuffle.
                sample_indices = sync_random_sample_list(
                    sample_indices, len(sample_indices))

            total_batch_size = self.batch_size[0] * self.batch_size[1]
            total_size = int(math.ceil(total_batch_size /
                                       self.world_size)) * self.world_size
            # BUGFIX: pad up to total_size (the multiple of world_size
            # actually sliced below); the original padded only to
            # total_batch_size, which falls short whenever the two differ.
            # Under the __init__ divisibility guard they are equal, so
            # this is a latent rather than an active bug.
            sample_indices += sample_indices[:(
                total_size - len(sample_indices))]

            # This rank's interleaved shard of the global batch.
            sample_indices = sample_indices[self.rank:total_size:self.world_size]
            yield sample_indices

    def __len__(self):
        # Dataset size; note the iterator itself is infinite.
        return len(self.dataset)
def sync_random_sample_list(obj_list, k, common_choice=False):
    """Draw ``k`` elements from ``obj_list``, identically on every rank.

    Rank 0's draw is broadcast with ``torch.distributed.broadcast`` so
    all processes return the same selection.

    Args:
        obj_list: sequence to sample from.
        k: number of elements to draw.
        common_choice: if True, always sample with replacement via
            ``random.choices``.

    Returns:
        List of ``k`` elements of ``obj_list`` (repeats possible when
        sampling with replacement).
    """
    if common_choice:
        # BUGFIX: this branch used to be followed by an independent
        # if/else that unconditionally overwrote idx, making
        # common_choice a no-op.  ``elif`` below restores its effect.
        idx = random.choices(range(len(obj_list)), k=k)
        idx = torch.tensor(idx)
    elif len(obj_list) < k:
        # Too few distinct elements: sample with replacement.
        idx = random.choices(range(len(obj_list)), k=k)
        idx = torch.tensor(idx)
    else:
        # Enough elements: sample without replacement.
        idx = torch.randperm(len(obj_list))[:k]
    if torch.cuda.is_available():
        # Collectives on NCCL-style backends require device tensors.
        idx = idx.cuda()
    torch.distributed.broadcast(idx, src=0)
    idx = idx.tolist()
    return [obj_list[i] for i in idx]
class InferenceSampler(tordata.sampler.Sampler):
    """Finite sampler dealing whole batches out to each rank.

    The index list [0, size) is padded (by wrapping around to the front)
    to a whole number of batches, cut into per-rank slices of
    ``batch_size / world_size`` indices, and every ``world_size``-th
    slice (offset by this rank) is kept.

    Args:
        dataset: sized dataset being iterated.
        batch_size: global batch size; must be divisible by world size.

    Raises:
        ValueError: if ``batch_size`` is not divisible by the world size.
    """

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

        self.size = len(dataset)
        indices = list(range(self.size))

        world_size = dist.get_world_size()
        rank = dist.get_rank()

        if batch_size % world_size != 0:
            # BUGFIX: original message stated the relation backwards
            # ("world size not divisible by batch size").
            raise ValueError("Batch size ({}) is not divisible by the world size ({})".format(
                batch_size, world_size))

        if batch_size != 1:
            # Round up to a whole number of batches by repeating leading
            # indices.  (batch_size == 1 needs no padding.)
            complement_size = math.ceil(self.size / batch_size) * \
                batch_size
            indices += indices[:(complement_size - self.size)]
            self.size = complement_size

        # Integer division (was int(a / b); // avoids the float round-trip).
        batch_size_per_rank = self.batch_size // world_size
        indx_batch_per_rank = []

        for i in range(self.size // batch_size_per_rank):
            indx_batch_per_rank.append(
                indices[i * batch_size_per_rank:(i + 1) * batch_size_per_rank])

        # This rank's slices: every world_size-th chunk from its offset.
        self.idx_batch_this_rank = indx_batch_per_rank[rank::world_size]

    def __iter__(self):
        yield from self.idx_batch_this_rank

    def __len__(self):
        # Unpadded dataset size, not the number of batches yielded.
        return len(self.dataset)
class CommonSampler(tordata.sampler.Sampler):
    """Infinite distributed sampler drawing B indices with replacement.

    Each iteration draws ``batch_size`` indices via
    ``sync_random_sample_list`` (identical on all ranks) and yields this
    rank's interleaved share of them.

    Args:
        dataset: sized dataset to sample from.
        batch_size: global batch size B (an int).
        batch_shuffle: stored for interface parity; not used by this
            sampler (draws are already random).

    Raises:
        ValueError: if ``batch_size`` is not an int, or not divisible by
            the world size.
    """

    def __init__(self, dataset, batch_size, batch_shuffle):
        self.dataset = dataset
        self.size = len(dataset)
        self.batch_size = batch_size
        if not isinstance(self.batch_size, int):
            # FIX: "shoude" typo in the message; `not isinstance(...)`
            # instead of `isinstance(...) == False`.
            raise ValueError(
                "batch_size should be (B) not {}".format(batch_size))
        self.batch_shuffle = batch_shuffle

        self.world_size = dist.get_world_size()
        if self.batch_size % self.world_size != 0:
            # FIX: "divisble" typo; message also stated the relation
            # backwards (the check is batch size % world size).
            raise ValueError("Batch size ({}) is not divisible by the world size ({})".format(
                batch_size, self.world_size))
        self.rank = dist.get_rank()

    def __iter__(self):
        # Infinite stream of batches.
        while True:
            indices_list = list(range(self.size))
            # common_choice=True: sample with replacement, synchronized
            # across ranks.
            sample_indices = sync_random_sample_list(
                indices_list, self.batch_size, common_choice=True)
            total_batch_size = self.batch_size
            total_size = int(math.ceil(total_batch_size /
                                       self.world_size)) * self.world_size
            # Wrap-around padding up to total_size, the multiple of
            # world_size sliced below (was total_batch_size; equal under
            # the divisibility guard, so latent only).
            sample_indices += sample_indices[:(
                total_size - len(sample_indices))]
            sample_indices = sample_indices[self.rank:total_size:self.world_size]
            yield sample_indices

    def __len__(self):
        # Dataset size; the iterator itself is infinite.
        return len(self.dataset)
# **************** For GaitSSB ****************
# Fan, et al: Learning Gait Representation from Massive Unlabelled Walking Videos: A Benchmark, T-PAMI2023
import random
class BilateralSampler(tordata.sampler.Sampler):
    """Infinite sampler for GaitSSB-style bilateral (two-view) batches.

    Walks a shuffled copy of the full index list one batch at a time and
    reshuffles whenever the remaining indices cannot fill another batch.
    The per-rank shard is yielded twice over (list repetition) so the
    consumer can build two augmented views per sample.
    """

    def __init__(self, dataset, batch_size, batch_shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.batch_shuffle = batch_shuffle

        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()

        self.dataset_length = len(self.dataset)
        self.total_indices = list(range(self.dataset_length))

    def __iter__(self):
        random.shuffle(self.total_indices)
        cursor = 0
        global_batch = self.batch_size[0] * self.batch_size[1]
        while True:
            # Out of full batches for this epoch: rewind and reshuffle.
            if (cursor + 1) * global_batch >= self.dataset_length:
                cursor = 0
                random.shuffle(self.total_indices)

            window = self.total_indices[cursor * global_batch:(cursor + 1) * global_batch]
            # Full-length draw acts as a rank-synchronized shuffle.
            window = sync_random_sample_list(window, len(window))

            total_size = int(math.ceil(global_batch / self.world_size)) * self.world_size
            window += window[:(global_batch - len(window))]

            # This rank's interleaved shard of the batch.
            shard = window[self.rank:total_size:self.world_size]
            cursor += 1

            yield shard * 2

    def __len__(self):
        return len(self.dataset)