# General
import random
import numpy as np
import pandas as pd
import time
import math
from typing import NamedTuple, Optional, Tuple
import plotly.express as px

# Pytorch
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Sigmoid
from torch_geometric.data import Dataset, NeighborSampler, Data

# Sci-kit Learn
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, f1_score, roc_curve, precision_recall_curve

# Global variables
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def to_numpy(input):
    """Detach a tensor from the graph, move it to CPU, and return a numpy array.

    Sparse tensors (any dtype/device) are densified first. Uses `is_sparse`
    rather than the deprecated `torch.sparse.FloatTensor` class check.
    """
    if input.is_sparse:
        return input.to_dense().cpu().detach().numpy()
    return input.cpu().detach().numpy()


def from_numpy(np_array):
    """Wrap a numpy array as a torch tensor (shares memory when possible)."""
    return torch.as_tensor(np_array)


def sample_node_for_et(et, targets):
    """Uniformly sample one target node for edge type `et`.

    Args:
        et: edge-type key into `targets`.
        targets: mapping from edge type to a 1-D tensor of candidate node ids.

    Returns:
        A single node id drawn uniformly at random from `targets[et]`.
    """
    # Randomly select an index into the targets for the given edge type
    neg_idx = torch.randperm(targets[et].shape[0])[0]
    return targets[et][neg_idx]


class HeterogeneousEdgeIndex(NamedTuple):  # adopted from NeighborSampler code in Pytorch Geometric
    """EdgeIndex-style tuple that additionally carries per-edge types."""
    edge_index: Tensor
    e_id: Optional[Tensor]
    edge_type: Optional[Tensor]
    size: Tuple[int, int]

    def to(self, *args, **kwargs):
        """Move all tensor fields to a device/dtype, returning a new tuple."""
        edge_index = self.edge_index.to(*args, **kwargs)
        e_id = self.e_id.to(*args, **kwargs) if self.e_id is not None else None
        edge_type = self.edge_type.to(*args, **kwargs) if self.edge_type is not None else None
        # BUG FIX: original returned `EdgeIndex(...)`, an undefined name in this
        # module, which raised NameError on first use. Return our own type.
        return HeterogeneousEdgeIndex(edge_index, e_id, edge_type, self.size)


def get_batched_data(data, all_data):
    """Re-package a NeighborSampler batch as a `Data` object with typed adjs.

    Args:
        data: `(batch_size, n_id, adjs)` tuple yielded by NeighborSampler.
        all_data: full graph `Data` holding `edge_attr` (per-edge types).

    Returns:
        `Data` with `adjs` (each adj augmented with its edges' types via
        `edge_attr[e_id]`), `batch_size`, and `n_id`.
    """
    batch_size, n_id, adjs = data
    adjs = [HeterogeneousEdgeIndex(adj.edge_index, adj.e_id,
                                   all_data.edge_attr[adj.e_id], adj.size)
            for adj in adjs]
    return Data(adjs=adjs,
                batch_size=batch_size,
                n_id=n_id,
                )


MAX_SIZE = 625  # chunk size to bound the (nodes x edges) comparison matrix


def get_mask(edge_index, nodes, ind):
    """Find (node-position, edge-position) pairs where `edge_index[ind]` hits `nodes`.

    Processes `nodes` in chunks of MAX_SIZE to keep the broadcasted equality
    matrix small. Row offsets are shifted so indices refer to positions in the
    full `nodes` tensor. Intended for `nodes.size(0) > MAX_SIZE`
    (i.e. at least two chunks).
    """
    n_splits = math.ceil(nodes.size(0) / MAX_SIZE)
    # First chunk.
    node_mask = (edge_index[ind, :] == nodes[:MAX_SIZE].unsqueeze(-1)).nonzero()
    # Middle chunks (indices 1 .. n_splits-2); offset rows into full-node coords.
    for i in range(1, n_splits - 1):
        node_mask_mid = (edge_index[ind, :] == nodes[MAX_SIZE * i:MAX_SIZE * (i + 1)].unsqueeze(-1)).nonzero()
        node_mask_mid[:, 0] = node_mask_mid[:, 0] + (MAX_SIZE * i)
        node_mask = torch.cat([node_mask, node_mask_mid])
    # Final chunk (may be shorter than MAX_SIZE).
    node_mask_end = (edge_index[ind, :] == nodes[MAX_SIZE * (n_splits - 1):].unsqueeze(-1)).nonzero()
    node_mask_end[:, 0] = node_mask_end[:, 0] + (MAX_SIZE * (n_splits - 1))
    node_mask = torch.cat([node_mask, node_mask_end])
    return node_mask


def get_indices_into_edge_index(edge_index, source_nodes, target_nodes):
    """Locate edges whose source AND target both come from the given node lists.

    `source_nodes[i]` is paired with `target_nodes[i]`; an edge column matches
    when its row-0 entry equals a source node and its row-1 entry equals the
    target node at the same position (a match in both masks at the same
    (position, edge) coordinates appears twice in the concatenation).

    Returns:
        Tuple of (indices into edge_index columns, indices into the node lists).
    """
    if source_nodes.size(0) > MAX_SIZE:
        # Chunked comparison to avoid materializing a huge boolean matrix.
        source_node_mask = get_mask(edge_index, source_nodes, ind=0)
        target_node_mask = get_mask(edge_index, target_nodes, ind=1)
    else:
        source_node_mask = (edge_index[0, :] == source_nodes.unsqueeze(-1)).nonzero()
        target_node_mask = (edge_index[1, :] == target_nodes.unsqueeze(-1)).nonzero()

    # Coordinates appearing in both masks (count > 1) are true (source, target) hits.
    vals_pos, counts_pos = torch.unique(torch.cat([source_node_mask, target_node_mask]),
                                        return_counts=True, dim=0)
    if len(vals_pos) == 0 or len(counts_pos) == 0:
        # Debug aid: no candidate edges at all between the provided nodes.
        print(edge_index)
        print(source_nodes)
        print(target_nodes)

    return vals_pos[counts_pos > 1][:, 1], vals_pos[counts_pos > 1][:, 0]


def get_edges(data, all_data, dataset_type):
    """Attach positive (seed-to-seed) edges of a split to a batch.

    Assumes the first half of `data.n_id[:batch_size]` holds source seed nodes
    and the second half their paired positive targets.

    Args:
        data: batched `Data` with `n_id` and `batch_size`.
        all_data: full graph with `edge_index`, `edge_attr`, and
            `{dataset_type}_mask` boolean edge masks.
        dataset_type: split name, e.g. 'train' / 'val' / 'test'.

    Returns:
        `data`, augmented with `pos_edge_indices`, `pos_edge_types`, and
        `index_to_node_features_pos`.
    """
    # Restrict to this split's edges.
    edge_index = all_data.edge_index[:, all_data[f'{dataset_type}_mask']].to(data.n_id.device)
    edge_type = all_data.edge_attr[all_data[f'{dataset_type}_mask']].to(data.n_id.device)

    # Seed nodes: first half are sources, second half their positive targets.
    source_nodes = data.n_id[:int(data.batch_size / 2)]
    pos_target_nodes = data.n_id[int(data.batch_size / 2):int(data.batch_size)]

    # Indices into the split's edge list & into the seed-node list.
    ind_to_edge_index_pos, ind_to_nodes_pos = get_indices_into_edge_index(edge_index, source_nodes, pos_target_nodes)

    # Keep only edges where both endpoints are seed nodes.
    data.pos_edge_indices = edge_index[:, ind_to_edge_index_pos]
    data.pos_edge_types = edge_type[ind_to_edge_index_pos]
    data.index_to_node_features_pos = ind_to_nodes_pos

    return data


def calc_metrics(pred, y, threshold=0.5):
    """Compute ROC-AUC, average precision, accuracy, and micro-F1.

    Negative labels are treated as 0. Operates on a copy so the caller's `y`
    is not mutated (the original clobbered it in place). ROC-AUC falls back to
    0.5 when only one class is present (sklearn raises ValueError).
    """
    y = np.array(y, copy=True)
    y[y < 0] = 0
    try:
        roc_score = roc_auc_score(y, pred)
    except ValueError:
        roc_score = 0.5
    ap_score = average_precision_score(y, pred)
    acc = accuracy_score(y, pred > threshold)
    f1 = f1_score(y, pred > threshold, average='micro')
    return roc_score, ap_score, acc, f1


def metrics_per_rel(pred, link_labels, edge_attr_dict, total_edge_type, split, threshold=0.5, verbose=False):
    """Compute per-edge-type metrics and return them as a flat log dict.

    Args:
        pred: predicted scores per edge.
        link_labels: ground-truth labels per edge.
        edge_attr_dict: mapping from edge-type name to its integer id.
        total_edge_type: integer edge-type id per edge.
        split: split name used in the logged metric keys.
        threshold: decision threshold for accuracy/F1.
        verbose: when True, also print each metric.

    Returns:
        Dict keyed 'edge_metrics/node.{type}_{split}_{metric}'.
    """
    log = {}
    for attr, idx in edge_attr_dict.items():
        mask = (total_edge_type == idx)
        if mask.sum() == 0:
            continue  # no edges of this type in the batch
        pred_per_rel = pred[mask]
        y_per_rel = link_labels[mask]
        roc_per_rel, ap_per_rel, acc_per_rel, f1_per_rel = calc_metrics(
            pred_per_rel.cpu().detach().numpy(),
            y_per_rel.cpu().detach().numpy(),
            threshold)
        if verbose:
            print("ROC for edge type {}: {:.5f}".format(attr, roc_per_rel))
            print("AP for edge type {}: {:.5f}".format(attr, ap_per_rel))
            print("ACC for edge type {}: {:.5f}".format(attr, acc_per_rel))
            print("F1 for edge type {}: {:.5f}".format(attr, f1_per_rel))
        log.update({"edge_metrics/node.%s_%s_roc" % (attr, split): roc_per_rel,
                    "edge_metrics/node.%s_%s_ap" % (attr, split): ap_per_rel,
                    "edge_metrics/node.%s_%s_acc" % (attr, split): acc_per_rel,
                    "edge_metrics/node.%s_%s_f1" % (attr, split): f1_per_rel})
    return log


def plot_roc_curve(pred, labels):
    """Build a plotly ROC-curve figure with AUC and G-mean in the hover data.

    G-mean = sqrt(TPR * (1 - FPR)); its maximum marks the threshold that best
    balances sensitivity and specificity.
    """
    fpr, tpr, thresholds = roc_curve(labels, pred)
    gmeans = np.sqrt(tpr * (1 - fpr))
    max_gmean = max(gmeans)
    roc = roc_auc_score(labels, pred)
    data = {"False Positive Rate": fpr, "True Positive Rate": tpr, "Threshold": thresholds,
            "ROC": [roc] * len(thresholds), "G-Mean": gmeans, "Max G-Mean": [max_gmean] * len(thresholds)}
    df = pd.DataFrame(data)
    fig = px.line(df, x="False Positive Rate", y="True Positive Rate", hover_data=list(data.keys()))
    return fig