# shepherd/train.py
"""Train SHEPHERD's combined patient models, or run inference from a saved checkpoint."""

# General
import numpy as np
import random
import argparse
import os
import sys
from pathlib import Path
from datetime import datetime
from collections import Counter
import pandas as pd
import pickle
import time

sys.path.insert(0, '..')  # add project_config to path

# Pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils.convert import to_networkx, to_scipy_sparse_matrix
from torch_geometric.data import Data, NeighborSampler
from torch_geometric.utils import negative_sampling
from torch.utils.data import DataLoader, random_split, SubsetRandomSampler

# Pytorch lightning
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# W&B
import wandb

# multiprocessing
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

# Own code
import project_config
from shepherd.dataset import PatientDataset
from shepherd.gene_prioritization_model import CombinedGPAligner
from shepherd.patient_nca_model import CombinedPatientNCA
from shepherd.samplers import PatientNeighborSampler

import preprocess
from hparams import get_pretrain_hparams, get_train_hparams

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"  # surface CUDA errors at the offending call (debugging aid)
import faulthandler
faulthandler.enable()  # dump Python tracebacks on hard crashes (e.g. segfaults in C extensions)


def parse_args():
    parser = argparse.ArgumentParser(description="Learning node embeddings.")

    # Input files/parameters
    parser.add_argument("--edgelist", type=str, default=None, help="File with edge list")
    parser.add_argument("--node_map", type=str, default=None, help="File with node list")
    parser.add_argument('--saved_node_embeddings_path', type=str, default=None, help='Path within kg_embeddings folder to the saved KG embeddings')
    parser.add_argument('--patient_data', default="disease_simulated", type=str)
    parser.add_argument('--run_type', choices=["causal_gene_discovery", "disease_characterization", "patients_like_me"], type=str)
    parser.add_argument("--aug_sim", type=str, default=None, help="File with the similarity dictionary")
    parser.add_argument("--aug_gene_by_deg", action="store_true", help="Augment gene by degree")
    parser.add_argument("--aug_gene_w", type=float, default=0.7, help="Contribution of augmentation (gene)")
    parser.add_argument("--n_sim_genes", type=int, default=3, help="K similar genes for augmentation")
    parser.add_argument("--n_transformer_layers", type=int, default=3, help="Number of transformer layers")
    parser.add_argument("--n_transformer_heads", type=int, default=8, help="Number of transformer heads")

    # Tunable parameters
    parser.add_argument('--sparse_sample', default=200, type=int)
    parser.add_argument('--lr', default=0.0001, type=float)
    parser.add_argument('--upsample_cand', default=1, type=int)
    parser.add_argument('--neighbor_sampler_size', default=-1, type=int)
    parser.add_argument('--lmbda', type=float, default=0.5, help='Lambda')
    parser.add_argument('--alpha', type=float, default=0, help='Alpha')
    parser.add_argument('--kappa', type=float, default=0.3, help='Kappa (only used for the combined model with link prediction loss)')
    parser.add_argument('--seed', default=33, type=int)
    parser.add_argument('--batch_size', default=64, type=int)

    # Resume / run inference with best checkpoint
    parser.add_argument('--resume', default="", type=str)
    parser.add_argument('--do_inference', action='store_true')
    parser.add_argument('--best_ckpt', type=str, default=None, help='Name of the best performing checkpoint')

    # argparse's type=bool treats any non-empty string (including "False") as True,
    # so parse the value explicitly.
    parser.add_argument('--use_wandb', type=lambda s: s.lower() in ('true', '1', 'yes'), default=True)

    args = parser.parse_args()
    return args
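

# Example training invocation (illustrative; the <...> file names are placeholders, not
# files shipped with this script):
#
#     python train.py \
#         --edgelist <KG_edgelist.txt> \
#         --node_map <KG_node_map.txt> \
#         --saved_node_embeddings_path <pretrained_kg_embeddings.ckpt> \
#         --patient_data disease_simulated \
#         --run_type causal_gene_discovery
#
# See the example workflows at the bottom of this file for resuming a run and for
# running inference from a saved checkpoint.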


def load_patient_datasets(hparams, inference=False):
    print('loading patient datasets')

    if inference:
        # Only the test split is needed at inference time.
        train_dataset = None
        val_dataset = None
        test_dataset = PatientDataset(project_config.PROJECT_DIR / 'patients' / hparams['test_data'], time=hparams['time'])
    else:
        train_dataset = PatientDataset(project_config.PROJECT_DIR / 'patients' / hparams['train_data'], time=hparams['time'])
        val_dataset = PatientDataset(project_config.PROJECT_DIR / 'patients' / hparams['validation_data'], time=hparams['time'])
        test_dataset = None

    print('finished loading patient datasets')
    return train_dataset, val_dataset, test_dataset


def get_dataloaders(hparams, all_data, nid_to_spl_dict, n_nodes, gene_phen_dis_node_idx, train_dataset, val_dataset, test_dataset, inference=False):
    # NOTE: also reads the module-level `args` parsed in __main__ (for args.aug_sim below),
    # so this function only works when the script is run directly.
    print('Get dataloaders', flush=True)
    shuffle = not (hparams['debug'] or inference)
    if not hparams['sample_from_gpd']:
        gene_phen_dis_node_idx = None
    batch_sz = hparams['inference_batch_size'] if inference else hparams['batch_size']
    sparse_sample = 1 if inference else hparams['sparse_sample']

    # get phenotypes & genes found in train patients
    if hparams['sample_edges_from_train_patients']:
        phenotype_counter = Counter()
        gene_counter = Counter()
        for patient in train_dataset:
            phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, disease_node_idx, labels, additional_labels, patient_ids = patient
            phenotype_counter += Counter(list(phenotype_node_idx.numpy()))
            gene_counter += Counter(list(candidate_gene_node_idx.numpy()))
    else:
        phenotype_counter = None
        gene_counter = None

    print('Loading SPL...')
    if hparams['spl'] is not None:
        spl = np.load(project_config.PROJECT_DIR / 'patients' / hparams['spl'])
    else:
        spl = None
    if hparams['spl_index'] is not None and (project_config.PROJECT_DIR / 'patients' / hparams['spl_index']).exists():
        with open(str(project_config.PROJECT_DIR / 'patients' / hparams['spl_index']), "rb") as input_file:
            spl_indexing_dict = pickle.load(input_file)
    else:
        spl_indexing_dict = None  # TODO: short term fix for simulated patients, get rid once we create this dict
    print('Loaded SPL information')

    if args.aug_sim is not None:
        with open(str(project_config.PROJECT_DIR / 'knowledge_graph/8.9.21_kg' / ('top_10_similar_genes_sim=%s.pkl' % args.aug_sim)), "rb") as input_file:
            gene_similarity_dict = pickle.load(input_file)
        print("Using augment gene similarity: %s" % args.aug_sim)
    else:
        gene_similarity_dict = None

    # NOTE: hardcoded absolute path; adjust for your environment.
    with open("/home/ema30/zaklab/rare_disease_dx/formatted_patients/degree_dict_8.9.21_kg.pkl", "rb") as input_file:
        gene_deg_dict = pickle.load(input_file)

    if inference:
        train_dataloader = None
        val_dataloader = None
    else:
        print('setting up train dataloader')
        train_dataloader = PatientNeighborSampler(
            'train', all_data.edge_index[:, all_data.train_mask], all_data.edge_index[:, all_data.train_mask],
            sizes=hparams['neighbor_sampler_sizes'], patient_dataset=train_dataset, batch_size=batch_sz,
            sparse_sample=sparse_sample, do_filter_edges=hparams['filter_edges'],
            all_edge_attributes=all_data.edge_attr, n_nodes=n_nodes, relevant_node_idx=gene_phen_dis_node_idx,
            shuffle=shuffle, train_phenotype_counter=phenotype_counter, train_gene_counter=gene_counter,
            sample_edges_from_train_patients=hparams['sample_edges_from_train_patients'], num_workers=hparams['num_workers'],
            upsample_cand=hparams['upsample_cand'], n_cand_diseases=hparams['n_cand_diseases'], use_diseases=hparams['use_diseases'],
            nid_to_spl_dict=nid_to_spl_dict, gp_spl=spl, spl_indexing_dict=spl_indexing_dict,
            hparams=hparams, pin_memory=hparams['pin_memory'],
            gene_similarity_dict=gene_similarity_dict,
            gene_deg_dict=gene_deg_dict)
        print('finished setting up train dataloader')

        print('setting up val dataloader')
        val_dataloader = PatientNeighborSampler(
            'val', all_data.edge_index, all_data.edge_index[:, all_data.val_mask],
            sizes=[-1, 10, 5],
            patient_dataset=val_dataset, batch_size=batch_sz,
            sparse_sample=sparse_sample, all_edge_attributes=all_data.edge_attr, n_nodes=n_nodes,
            relevant_node_idx=gene_phen_dis_node_idx,
            shuffle=False, train_phenotype_counter=phenotype_counter, train_gene_counter=gene_counter,
            sample_edges_from_train_patients=hparams['sample_edges_from_train_patients'], num_workers=hparams['num_workers'],
            n_cand_diseases=hparams['n_cand_diseases'], use_diseases=hparams['use_diseases'],
            nid_to_spl_dict=nid_to_spl_dict, gp_spl=spl, spl_indexing_dict=spl_indexing_dict,
            hparams=hparams, pin_memory=hparams['pin_memory'],
            gene_similarity_dict=gene_similarity_dict,
            gene_deg_dict=gene_deg_dict)
        print('finished setting up val dataloader')

    print('setting up test dataloader')
    if inference:
        sizes = [-1, 10, 5]
        print('SIZES: ', sizes)
        test_dataloader = PatientNeighborSampler(
            'test', all_data.edge_index, all_data.edge_index[:, all_data.test_mask],
            sizes=sizes, patient_dataset=test_dataset, batch_size=len(test_dataset),
            sparse_sample=sparse_sample, all_edge_attributes=all_data.edge_attr, n_nodes=n_nodes,
            relevant_node_idx=gene_phen_dis_node_idx,
            shuffle=False, num_workers=hparams['num_workers'],
            n_cand_diseases=hparams['test_n_cand_diseases'], use_diseases=hparams['use_diseases'],
            nid_to_spl_dict=nid_to_spl_dict, gp_spl=spl, spl_indexing_dict=spl_indexing_dict,
            hparams=hparams, pin_memory=hparams['pin_memory'],
            gene_similarity_dict=gene_similarity_dict,
            gene_deg_dict=gene_deg_dict)
    else:
        test_dataloader = None
    print('finished setting up test dataloader')

    return train_dataloader, val_dataloader, test_dataloader


def get_model(args, hparams, node_hparams, all_data, edge_attr_dict, n_nodes, load_from_checkpoint=False):
    print("setting up model", hparams['model_type'])
    # get patient model
    if hparams['model_type'] == 'aligner':
        if load_from_checkpoint:
            comb_patient_model = CombinedGPAligner.load_from_checkpoint(
                checkpoint_path=str(Path(project_config.PROJECT_DIR) / args.best_ckpt),
                edge_attr_dict=edge_attr_dict, all_data=all_data, n_nodes=n_nodes,
                node_ckpt=hparams["saved_checkpoint_path"], node_hparams=node_hparams)
        else:
            comb_patient_model = CombinedGPAligner(
                edge_attr_dict=edge_attr_dict, all_data=all_data, n_nodes=n_nodes, hparams=hparams,
                node_ckpt=hparams["saved_checkpoint_path"], node_hparams=node_hparams)
    elif hparams['model_type'] == 'patient_NCA':
        if load_from_checkpoint:
            comb_patient_model = CombinedPatientNCA.load_from_checkpoint(
                checkpoint_path=str(Path(project_config.PROJECT_DIR) / args.best_ckpt),
                all_data=all_data, edge_attr_dict=edge_attr_dict, n_nodes=n_nodes,
                node_ckpt=hparams["saved_checkpoint_path"])
        else:
            comb_patient_model = CombinedPatientNCA(
                edge_attr_dict=edge_attr_dict, all_data=all_data, n_nodes=n_nodes,
                node_ckpt=hparams["saved_checkpoint_path"], hparams=hparams)
    else:
        raise NotImplementedError(f"Unknown model_type: {hparams['model_type']}")
    print('finished setting up model')
    return comb_patient_model


def train(args, hparams):
    print('Training Model', flush=True)

    # Hyperparameters
    node_hparams = get_pretrain_hparams(args, combined=True)
    print('Edge List: ', args.edgelist, flush=True)
    print('Node Map: ', args.node_map, flush=True)

    # Set seed
    pl.seed_everything(hparams['seed'])

    # Read input data
    print('Read data', flush=True)
    all_data, edge_attr_dict, nodes = preprocess.preprocess_graph(args)
    n_nodes = len(nodes["node_idx"].unique())
    print(f'Number of nodes: {n_nodes}')
    gene_phen_dis_node_idx = torch.LongTensor(nodes.loc[nodes['node_type'].isin(['gene/protein', 'effect/phenotype', 'disease']), 'node_idx'].values)

    if args.resume != "":
        print('Resuming Run')
        # create Weights & Biases Logger; colons are not allowed in the run ID
        if ":" in args.resume:
            resume_id = "_".join(args.resume.split(":"))
        else:
            resume_id = args.resume
        run_name = args.resume
        wandb_logger = WandbLogger(run_name, project=hparams['wandb_project_name'], entity='rare_disease_dx', save_dir=hparams['wandb_save_dir'], id=resume_id, resume=resume_id)

        # add run name to hparams dict
        hparams['run_name'] = run_name

        # get patient model
        comb_patient_model = get_model(args, hparams, node_hparams, all_data, edge_attr_dict, n_nodes, load_from_checkpoint=True)

    else:
        print('Creating new W&B Logger')
        # create Weights & Biases Logger
        curr_time = datetime.now().strftime("%m_%d_%y:%H:%M:%S")
        val_data = str(hparams['validation_data']).split('.txt')[0].replace('/', '.')
        run_name = "{}_val_{}".format(curr_time, val_data).replace('patients', 'pats')
        run_name = run_name + f'_seed={args.seed}'
        # shorten verbose dataset names so the run name stays readable
        run_name = (run_name.replace('5_candidates_mapped_only', '5cand_map')
                            .replace('8.9.21_kgsolved_manual_baylor_nobgm_distractor_genes', 'manual')
                            .replace('patient_disease_NCA', 'pd_NCA')
                            .replace('_distractor', ''))
        wandb_logger = WandbLogger(name=run_name, project=hparams['wandb_project_name'], entity='rare_disease_dx', save_dir=hparams['wandb_save_dir'],
                                   id="_".join(run_name.split(":")), resume="allow")

        # add run name to hparams dict
        print('Run name', run_name)
        hparams['run_name'] = run_name

        # get patient model
        comb_patient_model = get_model(args, hparams, node_hparams, all_data, edge_attr_dict, n_nodes, load_from_checkpoint=False)

    # get dataloaders
    nid_to_spl_dict = {nid: idx for idx, nid in enumerate(nodes[nodes["node_type"] == "gene/protein"]["node_idx"].tolist())}
    train_dataset, val_dataset, test_dataset = load_patient_datasets(hparams)
    patient_train_dataloader, patient_val_dataloader, patient_test_dataloader = get_dataloaders(
        hparams, all_data, nid_to_spl_dict, n_nodes, gene_phen_dis_node_idx,
        train_dataset, val_dataset, test_dataset)

    # callbacks
    print('Init callbacks')
    checkpoint_path = project_config.PROJECT_DIR / 'checkpoints' / hparams['model_type'] / run_name
    hparams['checkpoint_path'] = checkpoint_path
    print('Checkpoint path: ', checkpoint_path)
    checkpoint_path.mkdir(parents=True, exist_ok=True)
    use_mrr = args.run_type in ('disease_characterization', 'patients_like_me')
    monitor_type = 'val/mrr' if use_mrr else 'val/gp_val_epoch_mrr'
    fname = 'epoch={epoch:02d}-val_mrr={val/mrr:.2f}' if use_mrr else 'epoch={epoch:02d}-val_mrr={val/gp_val_epoch_mrr:.2f}'
    patient_checkpoint_callback = ModelCheckpoint(
        monitor=monitor_type,
        dirpath=checkpoint_path,
        filename=fname,
        save_top_k=-1,
        mode='max',
        auto_insert_metric_name=False
    )
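    # With auto_insert_metric_name=False, PyTorch Lightning formats the '{...}' placeholders
    # in `filename` directly, so checkpoints land in `checkpoint_path` with names like
    # 'epoch=07-val_mrr=0.42.ckpt' (illustrative values), and save_top_k=-1 keeps the
    # checkpoint from every epoch rather than only the best one.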

    # log gradients with logger
    print('wandb logger watch')
    wandb_logger.watch(comb_patient_model, log='all')

    # initialize trainer
    if hparams['debug']:
        # integer limits = number of batches per epoch
        limit_train_batches = 1
        limit_val_batches = 1
        hparams['max_epochs'] = 6
    else:
        # float limits = fraction of the dataset
        limit_train_batches = 1.0
        limit_val_batches = 1.0

    print('initialize trainer')
    patient_trainer = pl.Trainer(gpus=hparams['n_gpus'],
                                 logger=wandb_logger,
                                 max_epochs=hparams['max_epochs'],
                                 callbacks=[patient_checkpoint_callback],
                                 profiler=hparams['profiler'],
                                 log_gpu_memory=hparams['log_gpu_memory'],
                                 limit_train_batches=limit_train_batches,
                                 limit_val_batches=limit_val_batches,
                                 weights_summary="full",
                                 gradient_clip_val=hparams['gradclip'])

    # Train
    patient_trainer.fit(comb_patient_model, patient_train_dataloader, patient_val_dataloader)


@torch.no_grad()
def inference(args, hparams):
    print('Running inference')
    # Hyperparameters
    node_hparams = get_pretrain_hparams(args, combined=True)

    hparams.update({'add_similar_patients': False})

    # Seed
    pl.seed_everything(hparams['seed'])

    # Read data
    all_data, edge_attr_dict, nodes = preprocess.preprocess_graph(args)
    n_nodes = len(nodes["node_idx"].unique())
    gene_phen_dis_node_idx = torch.LongTensor(nodes.loc[nodes['node_type'].isin(['gene/protein', 'effect/phenotype', 'disease']), 'node_idx'].values)

    # Get logger & trainer
    curr_time = datetime.now().strftime("%m_%d_%y:%H:%M:%S")
    lr = hparams['lr']
    test_data = hparams['test_data'].split('.txt')[0].replace('/', '.')
    run_name = "{}_lr_{}_test_{}".format(curr_time, lr, test_data)
    wandb_logger = WandbLogger(run_name, project=hparams['wandb_project_name'], entity='rare_disease_dx', save_dir=hparams['wandb_save_dir'])
    print('Run name: ', run_name)
    hparams['run_name'] = run_name

    # Get datasets
    train_dataset, val_dataset, test_dataset = load_patient_datasets(hparams, inference=True)

    # Get dataloader
    nid_to_spl_dict = {nid: idx for idx, nid in enumerate(nodes[nodes["node_type"] == "gene/protein"]["node_idx"].tolist())}
    _, _, test_dataloader = get_dataloaders(hparams, all_data, nid_to_spl_dict,
                                            n_nodes, gene_phen_dis_node_idx,
                                            train_dataset, val_dataset, test_dataset, inference=True)

    # Get patient model
    model = get_model(args, hparams, node_hparams, all_data, edge_attr_dict, n_nodes, load_from_checkpoint=True)

    trainer = pl.Trainer(gpus=0, logger=wandb_logger)  # run inference on CPU
    results = trainer.test(model, dataloaders=test_dataloader)
    print('---- RESULTS ----')
    print(results)


if __name__ == "__main__":

    # Get hyperparameters
    args = parse_args()
    hparams = get_train_hparams(args)

    # Run model
    if args.do_inference:
        inference(args, hparams)
    else:
        train(args, hparams)
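

# Illustrative workflows (run names and checkpoint paths are placeholders; '...' stands
# for the data arguments shown in the example after parse_args above):
#
#   Resume an interrupted training run under its original W&B run name:
#       python train.py ... --resume <run_name>
#
#   Evaluate a saved checkpoint on the test split (train() writes checkpoints to
#   checkpoints/<model_type>/<run_name>/ under PROJECT_DIR):
#       python train.py ... --do_inference --best_ckpt checkpoints/<model_type>/<run_name>/<ckpt>.ckpt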