|
a |
|
b/shepherd/preprocess.py |
|
|
1 |
# General |
|
|
2 |
import numpy as np |
|
|
3 |
import pandas as pd |
|
|
4 |
#from typing import List, Optional, Tuple, NamedTuple, Union, Callable |
|
|
5 |
|
|
|
6 |
# Pytorch |
|
|
7 |
import torch |
|
|
8 |
import torch.nn as nn |
|
|
9 |
from torch_geometric.data import Data |
|
|
10 |
from torch import Tensor |
|
|
11 |
|
|
|
12 |
import project_config |
|
|
13 |
|
|
|
14 |
|
|
|
15 |
def preprocess_graph(args): |
|
|
16 |
|
|
|
17 |
# Read in nodes & edges |
|
|
18 |
nodes = pd.read_csv(project_config.KG_DIR / args.node_map, sep="\t") |
|
|
19 |
edges = pd.read_csv(project_config.KG_DIR / args.edgelist, sep="\t") |
|
|
20 |
|
|
|
21 |
# Initialize edge index |
|
|
22 |
edge_index = torch.LongTensor(edges[['x_idx', 'y_idx']].values.T).contiguous() |
|
|
23 |
edge_attr = edges['full_relation'] |
|
|
24 |
|
|
|
25 |
# Convert edge attributes to idx |
|
|
26 |
edge_attr_list = [ |
|
|
27 |
'effect/phenotype;phenotype_protein;gene/protein', |
|
|
28 |
'gene/protein;phenotype_protein;effect/phenotype', |
|
|
29 |
'disease;disease_phenotype_negative;effect/phenotype', |
|
|
30 |
'effect/phenotype;disease_phenotype_negative;disease', |
|
|
31 |
'disease;disease_phenotype_positive;effect/phenotype', |
|
|
32 |
'effect/phenotype;disease_phenotype_positive;disease', |
|
|
33 |
'gene/protein;protein_pathway;pathway', |
|
|
34 |
'pathway;protein_pathway;gene/protein', |
|
|
35 |
'disease;disease_protein;gene/protein', |
|
|
36 |
'gene/protein;disease_protein;disease', |
|
|
37 |
'gene/protein;protein_molfunc;molecular_function', |
|
|
38 |
'molecular_function;protein_molfunc;gene/protein', |
|
|
39 |
'gene/protein;protein_cellcomp;cellular_component', |
|
|
40 |
'cellular_component;protein_cellcomp;gene/protein', |
|
|
41 |
'gene/protein;protein_bioprocess;biological_process', |
|
|
42 |
'biological_process;protein_bioprocess;gene/protein', |
|
|
43 |
'biological_process;bioprocess_bioprocess;biological_process', |
|
|
44 |
'biological_process;bioprocess_bioprocess_rev;biological_process', |
|
|
45 |
'molecular_function;molfunc_molfunc;molecular_function', |
|
|
46 |
'molecular_function;molfunc_molfunc_rev;molecular_function', |
|
|
47 |
'cellular_component;cellcomp_cellcomp;cellular_component', |
|
|
48 |
'cellular_component;cellcomp_cellcomp_rev;cellular_component', |
|
|
49 |
'effect/phenotype;phenotype_phenotype;effect/phenotype', |
|
|
50 |
'effect/phenotype;phenotype_phenotype_rev;effect/phenotype', |
|
|
51 |
'gene/protein;protein_protein;gene/protein', |
|
|
52 |
'gene/protein;protein_protein_rev;gene/protein', |
|
|
53 |
'disease;disease_disease;disease', |
|
|
54 |
'disease;disease_disease_rev;disease', |
|
|
55 |
'pathway;pathway_pathway;pathway', |
|
|
56 |
'pathway;pathway_pathway_rev;pathway' |
|
|
57 |
] |
|
|
58 |
|
|
|
59 |
edge_attr_to_idx_dict = {attr:i for i, attr in enumerate(edge_attr_list)} |
|
|
60 |
edge_attr = torch.LongTensor(np.vectorize(edge_attr_to_idx_dict.get)(edge_attr.values)) |
|
|
61 |
|
|
|
62 |
# Get train/val/test masks |
|
|
63 |
mask = edges["mask"].values |
|
|
64 |
train_mask = torch.BoolTensor(mask == "train") |
|
|
65 |
val_mask = torch.BoolTensor(mask == "val") |
|
|
66 |
test_mask = torch.BoolTensor(mask == "test") |
|
|
67 |
|
|
|
68 |
|
|
|
69 |
# Create data object |
|
|
70 |
data = Data(edge_index = edge_index, edge_attr = edge_attr, train_mask = train_mask, val_mask = val_mask, test_mask = test_mask) |
|
|
71 |
return data, edge_attr_to_idx_dict, nodes |
|
|
72 |
|