a b/submission/baselines/common/dataset.py
1
import numpy as np
2
3
class Dataset(object):
    """Minibatch sampler over a dict of arrays that share their first dimension.

    Every array is indexed with the same permutation / slice, so row ``i`` of
    each array stays aligned with row ``i`` of the others. Shuffling uses
    NumPy's global RNG (``np.random.shuffle``).
    """

    def __init__(self, data_map, deterministic=False, shuffle=True):
        """
        Args:
            data_map: dict mapping a name to an array indexable along axis 0
                (by slices, and by integer arrays unless ``deterministic``).
                The dict is shallow-copied, so the caller's dict is never
                rebound to shuffled arrays.
            deterministic: if True, rows are never permuted.
            shuffle: if True, every new pass (``next_batch`` running past the
                end, or ``iterate_once``) reshuffles the rows. Note that the
                rows are permuted once at construction whenever
                ``deterministic`` is False, regardless of this flag.

        Raises:
            ValueError: if ``data_map`` is empty or its arrays differ in length.
        """
        if not data_map:
            raise ValueError('data_map must contain at least one array')
        lengths = {key: len(value) for key, value in data_map.items()}
        n = next(iter(lengths.values()))
        if any(length != n for length in lengths.values()):
            raise ValueError('all arrays in data_map must have the same length, got %r' % (lengths,))
        # Copy so that shuffle() rebinding entries does not leak into the caller's dict.
        self.data_map = dict(data_map)
        self.deterministic = deterministic
        self.enable_shuffle = shuffle
        self.n = n
        self._next_id = 0  # row index where the next batch starts
        self.shuffle()

    def shuffle(self):
        """Start a new pass over the data.

        Applies one shared random permutation to every array (skipped when
        ``deterministic``) and rewinds the read position to row 0.
        """
        if not self.deterministic:
            perm = np.arange(self.n)
            np.random.shuffle(perm)
            for key in self.data_map:
                self.data_map[key] = self.data_map[key][perm]
        # Rewind even when deterministic; otherwise next_batch() would keep
        # returning empty batches once the first pass is exhausted.
        self._next_id = 0

    def _start_pass(self):
        """Rewind to row 0, reshuffling first if shuffling is enabled."""
        if self.enable_shuffle:
            self.shuffle()
        else:
            self._next_id = 0

    def next_batch(self, batch_size):
        """Return the next batch as a dict with the same keys as ``data_map``.

        Each value holds up to ``batch_size`` consecutive rows; the last batch
        of a pass may be shorter. Once a pass is exhausted, the next call
        starts a new pass from row 0 (reshuffling first if enabled).

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be at least 1, got %r' % (batch_size,))
        if self._next_id >= self.n:
            self._start_pass()

        start = self._next_id
        stop = min(start + batch_size, self.n)
        self._next_id = stop
        return {key: value[start:stop] for key, value in self.data_map.items()}

    def iterate_once(self, batch_size):
        """Yield the full batches of one pass over the data.

        The pass starts at row 0 (reshuffling first if enabled), and only
        batches of exactly ``batch_size`` rows are yielded; the trailing
        ``n % batch_size`` rows are skipped. When the generator runs to
        completion the read position is rewound to row 0.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1 (raised on the
                first iteration, since this is a generator).
        """
        if batch_size < 1:
            raise ValueError('batch_size must be at least 1, got %r' % (batch_size,))
        self._start_pass()
        while self._next_id <= self.n - batch_size:
            yield self.next_batch(batch_size)
        self._next_id = 0

    def subset(self, num_elements, deterministic=True):
        """Return a new Dataset over the first ``num_elements`` rows, in current order.

        The new arrays are slices of this dataset's current arrays (views, for
        NumPy arrays). The new dataset uses the default ``shuffle=True``.
        """
        data_map = {key: value[:num_elements] for key, value in self.data_map.items()}
        return Dataset(data_map, deterministic)
48
49
50
def iterbatches(arrays, *, num_batches=None, batch_size=None, shuffle=True, include_final_partial_batch=True):
    """Yield aligned minibatches drawn from several equally long arrays.

    Exactly one of ``num_batches`` / ``batch_size`` must be given.

    Args:
        arrays: sequence of array-likes; each is converted with ``np.asarray``
            and all must have the same length along axis 0.
        num_batches: split the rows into this many batches using
            ``np.array_split`` (sizes differ by at most one).
        batch_size: split the rows into consecutive batches of this size; the
            last batch may be smaller.
        shuffle: if True, permute the rows with NumPy's global RNG before
            splitting; the same permutation is applied to every array.
        include_final_partial_batch: only meaningful with ``batch_size``; when
            False, a trailing batch shorter than ``batch_size`` is dropped.
            With ``num_batches`` every non-empty batch is yielded.

    Yields:
        A tuple holding one batch per input array, in the order of ``arrays``.
        Empty batches (``n == 0``, or ``num_batches`` larger than ``n``) are
        skipped.

    Raises:
        AssertionError: on the first iteration (this is a generator) if both
            or neither of ``num_batches`` / ``batch_size`` are given, or if the
            arrays differ in length.
    """
    assert (num_batches is None) != (batch_size is None), 'Provide num_batches or batch_size, but not both'
    arrays = tuple(map(np.asarray, arrays))
    n = arrays[0].shape[0]
    assert all(a.shape[0] == n for a in arrays[1:]), 'All arrays must have the same length along axis 0'
    inds = np.arange(n)
    if shuffle:
        np.random.shuffle(inds)
    if num_batches is None:
        # Split points at batch_size, 2*batch_size, ... (the leading 0 is dropped).
        sections = np.arange(0, n, batch_size)[1:]
    else:
        sections = num_batches
    for batch_inds in np.array_split(inds, sections):
        if len(batch_inds) == 0:
            continue
        # Only batch_size mode has a notion of a "partial" batch; comparing the
        # length against a None batch_size would drop every batch.
        is_partial = batch_size is not None and len(batch_inds) < batch_size
        if include_final_partial_batch or not is_partial:
            yield tuple(a[batch_inds] for a in arrays)