Switch to unified view

a b/src/scpanel/GATclassifier.py
1
import copy
2
import inspect
3
import os.path as osp
4
import random
5
6
# import os,sys,pickle,time,random,glob
7
import time
8
from typing import Optional, Tuple, List
9
10
import numpy as np
11
import torch
12
import torch.nn.functional as F
13
import torch.utils.data
14
from sklearn.base import BaseEstimator
15
from torch_geometric.data import Data
16
from torch_geometric.nn import GATConv
17
from torch_sparse import SparseTensor, cat
18
import torch_geometric.data.data
19
from numpy import ndarray
20
from pandas.core.arrays.categorical import Categorical
21
from scipy.sparse._csr import csr_matrix
22
23
# from .utils_func import get_X_y_from_ann
24
25
26
# import pandas as pd
27
28
29
# Seed
30
seed = 42
31
torch.manual_seed(seed)
32
torch.cuda.manual_seed(seed)
33
torch.cuda.manual_seed_all(seed)
34
np.random.seed(seed)
35
random.seed(seed)
36
torch.backends.cudnn.benchmark = False
37
torch.backends.cudnn.deterministic = True
38
39
40
def scipysparse2torchsparse(x: csr_matrix) -> Tuple[torch.Tensor, torch.Tensor]:
41
    """
42
    Input: scipy csr_matrix
43
    Returns: torch tensor in experimental sparse format
44
45
    REF: Code adatped from [PyTorch discussion forum](https://discuss.pytorch.org/t/better-way-to-forward-sparse-matrix/21915>)
46
    """
47
    samples = x.shape[0]
48
    features = x.shape[1]
49
    values = x.data
50
    coo_data = x.tocoo()
51
    indices = torch.LongTensor(
52
        [coo_data.row, coo_data.col]
53
    )  # OR transpose list of index tuples
54
    t = torch.sparse.FloatTensor(
55
        indices, torch.from_numpy(values).float(), [samples, features]
56
    )
57
    return indices, t
58
59
60
class ClusterData(torch.utils.data.Dataset):
61
    r"""Clusters/partitions a graph data object into multiple subgraphs, as
62
    motivated by the `"Cluster-GCN: An Efficient Algorithm for Training Deep
63
    and Large Graph Convolutional Networks"
64
    <https://arxiv.org/abs/1905.07953>`_ paper.
65
66
    Args:
67
        data (torch_geometric.data.Data): The graph data object.
68
        num_parts (int): The number of partitions.
69
        recursive (bool, optional): If set to :obj:`True`, will use multilevel
70
            recursive bisection instead of multilevel k-way partitioning.
71
            (default: :obj:`False`)
72
        save_dir (string, optional): If set, will save the partitioned data to
73
            the :obj:`save_dir` directory for faster re-use.
74
    """
75
76
    def __init__(self, data:     torch_geometric.data.data.Data, num_parts: int, recursive: bool=False, save_dir: None=None) -> None:
77
        assert data.edge_index is not None
78
79
        self.num_parts = num_parts
80
        self.recursive = recursive
81
        self.save_dir = save_dir
82
83
        self.process(data)
84
85
    def process(self, data:     torch_geometric.data.data.Data) -> None:
86
        recursive = "_recursive" if self.recursive else ""
87
        filename = f"part_data_{self.num_parts}{recursive}.pt"
88
89
        path = osp.join(self.save_dir or "", filename)
90
        if self.save_dir is not None and osp.exists(path):
91
            data, partptr, perm = torch.load(path)
92
        else:
93
            data = copy.copy(data)
94
            num_nodes = data.num_nodes
95
96
            (row, col), edge_attr = data.edge_index, data.edge_attr
97
            adj = SparseTensor(row=row, col=col, value=edge_attr)
98
            adj, partptr, perm = adj.partition(self.num_parts, self.recursive)
99
100
            for key, item in data:
101
                if item.size(0) == num_nodes:
102
                    data[key] = item[perm]
103
104
            data.edge_index = None
105
            data.edge_attr = None
106
            data.adj = adj
107
108
            if self.save_dir is not None:
109
                torch.save((data, partptr, perm), path)
110
111
        self.data = data
112
        self.perm = perm
113
        self.partptr = partptr
114
115
    def __len__(self) -> int:
116
        return self.partptr.numel() - 1
117
118
    def __getitem__(self, idx):
119
        start = int(self.partptr[idx])
120
        length = int(self.partptr[idx + 1]) - start
121
122
        data = copy.copy(self.data)
123
        num_nodes = data.num_nodes
124
125
        for key, item in data:
126
            if item.size(0) == num_nodes:
127
                data[key] = item.narrow(0, start, length)
128
129
        data.adj = data.adj.narrow(1, start, length)
130
131
        row, col, value = data.adj.coo()
132
        data.adj = None
133
        data.edge_index = torch.stack([row, col], dim=0)
134
        data.edge_attr = value
135
136
        return data
137
138
    def __repr__(self):
139
        return f"{self.__class__.__name__}({self.data}, " f"num_parts={self.num_parts})"
140
141
142
class ClusterLoader(torch.utils.data.DataLoader):
143
    r"""The data loader scheme from the `"Cluster-GCN: An Efficient Algorithm
144
    for Training Deep and Large Graph Convolutional Networks"
145
    <https://arxiv.org/abs/1905.07953>`_ paper which merges partioned subgraphs
146
    and their between-cluster links from a large-scale graph data object to
147
    form a mini-batch.
148
149
    Args:
150
        cluster_data (torch_geometric.data.ClusterData): The already
151
            partioned data object.
152
        batch_size (int, optional): How many samples per batch to load.
153
            (default: :obj:`1`)
154
        shuffle (bool, optional): If set to :obj:`True`, the data will be
155
            reshuffled at every epoch. (default: :obj:`False`)
156
    """
157
158
    def __init__(self, cluster_data: ClusterData, batch_size: int=1, shuffle: bool=False, **kwargs) -> None:
159
        class HelperDataset(torch.utils.data.Dataset):
160
            def __len__(self):
161
                return len(cluster_data)
162
163
            def __getitem__(self, idx):
164
                start = int(cluster_data.partptr[idx])
165
                length = int(cluster_data.partptr[idx + 1]) - start
166
167
                data = copy.copy(cluster_data.data)
168
                num_nodes = data.num_nodes
169
                for key, item in data:
170
                    if item.size(0) == num_nodes:
171
                        data[key] = item.narrow(0, start, length)
172
173
                return data, idx
174
175
        def collate(batch):
176
            data_list = [data[0] for data in batch]
177
            parts: List[int] = [data[1] for data in batch]
178
            partptr = cluster_data.partptr
179
180
            adj = cat([data.adj for data in data_list], dim=0)
181
182
            adj = adj.t()
183
            adjs = []
184
            for part in parts:
185
                start = partptr[part]
186
                length = partptr[part + 1] - start
187
                adjs.append(adj.narrow(0, start, length))
188
            adj = cat(adjs, dim=0).t()
189
            row, col, value = adj.coo()
190
191
            data = cluster_data.data.__class__()
192
            data.num_nodes = adj.size(0)
193
            data.edge_index = torch.stack([row, col], dim=0)
194
            data.edge_attr = value
195
196
            ref = data_list[0]
197
            keys = list(ref.keys())
198
            keys.remove("adj")
199
200
            for key in keys:
201
                if ref[key].size(0) != ref.adj.size(0):
202
                    data[key] = ref[key]
203
                else:
204
                    data[key] = torch.cat(
205
                        [d[key] for d in data_list], dim=ref.__cat_dim__(key, ref[key])
206
                    )
207
208
            return data
209
210
        super(ClusterLoader, self).__init__(
211
            HelperDataset(), batch_size, shuffle, collate_fn=collate, **kwargs
212
        )
213
214
215
## model
216
class GAT(torch.nn.Module):  # torch.nn.Module is the base class for all NN modules.
217
    def __init__(self, n_nodes: int, nFeatures: int, nHiddenUnits: int, nHeads: int, alpha: float, dropout: float) -> None:
218
        super(GAT, self).__init__()
219
        # 定义实例属性
220
        self.n_nodes = n_nodes
221
        self.nFeatures = nFeatures
222
        self.nHiddenUnits = nHiddenUnits
223
        self.nHeads = nHeads
224
        self.alpha = alpha
225
        self.dropout = dropout
226
227
        self.gat1 = GATConv(
228
            self.nFeatures,
229
            out_channels=self.nHiddenUnits,  # 映射到8维
230
            heads=self.nHeads,
231
            concat=True,
232
            negative_slope=self.alpha,
233
            dropout=self.dropout,
234
            bias=True,
235
        )
236
        self.gat2 = GATConv(
237
            self.nHiddenUnits * self.nHeads,
238
            self.n_nodes,  # 最后一层映射到k维度(k=n_class)
239
            heads=self.nHeads,
240
            concat=False,
241
            negative_slope=self.alpha,
242
            dropout=self.dropout,
243
            bias=True,
244
        )
245
246
    def forward(self, data:     torch_geometric.data.data.Data) ->     torch.Tensor:
247
        x, edge_index = data.x, data.edge_index
248
        x = self.gat1(x, edge_index)  # 第一层输出经过ELU非线性函数
249
        x = F.elu(x)
250
        x = self.gat2(x, edge_index)  # 第二层输出经过softmax变成[0, 1]后直接用于分类
251
        # return F.log_softmax(x, dim=1)
252
        return x
253
254
255
## sklearn classifier
256
class GATclassifier(BaseEstimator):
257
    """A pytorch regressor"""
258
259
    def __init__(
260
        self,
261
        n_nodes: int=2,
262
        nFeatures: Optional[int]=None,
263
        nHiddenUnits: int=8,
264
        nHeads: int=8,
265
        alpha: float=0.2,
266
        dropout: float=0.4,
267
        clip: None=None,
268
        rs: int=random.randint(1, 1000000),
269
        LR: float=0.001,
270
        WeightDecay: float=5e-4,
271
        BatchSize: int=256,
272
        NumParts: int=200,
273
        nEpochs: int=100,
274
        fastmode: bool=True,
275
        verbose: int=0,
276
        device: str="cpu",
277
    ) -> None:
278
        """
279
        Called when initializing the regressor
280
        """
281
        self._history = None
282
        self._model = None
283
284
        args, _, _, values = inspect.getargvalues(inspect.currentframe())
285
        values.pop("self")
286
287
        for arg, val in values.copy().items():
288
            setattr(self, arg, val)
289
290
    def _build_model(self) -> None:
291
292
        self._model = GAT(
293
            self.n_nodes,
294
            self.nFeatures,
295
            self.nHiddenUnits,
296
            self.nHeads,
297
            self.alpha,
298
            self.dropout,
299
        )
300
301
    def _train_model(self, X: ndarray, y: Categorical, adj: csr_matrix) -> None:
302
        # X, y, adj = get_X_y_from_ann(adata_train_final, return_adj=True, n_pc=2, n_neigh=10)
303
304
        node_features = torch.from_numpy(X).float()
305
        labels = torch.LongTensor(y)
306
        edge_index, _ = scipysparse2torchsparse(adj)
307
308
        d = Data(x=node_features, edge_index=edge_index, y=labels)
309
        cd = ClusterData(d, num_parts=self.NumParts)
310
311
        cl = ClusterLoader(cd, batch_size=self.BatchSize, shuffle=True)
312
313
        optimizer = torch.optim.Adagrad(
314
            self._model.parameters(), lr=self.LR, weight_decay=self.WeightDecay
315
        )
316
317
        # Random Seed
318
        random.seed(self.rs)
319
        np.random.seed(self.rs)
320
        torch.manual_seed(self.rs)
321
322
        t_total = time.time()
323
        loss_values = []
324
        bad_counter = 0
325
        best = self.nEpochs + 1
326
        best_epoch = 0
327
328
        for epoch in range(self.nEpochs):
329
330
            t = time.time()
331
            epoch_loss = []
332
            epoch_acc = []
333
            epoch_acc_val = []
334
            epoch_loss_val = []
335
336
            self._model.train()  # It sets the mode to train
337
338
            for batch in cl:  # cl: clusterLoader
339
                batch = batch.to(self.device)  # move the data to CPU/GPU
340
                optimizer.zero_grad()  # weight init
341
                x_output = self._model(batch)  # ncell*2; log_softmax
342
                output = F.log_softmax(x_output, dim=1)
343
344
                loss = F.nll_loss(
345
                    input=output, target=batch.y
346
                )  # compute negative log likelihood loss
347
                # input: ncell*nclass;
348
                # target: ncell*1, 0 =< value <= nclass-1
349
                loss.backward()
350
                if self.clip is not None:
351
                    torch.nn.utils.clip_grad_norm_(self._model.parameters(), self.clip)
352
                optimizer.step()
353
                epoch_loss.append(loss.item())
354
                # epoch_acc.append(accuracy(output, batch.y).item())
355
356
            if not self.fastmode:
357
                d_val = Data(x=features_val, edge_index=edge_index_val, y=labels_val)
358
                d_val = d_val.to(self.device)
359
                self._model.eval()
360
                x_output = self._model(d_val)
361
                output = F.log_softmax(x_output, dim=1)
362
363
                loss_val = F.nll_loss(output, d_val.y)
364
                # acc_val = accuracy(output,d_val.y).item() # tensor.item() returns the value of this tensor as a standard Python number.
365
                if (self.verbose > 0) & ((epoch + 1) % 50 == 0):
366
                    print(
367
                        "Epoch {}\t<loss>={:.4f}\tloss_val={:.4f}\tin {:.2f}-s".format(
368
                            epoch + 1,
369
                            np.mean(epoch_loss),
370
                            loss_val.item(),
371
                            time.time() - t,
372
                        )
373
                    )
374
                loss_values.append(loss_val.item())
375
            else:
376
                if (self.verbose > 0) & ((epoch + 1) % 50 == 0):
377
                    print(
378
                        "Epoch {}\t<loss>={:.4f}\tin {:.2f}-s".format(
379
                            epoch + 1, np.mean(epoch_loss), time.time() - t
380
                        )
381
                    )
382
                loss_values.append(np.mean(epoch_loss))
383
384
    def fit(self, X: ndarray, y: Categorical, adj: csr_matrix) -> "GATclassifier":
385
        """
386
        Trains the pytorch regressor.
387
        """
388
389
        self._build_model()
390
        self._train_model(X, y, adj)
391
392
        return self
393
394
    def predict(self, X: ndarray, y: Categorical, adj: csr_matrix) ->     torch.Tensor:
395
        """
396
        Makes a prediction using the trained pytorch model
397
        """
398
399
        # X, y, adj = get_X_y_from_ann(adata_test, return_adj=True, n_pc=2, n_neigh=10)
400
401
        node_features = torch.from_numpy(X).float()
402
        labels = torch.LongTensor(y)
403
        edge_index, _ = scipysparse2torchsparse(adj)
404
405
        d_test = Data(x=node_features, edge_index=edge_index, y=labels)
406
407
        self._model.eval()  # define the evaluation mode
408
        d_test = d_test.to(self.device)
409
        x_output = self._model(d_test)
410
        output = F.log_softmax(x_output, dim=1)
411
        preds = output.max(1)[1].type_as(labels)
412
413
        return preds
414
415
    def predict_proba(self, X: ndarray, y: Categorical, adj: csr_matrix) -> ndarray:
416
417
        # X, y, adj = get_X_y_from_ann(adata_test, return_adj=True, n_pc=2, n_neigh=10)
418
419
        node_features = torch.from_numpy(X).float()
420
        labels = torch.LongTensor(y)
421
        edge_index, _ = scipysparse2torchsparse(adj)
422
423
        d_test = Data(x=node_features, edge_index=edge_index, y=labels)
424
425
        self._model.eval()  # define the evaluation mode
426
        d_test = d_test.to(self.device)
427
        x_output = self._model(d_test)
428
        output = F.log_softmax(x_output, dim=1)
429
430
        probs = torch.exp(output)  # return softmax (output is logsoftmax)
431
        y_prob = (
432
            probs.detach().cpu().numpy()
433
        )  # detach() here prune away the gradients bond with the probs tensor
434
435
        return y_prob
436
437
    def score(self, X, y, sample_weight=None):
438
        """
439
        Scores the data using the trained pytorch model. Under current implementation
440
        returns negative mae.
441
        """
442
        y_pred = self.predict(X, y)
443
        return F.nll_loss(y_pred, y)