Switch to unified view

a b/baselines/common/segment_tree.py
1
import operator
2
3
4
class SegmentTree(object):
5
    def __init__(self, capacity, operation, neutral_element):
6
        """Build a Segment Tree data structure.
7
8
        https://en.wikipedia.org/wiki/Segment_tree
9
10
        Can be used as regular array, but with two
11
        important differences:
12
13
            a) setting item's value is slightly slower.
14
               It is O(lg capacity) instead of O(1).
15
            b) user has access to an efficient ( O(log segment size) )
16
               `reduce` operation which reduces `operation` over
17
               a contiguous subsequence of items in the array.
18
19
        Paramters
20
        ---------
21
        capacity: int
22
            Total size of the array - must be a power of two.
23
        operation: lambda obj, obj -> obj
24
            and operation for combining elements (eg. sum, max)
25
            must form a mathematical group together with the set of
26
            possible values for array elements (i.e. be associative)
27
        neutral_element: obj
28
            neutral element for the operation above. eg. float('-inf')
29
            for max and 0 for sum.
30
        """
31
        assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2."
32
        self._capacity = capacity
33
        self._value = [neutral_element for _ in range(2 * capacity)]
34
        self._operation = operation
35
36
    def _reduce_helper(self, start, end, node, node_start, node_end):
37
        if start == node_start and end == node_end:
38
            return self._value[node]
39
        mid = (node_start + node_end) // 2
40
        if end <= mid:
41
            return self._reduce_helper(start, end, 2 * node, node_start, mid)
42
        else:
43
            if mid + 1 <= start:
44
                return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)
45
            else:
46
                return self._operation(
47
                    self._reduce_helper(start, mid, 2 * node, node_start, mid),
48
                    self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
49
                )
50
51
    def reduce(self, start=0, end=None):
52
        """Returns result of applying `self.operation`
53
        to a contiguous subsequence of the array.
54
55
            self.operation(arr[start], operation(arr[start+1], operation(... arr[end])))
56
57
        Parameters
58
        ----------
59
        start: int
60
            beginning of the subsequence
61
        end: int
62
            end of the subsequences
63
64
        Returns
65
        -------
66
        reduced: obj
67
            result of reducing self.operation over the specified range of array elements.
68
        """
69
        if end is None:
70
            end = self._capacity
71
        if end < 0:
72
            end += self._capacity
73
        end -= 1
74
        return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
75
76
    def __setitem__(self, idx, val):
77
        # index of the leaf
78
        idx += self._capacity
79
        self._value[idx] = val
80
        idx //= 2
81
        while idx >= 1:
82
            self._value[idx] = self._operation(
83
                self._value[2 * idx],
84
                self._value[2 * idx + 1]
85
            )
86
            idx //= 2
87
88
    def __getitem__(self, idx):
89
        assert 0 <= idx < self._capacity
90
        return self._value[self._capacity + idx]
91
92
93
class SumSegmentTree(SegmentTree):
94
    def __init__(self, capacity):
95
        super(SumSegmentTree, self).__init__(
96
            capacity=capacity,
97
            operation=operator.add,
98
            neutral_element=0.0
99
        )
100
101
    def sum(self, start=0, end=None):
102
        """Returns arr[start] + ... + arr[end]"""
103
        return super(SumSegmentTree, self).reduce(start, end)
104
105
    def find_prefixsum_idx(self, prefixsum):
106
        """Find the highest index `i` in the array such that
107
            sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum
108
109
        if array values are probabilities, this function
110
        allows to sample indexes according to the discrete
111
        probability efficiently.
112
113
        Parameters
114
        ----------
115
        perfixsum: float
116
            upperbound on the sum of array prefix
117
118
        Returns
119
        -------
120
        idx: int
121
            highest index satisfying the prefixsum constraint
122
        """
123
        assert 0 <= prefixsum <= self.sum() + 1e-5
124
        idx = 1
125
        while idx < self._capacity:  # while non-leaf
126
            if self._value[2 * idx] > prefixsum:
127
                idx = 2 * idx
128
            else:
129
                prefixsum -= self._value[2 * idx]
130
                idx = 2 * idx + 1
131
        return idx - self._capacity
132
133
134
class MinSegmentTree(SegmentTree):
135
    def __init__(self, capacity):
136
        super(MinSegmentTree, self).__init__(
137
            capacity=capacity,
138
            operation=min,
139
            neutral_element=float('inf')
140
        )
141
142
    def min(self, start=0, end=None):
143
        """Returns min(arr[start], ...,  arr[end])"""
144
145
        return super(MinSegmentTree, self).reduce(start, end)