shepherd/predict.py

# General
import numpy as np
import pickle
import random
import argparse
import os
import sys
from pathlib import Path
from datetime import datetime
import time
from collections import Counter
import pandas as pd

sys.path.insert(0, '..')  # add project_config to path

# Pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F

# Pytorch lightning
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# W&B
import wandb

# Own code
import project_config
from shepherd.dataset import PatientDataset
from shepherd.samplers import PatientNeighborSampler

import preprocess
from hparams import get_predict_hparams
from train import get_model, load_patient_datasets, get_dataloaders

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"  # os is already imported above


'''
Example commands:

python predict.py \
    --run_type causal_gene_discovery \
    --patient_data test_predict \
    --edgelist KG_edgelist_mask.txt \
    --node_map KG_node_map.txt \
    --saved_node_embeddings_path checkpoints/pretrain.ckpt \
    --best_ckpt checkpoints/causal_gene_discovery.ckpt

python predict.py \
    --run_type patients_like_me \
    --patient_data test_predict \
    --edgelist KG_edgelist_mask.txt \
    --node_map KG_node_map.txt \
    --saved_node_embeddings_path checkpoints/pretrain.ckpt \
    --best_ckpt checkpoints/patients_like_me.ckpt

python predict.py \
    --run_type disease_characterization \
    --patient_data test_predict \
    --edgelist KG_edgelist_mask.txt \
    --node_map KG_node_map.txt \
    --saved_node_embeddings_path checkpoints/pretrain.ckpt \
    --best_ckpt checkpoints/disease_characterization.ckpt
'''


def parse_args():
    parser = argparse.ArgumentParser(description="Predict using SHEPHERD")
    parser.add_argument("--edgelist", type=str, default=None, help="File with edge list")
    parser.add_argument("--node_map", type=str, default=None, help="File with node list")
    parser.add_argument('--patient_data', default="disease_simulated", type=str, help='Name of the patient dataset')
    parser.add_argument('--run_type', choices=["causal_gene_discovery", "disease_characterization", "patients_like_me"], type=str)
    parser.add_argument('--saved_node_embeddings_path', type=str, default=None, help='Path to pretrained model checkpoint')
    parser.add_argument('--best_ckpt', type=str, default=None, help='Path to the best performing model checkpoint')
    args = parser.parse_args()
    return args


@torch.no_grad()
def predict(args):

    # Hyperparameters
    hparams = get_predict_hparams(args)

    # Seed
    pl.seed_everything(hparams['seed'])

    # Read KG
    all_data, edge_attr_dict, nodes = preprocess.preprocess_graph(args)
    n_nodes = len(nodes["node_idx"].unique())
    gene_phen_dis_node_idx = torch.LongTensor(nodes.loc[nodes['node_type'].isin(['gene/protein', 'effect/phenotype', 'disease']), 'node_idx'].values)

    # Get dataset
    print('Loading SPL...')
    spl = np.load(project_config.PROJECT_DIR / 'patients' / hparams['spl'])
    if (project_config.PROJECT_DIR / 'patients' / hparams['spl_index']).exists():
        with open(str(project_config.PROJECT_DIR / 'patients' / hparams['spl_index']), "rb") as input_file:
            spl_indexing_dict = pickle.load(input_file)
    else:
        spl_indexing_dict = None
    print('Loaded SPL information')
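    # Note: SPL = precomputed shortest path lengths over the KG; the matrix is
    # passed to the sampler below as gp_spl, and spl_indexing_dict (when present)
    # provides the index mapping into it.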

    dataset = PatientDataset(project_config.PROJECT_DIR / 'patients' / hparams['test_data'], time=hparams['time'])
    print(f'There are {len(dataset)} patients in the test dataset')
    hparams.update({'inference_batch_size': len(dataset)})  # score all test patients in a single batch
    print('batch size: ', hparams['inference_batch_size'])

    # Get dataloader
    nid_to_spl_dict = {nid: idx for idx, nid in enumerate(nodes[nodes["node_type"] == "gene/protein"]["node_idx"].tolist())}

    dataloader = PatientNeighborSampler(
        'predict',
        all_data.edge_index,
        all_data.edge_index[:, all_data.test_mask],
        sizes=[-1, 10, 5],
        patient_dataset=dataset,
        batch_size=hparams['inference_batch_size'],
        sparse_sample=0,
        all_edge_attributes=all_data.edge_attr,
        n_nodes=n_nodes,
        relevant_node_idx=gene_phen_dis_node_idx,
        n_cand_diseases=hparams['test_n_cand_diseases'],
        use_diseases=hparams['use_diseases'],
        nid_to_spl_dict=nid_to_spl_dict,
        gp_spl=spl,
        spl_indexing_dict=spl_indexing_dict,
        shuffle=False,
        num_workers=hparams['num_workers'],
        hparams=hparams,
        pin_memory=hparams['pin_memory'])
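    # Note: sizes=[-1, 10, 5] is the per-layer neighbor fan-out for subgraph
    # sampling (-1 keeps all first-hop neighbors), following the NeighborSampler
    # convention from PyTorch Geometric that PatientNeighborSampler builds on.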

    # Create Weights & Biases logger
    run_name = 'test'
    wandb_logger = WandbLogger(name=run_name, project='rare_disease_dx_combined', entity='rare_disease_dx',
                               save_dir=hparams['wandb_save_dir'], id="_".join(run_name.split(":")), resume="allow")

    # Get patient model
    model = get_model(args, hparams, None, all_data, edge_attr_dict, n_nodes, load_from_checkpoint=True)

    trainer = pl.Trainer(gpus=hparams['n_gpus'])

    t1 = time.time()
    results = trainer.predict(model, dataloaders=dataloader)
    t2 = time.time()
    print(f'Predicting took {t2 - t1:0.4f} seconds for {len(dataset)} patients')

    scores_dfs, attn_dfs, gat_attn_df_1, gat_attn_df_2, gat_attn_df_3, phenotype_embeddings, disease_embeddings = zip(*results)
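    # trainer.predict returns one tuple of outputs per batch; zip(*results)
    # transposes this into one tuple per output type.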

    print('---- RESULTS ----')
    if not os.path.exists(project_config.PROJECT_DIR / 'results'):
        os.mkdir(project_config.PROJECT_DIR / 'results')
    output_base = project_config.PROJECT_DIR / 'results' / (str(args.best_ckpt).replace('/', '.').split('.ckpt')[0])

    # Save scores
    scores_df = pd.concat(scores_dfs).reset_index(drop=True)
    scores_df.to_csv(str(output_base) + '_scores.csv', index=False)
    print(scores_df)

    # Save patient phenotype attention
    attn_df = pd.concat(attn_dfs).reset_index(drop=True)
    attn_df.to_csv(str(output_base) + '_phenotype_attn.csv', index=False)
    print(attn_df)

    # Save patient phenotype embeddings
    if isinstance(phenotype_embeddings, tuple):
        phenotype_embeddings = phenotype_embeddings[0]
    torch.save(phenotype_embeddings, str(output_base) + '_phenotype_embeddings.pth')
    print("Phenotype embeddings", phenotype_embeddings)

    # Save disease embeddings
    if args.run_type == "disease_characterization":
        if isinstance(disease_embeddings, tuple):
            disease_embeddings = disease_embeddings[0]
        torch.save(disease_embeddings, str(output_base) + '_disease_embeddings.pth')
        print("Disease embeddings", disease_embeddings)


if __name__ == "__main__":

    # Get command-line arguments
    args = parse_args()

    # Perform prediction
    predict(args)