def scipysparse2torchsparse(x: csr_matrix) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert a scipy CSR matrix to COO indices plus a torch sparse tensor.

    Args:
        x: scipy ``csr_matrix`` of shape (samples, features).

    Returns:
        A tuple ``(indices, t)`` where ``indices`` is a ``(2, nnz)``
        LongTensor of (row, col) coordinates and ``t`` is the equivalent
        torch sparse COO float tensor with the same shape as ``x``.

    REF: Code adapted from the
    `PyTorch discussion forum <https://discuss.pytorch.org/t/better-way-to-forward-sparse-matrix/21915>`_
    """
    samples = x.shape[0]
    features = x.shape[1]
    coo_data = x.tocoo()
    # (2, nnz) coordinate matrix; CSR->COO conversion preserves the order of
    # x.data, so values below line up with these indices.
    indices = torch.LongTensor([coo_data.row, coo_data.col])
    # torch.sparse.FloatTensor is deprecated; torch.sparse_coo_tensor is the
    # supported constructor and builds the identical sparse COO tensor.
    t = torch.sparse_coo_tensor(
        indices, torch.from_numpy(x.data).float(), (samples, features)
    )
    return indices, t
class ClusterData(torch.utils.data.Dataset):
    r"""Clusters/partitions a graph data object into multiple subgraphs, as
    motivated by the `"Cluster-GCN: An Efficient Algorithm for Training Deep
    and Large Graph Convolutional Networks"
    <https://arxiv.org/abs/1905.07953>`_ paper.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        num_parts (int): The number of partitions.
        recursive (bool, optional): If set to :obj:`True`, will use multilevel
            recursive bisection instead of multilevel k-way partitioning.
            (default: :obj:`False`)
        save_dir (string, optional): If set, will save the partitioned data to
            the :obj:`save_dir` directory for faster re-use.
    """

    def __init__(self, data: torch_geometric.data.data.Data, num_parts: int, recursive: bool=False, save_dir: Optional[str]=None) -> None:
        # Partitioning needs the connectivity up front; fail early otherwise.
        assert data.edge_index is not None

        self.num_parts = num_parts
        self.recursive = recursive
        self.save_dir = save_dir

        # Eagerly partition (or load a cached partition of) the graph.
        self.process(data)

    def process(self, data: torch_geometric.data.data.Data) -> None:
        """Partition ``data`` into ``num_parts`` clusters, caching the result
        under ``save_dir`` (if set) for faster re-use on later runs."""
        recursive = "_recursive" if self.recursive else ""
        filename = f"part_data_{self.num_parts}{recursive}.pt"

        path = osp.join(self.save_dir or "", filename)
        if self.save_dir is not None and osp.exists(path):
            # Cache hit: reuse the previously computed partition.
            # NOTE(review): torch.load unpickles; the cache dir is assumed trusted.
            data, partptr, perm = torch.load(path)
        else:
            data = copy.copy(data)
            num_nodes = data.num_nodes

            (row, col), edge_attr = data.edge_index, data.edge_attr
            adj = SparseTensor(row=row, col=col, value=edge_attr)
            # Partition the adjacency: returns the node-permuted adjacency,
            # the cluster boundary offsets (partptr) and the permutation (perm).
            adj, partptr, perm = adj.partition(self.num_parts, self.recursive)

            # Re-order every node-level attribute to match the permutation
            # (anything whose first dimension equals num_nodes).
            for key, item in data:
                if item.size(0) == num_nodes:
                    data[key] = item[perm]

            # Edges now live in the SparseTensor; drop the COO fields.
            data.edge_index = None
            data.edge_attr = None
            data.adj = adj

            if self.save_dir is not None:
                torch.save((data, partptr, perm), path)

        self.data = data
        self.perm = perm
        self.partptr = partptr

    def __len__(self) -> int:
        # partptr holds num_parts + 1 boundary offsets.
        return self.partptr.numel() - 1

    def __getitem__(self, idx):
        """Return partition ``idx`` as a Data object with its edges restored
        in COO form (edge_index / edge_attr)."""
        start = int(self.partptr[idx])
        length = int(self.partptr[idx + 1]) - start

        data = copy.copy(self.data)
        num_nodes = data.num_nodes

        # Slice every node-level attribute (including the adjacency's rows)
        # down to this partition's node range.
        for key, item in data:
            if item.size(0) == num_nodes:
                data[key] = item.narrow(0, start, length)

        # Rows were narrowed in the loop above; narrow the columns too so the
        # adjacency is the intra-cluster square block.
        data.adj = data.adj.narrow(1, start, length)

        # Convert back to COO edge representation expected by downstream code.
        row, col, value = data.adj.coo()
        data.adj = None
        data.edge_index = torch.stack([row, col], dim=0)
        data.edge_attr = value

        return data

    def __repr__(self):
        return f"{self.__class__.__name__}({self.data}, " f"num_parts={self.num_parts})"
class ClusterLoader(torch.utils.data.DataLoader):
    r"""The data loader scheme from the `"Cluster-GCN: An Efficient Algorithm
    for Training Deep and Large Graph Convolutional Networks"
    <https://arxiv.org/abs/1905.07953>`_ paper which merges partioned subgraphs
    and their between-cluster links from a large-scale graph data object to
    form a mini-batch.

    Args:
        cluster_data (torch_geometric.data.ClusterData): The already
            partioned data object.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch. (default: :obj:`False`)
    """

    def __init__(self, cluster_data: ClusterData, batch_size: int=1, shuffle: bool=False, **kwargs) -> None:
        class HelperDataset(torch.utils.data.Dataset):
            # Serves one partition per index: node-level attributes are
            # narrowed to the partition's node range; edges are reassembled
            # for the whole batch later, in ``collate``.
            def __len__(self):
                return len(cluster_data)

            def __getitem__(self, idx):
                start = int(cluster_data.partptr[idx])
                length = int(cluster_data.partptr[idx + 1]) - start

                data = copy.copy(cluster_data.data)
                num_nodes = data.num_nodes
                # Narrow everything whose first dim is num_nodes — this
                # includes ``adj``, so each item carries its adjacency rows.
                for key, item in data:
                    if item.size(0) == num_nodes:
                        data[key] = item.narrow(0, start, length)

                # Return the partition index too, so collate can locate the
                # matching column range in the global adjacency.
                return data, idx

        def collate(batch):
            data_list = [data[0] for data in batch]
            parts: List[int] = [data[1] for data in batch]
            partptr = cluster_data.partptr

            # Concatenate the adjacency rows of all partitions in the batch.
            adj = cat([data.adj for data in data_list], dim=0)

            # Transpose, then keep only the columns belonging to the batched
            # partitions — this retains both intra-cluster edges and the
            # between-cluster links among partitions in this batch.
            adj = adj.t()
            adjs = []
            for part in parts:
                start = partptr[part]
                length = partptr[part + 1] - start
                adjs.append(adj.narrow(0, start, length))
            adj = cat(adjs, dim=0).t()
            row, col, value = adj.coo()

            # Assemble the merged mini-batch Data object.
            data = cluster_data.data.__class__()
            data.num_nodes = adj.size(0)
            data.edge_index = torch.stack([row, col], dim=0)
            data.edge_attr = value

            ref = data_list[0]
            keys = list(ref.keys())
            keys.remove("adj")

            for key in keys:
                if ref[key].size(0) != ref.adj.size(0):
                    # Not a node-level attribute: take it from the first item.
                    data[key] = ref[key]
                else:
                    # Node-level attribute: concatenate along the node dim.
                    data[key] = torch.cat(
                        [d[key] for d in data_list], dim=ref.__cat_dim__(key, ref[key])
                    )

            return data

        super(ClusterLoader, self).__init__(
            HelperDataset(), batch_size, shuffle, collate_fn=collate, **kwargs
        )
## model
class GAT(torch.nn.Module):
    """Two-layer Graph Attention Network.

    The first :class:`GATConv` maps ``nFeatures`` inputs to ``nHiddenUnits``
    per head with the head outputs concatenated; the second averages its
    heads down to ``n_nodes`` output channels. The forward pass returns raw
    scores — callers apply ``log_softmax`` themselves.
    """

    def __init__(self, n_nodes: int, nFeatures: int, nHiddenUnits: int, nHeads: int, alpha: float, dropout: float) -> None:
        super().__init__()

        # Keep the hyper-parameters on the instance for introspection.
        self.n_nodes = n_nodes
        self.nFeatures = nFeatures
        self.nHiddenUnits = nHiddenUnits
        self.nHeads = nHeads
        self.alpha = alpha
        self.dropout = dropout

        # Layer 1: nFeatures -> nHiddenUnits per head, heads concatenated,
        # so its output width is nHiddenUnits * nHeads.
        self.gat1 = GATConv(
            nFeatures,
            out_channels=nHiddenUnits,
            heads=nHeads,
            concat=True,
            negative_slope=alpha,
            dropout=dropout,
            bias=True,
        )
        # Layer 2: heads averaged (concat=False) into n_nodes channels
        # (n_nodes is used as the number of classes by the caller).
        self.gat2 = GATConv(
            nHiddenUnits * nHeads,
            n_nodes,
            heads=nHeads,
            concat=False,
            negative_slope=alpha,
            dropout=dropout,
            bias=True,
        )

    def forward(self, data: torch_geometric.data.data.Data) -> torch.Tensor:
        """Return per-node class scores (no softmax applied)."""
        hidden = F.elu(self.gat1(data.x, data.edge_index))
        return self.gat2(hidden, data.edge_index)
## sklearn classifier
class GATclassifier(BaseEstimator):
    """Scikit-learn-style classifier wrapping the two-layer :class:`GAT`.

    ``fit``/``predict`` take a dense feature matrix ``X``, integer labels
    ``y`` and a sparse adjacency matrix ``adj`` over the samples. Training
    follows the Cluster-GCN scheme: the graph is split into ``NumParts``
    partitions and mini-batches of partitions are drawn by
    :class:`ClusterLoader`.
    """

    def __init__(
        self,
        n_nodes: int = 2,
        nFeatures: Optional[int] = None,
        nHiddenUnits: int = 8,
        nHeads: int = 8,
        alpha: float = 0.2,
        dropout: float = 0.4,
        clip: Optional[float] = None,
        rs: Optional[int] = None,
        LR: float = 0.001,
        WeightDecay: float = 5e-4,
        BatchSize: int = 256,
        NumParts: int = 200,
        nEpochs: int = 100,
        fastmode: bool = True,
        verbose: int = 0,
        device: str = "cpu",
    ) -> None:
        """Store every constructor argument as an attribute (sklearn convention).

        Args:
            n_nodes: number of output classes of the GAT head.
            nFeatures: number of input features per node; must be set before fit.
            nHiddenUnits, nHeads, alpha, dropout: GAT hyper-parameters.
            clip: optional max gradient norm; ``None`` disables clipping.
            rs: random seed. When ``None`` a fresh seed is drawn at fit time.
                (The old default ``random.randint(...)`` was evaluated once at
                import, so all instances silently shared the same seed.)
            LR, WeightDecay: Adagrad optimizer settings.
            BatchSize, NumParts: Cluster-GCN batching parameters.
            nEpochs: number of training epochs.
            fastmode: must be ``True``; the validation path is not implemented.
            verbose: > 0 prints a progress line every 50 epochs.
            device: torch device string, e.g. ``"cpu"`` or ``"cuda"``.
        """
        self._history = None
        self._model = None

        # Mirror every named constructor argument onto ``self`` so sklearn's
        # get_params()/set_params()/clone() work without listing each one.
        args, _, _, values = inspect.getargvalues(inspect.currentframe())
        values.pop("self")
        for arg, val in values.copy().items():
            setattr(self, arg, val)

    def _build_model(self) -> None:
        """Instantiate the underlying GAT with the stored hyper-parameters."""
        self._model = GAT(
            self.n_nodes,
            self.nFeatures,
            self.nHiddenUnits,
            self.nHeads,
            self.alpha,
            self.dropout,
        )

    def _train_model(self, X: ndarray, y: Categorical, adj: csr_matrix) -> None:
        """Train the GAT on one graph using Cluster-GCN mini-batching.

        Args:
            X: (n_samples, n_features) dense feature matrix.
            y: integer class labels, length n_samples.
            adj: sparse (n_samples, n_samples) adjacency matrix.

        Raises:
            ValueError: if ``fastmode`` is False (validation is unsupported).
        """
        if not self.fastmode:
            # The original validation branch referenced undefined variables
            # (features_val / edge_index_val / labels_val) and crashed with a
            # NameError mid-training; fail fast with a clear message instead.
            raise ValueError(
                "fastmode=False is not supported: no validation data is wired up"
            )

        node_features = torch.from_numpy(X).float()
        labels = torch.LongTensor(y)
        edge_index, _ = scipysparse2torchsparse(adj)

        d = Data(x=node_features, edge_index=edge_index, y=labels)
        cd = ClusterData(d, num_parts=self.NumParts)
        cl = ClusterLoader(cd, batch_size=self.BatchSize, shuffle=True)

        optimizer = torch.optim.Adagrad(
            self._model.parameters(), lr=self.LR, weight_decay=self.WeightDecay
        )

        # Seed all RNGs for reproducibility; draw a fresh seed when none was
        # supplied so independent instances don't share one.
        seed = self.rs if self.rs is not None else random.randint(1, 1000000)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        for epoch in range(self.nEpochs):
            t = time.time()
            epoch_loss = []

            self._model.train()  # switch to training mode (dropout active)

            for batch in cl:
                batch = batch.to(self.device)
                optimizer.zero_grad()
                x_output = self._model(batch)
                output = F.log_softmax(x_output, dim=1)

                # NLL over log-probabilities == cross-entropy over raw scores.
                loss = F.nll_loss(input=output, target=batch.y)
                loss.backward()
                if self.clip is not None:
                    torch.nn.utils.clip_grad_norm_(self._model.parameters(), self.clip)
                optimizer.step()
                epoch_loss.append(loss.item())

            if self.verbose > 0 and (epoch + 1) % 50 == 0:
                print(
                    "Epoch {}\t<loss>={:.4f}\tin {:.2f}-s".format(
                        epoch + 1, np.mean(epoch_loss), time.time() - t
                    )
                )

    def fit(self, X: ndarray, y: Categorical, adj: csr_matrix) -> "GATclassifier":
        """Build the GAT and train it on ``(X, y, adj)``; return ``self``."""
        self._build_model()
        self._train_model(X, y, adj)
        return self

    def _log_probs(self, X: ndarray, y: Categorical, adj: csr_matrix) -> torch.Tensor:
        """Run the trained model in eval mode; return per-node log-probabilities.

        Shared forward pass for ``predict`` / ``predict_proba`` (previously
        duplicated in both methods).
        """
        node_features = torch.from_numpy(X).float()
        labels = torch.LongTensor(y)
        edge_index, _ = scipysparse2torchsparse(adj)

        d = Data(x=node_features, edge_index=edge_index, y=labels)

        self._model.eval()  # evaluation mode (dropout off)
        d = d.to(self.device)
        x_output = self._model(d)
        return F.log_softmax(x_output, dim=1)

    def predict(self, X: ndarray, y: Categorical, adj: csr_matrix) -> torch.Tensor:
        """Return predicted class indices as a LongTensor of length n_samples."""
        output = self._log_probs(X, y, adj)
        labels = torch.LongTensor(y)
        preds = output.max(1)[1].type_as(labels)
        return preds

    def predict_proba(self, X: ndarray, y: Categorical, adj: csr_matrix) -> ndarray:
        """Return class probabilities as an (n_samples, n_classes) numpy array."""
        output = self._log_probs(X, y, adj)
        probs = torch.exp(output)  # log-softmax -> softmax
        # detach() prunes gradient tracking before the numpy conversion.
        y_prob = probs.detach().cpu().numpy()
        return y_prob

    def score(self, X, y, adj=None, sample_weight=None):
        """Return mean accuracy of ``predict(X, y, adj)`` against ``y``.

        The previous implementation called ``predict(X, y)`` without the
        required ``adj`` and applied ``nll_loss`` to hard class indices, so
        it always raised; it now reports accuracy, the sklearn convention
        for classifiers. ``sample_weight`` is accepted for API compatibility
        but ignored.

        Raises:
            ValueError: if ``adj`` is not provided.
        """
        if adj is None:
            raise ValueError("adj is required to score a graph classifier")
        preds = self.predict(X, y, adj)
        labels = torch.LongTensor(y)
        return (preds == labels).float().mean().item()