Diff of /shepherd/preprocess.py [000000] .. [db6163]

Switch to unified view

a b/shepherd/preprocess.py
1
# General
2
import numpy as np
3
import pandas as pd
4
#from typing import List, Optional, Tuple, NamedTuple, Union, Callable
5
6
# Pytorch
7
import torch
8
import torch.nn as nn
9
from torch_geometric.data import Data
10
from torch import Tensor
11
12
import project_config
13
14
15
# Vocabulary of edge relations, written as
# "<source node type>;<relation>;<target node type>".
# A relation's position in this tuple is the integer id stored in
# ``edge_attr``, so reordering or inserting entries changes the encoding.
_EDGE_RELATIONS = (
    'effect/phenotype;phenotype_protein;gene/protein',
    'gene/protein;phenotype_protein;effect/phenotype',
    'disease;disease_phenotype_negative;effect/phenotype',
    'effect/phenotype;disease_phenotype_negative;disease',
    'disease;disease_phenotype_positive;effect/phenotype',
    'effect/phenotype;disease_phenotype_positive;disease',
    'gene/protein;protein_pathway;pathway',
    'pathway;protein_pathway;gene/protein',
    'disease;disease_protein;gene/protein',
    'gene/protein;disease_protein;disease',
    'gene/protein;protein_molfunc;molecular_function',
    'molecular_function;protein_molfunc;gene/protein',
    'gene/protein;protein_cellcomp;cellular_component',
    'cellular_component;protein_cellcomp;gene/protein',
    'gene/protein;protein_bioprocess;biological_process',
    'biological_process;protein_bioprocess;gene/protein',
    'biological_process;bioprocess_bioprocess;biological_process',
    'biological_process;bioprocess_bioprocess_rev;biological_process',
    'molecular_function;molfunc_molfunc;molecular_function',
    'molecular_function;molfunc_molfunc_rev;molecular_function',
    'cellular_component;cellcomp_cellcomp;cellular_component',
    'cellular_component;cellcomp_cellcomp_rev;cellular_component',
    'effect/phenotype;phenotype_phenotype;effect/phenotype',
    'effect/phenotype;phenotype_phenotype_rev;effect/phenotype',
    'gene/protein;protein_protein;gene/protein',
    'gene/protein;protein_protein_rev;gene/protein',
    'disease;disease_disease;disease',
    'disease;disease_disease_rev;disease',
    'pathway;pathway_pathway;pathway',
    'pathway;pathway_pathway_rev;pathway',
)


def preprocess_graph(args):
    """Load the knowledge-graph tables and pack the edges into a ``Data`` object.

    Two tab-separated files are read from ``project_config.KG_DIR``:
    ``args.node_map`` (returned unchanged) and ``args.edgelist``.  The edge
    list must provide the columns ``x_idx``, ``y_idx``, ``full_relation``
    and ``mask``.

    Args:
        args: Object exposing ``node_map`` and ``edgelist`` attributes, each
            a file name relative to ``project_config.KG_DIR``.

    Returns:
        A tuple ``(data, edge_attr_to_idx_dict, nodes)`` where

        * ``data`` is a ``torch_geometric.data.Data`` holding ``edge_index``
          (2 x num_edges, long), ``edge_attr`` (integer relation ids, long)
          and boolean ``train_mask`` / ``val_mask`` / ``test_mask``;
        * ``edge_attr_to_idx_dict`` maps every known relation string to its
          integer id;
        * ``nodes`` is the node table as a ``pandas.DataFrame``.

    Raises:
        ValueError: If ``full_relation`` contains a value (or a missing
            cell) that is not part of the known relation vocabulary.
    """
    # Read in nodes & edges
    nodes = pd.read_csv(project_config.KG_DIR / args.node_map, sep="\t")
    edges = pd.read_csv(project_config.KG_DIR / args.edgelist, sep="\t")

    # Initialize edge index: one column per edge, row 0 = x_idx, row 1 = y_idx
    edge_index = torch.LongTensor(edges[['x_idx', 'y_idx']].values.T).contiguous()

    # Convert edge attributes to idx.  A new dict is built per call because
    # it is handed back to the caller.
    edge_attr_to_idx_dict = {attr: i for i, attr in enumerate(_EDGE_RELATIONS)}
    relation_ids = edges['full_relation'].map(edge_attr_to_idx_dict)

    # Series.map leaves NaN wherever the key is absent from the dict; fail
    # loudly with the offending values instead of letting a None/NaN reach
    # the tensor constructor.
    unmapped = relation_ids.isna()
    if unmapped.any():
        unknown = edges.loc[unmapped, 'full_relation'].unique().tolist()
        raise ValueError(
            f"Unknown edge relation(s) in column 'full_relation' of "
            f"{args.edgelist}: {unknown}"
        )
    # Explicit int64 keeps the dtype independent of what pandas inferred.
    edge_attr = torch.LongTensor(relation_ids.to_numpy(dtype=np.int64))

    # Get train/val/test masks
    mask = edges["mask"].values
    train_mask = torch.BoolTensor(mask == "train")
    val_mask = torch.BoolTensor(mask == "val")
    test_mask = torch.BoolTensor(mask == "test")

    # Create data object
    data = Data(
        edge_index=edge_index,
        edge_attr=edge_attr,
        train_mask=train_mask,
        val_mask=val_mask,
        test_mask=test_mask,
    )
    return data, edge_attr_to_idx_dict, nodes
72