[6d389a]: mmaction/datasets/samplers/distributed_sampler.py

# Copyright (c) OpenMMLab. All rights reserved.
import math
from collections import defaultdict

import torch
from torch.utils.data import DistributedSampler as _DistributedSampler


class DistributedSampler(_DistributedSampler):
    """DistributedSampler inheriting from
    ``torch.utils.data.DistributedSampler``.

    In older versions of PyTorch, there is no ``shuffle`` argument. This
    child class ports one to ``DistributedSampler``.
    """
    def __init__(self,
                 dataset,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 seed=0):
        super().__init__(
            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        # for compatibility with PyTorch 1.3+
        self.seed = seed if seed is not None else 0

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample with a rank-based stride so each process gets a
        # disjoint, equally sized shard
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples
        return iter(indices)
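

# Illustrative usage sketch (not part of the upstream file): wiring the
# sampler into a DataLoader, assuming torch.distributed has been
# initialized. ``set_epoch`` must be called once per epoch so that
# ``__iter__`` reshuffles deterministically with ``epoch + seed``; the
# dataset/loader names below are hypothetical.
#
#     from torch.utils.data import DataLoader
#
#     sampler = DistributedSampler(dataset, shuffle=True, seed=42)
#     loader = DataLoader(dataset, batch_size=16, sampler=sampler)
#     for epoch in range(num_epochs):
#         sampler.set_epoch(epoch)  # changes the shuffle order each epoch
#         for batch in loader:
#             ...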
class ClassSpecificDistributedSampler(_DistributedSampler):
    """ClassSpecificDistributedSampler inheriting from
    ``torch.utils.data.DistributedSampler``.

    Samples are drawn with a class-specific probability, which should be an
    attribute of the dataset (``dataset.class_prob``, a dictionary mapping
    label index to probability). This sampler is only applicable to
    single-class recognition datasets. It is also compatible with
    RepeatDataset.

    The default value of ``dynamic_length`` is True, which means we use
    oversampling / subsampling, and the dataset length may change. If
    ``dynamic_length`` is set to False, the dataset length is fixed.
    """
    def __init__(self,
                 dataset,
                 num_replicas=None,
                 rank=None,
                 dynamic_length=True,
                 shuffle=True,
                 seed=0):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle

        # unwrap RepeatDataset to reach the underlying dataset
        if type(dataset).__name__ == 'RepeatDataset':
            dataset = dataset.dataset

        assert hasattr(dataset, 'class_prob')

        self.class_prob = dataset.class_prob
        self.dynamic_length = dynamic_length

        # for compatibility with PyTorch 1.3+
        self.seed = seed if seed is not None else 0
    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        class_indices = defaultdict(list)

        # to be compatible with RepeatDataset
        times = 1
        dataset = self.dataset
        if type(dataset).__name__ == 'RepeatDataset':
            times = dataset.times
            dataset = dataset.dataset

        # group dataset indices by label
        for i, item in enumerate(dataset.video_infos):
            class_indices[item['label']].append(i)

        if self.dynamic_length:
            indices = []
            for k, prob in self.class_prob.items():
                prob = prob * times
                # repeat every index of class k floor(prob) times
                for i in range(int(prob // 1)):
                    indices.extend(class_indices[k])
                # for the fractional part, sample that fraction of class k
                # without replacement
                rem = int((prob % 1) * len(class_indices[k]))
                rem_indices = torch.randperm(
                    len(class_indices[k]), generator=g).tolist()[:rem]
                # map the sampled positions back to dataset indices
                indices.extend(class_indices[k][i] for i in rem_indices)
            if self.shuffle:
                shuffle = torch.randperm(len(indices), generator=g).tolist()
                indices = [indices[i] for i in shuffle]

            # re-calc num_samples & total_size
            self.num_samples = math.ceil(len(indices) / self.num_replicas)
            self.total_size = self.num_samples * self.num_replicas
        else:
            # keep the dataloader length the same as the original: draw
            # total_size indices with per-sample probability
            # class_prob[label] / len(class_indices[label])
            video_labels = [x['label'] for x in dataset.video_infos]
            probs = [
                self.class_prob[lb] / len(class_indices[lb])
                for lb in video_labels
            ]

            indices = torch.multinomial(
                torch.Tensor(probs),
                self.total_size,
                replacement=True,
                generator=g)
            indices = indices.data.numpy().tolist()

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # retrieve indices for current process
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples
        return iter(indices)
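

# Illustrative usage sketch (not part of the upstream file): a minimal
# dataset satisfying the sampler's contract. ``class_prob`` maps each label
# to a sampling weight; with dynamic_length=True a weight of e.g. 2.5 keeps
# every index of that class twice and randomly keeps roughly half of them
# once more. ``ToyVideoDataset`` and its fields are hypothetical names.
#
#     class ToyVideoDataset:
#         def __init__(self):
#             self.video_infos = [{'label': i % 2} for i in range(10)]
#             self.class_prob = {0: 2.5, 1: 1.0}  # oversample class 0
#
#         def __len__(self):
#             return len(self.video_infos)
#
#     # explicit num_replicas/rank avoid needing an initialized process group
#     sampler = ClassSpecificDistributedSampler(
#         ToyVideoDataset(), num_replicas=2, rank=0, dynamic_length=True)
#     sampler.set_epoch(0)
#     print(list(sampler))  # this rank's share of the resampled indices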