a b/pathaia/graphs/clustering.py
1
from typing import Any, Dict, Optional, Sequence, Tuple, Union
2
3
import numpy as np
4
from nptyping import NDArray, Shape
5
from scipy.sparse import triu
6
from sortedcontainers import SortedDict
7
from tqdm import tqdm
8
9
from .object_api import Tree, UGraph
10
from .types import Edge, Node
11
12
13
class AgglomerativeClustering:
    r"""
    Object used to hierarchically cluster nodes on a graph. Clustering greedily chooses
    to merge linked nodes that have minimum distance/strength ratio. Strength between
    2 nodes is initially 1 for every edge and 0 when there is no edge, then when 2 nodes
    are merged the strength of a newly formed link between the new node and another node
    is the weighted (by node population) average of the strengths between the 2 old
    nodes and the other node. This algorithm uses centroid linkage clustering (UPGMC).

    When no linked pair remains (disconnected graph), every remaining cluster is linked
    to every other remaining cluster with strength 1 so that the dendrogram can be
    completed.

    Args:
        compute_all: whether to initially compute all distances between nodes regardless
            of their linkage.
    """

    def __init__(self, compute_all: bool = False):
        self.compute_all = compute_all

    def init_graph(
        self,
        G: UGraph,
        feats: Union[Dict[Node, NDArray[Shape["*"], Any]], Sequence[str]],
        weights: Optional[Union[Dict[Edge, float], str]] = None,
    ):
        r"""
        Initialize main graph attributes (adjacency matrix, edge index, n_nodes and
        features) using a graph object, a list of features and a list of weights.

        After this call, ``self.A`` is an upper-triangular CSR matrix whose entry
        ``(i, j)`` holds the *squared* distance between linked nodes ``i < j``.

        Args:
            G: graph to cluster nodes on.
            feats: either a dictionary that maps nodes to their corresponding feature
                vectors or a sequence of property names that will be used as features.
            weights: either a dictionary that maps edges to their corresponding weight
                or a property name that will be used as weight. If `None` is passed,
                weights are computed using euclidean distances between feature vectors.
                Weights are squared before being stored.
        """
        # k=1 drops the diagonal: a self-loop would otherwise become an "edge" (i, i)
        # and the algorithm would try to merge a node with itself.
        self.A = triu(G.A, k=1, format="csr").astype(np.float32)
        self.n_nodes = G.n_nodes
        # Record the edge structure *before* distances are written into A: an edge with
        # a null distance would otherwise become an explicit zero and vanish from
        # ``A.nonzero()``.
        self.edge_index = self.A.nonzero()
        # O(1) node -> row index lookup (instead of ``G.nodes.index`` per edge).
        node_index = {node: k for k, node in enumerate(G.nodes)}

        if isinstance(feats, dict):
            self.feats = np.stack([feats[node] for node in G.nodes])
        else:
            self.feats = np.array(
                [[G.nodeprops[feat][node] for feat in feats] for node in G.nodes]
            )

        if weights is None:
            ii, jj = self.edge_index
            dists = ((self.feats[ii] - self.feats[jj]) ** 2).sum(1)
            self.A[ii, jj] = dists
        elif isinstance(weights, dict):
            for (n1, n2), weight in weights.items():
                i, j = sorted((node_index[n1], node_index[n2]))
                self.A[i, j] = weight**2
        else:
            # ``weights`` is the name of an edge property.
            for n1, n2 in G.edges:
                i, j = sorted((node_index[n1], node_index[n2]))
                self.A[i, j] = G.edgeprops[weights][(n1, n2)] ** 2

    def reset(self):
        """
        Reset the algorithm attributes. Populations are initiated to 1 for every node,
        strengths are initiated to 1 for every edge, dendrogram is emptied.
        """
        edge_index = getattr(self, "edge_index", None)
        if edge_index is None:
            # Fallback for callers that set ``self.A`` without going through init_graph.
            edge_index = self.A.nonzero()
        # Plain Python ints so that every dictionary key has the same type.
        ii = np.asarray(edge_index[0]).tolist()
        jj = np.asarray(edge_index[1]).tolist()
        edges = list(zip(ii, jj))

        self.populations_ = {k: 1 for k in range(self.n_nodes)}
        self.centroids_ = {k: self.feats[k] for k in range(self.n_nodes)}
        self.links_ = {k: set() for k in range(self.n_nodes)}
        for i, j in edges:
            self.links_[i].add(j)
            self.links_[j].add(i)
        self.strengths_ = {edge: 1 for edge in edges}
        if self.compute_all:
            self.distances_ = {
                (i, j): self.distance(i, j)
                for i in range(self.n_nodes)
                for j in range(i + 1, self.n_nodes)
            }
        else:
            self.distances_ = {(i, j): self.distance(i, j) for i, j in edges}
        # Candidate merges, ordered by ``criterion`` (smallest first).
        self.edges_ = SortedDict(
            self.criterion, {edge: self.distances_[edge] for edge in edges}
        )
        self.dendrogram_ = np.zeros((max(self.n_nodes - 1, 0), 4))

    def distance(self, i: int, j: int) -> float:
        """
        Get squared distance between nodes `i` and `j`. If available in the adjacency
        matrix (original nodes) or in the `distances_` dictionary (when a merged node is
        involved) it is not recomputed; otherwise it is computed from the centroids.

        Args:
            i: first node.
            j: second node.

        Returns:
            Squared euclidean distance between i and j.
        """
        i, j = sorted((i, j))
        try:
            d = self.A[i, j]
        except IndexError:
            # At least one of the nodes is a merged node (index >= n_nodes).
            d = self.distances_.get((i, j), 0)
        if not d:
            d = ((self.centroids_[i] - self.centroids_[j]) ** 2).sum()
        return d

    def criterion(self, x: Tuple[int, int]) -> float:
        """
        Criterion function used to find the next nodes to merge. Override it to use
        another criterion.

        Args:
            x: tuple containing the two nodes to merge.

        Returns:
            Squared distance between the 2 nodes divided by link strength.
        """
        i, j = x
        return self.distances_[i, j] / self.strengths_[i, j]

    def create_centroid_link(self, i, j, c, k):
        """
        Create a new link between centroid `c` (that comes from merging nodes `i` and
        `j`) and node `k`, replacing the links `i`-`k` and `j`-`k`.

        Args:
            i: first merged node.
            j: second merged node.
            c: centroid of nodes `i` and `j`.
            k: node linked to either `i` or `j` or both.
        """
        if i == k or j == k:
            # The i-j link itself is consumed by the merge.
            return
        ik, ki = sorted((i, k))
        jk, kj = sorted((j, k))
        ck, kc = sorted((c, k))
        ij, ji = sorted((i, j))

        pi = self.populations_[i]
        pj = self.populations_[j]
        ri = pi / (pi + pj)
        rj = pj / (pi + pj)
        try:
            # Lance-Williams update for centroid linkage.
            dik = self.distances_[ik, ki]
            djk = self.distances_[jk, kj]
            dij = self.distances_[ij, ji]
            dck = ri * dik + rj * djk - ri * rj * dij
        except KeyError:
            # Some pairwise distance is unknown: compute it from the centroids.
            dck = self.distance(c, k)

        self.distances_[ck, kc] = dck

        # Old links are no longer merge candidates. distances_/strengths_ entries are
        # kept because the SortedDict needs ``criterion`` to locate keys on removal.
        self.edges_.pop((ik, ki), 0)
        self.edges_.pop((jk, kj), 0)
        sik = self.strengths_.get((ik, ki), 0)
        sjk = self.strengths_.get((jk, kj), 0)
        if sik or sjk:
            self.strengths_[ck, kc] = ri * sik + rj * sjk
            self.edges_[ck, kc] = dck

        self.links_[k].discard(i)
        self.links_[k].discard(j)
        self.links_[k].add(c)
        # Prevent k from being processed a second time from j's side.
        self.links_[j].discard(k)
        self.links_[c].add(k)

    def add_link(self, i: int, j: int):
        """
        Create a new link (with strength 1) between 2 nodes.

        Args:
            i: first node.
            j: second node.
        """
        i, j = sorted((i, j))
        dij = self.distance(i, j)
        self.distances_[i, j] = dij
        self.strengths_[i, j] = 1
        self.edges_[i, j] = dij
        self.links_[i].add(j)
        self.links_[j].add(i)

    def fit(
        self,
        G: UGraph,
        feats: Union[Dict[Node, NDArray[Shape["*"], Any]], Sequence[str]],
        weights: Optional[Union[Dict[Edge, float], str]] = None,
    ):
        r"""
        Fits on the given graph and completes the dendrogram. A dendrogram is an array
        of size :math:`(n-1) \times 4` (where :math:`n` is the number of nodes)
        representing the successive merges of nodes. Each row gives the two merged
        nodes, the value of the merge criterion (squared distance divided by link
        strength by default) and the size of the resulting cluster. Any new node
        resulting from a merge takes the first available index (e.g., the first merge
        corresponds to node :math:`n`).

        Args:
            G: graph to cluster nodes on.
            feats: either a dictionary that maps nodes to their corresponding feature
                vectors or a sequence of property names that will be used as features.
            weights: either a dictionary that maps edges to their corresponding weight
                or a property name that will be used as weight. If `None` is passed,
                weights are computed using euclidean distances between feature vectors.
        """
        self.init_graph(G, feats, weights)
        self.reset()

        c = self.n_nodes
        for n in tqdm(range(self.n_nodes - 1), total=self.n_nodes - 1):
            if not self.edges_:
                # No linked pair left (disconnected graph): link every remaining pair.
                # ``links_`` is keyed by exactly the clusters that are not merged yet.
                remaining = sorted(self.links_)
                for pos, a in enumerate(remaining):
                    for b in remaining[pos + 1 :]:
                        self.add_link(a, b)
            (i, j), _ = self.edges_.popitem(0)

            pi = self.populations_[i]
            pj = self.populations_[j]
            ri = pi / (pi + pj)
            rj = pj / (pi + pj)
            self.dendrogram_[n] = [i, j, self.criterion((i, j)), pi + pj]

            self.centroids_[c] = ri * self.centroids_[i] + rj * self.centroids_[j]
            self.populations_[c] = pi + pj
            self.links_[c] = set()

            while self.links_[i]:
                k = self.links_[i].pop()
                self.create_centroid_link(i, j, c, k)
            while self.links_[j]:
                k = self.links_[j].pop()
                self.create_centroid_link(j, i, c, k)
            self.links_.pop(i)
            self.links_.pop(j)
            c += 1

    def fit_transform(
        self,
        G: UGraph,
        feats: Union[Dict[Node, NDArray[Shape["*"], Any]], Sequence[str]],
        weights: Optional[Union[Dict[Edge, float], str]] = None,
    ) -> Tree:
        """
        Fits on the given graph and returns the hierarchical clustering tree.

        Nodes of the tree are integers: ``0..n-1`` for the original nodes (in the order
        of ``G.nodes``) and ``n + k`` for the node created by the k-th merge.

        Args:
            G: graph to cluster nodes on.
            feats: either a dictionary that maps nodes to their corresponding feature
                vectors or a sequence of property names that will be used as features.
            weights: either a dictionary that maps edges to their corresponding weight
                or a property name that will be used as weight. If `None` is passed,
                weights are computed using euclidean distances between feature vectors.

        Returns:
            The tree that describes the hierarchical clustering procedure.
        """
        self.fit(G, feats, weights)

        children = {}
        parents = {}
        nodes = list(range(G.n_nodes))

        key = weights if isinstance(weights, str) else "weight"
        edgeprops = {key: {}}

        if isinstance(feats, dict):
            # self.feats has shape (n_nodes, n_feats) after init_graph.
            n_feats = self.feats.shape[1]
            nodeprops = {
                k: {n: feats[node][k] for n, node in enumerate(G.nodes)}
                for k in range(n_feats)
            }
        else:
            nodeprops = {
                feat: {n: G.nodeprops[feat][node] for n, node in enumerate(G.nodes)}
                for feat in feats
            }
        nodeprops["population"] = {n: 1 for n in range(G.n_nodes)}

        for merge_idx, row in enumerate(self.dendrogram_):
            n = merge_idx + self.n_nodes
            # The dendrogram is a float array: node ids must be cast back to int.
            n1, n2 = int(row[0]), int(row[1])
            children[n] = [n1, n2]
            parents[n1] = n
            parents[n2] = n
            nodes.append(n)
            centroid = self.centroids_[n]
            if isinstance(feats, dict):
                for k, value in enumerate(centroid):
                    nodeprops[k][n] = value
            else:
                for k, feat in enumerate(feats):
                    nodeprops[feat][n] = centroid[k]
            edgeprops[key][n, n1] = self.distance(n, n1) ** 0.5
            edgeprops[key][n, n2] = self.distance(n, n2) ** 0.5
            nodeprops["population"][n] = self.populations_[n]

        return Tree(nodes, parents, children, nodeprops, edgeprops)