Switch to unified view

a b/pathaia/graphs/functional_api.py
1
"""
2
A module implementing useful functions to handle trees.
3
Trees are stored as dictionaries.
4
"""
5
import json
6
import warnings
7
from typing import List, Optional, Sequence, Tuple, Union
8
9
import numpy as np
10
from nptyping import NDArray, Number, Shape
11
from scipy.sparse import spmatrix
12
from sklearn.neighbors import NearestNeighbors
13
14
from .errors import (
15
    InvalidEdgeProps,
16
    InvalidNodeProps,
17
    InvalidTree,
18
    UnknownNodeProperty,
19
    UnrelatedNode,
20
)
21
from .kruskal import UFDS
22
from .types import (
23
    BinaryNodeProperty,
24
    Childhood,
25
    Edge,
26
    EdgeProperties,
27
    Node,
28
    NodeProperties,
29
    NumericalEdgeProperty,
30
    NumericalNodeProperty,
31
    Parenthood,
32
)
33
34
35
def complete_tree(
    parents: Optional[Parenthood] = None, children: Optional[Childhood] = None
):
    """
    Complete a tree description from whichever relationship is provided.
    ********************************************************************

    Args:
        parents: mapping from a node to its parent, or None.
        children: mapping from a node to an iterable of its children, or None.

    Returns:
        A ``(parents, children)`` tuple. The values of the returned
        ``children`` mapping are sets. When only one mapping is given the
        other one is derived from it; when none is given both are empty.

    Raises:
        InvalidTree: if both mappings are given and a child listed in
            ``children`` is missing from ``parents`` or is registered there
            with a different parent.
    """
    # Build fresh sets instead of rewriting the caller's mapping in place.
    child_sets = {}
    if children is not None:
        child_sets = {parent: set(kids) for parent, kids in children.items()}

    if parents is None:
        # Derive the parenthood by inverting the childhood.
        parents = {
            child: parent for parent, kids in child_sets.items() for child in kids
        }
        return parents, child_sets

    if children is None:
        # Derive the childhood by inverting the parenthood.
        for child, parent in parents.items():
            child_sets.setdefault(parent, set()).add(child)
        return parents, child_sets

    # Both were given: every listed child must point back to its parent.
    for parent, kids in child_sets.items():
        for child in kids:
            if child not in parents or parents[child] != parent:
                raise InvalidTree
    return parents, child_sets
62
63
64
def get_root(parents: Parenthood, node: Optional[Node] = None) -> Node:
    """
    Get root of a node in a tree.
    *****************************

    Args:
        parents: mapping from a node to its parent.
        node: node whose root is requested. When None, the first key of
            ``parents`` is used as the starting point.

    Returns:
        The first node reached by following ``parents`` from ``node`` that is
        not itself a key of ``parents`` (``node`` itself if it has no parent).
        None when ``node`` is None and ``parents`` is empty.
    """
    if node is None:
        # No node requested: start from the first registered node.
        node = next(iter(parents), None)
    root = node
    # NOTE: terminates only if the parenthood contains no cycle.
    while root in parents:
        root = parents[root]
    return root
79
80
81
def get_root_path(parents: Parenthood, node: Node) -> List[Node]:
    """
    Get path to root of a node in a tree.
    *************************************

    Args:
        parents: mapping from a node to its parent.
        node: node the path starts from.

    Returns:
        The list of nodes met when following ``parents`` from ``node``, with
        ``node`` first and the root last. If ``node`` is not a key of
        ``parents`` a warning is issued and ``[node]`` is returned.
    """
    if node not in parents:
        warnings.warn("Requested node {} is not in the parenthood.".format(node))
        return [node]
    root_path = [node]
    # Keep climbing while the last node reached still has a parent.
    while root_path[-1] in parents:
        root_path.append(parents[root_path[-1]])
    return root_path
95
96
97
def get_root_path_match(parents: Parenthood, node: Node, target: Node) -> List[Node]:
    """
    Get path from a node up to a target ancestor in a tree.
    *******************************************************

    Args:
        parents: mapping from a node to its parent.
        node: node the path starts from.
        target: ancestor at which the path should stop.

    Returns:
        The list of nodes from ``node`` up to ``target`` (both included).
        With a warning in each case:

        * if ``target`` is not a key of ``parents``, the full path from
          ``node`` to its root;
        * if ``node`` is not a key of ``parents``, an empty list;
        * if ``target`` is never met while climbing from ``node``, the full
          path from ``node`` to its root.
    """
    if target not in parents:
        warnings.warn("Target node {} is not in the parenthood.".format(target))
        return get_root_path(parents, node)
    if node not in parents:
        warnings.warn("Requested node {} is not in the parenthood.".format(node))
        return []
    current = node
    root_path = [node]
    while current in parents:
        if current == target:
            return root_path
        current = parents[current]
        root_path.append(current)
    # The root was reached without meeting the target: always return a list
    # (the function used to fall through and return None here).
    warnings.warn(
        "Target node {} is not an ancestor of node {}.".format(target, node)
    )
    return root_path
115
116
117
def get_leaves_without_prop(
    children: Childhood,
    node: Node,
) -> List[Node]:
    """
    Get leaves of a node in a tree.
    *******************************

    Args:
        children: mapping from a node to an iterable of its children.
        node: node whose leaves are requested.

    Returns:
        The descendants of ``node`` that are not keys of ``children``,
        collected level by level. ``[node]`` if ``node`` itself is not a key
        of ``children``.
    """
    if node not in children:
        return [node]
    leaves = []
    level = [node]
    while level:
        next_level = []
        for current in level:
            if current in children:
                # Internal node: its children are explored at the next level.
                next_level.extend(children[current])
            else:
                leaves.append(current)
        level = next_level
    return leaves
139
140
141
def get_leaves_with_prop(
142
    children: Childhood, node: Node, prop: BinaryNodeProperty
143
) -> List[Node]:
144
    """
145
    Get leaves of a node in a tree.
146
    *******************************
147
    """
148
    if node not in prop:
149
        warnings.warn("Node {} does not have the property".format(node))
150
        return []
151
    if not prop[node]:
152
        warnings.warn("Root {} does not pass the property test.".format(node))
153
        return []
154
    if node not in children:
155
        warnings.warn("Children of Root {} does not pass the property.".format(node))
156
        return [node]
157
    no_lvs = [node]
158
    lvs = []
159
    while len(no_lvs) > 0:
160
        new_no_lvs = []
161
        for n in no_lvs:
162
            if n in prop:
163
                if prop[n]:
164
                    if n in children:
165
                        candidates = []
166
                        for c in children[n]:
167
                            if c in prop:
168
                                if prop[c]:
169
                                    candidates.append(c)
170
                        new_no_lvs += candidates
171
                        if len(candidates) == 0:
172
                            lvs.append(n)
173
                    else:
174
                        lvs.append(n)
175
            else:
176
                warnings.warn("Node {} does not have the property".format(node))
177
        no_lvs = new_no_lvs
178
    return lvs
179
180
181
def get_leaves(
    children: Childhood, node: Node, prop: Optional[BinaryNodeProperty] = None
) -> List[Node]:
    """
    Get leaves of a node in a tree.
    *******************************

    Args:
        children: mapping from a node to an iterable of its children.
        node: node whose leaves are requested.
        prop: optional mapping from a node to a truth value. When given, the
            search is delegated to ``get_leaves_with_prop``; otherwise to
            ``get_leaves_without_prop``.

    Returns:
        The list of leaves returned by the selected helper.
    """
    if prop is not None:
        return get_leaves_with_prop(children, node, prop)
    return get_leaves_without_prop(children, node)
191
192
193
def kruskal_edges(
    edges: Sequence[Edge], weights: NumericalEdgeProperty
) -> Tuple[List[Edge], List[Union[int, float]]]:
    """
    Select kruskal edges, given a list of edges.
    ********************************************

    Args:
        edges: edges of the graph, each one a pair of nodes.
        weights: mapping from an edge to its weight (dissimilarity).

    Returns:
        A ``(k_edges, k_weights)`` tuple: the edges kept, in non-decreasing
        order of weight, and their weights in the same order.
    """
    # Union-Find data structure tracking the connected components
    components = UFDS()
    k_edges = []
    k_weights = []

    # edges are visited by non-decreasing order of dissimilarity
    for edge in sorted(edges, key=lambda e: weights[e]):
        n1, n2 = edge
        # roots of both end nodes in the Union-Find
        rn1 = components.get_root(n1)
        rn2 = components.get_root(n2)
        # keep the edge only if it joins two different components
        if rn1 != rn2:
            components.union(edge)
            k_edges.append(edge)
            k_weights.append(weights[edge])
    return k_edges, k_weights
219
220
221
def kruskal_tree(
    edges: Sequence[Edge], weights: NumericalEdgeProperty, size: NumericalNodeProperty
) -> Tuple[Parenthood, Childhood, NodeProperties]:
    """
    Create parents and children relationships from kruskal edges.
    *************************************************************

    Each kruskal edge creates a new node that becomes the parent of the
    current roots of the two end nodes of the edge.

    Args:
        edges: edges of the graph, each one a pair of nodes.
        weights: mapping from an edge to its weight.
        size: mapping from a node of the graph to its size.

    Returns:
        A ``(parents, children, props)`` tuple where ``props`` holds two
        node properties: ``"weights"`` (weight of the edge that created each
        new node) and ``"size"`` (sum of the sizes merged under each node).
    """
    parents = {}
    children = {}
    props = {"weights": {}, "size": {}}
    sizes = props["size"]
    k_edges, k_weights = kruskal_edges(edges, weights)
    # New nodes are numbered from 2 * len(k_edges).
    # NOTE(review): presumably the graph nodes are integers smaller than this
    # value, otherwise identifiers could collide -- confirm with callers.
    max_node = 2 * len(k_edges)
    for (n1, n2), weight in zip(k_edges, k_weights):
        rn1 = get_root(parents, n1)
        rn2 = get_root(parents, n2)
        # A root without a recorded size takes the size of the end node.
        if rn1 not in sizes:
            sizes[rn1] = size[n1]
        if rn2 not in sizes:
            sizes[rn2] = size[n2]
        # kruskal edges always join two different components, so rn1 and rn2
        # are distinct roots
        parents[rn1] = max_node
        parents[rn2] = max_node
        children[max_node] = [rn1, rn2]
        props["weights"][max_node] = weight
        sizes[max_node] = sizes[rn1] + sizes[rn2]

        max_node += 1
    return parents, children, props
257
258
259
def tree_to_json(
    nodes: Sequence[Node],
    parents: Parenthood,
    children: Childhood,
    jsonfile: str,
    nodeprops: Optional[NodeProperties] = None,
    edgeprops: Optional[EdgeProperties] = None,
):
    """
    Store a jsonified tree to a json file.

    Args:
        nodes: nodes of the tree.
        parents: mapping from a node to its parent.
        children: mapping from a node to its children.
        jsonfile: path of the file to write.
        nodeprops: optional dictionary of node properties.
        edgeprops: optional dictionary of edge properties.

    Raises:
        InvalidNodeProps: if ``nodeprops`` is given and is not a dict.
        InvalidEdgeProps: if ``edgeprops`` is given and is not a dict.
    """
    if nodeprops is not None and not isinstance(nodeprops, dict):
        raise InvalidNodeProps(
            "Invalid node props, "
            "expected {} but got {}".format(dict, type(nodeprops))
        )
    if edgeprops is not None and not isinstance(edgeprops, dict):
        # This message used to say "node props" for an edge props failure.
        raise InvalidEdgeProps(
            "Invalid edge props, "
            "expected {} but got {}".format(dict, type(edgeprops))
        )
    output_dict = {
        "nodes": nodes,
        "parents": parents,
        "children": children,
        "nodeprops": dict(nodeprops) if nodeprops is not None else {},
        "edgeprops": dict(edgeprops) if edgeprops is not None else {},
    }
    # Serialize before opening the file so that a serialization error does
    # not leave an empty or truncated file behind.
    json_dict = json.dumps(output_dict)
    with open(jsonfile, "w") as outputjson:
        outputjson.write(json_dict)
295
296
297
def _expand_on_property(
298
    cut: List[Node],
299
    children: Childhood,
300
    prop: NumericalNodeProperty,
301
    threshold: Union[int, float],
302
) -> List[Node]:
303
    """Create a new tree by cutting based on property threshold."""
304
    candidates = []
305
    expansion = []
306
    for node in cut:
307
        if node in children:
308
            candidates += children[node]
309
    for candidate in candidates:
310
        if candidate in prop:
311
            if prop[candidate] >= threshold:
312
                expansion.append(candidate)
313
    return expansion
314
315
316
def cut_on_property(
    parents: Parenthood,
    children: Childhood,
    prop: NumericalNodeProperty,
    threshold: Union[int, float],
) -> List[Node]:
    """
    Produce a list of authorized nodes given a property threshold.

    Starting from the root found through ``parents``, the tree is explored
    downwards through the children selected by ``_expand_on_property``.

    Args:
        parents: mapping from a node to its parent.
        children: mapping from a node to an iterable of its children.
        prop: mapping from a node to a numerical value.
        threshold: value passed to ``_expand_on_property`` to select children.

    Returns:
        The root and every node reached during the exploration, in no
        particular order.
    """
    cut = set()
    # The root is always part of the cut; its own property is not tested.
    remaining = [get_root(parents)]
    while remaining:
        cut.update(remaining)
        remaining = _expand_on_property(remaining, children, prop, threshold)
    return list(cut)
330
331
332
def common_ancestor(parents: Parenthood, node1: Node, node2: Node) -> Node:
    """
    Get the closest common ancestor of two nodes.

    Args:
        parents: mapping from a node to its parent.
        node1: first node.
        node2: second node.

    Returns:
        The first node on the path from ``node1`` to its root that also
        belongs to the path from ``node2`` to its root.

    Raises:
        UnrelatedNode: if one of the nodes is not a key of ``parents``, or
            if the two root paths share no node.
    """
    if node1 not in parents or node2 not in parents:
        raise UnrelatedNode(
            "One of the provided nodes: ({}, {}) has no parent...".format(node1, node2)
        )
    rp1 = get_root_path(parents, node1)
    # A set gives constant-time membership tests while walking rp1.
    rp2 = set(get_root_path(parents, node2))
    for node in rp1:
        if node in rp2:
            return node
    raise UnrelatedNode(
        "Nodes {} and {} have no common ancestors!!!".format(node1, node2)
    )
347
348
349
def edge_dist(parents: Parenthood, node1: Node, node2: Node) -> int:
    """
    Return the number of edges to go from node1 to node2 (by common ancestor).

    Args:
        parents: mapping from a node to its parent.
        node1: first node.
        node2: second node.

    Returns:
        The number of edges on the path joining the two nodes through their
        common ancestor.
    """
    ancestor = common_ancestor(parents, node1, node2)
    rpm1 = get_root_path_match(parents, node1, ancestor)
    rpm2 = get_root_path_match(parents, node2, ancestor)
    # A path visiting k distinct nodes has k - 1 edges; the union counts the
    # shared ancestor only once.
    return len(set(rpm1) | set(rpm2)) - 1
355
356
357
def weighted_dist(
    parents: Parenthood, weights: NumericalNodeProperty, node1: Node, node2: Node
) -> float:
    """
    Return the weighted distance from node1 to node2 (by common ancestor).

    The distance is the sum of the weights of the nodes on the path joining
    the two nodes through their common ancestor, ``node1`` and ``node2``
    themselves excluded.

    Args:
        parents: mapping from a node to its parent.
        weights: mapping from a node to its weight.
        node1: first node.
        node2: second node.

    Returns:
        The sum of the weights of the intermediate nodes of the path.

    Raises:
        UnknownNodeProperty: if an intermediate node has no weight.
    """
    ancestor = common_ancestor(parents, node1, node2)
    rpm1 = get_root_path_match(parents, node1, ancestor)
    rpm2 = get_root_path_match(parents, node2, ancestor)
    # The union keeps the shared ancestor only once.
    nodes_in_path = set(rpm1) | set(rpm2)
    # End points do not contribute to the distance.
    nodes_in_path.discard(node1)
    nodes_in_path.discard(node2)
    dist = 0.0
    for node in nodes_in_path:
        if node not in weights:
            raise UnknownNodeProperty(
                "Missing weight for node {} to compute a weighted distance!!!".format(
                    node
                )
            )
        dist += weights[node]
    return dist
378
379
380
def farthest_point_sampling(
    coords: NDArray[Shape["N_points, N_dims"], Number], n_samples: Union[int, float]
) -> NDArray[Shape["N_samples"], np.int32]:
    """
    Perform farthest points sampling using point coordinates array.

    The first point is drawn at random; each following point is the one
    with the largest squared distance to the points already selected.

    Args:
        coords: array containing point coordinates.
        n_samples: number of point to sample. If a float is given, represents the
            proportion of points used instead.

    Returns:
        Array containing idxs of sampled points. The array is empty when the
        requested number of samples resolves to 0.
    """
    coords = np.asarray(coords)
    if isinstance(n_samples, float):
        n_samples = int(n_samples * len(coords))

    idxs = np.zeros(n_samples, dtype=np.int32)
    if n_samples == 0:
        # Nothing to sample: writing the random start into an empty array
        # would raise an IndexError.
        return idxs

    idxs[0] = np.random.randint(len(coords))
    # Squared distance from every point to the closest selected point.
    distances = ((coords[idxs[0]] - coords) ** 2).sum(1)
    for i in range(1, n_samples):
        idxs[i] = np.argmax(distances)
        distances = np.minimum(distances, ((coords[idxs[i]] - coords) ** 2).sum(1))

    return idxs
405
406
407
def random_farthest_point_sampling(
    coords: NDArray[Shape["N_points, N_dims"], Number],
    n_farthest_samples: Union[int, float] = 0.3,
    n_random_samples: Union[int, float] = 0.1,
) -> NDArray[Shape["N_samples"], np.int32]:
    """
    Perform farthest points sampling using point coordinates array followed by
    random sampling.

    Args:
        coords: array containing point coordinates.
        n_farthest_samples: number of points to keep using farthest points sampling.
            If a float is given, represents the proportion of points used instead.
        n_random_samples: number of points to keep using random sampling. If a float
            is given, represents the proportion of points used instead. Random
            points are drawn without replacement among the points not kept by
            farthest points sampling, so at most that many are returned.

    Returns:
        Array containing idxs of sampled points: farthest points first, then
        random points.
    """
    farthest_idxs = farthest_point_sampling(coords, n_farthest_samples)

    if isinstance(n_random_samples, float):
        n_random_samples = int(n_random_samples * len(coords))

    # Candidates for random sampling: every point not already selected.
    # (The previous implementation passed unnormalised probabilities to
    # np.random.choice, which rejects them with a ValueError.)
    remaining = np.setdiff1d(np.arange(len(coords)), farthest_idxs)
    n_random_samples = min(n_random_samples, len(remaining))
    if n_random_samples > 0:
        random_idxs = np.random.choice(
            remaining, size=n_random_samples, replace=False
        )
    else:
        random_idxs = remaining[:0]

    return np.concatenate((farthest_idxs, random_idxs))
438
439
440
def get_kneighbors_graph(
    points: NDArray[Shape["N_points, N_dims"], Number],
    n_farthest_samples: Union[int, float] = 0.3,
    n_random_samples: Union[int, float] = 0.1,
    dmax: int = 500,
    n_neighbors: int = 5,
    n_jobs: Optional[int] = None,
) -> spmatrix:
    """
    Get a graph generated by KNN on given points.

    The points are first subsampled with ``random_farthest_point_sampling``;
    the graph is built on the sampled points only, so its row and column
    indices refer to positions in that sample, not in ``points``.

    Args:
        points: array containing point coordinates.
        n_farthest_samples: number of points to keep using farthest points sampling. If
            a float is given, represents the proportion of points used instead.
        n_random_samples: number of points to keep using random sampling. If a float is
            given, represents the proportion of points used instead.
        dmax: maximum distance in pixels between two adjacent nodes.
        n_neighbors: number of neighbors to use for KNN algorithm.
        n_jobs: number of parallel jobs to run for neighbors search. None means 1.

    Returns:
        Sparse distance matrix representing the graph, made symmetric by
        taking the element-wise maximum with its transpose.
    """
    idxs = random_farthest_point_sampling(
        points,
        n_farthest_samples=n_farthest_samples,
        n_random_samples=n_random_samples,
    )
    X = points[idxs]

    knn = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=n_jobs).fit(X)

    A = knn.kneighbors_graph(mode="distance")
    # Drop edges longer than dmax by thresholding the stored distances
    # directly, instead of doing arithmetic on boolean sparse matrices.
    A.data[A.data > dmax] = 0
    A.eliminate_zeros()
    return A.maximum(A.T)
477
478
479
def get_nodeprops_edgeprops(
    A: spmatrix, coords: NDArray[Shape["N_points, N_dims"], Number]
) -> Tuple[NodeProperties, EdgeProperties]:
    """
    Get coordinates and distances between edges of a graph as NodeProperties and
    EdgeProperties.

    Args:
        A: Sparse distance matrix representing the graph.
        coords: coordinates of the nodes, one ``(x, y)`` pair per node.

    Returns:
        NodeProperties dictionary containing 'x' and 'y' entries for node coordinates
        and EdgeProperties dictionary containing a 'distance' entry for distances
        between edges.
    """
    # One pass over the stored entries instead of one sparse lookup per edge.
    # The copy keeps A untouched; duplicates are summed and explicit zeros
    # skipped so that the entries match what A[i, j] / A.nonzero() report.
    entries = A.tocoo(copy=True)
    entries.sum_duplicates()
    distances = {
        (i, j): d
        for i, j, d in zip(entries.row, entries.col, entries.data)
        if d != 0
    }
    edgeprops = {"distance": distances}
    nodeprops = {"x": {}, "y": {}}
    for i, (x, y) in enumerate(coords):
        nodeprops["x"][i] = x
        nodeprops["y"][i] = y
    return nodeprops, edgeprops