# shepherd/pretrain.py

# General
import numpy as np
import random
import argparse
import os
from copy import deepcopy
from pathlib import Path
import sys
from datetime import datetime

# Pytorch
import torch
import torch.nn as nn

# Pytorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

# Pytorch Geo
from torch_geometric.data.sampler import NeighborSampler as PyGeoNeighborSampler
from torch_geometric.data import Data, DataLoader

# W&B
import wandb

sys.path.insert(0, '..')  # add project_config to path

# Own code
import preprocess
from node_embedder_model import NodeEmbeder
import project_config
from hparams import get_pretrain_hparams
from samplers import NeighborSampler


def parse_args():
    parser = argparse.ArgumentParser(description="Learn node embeddings.")

    # Input files/parameters
    parser.add_argument("--edgelist", type=str, default=None, help="File with edge list")
    parser.add_argument("--node_map", type=str, default=None, help="File with node list")
    parser.add_argument('--save_dir', type=str, default=None, help='Directory for saving files')

    # Tunable parameters
    parser.add_argument('--nfeat', type=int, default=2048, help='Dimension of embedding layer')
    parser.add_argument('--hidden', default=256, type=int, help='Dimension of hidden layers')
    parser.add_argument('--output', default=128, type=int, help='Dimension of output embeddings')
    parser.add_argument('--n_heads', default=4, type=int, help='Number of attention heads')
    parser.add_argument('--wd', default=0.0, type=float, help='Weight decay')
    parser.add_argument('--dropout', type=float, default=0.3, help='Dropout')
    parser.add_argument('--lr', default=0.0001, type=float, help='Learning rate')
    parser.add_argument('--max_epochs', default=1000, type=int, help='Maximum number of training epochs')

    # Resume with best checkpoint
    parser.add_argument('--resume', default="", type=str, help='Name of the W&B run to resume')
    parser.add_argument('--best_ckpt', type=str, default=None, help='Name of the best performing checkpoint')

    # Output
    parser.add_argument('--save_embeddings', action='store_true', help='Save node embeddings instead of training')

    args = parser.parse_args()
    return args
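
# Example training invocation (illustrative; the data file names below are
# placeholders, not files shipped with the repo):
#   python pretrain.py \
#       --edgelist data/kg_edgelist.txt \
#       --node_map data/kg_node_map.txt \
#       --save_dir runs/pretrain \
#       --lr 1e-4 --dropout 0.3 --max_epochs 1000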


def get_dataloaders(hparams, all_data):
    print('get dataloaders')
    # The first edge_index passed to each sampler is always the training edges;
    # the second (the edges to score) varies per split, so validation and test
    # edges are not used for neighborhood construction.
    train_dataloader = NeighborSampler('train', all_data.edge_index[:, all_data.train_mask],
                                       all_data.edge_index[:, all_data.train_mask],
                                       sizes=hparams['neighbor_sampler_sizes'],
                                       batch_size=hparams['batch_size'], shuffle=True,
                                       num_workers=hparams['num_workers'],
                                       do_filter_edges=hparams['filter_edges'])
    val_dataloader = NeighborSampler('val', all_data.edge_index[:, all_data.train_mask],
                                     all_data.edge_index[:, all_data.val_mask],
                                     sizes=hparams['neighbor_sampler_sizes'],
                                     batch_size=hparams['batch_size'], shuffle=False,
                                     num_workers=hparams['num_workers'],
                                     do_filter_edges=hparams['filter_edges'])
    test_dataloader = NeighborSampler('test', all_data.edge_index[:, all_data.train_mask],
                                      all_data.edge_index[:, all_data.test_mask],
                                      sizes=hparams['neighbor_sampler_sizes'],
                                      batch_size=hparams['batch_size'], shuffle=False,
                                      num_workers=hparams['num_workers'],
                                      do_filter_edges=hparams['filter_edges'])
    return train_dataloader, val_dataloader, test_dataloader


def train(args, hparams):
    # Seed
    pl.seed_everything(hparams['seed'])

    # Read input data
    all_data, edge_attr_dict, nodes = preprocess.preprocess_graph(args)

    # Set up W&B logging and the model, either resuming a previous run or starting fresh
    if args.resume != "":
        if ":" in args.resume:  # colons are not allowed in the W&B run ID
            resume_id = "_".join(args.resume.split(":"))
        else:
            resume_id = args.resume
        run_name = args.resume
        wandb_logger = WandbLogger(run_name, project='kg-train', entity='rare_disease_dx',
                                   save_dir=hparams['wandb_save_dir'], id=resume_id, resume=resume_id)
        model = NodeEmbeder.load_from_checkpoint(
            checkpoint_path=str(Path(args.save_dir) / 'checkpoints' / args.best_ckpt),
            all_data=all_data, edge_attr_dict=edge_attr_dict,
            num_nodes=len(nodes["node_idx"].unique()), combined_training=False)
    else:
        curr_time = datetime.now().strftime("%H:%M:%S")
        run_name = f"{curr_time}_run"
        wandb_logger = WandbLogger(run_name, project='kg-train', entity='rare_disease_dx',
                                   save_dir=hparams['wandb_save_dir'],
                                   id="_".join(run_name.split(":")), resume="allow")
        model = NodeEmbeder(all_data, edge_attr_dict, hp_dict=hparams,
                            num_nodes=len(nodes["node_idx"].unique()), combined_training=False)

    # Keep only the checkpoint with the best validation node-classification accuracy
    checkpoint_callback = ModelCheckpoint(monitor='val/node_total_acc',
                                          dirpath=Path(args.save_dir) / 'checkpoints',
                                          filename=run_name + '_{epoch}',
                                          save_top_k=1, mode='max')
    lr_monitor = LearningRateMonitor(logging_interval='step')
    wandb_logger.watch(model, log='all')

    # In debug mode, train on a single batch for a few epochs
    if hparams['debug']:
        limit_train_batches = 1
        limit_val_batches = 1.0
        hparams['max_epochs'] = 3
    else:
        limit_train_batches = 1.0
        limit_val_batches = 1.0

    # NOTE: `gpus=` and `test_dataloaders=` are from the pytorch_lightning 1.x-era API
    trainer = pl.Trainer(gpus=hparams['n_gpus'], logger=wandb_logger,
                         max_epochs=hparams['max_epochs'],
                         callbacks=[checkpoint_callback, lr_monitor],
                         gradient_clip_val=hparams['gradclip'],
                         profiler=hparams['profiler'],
                         log_every_n_steps=hparams['log_every_n_steps'],
                         limit_train_batches=limit_train_batches,
                         limit_val_batches=limit_val_batches)

    train_dataloader, val_dataloader, test_dataloader = get_dataloaders(hparams, all_data)

    # Train
    trainer.fit(model, train_dataloader, val_dataloader)

    # Test with the best checkpoint
    trainer.test(ckpt_path='best', test_dataloaders=test_dataloader)
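
# Resuming: pass the previous run's name via --resume together with its saved
# checkpoint via --best_ckpt. Fresh run names default to the start time plus
# "_run", e.g. "10:30:00_run" (illustrative), with colons mapped to underscores
# for the W&B run ID.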


@torch.no_grad()
def save_embeddings(args, hparams):
    print('Saving Embeddings')

    # Seed
    pl.seed_everything(hparams['seed'])

    # Read input data
    all_data, edge_attr_dict, nodes = preprocess.preprocess_graph(args)
    all_data.num_nodes = len(nodes["node_idx"].unique())

    # Load the best checkpoint
    model = NodeEmbeder.load_from_checkpoint(
        checkpoint_path=str(Path(args.save_dir) / 'checkpoints' / args.best_ckpt),
        all_data=all_data, edge_attr_dict=edge_attr_dict,
        num_nodes=len(nodes["node_idx"].unique()), combined_training=False)

    # Run a single full-graph forward pass on CPU and save the node embeddings
    dataloader = DataLoader([all_data], batch_size=1)
    trainer = pl.Trainer(gpus=0, gradient_clip_val=hparams['gradclip'])
    embeddings = trainer.predict(model, dataloaders=dataloader)

    embed_path = Path(args.save_dir) / (str(args.best_ckpt).split('.ckpt')[0] + '.embed')
    torch.save(embeddings[0], str(embed_path))
    print(embeddings[0].shape)
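
# The saved '.embed' file holds the predict output for the full graph and can
# be loaded back with torch.load, e.g. (path is illustrative; it depends on
# --save_dir and --best_ckpt):
#   node_embeddings = torch.load('runs/pretrain/my_run_epoch=42.embed')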


if __name__ == "__main__":
    # Get hyperparameters
    args = parse_args()
    hparams = get_pretrain_hparams(args, combined=False)

    if args.save_embeddings:
        # Save node embeddings from a trained model
        save_embeddings(args, hparams)
    else:
        # Train model
        train(args, hparams)
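
# To export node embeddings from a trained checkpoint instead of training,
# rerun with --save_embeddings (file and checkpoint names are placeholders):
#   python pretrain.py --edgelist data/kg_edgelist.txt --node_map data/kg_node_map.txt \
#       --save_dir runs/pretrain --best_ckpt <run_name>_epoch=7.ckpt --save_embeddings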