shepherd/pretrain.py

# General
import numpy as np
import random
import argparse
import os
from copy import deepcopy
from pathlib import Path
import sys
from datetime import datetime

# Pytorch
import torch
import torch.nn as nn

# Pytorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

# Pytorch Geo
from torch_geometric.data.sampler import NeighborSampler as PyGeoNeighborSampler
from torch_geometric.data import Data, DataLoader

# W&B
import wandb

sys.path.insert(0, '..')  # add project_config to path

# Own code
import preprocess
from node_embedder_model import NodeEmbeder
import project_config
from hparams import get_pretrain_hparams
from samplers import NeighborSampler

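# This script pretrains the knowledge-graph node embedder. Example invocation
# (file names below are illustrative placeholders, not the project's actual data files):
#   python pretrain.py --edgelist kg_edgelist.txt --node_map kg_node_map.txt --save_dir runs/pretrain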
def parse_args():
    parser = argparse.ArgumentParser(description="Learn node embeddings.")

    # Input files/parameters
    parser.add_argument("--edgelist", type=str, default=None, help="File with edge list")
    parser.add_argument("--node_map", type=str, default=None, help="File with node list")
    parser.add_argument('--save_dir', type=str, default=None, help='Directory for saving files')

    # Tunable parameters
    parser.add_argument('--nfeat', type=int, default=2048, help='Dimension of embedding layer')
    parser.add_argument('--hidden', default=256, type=int, help='Dimension of hidden layers')
    parser.add_argument('--output', default=128, type=int, help='Dimension of output embeddings')
    parser.add_argument('--n_heads', default=4, type=int, help='Number of attention heads')
    parser.add_argument('--wd', default=0.0, type=float, help='Weight decay')
    parser.add_argument('--dropout', type=float, default=0.3, help='Dropout')
    parser.add_argument('--lr', default=0.0001, type=float, help='Learning rate')
    parser.add_argument('--max_epochs', default=1000, type=int, help='Maximum number of training epochs')

    # Resume with best checkpoint
    parser.add_argument('--resume', default="", type=str, help='Name of the W&B run to resume')
    parser.add_argument('--best_ckpt', type=str, default=None, help='Name of the best performing checkpoint')

    # Output
    parser.add_argument('--save_embeddings', action='store_true')

    args = parser.parse_args()
    return args

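# Note: all three loaders below sample neighborhoods from the *training* edges
# (the second argument); only the supervision edges (the third argument) differ
# per split, so validation and test edges are scored against a graph built from
# training edges alone.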
def get_dataloaders(hparams, all_data):
    print('get dataloaders')
    train_dataloader = NeighborSampler('train', all_data.edge_index[:, all_data.train_mask],
                                       all_data.edge_index[:, all_data.train_mask],
                                       sizes=hparams['neighbor_sampler_sizes'], batch_size=hparams['batch_size'],
                                       shuffle=True, num_workers=hparams['num_workers'],
                                       do_filter_edges=hparams['filter_edges'])
    val_dataloader = NeighborSampler('val', all_data.edge_index[:, all_data.train_mask],
                                     all_data.edge_index[:, all_data.val_mask],
                                     sizes=hparams['neighbor_sampler_sizes'], batch_size=hparams['batch_size'],
                                     shuffle=False, num_workers=hparams['num_workers'],
                                     do_filter_edges=hparams['filter_edges'])
    test_dataloader = NeighborSampler('test', all_data.edge_index[:, all_data.train_mask],
                                      all_data.edge_index[:, all_data.test_mask],
                                      sizes=hparams['neighbor_sampler_sizes'], batch_size=hparams['batch_size'],
                                      shuffle=False, num_workers=hparams['num_workers'],
                                      do_filter_edges=hparams['filter_edges'])
    return train_dataloader, val_dataloader, test_dataloader

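# train() runs one pretraining job end to end: it seeds everything, preprocesses
# the graph, builds the NodeEmbeder (or resumes it from a checkpoint), fits it
# with PyTorch Lightning, and evaluates the best checkpoint on the test split.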
def train(args, hparams):

    # Seed
    pl.seed_everything(hparams['seed'])

    # Read input data
    all_data, edge_attr_dict, nodes = preprocess.preprocess_graph(args)

    # Set up
    if args.resume != "":
        resume_id = "_".join(args.resume.split(":"))  # colons are not allowed in ID/resume name
        run_name = args.resume
        wandb_logger = WandbLogger(run_name, project='kg-train', entity='rare_disease_dx',
                                   save_dir=hparams['wandb_save_dir'], id=resume_id, resume=resume_id)
        model = NodeEmbeder.load_from_checkpoint(checkpoint_path=str(Path(args.save_dir) / 'checkpoints' / args.best_ckpt),
                                                 all_data=all_data, edge_attr_dict=edge_attr_dict,
                                                 num_nodes=len(nodes["node_idx"].unique()), combined_training=False)
    else:
        curr_time = datetime.now().strftime("%H:%M:%S")
        run_name = f"{curr_time}_run"
        wandb_logger = WandbLogger(run_name, project='kg-train', entity='rare_disease_dx',
                                   save_dir=hparams['wandb_save_dir'], id="_".join(run_name.split(":")), resume="allow")
        model = NodeEmbeder(all_data, edge_attr_dict, hp_dict=hparams,
                            num_nodes=len(nodes["node_idx"].unique()), combined_training=False)

    checkpoint_callback = ModelCheckpoint(monitor='val/node_total_acc', dirpath=Path(args.save_dir) / 'checkpoints',
                                          filename=f'{run_name}' + '_{epoch}', save_top_k=1, mode='max')
    lr_monitor = LearningRateMonitor(logging_interval='step')
    wandb_logger.watch(model, log='all')

    if hparams['debug']:
        limit_train_batches = 1  # a single training batch per epoch in debug mode
        limit_val_batches = 1.0
        hparams['max_epochs'] = 3
    else:
        limit_train_batches = 1.0
        limit_val_batches = 1.0

    trainer = pl.Trainer(gpus=hparams['n_gpus'], logger=wandb_logger,
                         max_epochs=hparams['max_epochs'],
                         callbacks=[checkpoint_callback, lr_monitor],
                         gradient_clip_val=hparams['gradclip'],
                         profiler=hparams['profiler'],
                         log_every_n_steps=hparams['log_every_n_steps'],
                         limit_train_batches=limit_train_batches,
                         limit_val_batches=limit_val_batches,
                         )
    train_dataloader, val_dataloader, test_dataloader = get_dataloaders(hparams, all_data)

    # Train
    trainer.fit(model, train_dataloader, val_dataloader)

    # Test on the best checkpoint found during training
    trainer.test(ckpt_path='best', test_dataloaders=test_dataloader)

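# save_embeddings() reloads the best checkpoint and runs a single full-graph
# forward pass (no gradients, on CPU) to dump the final node embeddings next to
# the checkpoint, with a .embed extension in place of .ckpt.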
@torch.no_grad()
def save_embeddings(args, hparams):
    print('Saving Embeddings')

    # Seed
    pl.seed_everything(hparams['seed'])

    # Read input data
    all_data, edge_attr_dict, nodes = preprocess.preprocess_graph(args)
    all_data.num_nodes = len(nodes["node_idx"].unique())

    model = NodeEmbeder.load_from_checkpoint(checkpoint_path=str(Path(args.save_dir) / 'checkpoints' / args.best_ckpt),
                                             all_data=all_data, edge_attr_dict=edge_attr_dict,
                                             num_nodes=len(nodes["node_idx"].unique()), combined_training=False)

    dataloader = DataLoader([all_data], batch_size=1)
    trainer = pl.Trainer(gpus=0,
                         gradient_clip_val=hparams['gradclip']
                         )
    embeddings = trainer.predict(model, dataloaders=dataloader)
    embed_path = Path(args.save_dir) / (str(args.best_ckpt).split('.ckpt')[0] + '.embed')
    torch.save(embeddings[0], str(embed_path))
    print(embeddings[0].shape)

if __name__ == "__main__":

    # Get hyperparameters
    args = parse_args()
    hparams = get_pretrain_hparams(args, combined=False)

    if args.save_embeddings:
        # Save node embeddings from a trained model
        save_embeddings(args, hparams)
    else:
        # Train model
        train(args, hparams)