--- a +++ b/pathaia/graphs/object_api.py @@ -0,0 +1,505 @@ +"""Classes used to represent graphs.""" +from typing import List, Sequence, Optional, Union, Tuple +import json +import warnings +from scipy.sparse import spmatrix, dok_matrix +import numpy as np +from ordered_set import OrderedSet +from shapely.geometry import Polygon +from shapely.ops import unary_union +from shapely.affinity import translate +from pathlib import Path +from .types import ( + Node, + NodeProperties, + BinaryNodeProperty, + NumericalNodeProperty, + Parenthood, + Childhood, + Edge, + UEdge, + EdgeProperties, + NumericalEdgeProperty, +) +from ..util.types import PathLike +from .errors import ( + InvalidNodeProps, + UndefinedParenthood, + UndefinedChildhood, + UnknownNodeProperty, +) +from .functional_api import ( + complete_tree as _complete_tree, + get_nodeprops_edgeprops, + get_root as _get_root, + get_root_path as _get_root_path, + get_leaves as _get_leaves, + tree_to_json as _tree_to_json, + kruskal_tree as _kruskal_tree, + cut_on_property as _cut_on_property, + common_ancestor as _common_ancestor, + edge_dist as _edge_dist, + weighted_dist as _weighted_dist, + get_kneighbors_graph, +) +from ..util.basic import ifnone +import ast + +MAX_N_NODES = int(10e7) + + +class Graph: + """Object to represent a directed graph.""" + + def __init__( + self, + nodes: Optional[Sequence[Node]] = None, + edges: Optional[Sequence[Edge]] = None, + A: Optional[spmatrix] = None, + nodeprops: Optional[NodeProperties] = None, + edgeprops: Optional[EdgeProperties] = None, + ): + self.A_ = dok_matrix((MAX_N_NODES, MAX_N_NODES), dtype=bool) + if nodes is None: + self.nodes_ = OrderedSet() + if edges is not None: + self.edges_ = set(edges) + for x, y in edges: + i = self.nodes_.add(x) + j = self.nodes_.add(y) + self.A_[i, j] = True + elif A is not None: + self.nodes_ = OrderedSet(np.arange(A.shape[0])) + self.edges_ = set() + for i, j in zip(*A.nonzero()): + self.edges_.add((i, j)) + self.A_[i, j] = True + else: + self.edges_ = set() + else: + self.nodes_ = OrderedSet(nodes) + if edges is not None: + self.edges_ = set(edges) + for x, y in edges: + i = self.nodes_.index(x) + j = self.nodes_.index(y) + self.A_[i, j] = True + elif A is not None: + self.edges_ = set() + for i, j in zip(*A.nonzero()): + self.edges_.add((self.nodes_[i], self.nodes_[j])) + self.A_[i, j] = True + else: + self.edges_ = set() + + self.nodeprops = ifnone(nodeprops, {}) + self.edgeprops = ifnone(edgeprops, {}) + + @property + def n_nodes(self): + return len(self.nodes_) + + @property + def nodes(self): + return self.nodes_ + + @property + def edges(self): + return self.edges_ + + @property + def A(self): + return self.A_.tocsr()[: self.n_nodes, : self.n_nodes] + + def add_node(self, node: Node): + self.nodes_.add(node) + + def add_nodes(self, nodes: Sequence[Node]): + for node in nodes: + self.add_node(node) + + def add_edge(self, edge: Edge): + self.add_nodes(edge) + self.edges_.add(edge) + n1, n2 = edge + i = self.nodes_.index(n1) + j = self.nodes_.index(n2) + self.A_[i, j] = True + + def add_edges(self, edges: Sequence[Edge]): + for edge in edges: + self.add_edge(edge) + + def remove_edge(self, edge: Edge): + try: + self.edges_.remove(edge) + except KeyError: + print(f"Edge {edge} was not found in graph") + n1, n2 = edge + i = self.nodes_.index(n1) + j = self.nodes_.index(n2) + self.A_[i, j] = False + + def reset(self): + self.nodes_ = OrderedSet() + self.edges_ = set() + self.A_ = dok_matrix((MAX_N_NODES, MAX_N_NODES), dtype=bool) + self.nodeprops = {} + self.edgeprops = {} + + +class UGraph(Graph): + """Class to represent an undirected graph.""" + + def __init__( + self, + nodes: Optional[Sequence[Node]] = None, + edges: Optional[Sequence[Edge]] = None, + A: Optional[spmatrix] = None, + nodeprops: Optional[NodeProperties] = None, + edgeprops: Optional[EdgeProperties] = None, + ): + super().__init__(nodes, edges, A, nodeprops, edgeprops) + self.edges_ = {UEdge(edge, key=self.nodes_.index) for edge in self.edges_} + + @property + def A(self): + A = self.A_.tocsr()[: self.n_nodes, : self.n_nodes] + return A + A.T + + def add_edge(self, edge: Edge): + super().add_edge(UEdge(edge, key=self.nodes_.index)) + + def remove_edge(self, edge: Edge): + super().remove_edge(UEdge(edge, key=self.nodes_.index)) + n1, n2 = edge + i = self.nodes_.index(n1) + j = self.nodes_.index(n2) + self.A_[j, i] = False + + @classmethod + def from_hovernet_wsi_file( + cls, + wsi_file: PathLike, + n_farthest_samples: Union[int, float] = 0.3, + n_random_samples: Union[int, float] = 0.1, + dmax: int = 500, + n_neighbors: int = 5, + n_jobs: Optional[int] = None, + ): + """ + Create a cell graph from a single hovernet json file generated from their WSI + script. + + Args: + wsi_file: json_file generated by hovernet's run_wsi.sh. + n_farthest_samples: number of points to keep using farthest points sampling. + If a float is given, represents the proportion of points used instead. + n_random_samples: number of points to keep using random sampling. If a float + is given, represents the proportion of points used instead. + dmax: maximum distance in pixels between two adjacent nodes. + n_neighbors: number of neighbors to use for KNN algorithm. + n_jobs: number of parallel jobs to run for neighbors search. None means 1. + + Returns: + A UGraph representing cell nuclei connections. + """ + with open(wsi_file, "r") as f: + nuc_dict = json.load(f) + centroids = [] + + for k in nuc_dict["nuc"]: + x, y = nuc_dict["nuc"][k]["centroid"] + centroids.append((x, y)) + centroids = np.array(centroids) + + A = get_kneighbors_graph( + centroids, + n_farthest_samples=n_farthest_samples, + n_random_samples=n_random_samples, + dmax=dmax, + n_neighbors=n_neighbors, + n_jobs=n_jobs, + ) + nodeprops, edgeprops = get_nodeprops_edgeprops(A, centroids) + return cls(A=A, nodeprops=nodeprops, edgeprops=edgeprops) + + @classmethod + def from_hovernet_patch_file( + cls, + patch_folder: PathLike, + n_farthest_samples: Union[int, float] = 0.3, + n_random_samples: Union[int, float] = 0.1, + dmax: int = 500, + n_neighbors: int = 5, + n_jobs: Optional[int] = None, + ): + """ + Create a cell graph from a folder containing hovernet json files generated from + their tile script. + + Args: + patch_folder: folder containing json_files generated by hovernet's + run_tile.sh. Files must be named with x_y_level.json formatting. + n_farthest_samples: number of points to keep using farthest points sampling. + If a float is given, represents the proportion of points used instead. + n_random_samples: number of points to keep using random sampling. If a float + is given, represents the proportion of points used instead. + dmax: maximum distance in pixels between two adjacent nodes. + n_neighbors: number of neighbors to use for KNN algorithm. + n_jobs: number of parallel jobs to run for neighbors search. None means 1. + + Returns: + A UGraph representing cell nuclei connections. + """ + patch_folder = Path(patch_folder) + polygons = [] + + for json_file in patch_folder.iterdir(): + with open(json_file, "r") as f: + nuc_dict = json.load(f) + x, y = map(int, json_file.stem.split("_")[:2]) + for k in nuc_dict["nuc"]: + contour = nuc_dict["nuc"][k]["contour"] + polygon = Polygon(contour) + polygon = translate(polygon, xoff=x, yoff=y) + polygons.append(polygon) + polygons = unary_union(polygons) + + centroids = [(polygon.centroid.x, polygon.centroid.y) for polygon in polygons] + centroids = np.array(centroids, dtype=np.int32) + + A = get_kneighbors_graph( + centroids, + n_farthest_samples=n_farthest_samples, + n_random_samples=n_random_samples, + dmax=dmax, + n_neighbors=n_neighbors, + n_jobs=n_jobs, + ) + nodeprops, edgeprops = get_nodeprops_edgeprops(A, centroids) + return cls(A=A, edgeprops=edgeprops, nodeprops=nodeprops) + + +class Tree(Graph): + """Object to handle trees.""" + + def __init__( + self, + nodes: Optional[Sequence[Node]] = None, + edges: Optional[Sequence[Edge]] = None, + parents: Optional[Parenthood] = None, + children: Optional[Childhood] = None, + nodeprops: Optional[NodeProperties] = None, + edgeprops: Optional[EdgeProperties] = None, + jsonfile: Optional[str] = None, + ): + """Init tree object.""" + if jsonfile is not None: + self.from_json(jsonfile) + edges = set() + for parent in self.children_: + for child in self.children_[parent]: + edges.add((parent, child)) + else: + if edges is not None and (parents is not None or children is not None): + warnings.warn( + "Be careful when specifying both edges and parents/children," + "consistency will not be checked and edges will be prioritized." + ) + if edges is None: + edges = set() + self.parents_, self.children_ = _complete_tree(parents, children) + for parent in self.children_: + for child in self.children_[parent]: + edges.add((parent, child)) + else: + edges = set(edges) + self.parents_ = {} + self.children_ = {} + for parent, child in edges: + self.parents_[child] = parent + try: + self.children_[parent].add(child) + except KeyError: + self.children_[parent] = {child} + super().__init__( + nodes=nodes, edges=edges, nodeprops=nodeprops, edgeprops=edgeprops + ) + + @property + def parents(self) -> Parenthood: + return self.parents_ + + @property + def children(self) -> Childhood: + return self.children_ + + def add_edge(self, parent: Node, child: Node): + self.parents_[child] = parent + try: + self.children_[parent].add(child) + except KeyError: + self.children_[parent] = {child} + super().add_edge((parent, child)) + + def add_children(self, parent: Node, children: Sequence[Node]): + for child in children: + self.parents_[child] = parent + super().add_edge((parent, child)) + try: + self.children_[parent] |= set(children) + except KeyError: + self.children_[parent] = set(children) + + def add_edges(self, edges: Sequence[Tuple[Node, Union[Node, Sequence[Node]]]]): + for p, c in edges: + if isinstance(c, Node): + self.add_edge(p, c) + else: + self.add_children(p, c) + + def reset(self): + super().reset() + self.parents_ = {} + self.children_ = {} + + def get_root(self, node: Node = None) -> Node: + """Give root of the tree.""" + if self.parents_ is not None: + return _get_root(self.parents_, node) + raise UndefinedParenthood( + "Parenthood of the tree was not defined, " + "please build the tree before use." + ) + + def get_root_path(self, node: Node) -> List[Node]: + """Get path to root of the tree.""" + if self.parents_ is not None: + return _get_root_path(self.parents_, node) + raise UndefinedParenthood( + "Parenthood of the tree was not defined, " + "please build the tree before use." + ) + + def get_leaves( + self, node: Node, prop: Optional[BinaryNodeProperty] = None + ) -> List[Node]: + """Get leaves of a node.""" + if self.children_ is not None: + return _get_leaves(self.children_, node, prop) + raise UndefinedChildhood( + "Childhood of the tree was not defined, " + "please build the tree before use." + ) + + def to_json(self, jsonfile): + """Store the tree to json file.""" + _tree_to_json( + self.nodes_, + self.parents_, + self.children_, + jsonfile, + self.nodeprops, + self.edgeprops, + ) + + def from_json(self, jsonfile): + """Create the tree from a json file.""" + # Keep in mind that json keys have to be str. + # In treez framework, they can be python object as well + # We use ast to parse the str to a python object before + + # This behaviour might limit even more the types of + # parenthood/childhood/props keys when using treez... + with open(jsonfile, "r") as jf: + json_dict = json.load(jf) + self.reset() + for parent, children in json_dict["children"].items(): + try: + parentkey = ast.literal_eval(parent) + self.add_children(parentkey, children) + except (ValueError, SyntaxError): + self.add_children(parent, children) + self.edgeprops = dict() + for name, edgeprop in json_dict["edgeprops"].item(): + self.edgeprops[name] = dict() + for edgein, edgeout in edgeprop.items(): + try: + edgekey = ast.literal_eval(edgein) + self.edgeprops[name][edgekey] = edgeout + except (ValueError, SyntaxError): + self.edgeprops[name][edgein] = edgeout + for name, nodeprop in json_dict["nodeprops"].items(): + for nodein, nodeout in nodeprop.items(): + try: + nodekey = ast.literal_eval(nodein) + self.nodeprops[name][nodekey] = nodeout + except (ValueError, SyntaxError): + self.nodeprops[name][nodein] = nodeout + + def build_kruskal( + self, + edges: Sequence[Edge], + weights: NumericalEdgeProperty, + size: NumericalNodeProperty, + ): + """Build tree with kruskal algorithm from graph edges.""" + _, k_children, k_props = _kruskal_tree(edges, weights, size) + for parent in k_children: + self.add_children(parent, k_children) + self.nodeprops = k_props + + def cut_on_property(self, cut_name: str, prop: str, threshold: Union[int, float]): + """ + Produce a list of authorized nodes given a property threshold. + Set a new property to these nodes. + """ + if prop in self.nodeprops: + node_of_interest = _cut_on_property( + self.parents_, self.children_, self.nodeprops[prop], threshold + ) + cut = dict() + for node in self.nodes: + if node in node_of_interest: + cut[node] = True + else: + cut[node] = False + self.nodeprops[cut_name] = cut + + else: + raise UnknownNodeProperty( + "Property {}" + " is not in the tree properties: {}".format( + prop, list(self.nodeprops.keys()) + ) + ) + + def common_ancestor(self, node1: Node, node2: Node) -> Node: + """Return the common ancestor of node1 and node2.""" + return _common_ancestor(self.parents_, node1, node2) + + def edge_dist(self, node1: Node, node2: Node) -> int: + """Return the number of edges to go from node1 to node2 (by common ancestor).""" + return _edge_dist(self.parents_, node1, node2) + + def weighted_dist( + self, weights: Union[NumericalNodeProperty, str], node1: Node, node2: Node + ) -> float: + """Return the number of edges to go from node1 to node2 (by common ancestor).""" + if isinstance(weights, str): + if weights in self.nodeprops: + return _weighted_dist( + self.parents_, self.nodeprops[weights], node1, node2 + ) + raise InvalidNodeProps( + "Property {} is not in tree properties: {}".format( + weights, self.nodeprops + ) + ) + if isinstance(weights, dict): + return _weighted_dist(self.parents_, weights, node1, node2) + raise InvalidNodeProps( + "Provided property is not a valid property. " + "Expected {} or {}, got {}".format(dict, str, type(weights)) + )