Diff of /shepherd/predict.py [000000] .. [db6163]

--- /dev/null
+++ b/shepherd/predict.py
# General
import numpy as np
import pickle
import random
import argparse
import os
import sys
from pathlib import Path
from datetime import datetime
import time
from collections import Counter
import pandas as pd

sys.path.insert(0, '..')  # add project_config to path

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# W&B
import wandb


# Own code
import project_config
from shepherd.dataset import PatientDataset
from shepherd.samplers import PatientNeighborSampler


import preprocess
from hparams import get_predict_hparams
from train import get_model, load_patient_datasets, get_dataloaders

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
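# CUDA_LAUNCH_BLOCKING=1 (set above) forces synchronous CUDA kernel launches so
# errors surface at the offending call rather than asynchronously; useful for
# debugging at the cost of inference speed.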


'''
Example Commands:

python predict.py \
    --run_type causal_gene_discovery \
    --patient_data test_predict \
    --edgelist KG_edgelist_mask.txt \
    --node_map KG_node_map.txt \
    --saved_node_embeddings_path checkpoints/pretrain.ckpt \
    --best_ckpt checkpoints/causal_gene_discovery.ckpt

python predict.py \
    --run_type patients_like_me \
    --patient_data test_predict \
    --edgelist KG_edgelist_mask.txt \
    --node_map KG_node_map.txt \
    --saved_node_embeddings_path checkpoints/pretrain.ckpt \
    --best_ckpt checkpoints/patients_like_me.ckpt

python predict.py \
    --run_type disease_characterization \
    --patient_data test_predict \
    --edgelist KG_edgelist_mask.txt \
    --node_map KG_node_map.txt \
    --saved_node_embeddings_path checkpoints/pretrain.ckpt \
    --best_ckpt checkpoints/disease_characterization.ckpt
'''


def parse_args():
    parser = argparse.ArgumentParser(description="Predict using SHEPHERD")
    parser.add_argument("--edgelist", type=str, default=None, help="File with KG edge list")
    parser.add_argument("--node_map", type=str, default=None, help="File with KG node map")
    parser.add_argument("--patient_data", type=str, default="disease_simulated", help="Name of the patient dataset")
    parser.add_argument("--run_type", type=str, choices=["causal_gene_discovery", "disease_characterization", "patients_like_me"], help="Prediction task to run")
    parser.add_argument("--saved_node_embeddings_path", type=str, default=None, help="Path to pretrained node embedding checkpoint")
    parser.add_argument("--best_ckpt", type=str, default=None, help="Path to the best performing model checkpoint")
    return parser.parse_args()


@torch.no_grad()
def predict(args):

    # Hyperparameters
    hparams = get_predict_hparams(args)

    # Seed
    pl.seed_everything(hparams['seed'])

    # Read KG
    all_data, edge_attr_dict, nodes = preprocess.preprocess_graph(args)
    n_nodes = len(nodes["node_idx"].unique())
    gene_phen_dis_node_idx = torch.LongTensor(nodes.loc[nodes['node_type'].isin(['gene/protein', 'effect/phenotype', 'disease']), 'node_idx'].values)
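    # These node indices (genes/proteins, phenotypes, diseases) are passed below
    # as relevant_node_idx to restrict neighbor sampling to the node types
    # relevant for prediction.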


    # Get dataset
    print('Loading SPL...')
    spl = np.load(project_config.PROJECT_DIR / 'patients' / hparams['spl'])
    if (project_config.PROJECT_DIR / 'patients' / hparams['spl_index']).exists():
        with open(str(project_config.PROJECT_DIR / 'patients' / hparams['spl_index']), "rb") as input_file:
            spl_indexing_dict = pickle.load(input_file)
    else:
        spl_indexing_dict = None
    print('Loaded SPL information')
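    # Note (assumption): 'spl' is a precomputed matrix of shortest-path lengths
    # in the KG, and 'spl_indexing_dict' maps node IDs to rows of that matrix;
    # both are produced by the project's preprocessing scripts.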

    dataset = PatientDataset(project_config.PROJECT_DIR / 'patients' / hparams['test_data'], time=hparams['time'])
    print(f'There are {len(dataset)} patients in the test dataset')
    hparams.update({'inference_batch_size': len(dataset)})  # score all test patients in a single batch
    print('batch size: ', hparams['inference_batch_size'])

    # Get dataloader
    nid_to_spl_dict = {nid: idx for idx, nid in enumerate(nodes[nodes["node_type"] == "gene/protein"]["node_idx"].tolist())}
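    # nid_to_spl_dict maps each gene/protein node index in the KG to its
    # position in the SPL matrix (assumed to match the precomputed ordering).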

    dataloader = PatientNeighborSampler(
        'predict', all_data.edge_index, all_data.edge_index[:, all_data.test_mask],
        sizes=[-1, 10, 5], patient_dataset=dataset, batch_size=hparams['inference_batch_size'], sparse_sample=0,
        all_edge_attributes=all_data.edge_attr, n_nodes=n_nodes, relevant_node_idx=gene_phen_dis_node_idx,
        n_cand_diseases=hparams['test_n_cand_diseases'], use_diseases=hparams['use_diseases'],
        nid_to_spl_dict=nid_to_spl_dict, gp_spl=spl, spl_indexing_dict=spl_indexing_dict,
        shuffle=False, num_workers=hparams['num_workers'],
        hparams=hparams, pin_memory=hparams['pin_memory'])
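    # sizes=[-1, 10, 5] above is the per-layer neighbor-sampling fan-out
    # (all first-hop neighbors, then 10 and 5 at the next two hops), presumably
    # following PyTorch Geometric's NeighborSampler convention.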

    # Create Weights & Biases logger
    run_name = 'test'
    wandb_logger = WandbLogger(name=run_name, project='rare_disease_dx_combined', entity='rare_disease_dx', save_dir=hparams['wandb_save_dir'],
                               id="_".join(run_name.split(":")), resume="allow")

    # Get patient model
    model = get_model(args, hparams, None, all_data, edge_attr_dict, n_nodes, load_from_checkpoint=True)

    trainer = pl.Trainer(gpus=hparams['n_gpus'], logger=wandb_logger)  # attach the W&B logger created above

    t1 = time.time()
    results = trainer.predict(model, dataloaders=dataloader)
    t2 = time.time()
    print(f"Predicting took {t2 - t1:0.4f} seconds for {len(dataset)} patients")

    scores_dfs, attn_dfs, gat_attn_df_1, gat_attn_df_2, gat_attn_df_3, phenotype_embeddings, disease_embeddings = zip(*results)
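    # trainer.predict returns one tuple of outputs per batch; zip(*results)
    # regroups them per output. With inference_batch_size == len(dataset)
    # there is a single batch, so each of these is a 1-tuple.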

    print('---- RESULTS ----')
    results_dir = project_config.PROJECT_DIR / 'results'
    results_dir.mkdir(exist_ok=True)
    output_base = results_dir / (str(args.best_ckpt).replace('/', '.').split('.ckpt')[0])
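    # e.g. --best_ckpt checkpoints/causal_gene_discovery.ckpt yields output
    # files prefixed results/checkpoints.causal_gene_discovery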

    # Save scores
    scores_df = pd.concat(scores_dfs).reset_index(drop=True)
    scores_df.to_csv(str(output_base) + '_scores.csv', index=False)
    print(scores_df)

    # Save patient phenotype attention
    attn_df = pd.concat(attn_dfs).reset_index(drop=True)
    attn_df.to_csv(str(output_base) + '_phenotype_attn.csv', index=False)
    print(attn_df)

    # Save patient phenotype embeddings
    if isinstance(phenotype_embeddings, tuple):
        phenotype_embeddings = phenotype_embeddings[0]
    torch.save(phenotype_embeddings, str(output_base) + '_phenotype_embeddings.pth')
    print("Phenotype embeddings", phenotype_embeddings)

    # Save disease embeddings
    if args.run_type == "disease_characterization":
        if isinstance(disease_embeddings, tuple):
            disease_embeddings = disease_embeddings[0]
        torch.save(disease_embeddings, str(output_base) + '_disease_embeddings.pth')
        print("Disease embeddings", disease_embeddings)


if __name__ == "__main__":

    # Parse command-line arguments
    args = parse_args()

    # Perform prediction
    predict(args)
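

# Illustrative only (not part of the pipeline): the saved outputs can be loaded
# for downstream analysis, e.g. for the causal_gene_discovery example above:
#
#   import pandas as pd
#   import torch
#   scores_df = pd.read_csv('results/checkpoints.causal_gene_discovery_scores.csv')
#   phen_emb = torch.load('results/checkpoints.causal_gene_discovery_phenotype_embeddings.pth')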