--- /dev/null
+++ b/shepherd/pretrain.py
@@ -0,0 +1,163 @@
+# General
+import numpy as np
+import random
+import argparse
+import os
+from copy import deepcopy
+from pathlib import Path
+import sys
+from datetime import datetime
+
+# Pytorch
+import torch
+import torch.nn as nn
+
+# Pytorch Lightning
+import pytorch_lightning as pl
+from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
+
+# Pytorch Geo
+from torch_geometric.data.sampler import NeighborSampler as PyGeoNeighborSampler
+from torch_geometric.data import Data, DataLoader
+
+# W&B
+import wandb
+
+sys.path.insert(0, '..')  # add project_config to path
+
+# Own code
+import preprocess
+from node_embedder_model import NodeEmbeder
+import project_config
+from hparams import get_pretrain_hparams
+from samplers import NeighborSampler
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Learn node embeddings.")
+
+    # Input files/parameters
+    parser.add_argument("--edgelist", type=str, default=None, help="File with edge list")
+    parser.add_argument("--node_map", type=str, default=None, help="File with node list")
+    parser.add_argument('--save_dir', type=str, default=None, help='Directory for saving files')
+
+    # Tunable parameters
+    parser.add_argument('--nfeat', type=int, default=2048, help='Dimension of embedding layer')
+    parser.add_argument('--hidden', default=256, type=int, help='Dimension of hidden layers')
+    parser.add_argument('--output', default=128, type=int, help='Dimension of output embeddings')
+    parser.add_argument('--n_heads', default=4, type=int, help='Number of attention heads')
+    parser.add_argument('--wd', default=0.0, type=float, help='Weight decay')
+    parser.add_argument('--dropout', type=float, default=0.3, help='Dropout')
+    parser.add_argument('--lr', default=0.0001, type=float, help='Learning rate')
+    parser.add_argument('--max_epochs', default=1000, type=int, help='Maximum number of training epochs')
+
+    # Resume with best checkpoint
+    parser.add_argument('--resume', default="", type=str, help='W&B run name to resume')
+    parser.add_argument('--best_ckpt', type=str, default=None, help='Name of the best performing checkpoint')
+
+    # Output
+    parser.add_argument('--save_embeddings', action='store_true', help='Save embeddings from a trained model instead of training')
+
+    args = parser.parse_args()
+    return args
+
+
+def get_dataloaders(hparams, all_data):
+    print('get dataloaders')
+    # Neighbors are always sampled over the training edges; only the target
+    # (supervision) edges differ between the train/val/test loaders.
+    train_edges = all_data.edge_index[:, all_data.train_mask]
+    train_dataloader = NeighborSampler('train', train_edges, train_edges,
+                                       sizes=hparams['neighbor_sampler_sizes'], batch_size=hparams['batch_size'],
+                                       shuffle=True, num_workers=hparams['num_workers'],
+                                       do_filter_edges=hparams['filter_edges'])
+    val_dataloader = NeighborSampler('val', train_edges, all_data.edge_index[:, all_data.val_mask],
+                                     sizes=hparams['neighbor_sampler_sizes'], batch_size=hparams['batch_size'],
+                                     shuffle=False, num_workers=hparams['num_workers'],
+                                     do_filter_edges=hparams['filter_edges'])
+    test_dataloader = NeighborSampler('test', train_edges, all_data.edge_index[:, all_data.test_mask],
+                                      sizes=hparams['neighbor_sampler_sizes'], batch_size=hparams['batch_size'],
+                                      shuffle=False, num_workers=hparams['num_workers'],
+                                      do_filter_edges=hparams['filter_edges'])
+    return train_dataloader, val_dataloader, test_dataloader
+
+
+def train(args, hparams):
+
+    # Seed
+    pl.seed_everything(hparams['seed'])
+
+    # Read input data
+    all_data, edge_attr_dict, nodes = preprocess.preprocess_graph(args)
+
+    # Set up logger and model, resuming from a checkpoint if requested
+    if args.resume != "":
+        # Colons are not allowed in W&B run IDs; the join below is a no-op
+        # when the run name contains no colon. (The previous guard left
+        # resume_id undefined for colon-free names.)
+        resume_id = "_".join(args.resume.split(":"))
+        run_name = args.resume
+        wandb_logger = WandbLogger(run_name, project='kg-train', entity='rare_disease_dx',
+                                   save_dir=hparams['wandb_save_dir'], id=resume_id, resume=resume_id)
+        model = NodeEmbeder.load_from_checkpoint(checkpoint_path=str(Path(args.save_dir) / 'checkpoints' / args.best_ckpt),
+                                                 all_data=all_data, edge_attr_dict=edge_attr_dict,
+                                                 num_nodes=len(nodes["node_idx"].unique()), combined_training=False)
+    else:
+        curr_time = datetime.now().strftime("%H:%M:%S")
+        run_name = f"{curr_time}_run"
+        wandb_logger = WandbLogger(run_name, project='kg-train', entity='rare_disease_dx',
+                                   save_dir=hparams['wandb_save_dir'], id="_".join(run_name.split(":")), resume="allow")
+        model = NodeEmbeder(all_data, edge_attr_dict, hp_dict=hparams,
+                            num_nodes=len(nodes["node_idx"].unique()), combined_training=False)
+
+    checkpoint_callback = ModelCheckpoint(monitor='val/node_total_acc', dirpath=Path(args.save_dir) / 'checkpoints',
+                                          filename=f'{run_name}' + '_{epoch}', save_top_k=1, mode='max')
+    lr_monitor = LearningRateMonitor(logging_interval='step')
+    wandb_logger.watch(model, log='all')
+
+    if hparams['debug']:
+        limit_train_batches = 1    # int: run exactly one training batch
+        limit_val_batches = 1.0    # float: fraction of batches (all of them)
+        hparams['max_epochs'] = 3
+    else:
+        limit_train_batches = 1.0
+        limit_val_batches = 1.0
+
+    trainer = pl.Trainer(gpus=hparams['n_gpus'], logger=wandb_logger,
+                         max_epochs=hparams['max_epochs'],
+                         callbacks=[checkpoint_callback, lr_monitor],
+                         gradient_clip_val=hparams['gradclip'],
+                         profiler=hparams['profiler'],
+                         log_every_n_steps=hparams['log_every_n_steps'],
+                         limit_train_batches=limit_train_batches,
+                         limit_val_batches=limit_val_batches,
+                         )
+    train_dataloader, val_dataloader, test_dataloader = get_dataloaders(hparams, all_data)
+
+    # Train
+    trainer.fit(model, train_dataloader, val_dataloader)
+
+    # Test with the best checkpoint from this run
+    trainer.test(ckpt_path='best', test_dataloaders=test_dataloader)
+
+
+@torch.no_grad()
+def save_embeddings(args, hparams):
+    print('Saving Embeddings')
+
+    # Seed
+    pl.seed_everything(hparams['seed'])
+
+    # Read input data
+    all_data, edge_attr_dict, nodes = preprocess.preprocess_graph(args)
+    all_data.num_nodes = len(nodes["node_idx"].unique())
+
+    model = NodeEmbeder.load_from_checkpoint(checkpoint_path=str(Path(args.save_dir) / 'checkpoints' / args.best_ckpt),
+                                             all_data=all_data, edge_attr_dict=edge_attr_dict,
+                                             num_nodes=len(nodes["node_idx"].unique()), combined_training=False)
+
+    # Predict over the whole graph as a single batch on CPU
+    dataloader = DataLoader([all_data], batch_size=1)
+    trainer = pl.Trainer(gpus=0,
+                         gradient_clip_val=hparams['gradclip'])
+    embeddings = trainer.predict(model, dataloaders=dataloader)
+    embed_path = Path(args.save_dir) / (str(args.best_ckpt).split('.ckpt')[0] + '.embed')
+    torch.save(embeddings[0], str(embed_path))
+    print(embeddings[0].shape)
+
+
+if __name__ == "__main__":
+
+    # Parse arguments and load hyperparameters
+    args = parse_args()
+    hparams = get_pretrain_hparams(args, combined=False)
+
+    if args.save_embeddings:
+        # Save node embeddings from a trained model
+        save_embeddings(args, hparams)
+    else:
+        # Train model
+        train(args, hparams)
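
Reviewer note on the masking pattern in `get_dataloaders`: judging by how they index `edge_index`, the `train_mask`, `val_mask`, and `test_mask` attributes are boolean vectors over edges, so `edge_index[:, mask]` keeps the columns (edges) belonging to one split. A minimal, self-contained sketch with made-up edges:

```python
import torch

# edge_index is [2, num_edges]: row 0 holds source nodes, row 1 holds targets.
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
# Hypothetical split: a boolean mask over edges keeping the first two.
train_mask = torch.tensor([True, True, False, False])

# Column indexing with a boolean mask selects only that split's edges.
print(edge_index[:, train_mask])
# tensor([[0, 1],
#         [1, 2]])
```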
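The `save_embeddings` path wraps the single `Data` object in a `DataLoader` of batch size 1, so `trainer.predict` sees the entire graph as one batch. A sketch of that pattern, assuming the pre-2.0 PyTorch Geometric layout this file imports from (`torch_geometric.data.DataLoader`) and using a toy graph in place of `all_data`:

```python
import torch
from torch_geometric.data import Data, DataLoader

# A toy graph standing in for all_data: 4 nodes, 8-dim features, 2 edges.
data = Data(x=torch.randn(4, 8),
            edge_index=torch.tensor([[0, 1],
                                     [1, 2]]))

# Batch size 1 over a one-element list yields a single whole-graph batch.
loader = DataLoader([data], batch_size=1)
batch = next(iter(loader))
print(batch.num_graphs)  # 1
print(batch.x.shape)     # torch.Size([4, 8])
```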
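Since the embeddings are written with `torch.save` and the `.embed` filename mirrors the checkpoint name, they can be loaded back with `torch.load`. A sketch with placeholder paths (the directory and checkpoint name below are hypothetical, standing in for `args.save_dir` and `args.best_ckpt`):

```python
import torch
from pathlib import Path

# Hypothetical paths; the real ones depend on the training run.
save_dir = Path('runs/example')           # stands in for args.save_dir
best_ckpt = 'best_run_epoch=42.ckpt'      # stands in for args.best_ckpt

# Mirror the naming scheme used in save_embeddings above.
embed_path = save_dir / (best_ckpt.split('.ckpt')[0] + '.embed')
node_embeddings = torch.load(str(embed_path))  # tensor of node embeddings
print(node_embeddings.shape)
```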