Diff of /shepherd/train.py [000000] .. [db6163]

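"""Training / inference entry point for SHEPHERD patient-level models.

Parses command-line arguments, preprocesses the knowledge graph
(preprocess.preprocess_graph), builds patient datasets and
PatientNeighborSampler dataloaders, and then either trains a
CombinedGPAligner or CombinedPatientNCA model with PyTorch Lightning
or runs test-set inference from a saved checkpoint (--do_inference).
Runs are logged to Weights & Biases.

Example invocation (illustrative only; paths and argument values are
placeholders, not defaults shipped with the repository):

    python train.py --edgelist <kg_edgelist.txt> --node_map <kg_node_map.txt> \
        --saved_node_embeddings_path <pretrained_kg_ckpt> \
        --run_type causal_gene_discovery
"""
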
# General
import numpy as np
import random
import argparse
import os
import sys
from pathlib import Path
from datetime import datetime
from collections import Counter
import pandas as pd
import pickle
import time

sys.path.insert(0, '..') # add project_config to path

# Pytorch
import torch
import torch.nn as nn
from torch_geometric.utils.convert import to_networkx, to_scipy_sparse_matrix
from torch_geometric.data import Data, NeighborSampler
from torch_geometric.utils import negative_sampling
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, SubsetRandomSampler

# Pytorch lightning
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# W&B
import wandb

# multiprocessing
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

# Own code
import project_config
from shepherd.dataset import PatientDataset
from shepherd.gene_prioritization_model import CombinedGPAligner
from shepherd.patient_nca_model import CombinedPatientNCA
from shepherd.samplers import PatientNeighborSampler

import preprocess
from hparams import get_pretrain_hparams, get_train_hparams


os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import faulthandler
faulthandler.enable()


def parse_args():
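    """Parse command-line arguments: input KG files, patient data, model/loss
    hyperparameters, and run mode (resume, inference, W&B logging)."""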
    parser = argparse.ArgumentParser(description="Learning node embeddings.")

    # Input files/parameters
    parser.add_argument("--edgelist", type=str, default=None, help="File with edge list")
    parser.add_argument("--node_map", type=str, default=None, help="File with node list")
    parser.add_argument('--saved_node_embeddings_path', type=str, default=None, help='Path within kg_embeddings folder to the saved KG embeddings')
    parser.add_argument('--patient_data', default="disease_simulated", type=str)
    parser.add_argument('--run_type', choices=["causal_gene_discovery", "disease_characterization", "patients_like_me"], type=str)
    parser.add_argument("--aug_sim", type=str, default=None, help="File with the similarity dictionary")
    # NOTE: argparse's type=bool does not parse "False"; any non-empty string is treated as True
    parser.add_argument("--aug_gene_by_deg", type=bool, default=False, help="Augment gene by degree")
    parser.add_argument("--aug_gene_w", type=float, default=0.7, help="Contribution of augmentation (gene)")
    parser.add_argument("--n_sim_genes", type=int, default=3, help="K similar genes for augmentation")
    parser.add_argument("--n_transformer_layers", type=int, default=3, help="Number of transformer layers")
    parser.add_argument("--n_transformer_heads", type=int, default=8, help="Number of transformer heads")

    # Tunable parameters
    parser.add_argument('--sparse_sample', default=200, type=int)
    parser.add_argument('--lr', default=0.0001, type=float)
    parser.add_argument('--upsample_cand', default=1, type=int)
    parser.add_argument('--neighbor_sampler_size', default=-1, type=int)
    parser.add_argument('--lmbda', type=float, default=0.5, help='Lambda')
    parser.add_argument('--alpha', type=float, default=0, help='Alpha')
    parser.add_argument('--kappa', type=float, default=0.3, help='Kappa (Only used for combined model with link prediction loss)')
    parser.add_argument('--seed', default=33, type=int)
    parser.add_argument('--batch_size', default=64, type=int)

    # Resume / run inference with best checkpoint
    parser.add_argument('--resume', default="", type=str)
    parser.add_argument('--do_inference', action='store_true')
    parser.add_argument('--best_ckpt', type=str, default=None, help='Name of the best performing checkpoint')

    parser.add_argument('--use_wandb', type=bool, default=True)

    args = parser.parse_args()
    return args


def load_patient_datasets(hparams, inference=False):
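    """Load PatientDatasets from project_config.PROJECT_DIR / 'patients'.

    Returns (train, val, test); the test split is only loaded when
    inference=True, and the train/val splits only when inference=False
    (unused splits are returned as None)."""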
    print('loading patient datasets')

    if inference:
        train_dataset = None
        val_dataset = None
        test_dataset = PatientDataset(project_config.PROJECT_DIR / 'patients' / hparams['test_data'], time=hparams['time'])
    else:
        train_dataset = PatientDataset(project_config.PROJECT_DIR / 'patients' / hparams['train_data'], time=hparams['time'])
        val_dataset = PatientDataset(project_config.PROJECT_DIR / 'patients' / hparams['validation_data'], time=hparams['time'])
        test_dataset = None

    print('finished loading patient datasets')
    return train_dataset, val_dataset, test_dataset


def get_dataloaders(hparams, all_data, nid_to_spl_dict, n_nodes, gene_phen_dis_node_idx, train_dataset, val_dataset, test_dataset, inference=False):
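    """Build PatientNeighborSampler dataloaders over the KG.

    For training, train/val loaders are built over the train/val edge masks;
    for inference, only the test loader is built. Also loads the
    shortest-path-length (SPL) data and optional gene-similarity and
    gene-degree dictionaries consumed by the samplers."""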
    print('Get dataloaders', flush=True)
    shuffle = False if hparams['debug'] or inference else True
    if not hparams['sample_from_gpd']: gene_phen_dis_node_idx = None
    batch_sz = hparams['inference_batch_size'] if inference else hparams['batch_size']
    sparse_sample = 1 if inference else hparams['sparse_sample']

    # get phenotypes & genes found in train patients
    if hparams['sample_edges_from_train_patients']:
        phenotype_counter = Counter()
        gene_counter = Counter()
        for patient in train_dataset:
            phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, disease_node_idx, labels, additional_labels, patient_ids = patient

            phenotype_counter += Counter(list(phenotype_node_idx.numpy()))
            gene_counter += Counter(list(candidate_gene_node_idx.numpy()))
    else:
        phenotype_counter = None
        gene_counter = None

    print('Loading SPL...')
    if hparams['spl'] is not None:
        spl = np.load(project_config.PROJECT_DIR / 'patients' / hparams['spl'])
    else: spl = None
    if hparams['spl_index'] is not None and (project_config.PROJECT_DIR / 'patients' / hparams['spl_index']).exists():
        with open(str(project_config.PROJECT_DIR / 'patients' / hparams['spl_index']), "rb") as input_file:
            spl_indexing_dict = pickle.load(input_file)
    else: spl_indexing_dict = None # TODO: short term fix for simulated patients, get rid once we create this dict

    print('Loaded SPL information')

    # NOTE: `args` here is the module-level namespace parsed in __main__, not a parameter of this function
    if args.aug_sim is not None:
        with open(str(project_config.PROJECT_DIR / 'knowledge_graph/8.9.21_kg' / ('top_10_similar_genes_sim=%s.pkl' % args.aug_sim)), "rb") as input_file:
            gene_similarity_dict = pickle.load(input_file)
        print("Using augment gene similarity: %s" % args.aug_sim)
    else: gene_similarity_dict = None

    # NOTE: hardcoded path to a precomputed gene degree dictionary
    with open("/home/ema30/zaklab/rare_disease_dx/formatted_patients/degree_dict_8.9.21_kg.pkl", "rb") as input_file:
        gene_deg_dict = pickle.load(input_file)

    if inference:
        train_dataloader = None
        val_dataloader = None
    else:
        print('setting up train dataloader')
        train_dataloader = PatientNeighborSampler('train', all_data.edge_index[:,all_data.train_mask], all_data.edge_index[:,all_data.train_mask],
                        sizes=hparams['neighbor_sampler_sizes'], patient_dataset=train_dataset, batch_size=batch_sz,
                        sparse_sample=sparse_sample, do_filter_edges=hparams['filter_edges'],
                        all_edge_attributes=all_data.edge_attr, n_nodes=n_nodes, relevant_node_idx=gene_phen_dis_node_idx,
                        shuffle=shuffle, train_phenotype_counter=phenotype_counter, train_gene_counter=gene_counter, sample_edges_from_train_patients=hparams['sample_edges_from_train_patients'], num_workers=hparams['num_workers'],
                        upsample_cand=hparams['upsample_cand'], n_cand_diseases=hparams['n_cand_diseases'], use_diseases=hparams['use_diseases'], nid_to_spl_dict=nid_to_spl_dict, gp_spl=spl, spl_indexing_dict=spl_indexing_dict,
                        hparams=hparams, pin_memory=hparams['pin_memory'],
                        gene_similarity_dict=gene_similarity_dict,
                        gene_deg_dict=gene_deg_dict)
        print('finished setting up train dataloader')
        print('setting up val dataloader')
        val_dataloader = PatientNeighborSampler('val', all_data.edge_index, all_data.edge_index[:,all_data.val_mask],
                        sizes=[-1,10,5],
                        patient_dataset=val_dataset, batch_size=batch_sz,
                        sparse_sample=sparse_sample, all_edge_attributes=all_data.edge_attr, n_nodes=n_nodes,
                        relevant_node_idx=gene_phen_dis_node_idx,
                        shuffle=False, train_phenotype_counter=phenotype_counter, train_gene_counter=gene_counter, sample_edges_from_train_patients=hparams['sample_edges_from_train_patients'], num_workers=hparams['num_workers'],
                        n_cand_diseases=hparams['n_cand_diseases'], use_diseases=hparams['use_diseases'], nid_to_spl_dict=nid_to_spl_dict, gp_spl=spl, spl_indexing_dict=spl_indexing_dict,
                        hparams=hparams, pin_memory=hparams['pin_memory'],
                        gene_similarity_dict=gene_similarity_dict,
                        gene_deg_dict=gene_deg_dict)
        print('finished setting up val dataloader')

    print('setting up test dataloader')
    if inference:
        sizes = [-1,10,5]
        print('SIZES: ', sizes)
        test_dataloader = PatientNeighborSampler('test', all_data.edge_index, all_data.edge_index[:,all_data.test_mask],
                        sizes=sizes, patient_dataset=test_dataset, batch_size=len(test_dataset),
                        sparse_sample=sparse_sample, all_edge_attributes=all_data.edge_attr, n_nodes=n_nodes, relevant_node_idx=gene_phen_dis_node_idx,
                        shuffle=False, num_workers=hparams['num_workers'],
                        n_cand_diseases=hparams['test_n_cand_diseases'], use_diseases=hparams['use_diseases'], nid_to_spl_dict=nid_to_spl_dict, gp_spl=spl, spl_indexing_dict=spl_indexing_dict,
                        hparams=hparams, pin_memory=hparams['pin_memory'],
                        gene_similarity_dict=gene_similarity_dict,
                        gene_deg_dict=gene_deg_dict)
    else: test_dataloader = None
    print('finished setting up test dataloader')

    return train_dataloader, val_dataloader, test_dataloader


def get_model(args, hparams, node_hparams, all_data, edge_attr_dict, n_nodes, load_from_checkpoint=False):
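    """Instantiate the patient model specified by hparams['model_type']:
    CombinedGPAligner ('aligner') or CombinedPatientNCA ('patient_NCA'),
    optionally restored from the checkpoint named by --best_ckpt."""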
    print("setting up model", hparams['model_type'])
    # get patient model
    if hparams['model_type'] == 'aligner':
        if load_from_checkpoint:
            comb_patient_model = CombinedGPAligner.load_from_checkpoint(checkpoint_path=str(Path(project_config.PROJECT_DIR / args.best_ckpt)),
                                    edge_attr_dict=edge_attr_dict, all_data=all_data, n_nodes=n_nodes, node_ckpt=hparams["saved_checkpoint_path"], node_hparams=node_hparams)
        else:
            comb_patient_model = CombinedGPAligner(edge_attr_dict=edge_attr_dict, all_data=all_data, n_nodes=n_nodes, hparams=hparams, node_ckpt=hparams["saved_checkpoint_path"], node_hparams=node_hparams)
    elif hparams['model_type'] == 'patient_NCA':
        if load_from_checkpoint:
            comb_patient_model = CombinedPatientNCA.load_from_checkpoint(checkpoint_path=str(Path(project_config.PROJECT_DIR) / args.best_ckpt),
                                    all_data=all_data, edge_attr_dict=edge_attr_dict, n_nodes=n_nodes, node_ckpt=hparams["saved_checkpoint_path"])
        else:
            comb_patient_model = CombinedPatientNCA(edge_attr_dict=edge_attr_dict, all_data=all_data, n_nodes=n_nodes, node_ckpt=hparams["saved_checkpoint_path"], hparams=hparams)
    else:
        raise NotImplementedError(f"Unknown model_type: {hparams['model_type']}")
    print('finished setting up model')
    return comb_patient_model


def train(args, hparams):
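    """Train the patient model: preprocess the KG, create (or resume) a W&B
    logger, build the dataloaders, and fit with PyTorch Lightning while
    checkpointing on the validation MRR."""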
    print('Training Model', flush=True)

    # Hyperparameters
    node_hparams = get_pretrain_hparams(args, combined=True)
    print('Edge List: ', args.edgelist, flush=True)
    print('Node Map: ', args.node_map, flush=True)

    # Set seed
    pl.seed_everything(hparams['seed'])

    # Read input data
    print('Read data', flush=True)
    all_data, edge_attr_dict, nodes = preprocess.preprocess_graph(args)
    n_nodes = len(nodes["node_idx"].unique())
    print(f'Number of nodes: {n_nodes}')
    gene_phen_dis_node_idx = torch.LongTensor(nodes.loc[nodes['node_type'].isin(['gene/protein', 'effect/phenotype', 'disease']), 'node_idx'].values)

    if args.resume != "":
        print('Resuming Run')
        # create Weights & Biases Logger
        if ":" in args.resume: # colons are not allowed in ID/resume name
            resume_id = "_".join(args.resume.split(":"))
        else:
            resume_id = args.resume
        run_name = args.resume
        wandb_logger = WandbLogger(run_name, project=hparams['wandb_project_name'], entity='rare_disease_dx', save_dir=hparams['wandb_save_dir'], id=resume_id, resume=resume_id)

        # add run name to hparams dict
        hparams['run_name'] = run_name

        # get patient model
        comb_patient_model = get_model(args, hparams, node_hparams, all_data, edge_attr_dict, n_nodes, load_from_checkpoint=True)

    else:
        print('Creating new W&B Logger')
        # create Weights & Biases Logger
        curr_time = datetime.now().strftime("%m_%d_%y:%H:%M:%S")
        lr = hparams['lr']
        val_data = str(hparams['validation_data']).split('.txt')[0].replace('/', '.')
        run_name = "{}_val_{}".format(curr_time, val_data).replace('patients', 'pats')
        run_name = run_name + f'_seed={args.seed}'
        run_name = run_name.replace('5_candidates_mapped_only', '5cand_map').replace('8.9.21_kgsolved_manual_baylor_nobgm_distractor_genes', 'manual').replace('patient_disease_NCA', 'pd_NCA').replace('_distractor', '')
        wandb_logger = WandbLogger(name=run_name, project=hparams['wandb_project_name'], entity='rare_disease_dx', save_dir=hparams['wandb_save_dir'],
                        id="_".join(run_name.split(":")), resume="allow")

        # add run name to hparams dict
        print('Run name', run_name)
        hparams['run_name'] = run_name

        # get patient model
        comb_patient_model = get_model(args, hparams, node_hparams, all_data, edge_attr_dict, n_nodes, load_from_checkpoint=False)

    # get dataloaders
    nid_to_spl_dict = {nid: idx for idx, nid in enumerate(nodes[nodes["node_type"] == "gene/protein"]["node_idx"].tolist())}
    train_dataset, val_dataset, test_dataset = load_patient_datasets(hparams)
    patient_train_dataloader, patient_val_dataloader, patient_test_dataloader = get_dataloaders(hparams, all_data, nid_to_spl_dict,
                                                                                                n_nodes, gene_phen_dis_node_idx,
                                                                                                train_dataset, val_dataset, test_dataset)

    # callbacks
    print('Init callbacks')
    checkpoint_path = (project_config.PROJECT_DIR / 'checkpoints' / hparams['model_type'] / run_name)
    hparams['checkpoint_path'] = checkpoint_path
    print('Checkpoint path: ', checkpoint_path)
    if not os.path.exists(project_config.PROJECT_DIR / 'checkpoints' / hparams['model_type']): (project_config.PROJECT_DIR / 'checkpoints' / hparams['model_type']).mkdir()
    if not os.path.exists(checkpoint_path): checkpoint_path.mkdir()
    monitor_type = 'val/mrr' if args.run_type in ('disease_characterization', 'patients_like_me') else 'val/gp_val_epoch_mrr'
    fname = 'epoch={epoch:02d}-val_mrr={val/mrr:.2f}' if args.run_type in ('disease_characterization', 'patients_like_me') else 'epoch={epoch:02d}-val_mrr={val/gp_val_epoch_mrr:.2f}'
    patient_checkpoint_callback = ModelCheckpoint(
        monitor=monitor_type,
        dirpath=checkpoint_path,
        filename=fname,
        save_top_k=-1,
        mode='max',
        auto_insert_metric_name=False
    )

    # log gradients with logger
    print('wandb logger watch')
    wandb_logger.watch(comb_patient_model, log='all')

    # initialize trainer
    if hparams['debug']:
        limit_train_batches = 1
        limit_val_batches = 1
        hparams['max_epochs'] = 6
    else:
        limit_train_batches = 1.0
        limit_val_batches = 1.0

    print('initialize trainer')
    patient_trainer = pl.Trainer(gpus=hparams['n_gpus'],
                                logger=wandb_logger,
                                max_epochs=hparams['max_epochs'],
                                callbacks=[patient_checkpoint_callback],
                                profiler=hparams['profiler'],
                                log_gpu_memory=hparams['log_gpu_memory'],
                                limit_train_batches=limit_train_batches,
                                limit_val_batches=limit_val_batches,
                                weights_summary="full",
                                gradient_clip_val=hparams['gradclip'])

    # Train
    patient_trainer.fit(comb_patient_model, patient_train_dataloader, patient_val_dataloader)


@torch.no_grad()
def inference(args, hparams):
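    """Run test-set inference with the checkpoint named by --best_ckpt:
    rebuild the test dataloader, restore the model, and evaluate it with a
    PyTorch Lightning Trainer, logging results to W&B."""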
    print('Running inference')
    # Hyperparameters
    node_hparams = get_pretrain_hparams(args, combined=True)

    hparams.update({'add_similar_patients': False})

    # Seed
    pl.seed_everything(hparams['seed'])

    # Read data
    all_data, edge_attr_dict, nodes = preprocess.preprocess_graph(args)
    n_nodes = len(nodes["node_idx"].unique())
    gene_phen_dis_node_idx = torch.LongTensor(nodes.loc[nodes['node_type'].isin(['gene/protein', 'effect/phenotype', 'disease']), 'node_idx'].values)

    # Get logger & trainer
    curr_time = datetime.now().strftime("%m_%d_%y:%H:%M:%S")
    lr = hparams['lr']
    test_data = hparams['test_data'].split('.txt')[0].replace('/', '.')
    run_name = "{}_lr_{}_test_{}".format(curr_time, lr, test_data)
    wandb_logger = WandbLogger(run_name, project=hparams['wandb_project_name'], entity='rare_disease_dx', save_dir=hparams['wandb_save_dir'])
    print('Run name: ', run_name)
    hparams['run_name'] = run_name

    # Get datasets
    train_dataset, val_dataset, test_dataset = load_patient_datasets(hparams, inference=True)

    # Get dataloader
    nid_to_spl_dict = {nid: idx for idx, nid in enumerate(nodes[nodes["node_type"] == "gene/protein"]["node_idx"].tolist())}
    _, _, test_dataloader = get_dataloaders(hparams, all_data, nid_to_spl_dict,
                                                                        n_nodes, gene_phen_dis_node_idx,
                                                                        train_dataset, val_dataset, test_dataset, inference=True)

    # Get patient model
    model = get_model(args, hparams, node_hparams, all_data, edge_attr_dict, n_nodes, load_from_checkpoint=True)

    trainer = pl.Trainer(gpus=0, logger=wandb_logger)
    results = trainer.test(model, dataloaders=test_dataloader)
    print('---- RESULTS ----')
    print(results)


if __name__ == "__main__":

    # Get hyperparameters
    args = parse_args()
    hparams = get_train_hparams(args)

    # Run model
    if args.do_inference:
        inference(args, hparams)
    else:
        train(args, hparams)