[db6163]: / shepherd / utils / pretrain_utils.py

Download this file

157 lines (121 with data), 6.2 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# General (standard library)
import logging
import math
import random
import time
from typing import NamedTuple, Optional, Tuple

# General (third-party)
import numpy as np
import pandas as pd
import plotly.express as px

# Pytorch
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Sigmoid
from torch_geometric.data import Dataset, NeighborSampler, Data

# Sci-kit Learn
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    f1_score,
    precision_recall_curve,
    roc_auc_score,
    roc_curve,
)
# Global variables
# Module-wide default device: the CUDA device when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def to_numpy(input):
    """Convert a PyTorch tensor into a NumPy array on the host.

    Any non-strided (sparse) layout -- COO, CSR, ... -- of any dtype and on any
    device is densified first, because ``Tensor.numpy()`` only supports dense
    CPU tensors. The result is detached from the autograd graph. For a dense
    CPU tensor the returned array shares memory with ``input``.

    Args:
        input: the tensor to convert.

    Returns:
        numpy.ndarray holding the same values as ``input``.
    """
    # Check the layout rather than the legacy ``torch.sparse.FloatTensor`` type:
    # that type only matches float32 COO tensors on the CPU, so e.g. float64 or
    # CUDA sparse tensors would reach ``.numpy()`` still sparse and fail.
    if input.layout != torch.strided:
        input = input.to_dense()
    return input.detach().cpu().numpy()
def from_numpy(np_array):
    """Wrap ``np_array`` (any array-like) as a torch tensor via ``torch.as_tensor``.

    ``torch.as_tensor`` avoids a copy when given a NumPy array whose dtype and
    device it can use directly, so the tensor may share memory with the input.
    """
    tensor = torch.as_tensor(np_array)
    return tensor
def sample_node_for_et(et, targets):
    """Return one entry of ``targets[et]`` chosen uniformly at random.

    Args:
        et: key into ``targets`` (presumably an edge type -- confirm with callers).
        targets: mapping from ``et`` to a tensor/array; an entry is drawn along
            its first dimension.

    Returns:
        ``targets[et][i]`` for a uniformly random index ``i``.

    Raises:
        IndexError: if ``targets[et]`` has no entries.
    """
    candidates = targets[et]
    n_candidates = candidates.shape[0]
    if n_candidates == 0:
        raise IndexError(f"no candidate nodes for edge type {et!r}")
    # Draw one index directly (O(1)) instead of materializing a full random
    # permutation of length n_candidates only to keep its first element.
    idx = int(torch.randint(n_candidates, (1,)).item())
    return candidates[idx]
class HeterogeneousEdgeIndex(NamedTuple):  # adapted from NeighborSampler code in PyTorch Geometric
    """One sampled adjacency block, extended with the type of every edge.

    Modeled on PyTorch Geometric's ``EdgeIndex`` tuple from ``NeighborSampler``,
    with an extra ``edge_type`` field.
    """

    edge_index: Tensor  # edge list of this sampled block
    e_id: Optional[Tensor]  # ids of the sampled edges in the full graph, or None
    edge_type: Optional[Tensor]  # per-edge type of the sampled edges, or None
    size: Tuple[int, int]  # block size as given by the sampler (presumably (num_src, num_dst) -- confirm)

    def to(self, *args, **kwargs):
        """Return a copy whose tensor fields are moved/cast with ``Tensor.to``.

        All arguments are forwarded unchanged to ``Tensor.to``. ``None`` fields
        stay ``None`` and ``size`` is kept as-is.
        """
        edge_index = self.edge_index.to(*args, **kwargs)
        e_id = self.e_id.to(*args, **kwargs) if self.e_id is not None else None
        edge_type = self.edge_type.to(*args, **kwargs) if self.edge_type is not None else None
        # _replace keeps the result a HeterogeneousEdgeIndex (including edge_type).
        return self._replace(edge_index=edge_index, e_id=e_id, edge_type=edge_type)
def get_batched_data(data, all_data):
    """Wrap one sampled mini-batch in a ``Data`` object with typed adjacencies.

    Args:
        data: ``(batch_size, n_id, adjs)`` tuple, presumably as yielded by
            PyTorch Geometric's ``NeighborSampler`` -- confirm with callers.
        all_data: the full graph. ``all_data.edge_attr`` is indexed with each
            ``adj.e_id`` to get the type of every sampled edge.

    Returns:
        ``Data`` with ``adjs`` (list of ``HeterogeneousEdgeIndex``),
        ``batch_size`` and ``n_id``.

    NOTE(review): ``adjs`` is iterated as a list of per-layer blocks. If a
    single-layer sampler yields one EdgeIndex instead of a list, this would
    iterate its fields -- confirm that case is never used.
    """
    batch_size, n_id, sampled_adjs = data
    typed_adjs = []
    for adj in sampled_adjs:
        # Look up the type of every sampled edge in the full-graph edge attributes.
        per_edge_type = all_data.edge_attr[adj.e_id]
        typed_adjs.append(HeterogeneousEdgeIndex(adj.edge_index, adj.e_id, per_edge_type, adj.size))
    return Data(adjs=typed_adjs, batch_size=batch_size, n_id=n_id)
# Maximum number of query nodes compared against the edge list at once. Each
# comparison materializes a (chunk_len, num_edges) boolean matrix, so this
# bounds peak memory.
MAX_SIZE = 625


def get_mask(edge_index, nodes, ind):
    """Find all (node position, edge position) pairs with ``edge_index[ind, e] == nodes[k]``.

    ``nodes`` is compared against row ``ind`` of ``edge_index`` in chunks of at
    most ``MAX_SIZE`` nodes to limit memory use.

    Args:
        edge_index: tensor of edge endpoints; row ``ind`` is matched.
        nodes: 1-D tensor of node ids to look up.
        ind: which row of ``edge_index`` to compare against.

    Returns:
        Long tensor of shape ``(num_matches, 2)`` whose rows are ``[k, e]``
        (``k`` indexes ``nodes``, ``e`` indexes the columns of ``edge_index``),
        ordered by ``k`` then ``e``. Each pair appears exactly once.
    """
    row = edge_index[ind, :]
    chunks = []
    # One uniform loop over chunks: a separate "first" and "last" chunk would
    # be the same slice when everything fits in one chunk, duplicating rows.
    for start in range(0, nodes.size(0), MAX_SIZE):
        chunk_mask = (row == nodes[start:start + MAX_SIZE].unsqueeze(-1)).nonzero()
        chunk_mask[:, 0] += start  # chunk-local node position -> position in `nodes`
        chunks.append(chunk_mask)
    if not chunks:
        return torch.empty((0, 2), dtype=torch.long, device=edge_index.device)
    return torch.cat(chunks)
def get_indices_into_edge_index(edge_index, source_nodes, target_nodes):
    """Find edges ``e`` with ``edge_index[:, e] == (source_nodes[k], target_nodes[k])`` for some ``k``.

    Args:
        edge_index: tensor whose row 0 is matched against ``source_nodes`` and
            row 1 against ``target_nodes``.
        source_nodes: 1-D tensor of source node ids.
        target_nodes: 1-D tensor of target node ids, paired with
            ``source_nodes`` by position.

    Returns:
        ``(edge_positions, node_positions)``: 1-D tensors where
        ``edge_positions[i]`` is a column of ``edge_index`` and
        ``node_positions[i]`` is the pair position ``k`` it matches, ordered
        by ``k`` then ``e``.
    """
    def _matches(nodes, row):
        # Chunk only large queries, deciding per tensor: this bounds the
        # (len(nodes), num_edges) comparison matrix, and keeps get_mask for
        # inputs that actually need more than one chunk.
        if nodes.size(0) > MAX_SIZE:
            return get_mask(edge_index, nodes, ind=row)
        return (edge_index[row, :] == nodes.unsqueeze(-1)).nonzero()

    source_node_mask = _matches(source_nodes, 0)
    target_node_mask = _matches(target_nodes, 1)
    # Each mask lists a (k, e) pair at most once, so a pair counted twice means
    # edge e starts at source_nodes[k] and ends at target_nodes[k].
    vals_pos, counts_pos = torch.unique(torch.cat([source_node_mask, target_node_mask]), return_counts=True, dim=0)
    matched = vals_pos[counts_pos > 1]
    if matched.size(0) == 0:
        logging.getLogger(__name__).debug(
            "no edges matched: edge_index=%s source_nodes=%s target_nodes=%s",
            edge_index, source_nodes, target_nodes,
        )
    return matched[:, 1], matched[:, 0]
def get_edges(data, all_data, dataset_type):
    """Attach to ``data`` the ``dataset_type``-split edges joining paired seed nodes.

    The first ``int(data.batch_size / 2)`` entries of ``data.n_id`` are used as
    sources and the following ones (up to ``data.batch_size``) as targets,
    paired by position.

    Sets on ``data``:
        pos_edge_indices: matched columns of the split's ``edge_index``.
        pos_edge_types: ``edge_attr`` values of those edges.
        index_to_node_features_pos: for each matched edge, its seed-pair position.

    Returns:
        The same ``data`` object, modified in place.
    """
    split_mask = all_data[f'{dataset_type}_mask']
    target_device = data.n_id.device
    edge_index = all_data.edge_index[:, split_mask].to(target_device)
    edge_type = all_data.edge_attr[split_mask].to(target_device)

    half = int(data.batch_size / 2)
    seeds = data.n_id
    edge_pos, node_pos = get_indices_into_edge_index(
        edge_index, seeds[:half], seeds[half:int(data.batch_size)]
    )

    data.pos_edge_indices = edge_index[:, edge_pos]
    data.pos_edge_types = edge_type[edge_pos]
    data.index_to_node_features_pos = node_pos
    return data
def calc_metrics(pred, y, threshold=0.5):
    """Compute ROC-AUC, average precision, accuracy and micro-F1 for binary scores.

    Args:
        pred: array-like of predicted scores; ``pred > threshold`` is the
            predicted positive class.
        y: array-like of labels; negative labels are treated as 0. The
            caller's ``y`` is not modified.
        threshold: decision threshold used for accuracy and F1.

    Returns:
        ``(roc_score, ap_score, acc, f1)``. ``roc_score`` is 0.5 when
        ``roc_auc_score`` raises ``ValueError`` (e.g. only one class in ``y``).
    """
    pred = np.asarray(pred)
    # Work on a copy: assigning into ``y`` directly would silently rewrite
    # the caller's label array.
    y = np.array(y)
    y[y < 0] = 0
    try:
        roc_score = roc_auc_score(y, pred)
    except ValueError:
        roc_score = 0.5  # AUC undefined -> report chance level
    ap_score = average_precision_score(y, pred)
    predicted = pred > threshold
    acc = accuracy_score(y, predicted)
    # NOTE(review): for binary labels, micro-averaged F1 equals accuracy;
    # average='binary' may have been intended -- confirm before changing.
    f1 = f1_score(y, predicted, average='micro')
    return roc_score, ap_score, acc, f1
def metrics_per_rel(pred, link_labels, edge_attr_dict, total_edge_type, split, threshold=0.5, verbose=False):
    """Compute link-prediction metrics separately for every edge type.

    Args:
        pred: tensor of predicted scores, one per link.
        link_labels: tensor of labels aligned with ``pred``.
        edge_attr_dict: mapping whose values are the edge-type ids found in
            ``total_edge_type``; its keys name the type in the output.
        total_edge_type: tensor giving the edge-type id of each link.
        split: split name embedded in the returned keys.
        threshold: decision threshold passed to ``calc_metrics``.
        verbose: if True, print every metric per edge type.

    Returns:
        dict mapping ``"edge_metrics/node.<key>_<split>_<metric>"`` (metric in
        roc, ap, acc, f1) to its value. Edge types with no links are skipped.
    """
    print_templates = (
        "ROC for edge type {}: {:.5f}",
        "AP for edge type {}: {:.5f}",
        "ACC for edge type {}: {:.5f}",
        "F1 for edge type {}: {:.5f}",
    )
    key_templates = (
        "edge_metrics/node.%s_%s_roc",
        "edge_metrics/node.%s_%s_ap",
        "edge_metrics/node.%s_%s_acc",
        "edge_metrics/node.%s_%s_f1",
    )
    log = {}
    for attr, type_id in edge_attr_dict.items():
        rel_mask = (total_edge_type == type_id)
        if rel_mask.sum() == 0:
            continue
        scores = pred[rel_mask].cpu().detach().numpy()
        labels = link_labels[rel_mask].cpu().detach().numpy()
        roc, ap, acc, f1 = calc_metrics(scores, labels, threshold)
        results = (roc, ap, acc, f1)
        if verbose:
            for template, value in zip(print_templates, results):
                print(template.format(attr, value))
        for template, value in zip(key_templates, results):
            log[template % (attr, split)] = value
    return log
def plot_roc_curve(pred, labels):
    """Build an interactive Plotly ROC curve for scores ``pred`` against ``labels``.

    Every point's hover data includes its threshold, the overall ROC-AUC, the
    G-mean ``sqrt(TPR * (1 - FPR))`` at that threshold, and the maximum G-mean
    over all thresholds.

    Returns:
        The figure produced by ``plotly.express.line`` (TPR against FPR).
    """
    fpr, tpr, thresholds = roc_curve(labels, pred)
    n_points = len(thresholds)
    gmeans = np.sqrt(tpr * (1 - fpr))
    auc = roc_auc_score(labels, pred)
    columns = {
        "False Positive Rate": fpr,
        "True Positive Rate": tpr,
        "Threshold": thresholds,
        "ROC": [auc] * n_points,
        "G-Mean": gmeans,
        "Max G-Mean": [max(gmeans)] * n_points,
    }
    return px.line(
        pd.DataFrame(columns),
        x="False Positive Rate",
        y="True Positive Rate",
        hover_data=list(columns.keys()),
    )