[6d389a]: / tests / test_data / test_sampler.py

Download this file

97 lines (73 with data), 3.2 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Copyright (c) OpenMMLab. All rights reserved.
from torch.utils.data import DataLoader, Dataset
from mmaction.datasets.samplers import (ClassSpecificDistributedSampler,
DistributedSampler)
class MyDataset(Dataset):
    """Minimal in-memory dataset used to exercise the samplers.

    Holds 100 samples; the sample at position ``idx`` is
    ``dict(data=idx, label=idx % 10)``, so each of the ten labels occurs
    exactly ten times.

    Args:
        class_prob (dict | None): Mapping keyed by label, stored unchanged
            on ``self.class_prob``. When ``None`` (the default) a fresh
            ``{0: 1, ..., 9: 1}`` dict is created for this instance.
            NOTE(review): presumably read by
            ``ClassSpecificDistributedSampler`` -- confirm against the
            sampler implementation.
    """

    def __init__(self, class_prob=None):
        super().__init__()
        # Build the default per instance: a dict placed directly in the
        # signature is evaluated once and shared by every instance that
        # relies on it, so a mutation through one dataset would leak into
        # all later ones.
        if class_prob is None:
            class_prob = {i: 1 for i in range(10)}
        self.class_prob = class_prob
        self.video_infos = [
            dict(data=idx, label=idx % 10) for idx in range(100)
        ]

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.video_infos)

    def __getitem__(self, idx):
        """Return the sample dict stored at position ``idx``."""
        return self.video_infos[idx]
def test_distributed_sampler():
    """Check the batch and sample counts produced by ``DistributedSampler``.

    Every scenario wraps the same ``MyDataset`` (100 samples) in a
    ``DataLoader`` with ``batch_size=4`` and compares the number of batches
    and the summed batch sizes against fixed expectations.
    """
    dataset = MyDataset()
    # (num_replicas, rank, expected batch count, expected sample count)
    scenarios = (
        (1, 0, 25, 100),
        (4, 2, 7, 25),
        (6, 3, 5, 17),
    )
    for world_size, rank, expected_batches, expected_samples in scenarios:
        sampler = DistributedSampler(
            dataset, num_replicas=world_size, rank=rank)
        loader = DataLoader(dataset, batch_size=4, sampler=sampler)
        batches = list(loader)
        assert len(batches) == expected_batches
        assert sum(len(batch['data']) for batch in batches) == \
            expected_samples
def test_class_specific_distributed_sampler():
    """Check the counts produced by ``ClassSpecificDistributedSampler``.

    The dataset is built with ``class_prob`` mapping labels 0-4 to 1 and
    labels 5-9 to 3. Each scenario iterates a ``DataLoader`` with
    ``batch_size=4`` and compares the number of batches and the summed
    batch sizes against fixed expectations, with ``dynamic_length``
    switched on and off.
    """
    weights = [1] * 5 + [3] * 5
    class_prob = dict(zip(list(range(10)), weights))
    dataset = MyDataset(class_prob=class_prob)
    # (num_replicas, rank, dynamic_length,
    #  expected batch count, expected sample count)
    scenarios = (
        (1, 0, True, 50, 200),
        (1, 0, False, 25, 100),
        (6, 2, True, 9, 34),
        (6, 2, False, 5, 17),
    )
    for world_size, rank, dynamic, expected_batches, expected_samples \
            in scenarios:
        sampler = ClassSpecificDistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            dynamic_length=dynamic)
        loader = DataLoader(dataset, batch_size=4, sampler=sampler)
        batches = list(loader)
        assert len(batches) == expected_batches
        assert sum(len(batch['data']) for batch in batches) == \
            expected_samples