|
a |
|
b/baselines/common/segment_tree.py |
|
|
1 |
import operator |
|
|
2 |
|
|
|
3 |
|
|
|
4 |
class SegmentTree(object):
    """Fixed-capacity array with O(log capacity) updates and range reductions.

    Values are stored in a flat list laid out as a complete binary tree:
    node 1 is the root, node `i` has children `2 * i` and `2 * i + 1`, and
    the leaves (the array elements) occupy `capacity .. 2 * capacity - 1`.
    """

    def __init__(self, capacity, operation, neutral_element):
        """Build a Segment Tree data structure.

        https://en.wikipedia.org/wiki/Segment_tree

        Can be used as regular array, but with two
        important differences:

            a) setting item's value is slightly slower.
               It is O(lg capacity) instead of O(1).
            b) user has access to an efficient ( O(log segment size) )
               `reduce` operation which reduces `operation` over
               a contiguous subsequence of items in the array.

        Parameters
        ----------
        capacity: int
            Total size of the array - must be a power of two.
        operation: lambda obj, obj -> obj
            an operation for combining elements (eg. sum, max).
            It must be associative, and `neutral_element` must be
            its identity.
        neutral_element: obj
            neutral element for the operation above. eg. float('-inf')
            for max and 0 for sum. Every slot starts out holding this
            value, and it is what `reduce` returns for an empty range.
        """
        assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2."
        self._capacity = capacity
        # Kept so that `reduce` can answer empty ranges without recursing.
        self._neutral_element = neutral_element
        self._value = [neutral_element for _ in range(2 * capacity)]
        self._operation = operation

    def _reduce_helper(self, start, end, node, node_start, node_end):
        """Reduce the inclusive range [start, end] within the subtree at `node`.

        `node` covers the inclusive leaf range [node_start, node_end]. The
        caller must guarantee node_start <= start <= end <= node_end;
        otherwise the recursion does not terminate.
        """
        if start == node_start and end == node_end:
            return self._value[node]
        mid = (node_start + node_end) // 2
        if end <= mid:
            # Query lies entirely in the left child.
            return self._reduce_helper(start, end, 2 * node, node_start, mid)
        if mid + 1 <= start:
            # Query lies entirely in the right child.
            return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)
        # Query straddles the midpoint: split it and combine both halves.
        return self._operation(
            self._reduce_helper(start, mid, 2 * node, node_start, mid),
            self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
        )

    def reduce(self, start=0, end=None):
        """Returns result of applying `self.operation`
        to a contiguous subsequence of the array.

            self.operation(arr[start], operation(arr[start+1], operation(... arr[end - 1])))

        Parameters
        ----------
        start: int
            beginning of the subsequence (inclusive)
        end: int
            end of the subsequence (exclusive). `None` means the full
            capacity; a negative value counts back from the capacity,
            like a slice bound.

        Returns
        -------
        reduced: obj
            result of reducing self.operation over the specified range of
            array elements, or the neutral element when the range is empty.

        Raises
        ------
        IndexError
            if `start` is negative or `end` exceeds the capacity.
        """
        if end is None:
            end = self._capacity
        if end < 0:
            end += self._capacity
        if start < 0 or end > self._capacity:
            raise IndexError(
                "reduce range [%r, %r) is outside [0, %r)" % (start, end, self._capacity)
            )
        if start >= end:
            # Empty range: the helper requires start <= end and would
            # otherwise recurse forever.
            return self._neutral_element
        # The helper works on inclusive bounds.
        return self._reduce_helper(start, end - 1, 1, 0, self._capacity - 1)

    def __setitem__(self, idx, val):
        """Store `val` at position `idx` and refresh every ancestor aggregate.

        Raises
        ------
        IndexError
            if `idx` is outside [0, capacity). Without this check a negative
            index would land on an internal node and corrupt the tree.
        """
        if not 0 <= idx < self._capacity:
            raise IndexError("segment tree index out of range")
        # index of the leaf
        node = idx + self._capacity
        self._value[node] = val
        # Walk up to the root, recomputing each parent from its two children.
        node //= 2
        while node >= 1:
            self._value[node] = self._operation(
                self._value[2 * node],
                self._value[2 * node + 1]
            )
            node //= 2

    def __getitem__(self, idx):
        """Return the value stored at position `idx` (0 <= idx < capacity)."""
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert 0 <= idx < self._capacity
        return self._value[self._capacity + idx]
|
|
91 |
|
|
|
92 |
|
|
|
93 |
class SumSegmentTree(SegmentTree):
    """Segment tree that combines elements with addition."""

    def __init__(self, capacity):
        """Create a sum tree of `capacity` slots, each starting at 0.0."""
        super(SumSegmentTree, self).__init__(capacity, operator.add, 0.0)

    def sum(self, start=0, end=None):
        """Returns arr[start] + ... + arr[end - 1] (`end` is exclusive)."""
        return super(SumSegmentTree, self).reduce(start, end)

    def find_prefixsum_idx(self, prefixsum):
        """Find the highest index `i` in the array such that
            arr[0] + arr[1] + ... + arr[i - 1] <= prefixsum

        if array values are probabilities, this function
        allows to sample indexes according to the discrete
        probability efficiently.

        Parameters
        ----------
        prefixsum: float
            upperbound on the sum of array prefix

        Returns
        -------
        idx: int
            highest index satisfying the prefixsum constraint
        """
        assert 0 <= prefixsum <= self.sum() + 1e-5
        remaining = prefixsum
        node = 1
        # Descend from the root until a leaf is reached.
        while node < self._capacity:
            left = 2 * node
            left_total = self._value[left]
            if left_total > remaining:
                # The target falls inside the left subtree.
                node = left
            else:
                # Skip the whole left subtree and continue on the right.
                remaining -= left_total
                node = left + 1
        # Convert the leaf's node number back to an array index.
        return node - self._capacity
|
|
132 |
|
|
|
133 |
|
|
|
134 |
class MinSegmentTree(SegmentTree):
    """Segment tree that combines elements with the built-in `min`."""

    def __init__(self, capacity):
        """Create a min tree of `capacity` slots, each starting at +infinity."""
        super(MinSegmentTree, self).__init__(capacity, min, float('inf'))

    def min(self, start=0, end=None):
        """Returns min(arr[start], ..., arr[end - 1]) (`end` is exclusive)."""
        return super(MinSegmentTree, self).reduce(start, end)