Diff of /shepherd/pretrain.py [000000] .. [db6163]

# General
import numpy as np
import random
import argparse
import os
from copy import deepcopy
from pathlib import Path
import sys
from datetime import datetime

# Pytorch
import torch
import torch.nn as nn

# Pytorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

# Pytorch Geo
from torch_geometric.data.sampler import NeighborSampler as PyGeoNeighborSampler
from torch_geometric.data import Data, DataLoader

# W&B
import wandb

sys.path.insert(0, '..')  # add project_config to path

# Own code
import preprocess
from node_embedder_model import NodeEmbeder
import project_config
from hparams import get_pretrain_hparams
from samplers import NeighborSampler


def parse_args():
    parser = argparse.ArgumentParser(description="Learn node embeddings.")

    # Input files/parameters
    parser.add_argument("--edgelist", type=str, default=None, help="File with edge list")
    parser.add_argument("--node_map", type=str, default=None, help="File with node list")
    parser.add_argument('--save_dir', type=str, default=None, help='Directory for saving files')

    # Tunable parameters
    parser.add_argument('--nfeat', type=int, default=2048, help='Dimension of embedding layer')
    parser.add_argument('--hidden', default=256, type=int)
    parser.add_argument('--output', default=128, type=int)
    parser.add_argument('--n_heads', default=4, type=int)
    parser.add_argument('--wd', default=0.0, type=float)
    parser.add_argument('--dropout', type=float, default=0.3, help='Dropout')
    parser.add_argument('--lr', default=0.0001, type=float)
    parser.add_argument('--max_epochs', default=1000, type=int)

    # Resume with best checkpoint
    parser.add_argument('--resume', default="", type=str)
    parser.add_argument('--best_ckpt', type=str, default=None, help='Name of the best performing checkpoint')

    # Output
    parser.add_argument('--save_embeddings', action='store_true')

    args = parser.parse_args()
    return args

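
# Example training invocation (an illustrative sketch: the edge list, node map,
# and save directory names below are placeholders rather than files shipped with
# the repository; any hyperparameters not passed here fall back to the argparse
# defaults above and to hparams.get_pretrain_hparams):
#
#   python pretrain.py \
#       --edgelist kg_edgelist.txt \
#       --node_map kg_node_map.txt \
#       --save_dir ./pretrain_runs \
#       --nfeat 2048 --hidden 256 --output 128 --n_heads 4 \
#       --lr 0.0001 --max_epochs 1000
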
def get_dataloaders(hparams, all_data):
    print('get dataloaders')
    # All three samplers are built on the training edges; only the third argument
    # changes, selecting which split's edges (train/val/test masks) are batched over.
    train_dataloader = NeighborSampler('train', all_data.edge_index[:,all_data.train_mask], all_data.edge_index[:,all_data.train_mask], sizes=hparams['neighbor_sampler_sizes'], batch_size=hparams['batch_size'], shuffle=True, num_workers=hparams['num_workers'], do_filter_edges=hparams['filter_edges'])
    val_dataloader = NeighborSampler('val', all_data.edge_index[:,all_data.train_mask], all_data.edge_index[:,all_data.val_mask], sizes=hparams['neighbor_sampler_sizes'], batch_size=hparams['batch_size'], shuffle=False, num_workers=hparams['num_workers'], do_filter_edges=hparams['filter_edges'])
    test_dataloader = NeighborSampler('test', all_data.edge_index[:,all_data.train_mask], all_data.edge_index[:,all_data.test_mask], sizes=hparams['neighbor_sampler_sizes'], batch_size=hparams['batch_size'], shuffle=False, num_workers=hparams['num_workers'], do_filter_edges=hparams['filter_edges'])
    return train_dataloader, val_dataloader, test_dataloader

def train(args, hparams):

    # Seed
    pl.seed_everything(hparams['seed'])

    # Read input data
    all_data, edge_attr_dict, nodes = preprocess.preprocess_graph(args)

    # Set up logger and model: either resume from the best checkpoint or start a fresh run
    if args.resume != "":
        if ":" in args.resume:  # colons are not allowed in ID/resume name
            resume_id = "_".join(args.resume.split(":"))
        else:
            resume_id = args.resume
        run_name = args.resume
        wandb_logger = WandbLogger(run_name, project='kg-train', entity='rare_disease_dx', save_dir=hparams['wandb_save_dir'], id=resume_id, resume=resume_id)
        model = NodeEmbeder.load_from_checkpoint(checkpoint_path=str(Path(args.save_dir) / 'checkpoints' / args.best_ckpt),
                                                 all_data=all_data, edge_attr_dict=edge_attr_dict,
                                                 num_nodes=len(nodes["node_idx"].unique()), combined_training=False)
    else:
        curr_time = datetime.now().strftime("%H:%M:%S")
        run_name = f"{curr_time}_run"
        wandb_logger = WandbLogger(run_name, project='kg-train', entity='rare_disease_dx', save_dir=hparams['wandb_save_dir'], id="_".join(run_name.split(":")), resume="allow")
        model = NodeEmbeder(all_data, edge_attr_dict, hp_dict=hparams, num_nodes=len(nodes["node_idx"].unique()), combined_training=False)

    checkpoint_callback = ModelCheckpoint(monitor='val/node_total_acc', dirpath=Path(args.save_dir) / 'checkpoints', filename=f'{run_name}' + '_{epoch}', save_top_k=1, mode='max')
    lr_monitor = LearningRateMonitor(logging_interval='step')
    wandb_logger.watch(model, log='all')

    if hparams['debug']:
        limit_train_batches = 1
        limit_val_batches = 1.0
        hparams['max_epochs'] = 3
    else:
        limit_train_batches = 1.0
        limit_val_batches = 1.0

    trainer = pl.Trainer(gpus=hparams['n_gpus'], logger=wandb_logger,
                         max_epochs=hparams['max_epochs'],
                         callbacks=[checkpoint_callback, lr_monitor],
                         gradient_clip_val=hparams['gradclip'],
                         profiler=hparams['profiler'],
                         log_every_n_steps=hparams['log_every_n_steps'],
                         limit_train_batches=limit_train_batches,
                         limit_val_batches=limit_val_batches,
                         )
    train_dataloader, val_dataloader, test_dataloader = get_dataloaders(hparams, all_data)

    # Train
    trainer.fit(model, train_dataloader, val_dataloader)

    # Test on the best checkpoint
    trainer.test(ckpt_path='best', test_dataloaders=test_dataloader)

@torch.no_grad()
def save_embeddings(args, hparams):
    print('Saving Embeddings')

    # Seed
    pl.seed_everything(hparams['seed'])

    # Read input data
    all_data, edge_attr_dict, nodes = preprocess.preprocess_graph(args)
    all_data.num_nodes = len(nodes["node_idx"].unique())

    model = NodeEmbeder.load_from_checkpoint(checkpoint_path=str(Path(args.save_dir) / 'checkpoints' / args.best_ckpt),
                                             all_data=all_data, edge_attr_dict=edge_attr_dict,
                                             num_nodes=len(nodes["node_idx"].unique()), combined_training=False)

    dataloader = DataLoader([all_data], batch_size=1)
    trainer = pl.Trainer(gpus=0,
                         gradient_clip_val=hparams['gradclip'])
    embeddings = trainer.predict(model, dataloaders=dataloader)
    embed_path = Path(args.save_dir) / (str(args.best_ckpt).split('.ckpt')[0] + '.embed')
    torch.save(embeddings[0], str(embed_path))
    print(embeddings[0].shape)

if __name__ == "__main__":

    # Get hyperparameters
    args = parse_args()
    hparams = get_pretrain_hparams(args, combined=False)

    if args.save_embeddings:
        # Save node embeddings from a trained model
        save_embeddings(args, hparams)
    else:
        # Train model
        train(args, hparams)
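
# Sketch of the embedding-export flow (illustrative only; the checkpoint name,
# input files, and save directory are placeholders):
#
#   python pretrain.py --edgelist kg_edgelist.txt --node_map kg_node_map.txt \
#       --save_dir ./pretrain_runs --best_ckpt best_run_epoch=5.ckpt --save_embeddings
#
# writes the node embedding tensor to ./pretrain_runs/best_run_epoch=5.embed,
# which can later be reloaded with torch.load('./pretrain_runs/best_run_epoch=5.embed').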