Switch to side-by-side view

--- a
+++ b/pathaia/graphs/object_api.py
@@ -0,0 +1,505 @@
+"""Classes used to represent graphs."""
+from typing import List, Sequence, Optional, Union, Tuple
+import json
+import warnings
+from scipy.sparse import spmatrix, dok_matrix
+import numpy as np
+from ordered_set import OrderedSet
+from shapely.geometry import Polygon
+from shapely.ops import unary_union
+from shapely.affinity import translate
+from pathlib import Path
+from .types import (
+    Node,
+    NodeProperties,
+    BinaryNodeProperty,
+    NumericalNodeProperty,
+    Parenthood,
+    Childhood,
+    Edge,
+    UEdge,
+    EdgeProperties,
+    NumericalEdgeProperty,
+)
+from ..util.types import PathLike
+from .errors import (
+    InvalidNodeProps,
+    UndefinedParenthood,
+    UndefinedChildhood,
+    UnknownNodeProperty,
+)
+from .functional_api import (
+    complete_tree as _complete_tree,
+    get_nodeprops_edgeprops,
+    get_root as _get_root,
+    get_root_path as _get_root_path,
+    get_leaves as _get_leaves,
+    tree_to_json as _tree_to_json,
+    kruskal_tree as _kruskal_tree,
+    cut_on_property as _cut_on_property,
+    common_ancestor as _common_ancestor,
+    edge_dist as _edge_dist,
+    weighted_dist as _weighted_dist,
+    get_kneighbors_graph,
+)
+from ..util.basic import ifnone
+import ast
+
+MAX_N_NODES = int(10e7)
+
+
+class Graph:
+    """Object to represent a directed graph."""
+
+    def __init__(
+        self,
+        nodes: Optional[Sequence[Node]] = None,
+        edges: Optional[Sequence[Edge]] = None,
+        A: Optional[spmatrix] = None,
+        nodeprops: Optional[NodeProperties] = None,
+        edgeprops: Optional[EdgeProperties] = None,
+    ):
+        self.A_ = dok_matrix((MAX_N_NODES, MAX_N_NODES), dtype=bool)
+        if nodes is None:
+            self.nodes_ = OrderedSet()
+            if edges is not None:
+                self.edges_ = set(edges)
+                for x, y in edges:
+                    i = self.nodes_.add(x)
+                    j = self.nodes_.add(y)
+                    self.A_[i, j] = True
+            elif A is not None:
+                self.nodes_ = OrderedSet(np.arange(A.shape[0]))
+                self.edges_ = set()
+                for i, j in zip(*A.nonzero()):
+                    self.edges_.add((i, j))
+                    self.A_[i, j] = True
+            else:
+                self.edges_ = set()
+        else:
+            self.nodes_ = OrderedSet(nodes)
+            if edges is not None:
+                self.edges_ = set(edges)
+                for x, y in edges:
+                    i = self.nodes_.index(x)
+                    j = self.nodes_.index(y)
+                    self.A_[i, j] = True
+            elif A is not None:
+                self.edges_ = set()
+                for i, j in zip(*A.nonzero()):
+                    self.edges_.add((self.nodes_[i], self.nodes_[j]))
+                    self.A_[i, j] = True
+            else:
+                self.edges_ = set()
+
+        self.nodeprops = ifnone(nodeprops, {})
+        self.edgeprops = ifnone(edgeprops, {})
+
+    @property
+    def n_nodes(self):
+        return len(self.nodes_)
+
+    @property
+    def nodes(self):
+        return self.nodes_
+
+    @property
+    def edges(self):
+        return self.edges_
+
+    @property
+    def A(self):
+        return self.A_.tocsr()[: self.n_nodes, : self.n_nodes]
+
+    def add_node(self, node: Node):
+        self.nodes_.add(node)
+
+    def add_nodes(self, nodes: Sequence[Node]):
+        for node in nodes:
+            self.add_node(node)
+
+    def add_edge(self, edge: Edge):
+        self.add_nodes(edge)
+        self.edges_.add(edge)
+        n1, n2 = edge
+        i = self.nodes_.index(n1)
+        j = self.nodes_.index(n2)
+        self.A_[i, j] = True
+
+    def add_edges(self, edges: Sequence[Edge]):
+        for edge in edges:
+            self.add_edge(edge)
+
+    def remove_edge(self, edge: Edge):
+        try:
+            self.edges_.remove(edge)
+        except KeyError:
+            print(f"Edge {edge} was not found in graph")
+        n1, n2 = edge
+        i = self.nodes_.index(n1)
+        j = self.nodes_.index(n2)
+        self.A_[i, j] = False
+
+    def reset(self):
+        self.nodes_ = OrderedSet()
+        self.edges_ = set()
+        self.A_ = dok_matrix((MAX_N_NODES, MAX_N_NODES), dtype=bool)
+        self.nodeprops = {}
+        self.edgeprops = {}
+
+
+class UGraph(Graph):
+    """Class to represent an undirected graph."""
+
+    def __init__(
+        self,
+        nodes: Optional[Sequence[Node]] = None,
+        edges: Optional[Sequence[Edge]] = None,
+        A: Optional[spmatrix] = None,
+        nodeprops: Optional[NodeProperties] = None,
+        edgeprops: Optional[EdgeProperties] = None,
+    ):
+        super().__init__(nodes, edges, A, nodeprops, edgeprops)
+        self.edges_ = {UEdge(edge, key=self.nodes_.index) for edge in self.edges_}
+
+    @property
+    def A(self):
+        A = self.A_.tocsr()[: self.n_nodes, : self.n_nodes]
+        return A + A.T
+
+    def add_edge(self, edge: Edge):
+        super().add_edge(UEdge(edge, key=self.nodes_.index))
+
+    def remove_edge(self, edge: Edge):
+        super().remove_edge(UEdge(edge, key=self.nodes_.index))
+        n1, n2 = edge
+        i = self.nodes_.index(n1)
+        j = self.nodes_.index(n2)
+        self.A_[j, i] = False
+
+    @classmethod
+    def from_hovernet_wsi_file(
+        cls,
+        wsi_file: PathLike,
+        n_farthest_samples: Union[int, float] = 0.3,
+        n_random_samples: Union[int, float] = 0.1,
+        dmax: int = 500,
+        n_neighbors: int = 5,
+        n_jobs: Optional[int] = None,
+    ):
+        """
+        Create a cell graph from a single hovernet json file generated from their WSI
+        script.
+
+        Args:
+            wsi_file: json_file generated by hovernet's run_wsi.sh.
+            n_farthest_samples: number of points to keep using farthest points sampling.
+                If a float is given, represents the proportion of points used instead.
+            n_random_samples: number of points to keep using random sampling. If a float
+                is given, represents the proportion of points used instead.
+            dmax: maximum distance in pixels between two adjacent nodes.
+            n_neighbors: number of neighbors to use for KNN algorithm.
+            n_jobs: number of parallel jobs to run for neighbors search. None means 1.
+
+        Returns:
+            A UGraph representing cell nuclei connections.
+        """
+        with open(wsi_file, "r") as f:
+            nuc_dict = json.load(f)
+        centroids = []
+
+        for k in nuc_dict["nuc"]:
+            x, y = nuc_dict["nuc"][k]["centroid"]
+            centroids.append((x, y))
+        centroids = np.array(centroids)
+
+        A = get_kneighbors_graph(
+            centroids,
+            n_farthest_samples=n_farthest_samples,
+            n_random_samples=n_random_samples,
+            dmax=dmax,
+            n_neighbors=n_neighbors,
+            n_jobs=n_jobs,
+        )
+        nodeprops, edgeprops = get_nodeprops_edgeprops(A, centroids)
+        return cls(A=A, nodeprops=nodeprops, edgeprops=edgeprops)
+
+    @classmethod
+    def from_hovernet_patch_file(
+        cls,
+        patch_folder: PathLike,
+        n_farthest_samples: Union[int, float] = 0.3,
+        n_random_samples: Union[int, float] = 0.1,
+        dmax: int = 500,
+        n_neighbors: int = 5,
+        n_jobs: Optional[int] = None,
+    ):
+        """
+        Create a cell graph from a folder containing hovernet json files generated from
+        their tile script.
+
+        Args:
+            patch_folder: folder containing json_files generated by hovernet's
+            run_tile.sh. Files must be named with x_y_level.json formatting.
+            n_farthest_samples: number of points to keep using farthest points sampling.
+                If a float is given, represents the proportion of points used instead.
+            n_random_samples: number of points to keep using random sampling. If a float
+                is given, represents the proportion of points used instead.
+            dmax: maximum distance in pixels between two adjacent nodes.
+            n_neighbors: number of neighbors to use for KNN algorithm.
+            n_jobs: number of parallel jobs to run for neighbors search. None means 1.
+
+        Returns:
+            A UGraph representing cell nuclei connections.
+        """
+        patch_folder = Path(patch_folder)
+        polygons = []
+
+        for json_file in patch_folder.iterdir():
+            with open(json_file, "r") as f:
+                nuc_dict = json.load(f)
+            x, y = map(int, json_file.stem.split("_")[:2])
+            for k in nuc_dict["nuc"]:
+                contour = nuc_dict["nuc"][k]["contour"]
+                polygon = Polygon(contour)
+                polygon = translate(polygon, xoff=x, yoff=y)
+                polygons.append(polygon)
+        polygons = unary_union(polygons)
+
+        centroids = [(polygon.centroid.x, polygon.centroid.y) for polygon in polygons]
+        centroids = np.array(centroids, dtype=np.int32)
+
+        A = get_kneighbors_graph(
+            centroids,
+            n_farthest_samples=n_farthest_samples,
+            n_random_samples=n_random_samples,
+            dmax=dmax,
+            n_neighbors=n_neighbors,
+            n_jobs=n_jobs,
+        )
+        nodeprops, edgeprops = get_nodeprops_edgeprops(A, centroids)
+        return cls(A=A, edgeprops=edgeprops, nodeprops=nodeprops)
+
+
+class Tree(Graph):
+    """Object to handle trees."""
+
+    def __init__(
+        self,
+        nodes: Optional[Sequence[Node]] = None,
+        edges: Optional[Sequence[Edge]] = None,
+        parents: Optional[Parenthood] = None,
+        children: Optional[Childhood] = None,
+        nodeprops: Optional[NodeProperties] = None,
+        edgeprops: Optional[EdgeProperties] = None,
+        jsonfile: Optional[str] = None,
+    ):
+        """Init tree object."""
+        if jsonfile is not None:
+            self.from_json(jsonfile)
+            edges = set()
+            for parent in self.children_:
+                for child in self.children_[parent]:
+                    edges.add((parent, child))
+        else:
+            if edges is not None and (parents is not None or children is not None):
+                warnings.warn(
+                    "Be careful when specifying both edges and parents/children,"
+                    "consistency will not be checked and edges will be prioritized."
+                )
+            if edges is None:
+                edges = set()
+                self.parents_, self.children_ = _complete_tree(parents, children)
+                for parent in self.children_:
+                    for child in self.children_[parent]:
+                        edges.add((parent, child))
+            else:
+                edges = set(edges)
+                self.parents_ = {}
+                self.children_ = {}
+                for parent, child in edges:
+                    self.parents_[child] = parent
+                    try:
+                        self.children_[parent].add(child)
+                    except KeyError:
+                        self.children_[parent] = {child}
+        super().__init__(
+            nodes=nodes, edges=edges, nodeprops=nodeprops, edgeprops=edgeprops
+        )
+
+    @property
+    def parents(self) -> Parenthood:
+        return self.parents_
+
+    @property
+    def children(self) -> Childhood:
+        return self.children_
+
+    def add_edge(self, parent: Node, child: Node):
+        self.parents_[child] = parent
+        try:
+            self.children_[parent].add(child)
+        except KeyError:
+            self.children_[parent] = {child}
+        super().add_edge((parent, child))
+
+    def add_children(self, parent: Node, children: Sequence[Node]):
+        for child in children:
+            self.parents_[child] = parent
+            super().add_edge((parent, child))
+        try:
+            self.children_[parent] |= set(children)
+        except KeyError:
+            self.children_[parent] = set(children)
+
+    def add_edges(self, edges: Sequence[Tuple[Node, Union[Node, Sequence[Node]]]]):
+        for p, c in edges:
+            if isinstance(c, Node):
+                self.add_edge(p, c)
+            else:
+                self.add_children(p, c)
+
+    def reset(self):
+        super().reset()
+        self.parents_ = {}
+        self.children_ = {}
+
+    def get_root(self, node: Node = None) -> Node:
+        """Give root of the tree."""
+        if self.parents_ is not None:
+            return _get_root(self.parents_, node)
+        raise UndefinedParenthood(
+            "Parenthood of the tree was not defined, "
+            "please build the tree before use."
+        )
+
+    def get_root_path(self, node: Node) -> List[Node]:
+        """Get path to root of the tree."""
+        if self.parents_ is not None:
+            return _get_root_path(self.parents_, node)
+        raise UndefinedParenthood(
+            "Parenthood of the tree was not defined, "
+            "please build the tree before use."
+        )
+
+    def get_leaves(
+        self, node: Node, prop: Optional[BinaryNodeProperty] = None
+    ) -> List[Node]:
+        """Get leaves of a node."""
+        if self.children_ is not None:
+            return _get_leaves(self.children_, node, prop)
+        raise UndefinedChildhood(
+            "Childhood of the tree was not defined, "
+            "please build the tree before use."
+        )
+
+    def to_json(self, jsonfile):
+        """Store the tree to json file."""
+        _tree_to_json(
+            self.nodes_,
+            self.parents_,
+            self.children_,
+            jsonfile,
+            self.nodeprops,
+            self.edgeprops,
+        )
+
+    def from_json(self, jsonfile):
+        """Create the tree from a json file."""
+        # Keep in mind that json keys have to be str.
+        # In treez framework, they can be python object as well
+        # We use ast to parse the str to a python object before
+
+        # This behaviour might limit even more the types of
+        # parenthood/childhood/props keys when using treez...
+        with open(jsonfile, "r") as jf:
+            json_dict = json.load(jf)
+        self.reset()
+        for parent, children in json_dict["children"].items():
+            try:
+                parentkey = ast.literal_eval(parent)
+                self.add_children(parentkey, children)
+            except (ValueError, SyntaxError):
+                self.add_children(parent, children)
+        self.edgeprops = dict()
+        for name, edgeprop in json_dict["edgeprops"].item():
+            self.edgeprops[name] = dict()
+            for edgein, edgeout in edgeprop.items():
+                try:
+                    edgekey = ast.literal_eval(edgein)
+                    self.edgeprops[name][edgekey] = edgeout
+                except (ValueError, SyntaxError):
+                    self.edgeprops[name][edgein] = edgeout
+        for name, nodeprop in json_dict["nodeprops"].items():
+            for nodein, nodeout in nodeprop.items():
+                try:
+                    nodekey = ast.literal_eval(nodein)
+                    self.nodeprops[name][nodekey] = nodeout
+                except (ValueError, SyntaxError):
+                    self.nodeprops[name][nodein] = nodeout
+
+    def build_kruskal(
+        self,
+        edges: Sequence[Edge],
+        weights: NumericalEdgeProperty,
+        size: NumericalNodeProperty,
+    ):
+        """Build tree with kruskal algorithm from graph edges."""
+        _, k_children, k_props = _kruskal_tree(edges, weights, size)
+        for parent in k_children:
+            self.add_children(parent, k_children)
+        self.nodeprops = k_props
+
+    def cut_on_property(self, cut_name: str, prop: str, threshold: Union[int, float]):
+        """
+        Produce a list of authorized nodes given a property threshold.
+        Set a new property to these nodes.
+        """
+        if prop in self.nodeprops:
+            node_of_interest = _cut_on_property(
+                self.parents_, self.children_, self.nodeprops[prop], threshold
+            )
+            cut = dict()
+            for node in self.nodes:
+                if node in node_of_interest:
+                    cut[node] = True
+                else:
+                    cut[node] = False
+            self.nodeprops[cut_name] = cut
+
+        else:
+            raise UnknownNodeProperty(
+                "Property {}"
+                " is not in the tree properties: {}".format(
+                    prop, list(self.nodeprops.keys())
+                )
+            )
+
+    def common_ancestor(self, node1: Node, node2: Node) -> Node:
+        """Return the common ancestor of node1 and node2."""
+        return _common_ancestor(self.parents_, node1, node2)
+
+    def edge_dist(self, node1: Node, node2: Node) -> int:
+        """Return the number of edges to go from node1 to node2 (by common ancestor)."""
+        return _edge_dist(self.parents_, node1, node2)
+
+    def weighted_dist(
+        self, weights: Union[NumericalNodeProperty, str], node1: Node, node2: Node
+    ) -> float:
+        """Return the number of edges to go from node1 to node2 (by common ancestor)."""
+        if isinstance(weights, str):
+            if weights in self.nodeprops:
+                return _weighted_dist(
+                    self.parents_, self.nodeprops[weights], node1, node2
+                )
+            raise InvalidNodeProps(
+                "Property {} is not in tree properties: {}".format(
+                    weights, self.nodeprops
+                )
+            )
+        if isinstance(weights, dict):
+            return _weighted_dist(self.parents_, weights, node1, node2)
+        raise InvalidNodeProps(
+            "Provided property is not a valid property. "
+            "Expected {} or {}, got {}".format(dict, str, type(weights))
+        )