Diff of /kgwas/kgwas.py [000000] .. [8790ab]

from copy import deepcopy
from tqdm import tqdm
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import pandas as pd
import numpy as np
import pickle
import subprocess

import torch
import torch.nn.functional as F
import torch.optim as optim

import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
from torch_geometric.loader import NeighborLoader
from .utils import print_sys, compute_metrics, save_dict, \
                        load_dict, load_pretrained, save_model, \
                        evaluate_minibatch_clean, process_data, \
                        get_network_weight, generate_viz
from .eval_utils import storey_ribshirani_integrate, get_clumps_gold_label, get_meta_clumps, \
                        get_mega_clump_query, get_curve, find_closest_x
from .model import HeteroGNN

class KGWAS:
    def __init__(self,
                data,
                weight_bias_track = False,
                device = 'cuda',
                proj_name = 'KGWAS',
                exp_name = 'KGWAS',
                seed = 42):
        # Seed all RNGs for reproducibility.
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
        self.seed = seed
        torch.backends.cudnn.enabled = False
        # Fall back to CPU when CUDA is unavailable.
        use_cuda = torch.cuda.is_available()
        self.device = device if use_cuda else "cpu"

        self.data = data
        self.data_path = data.data_path
        if weight_bias_track:
            # Optional Weights & Biases experiment tracking.
            import wandb
            wandb.init(project=proj_name, name=exp_name)
            self.wandb = wandb
        else:
            self.wandb = False
        self.exp_name = exp_name

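    # Usage sketch (illustrative only): constructing the runner requires a
    # loaded KGWAS data object. The `KGWAS_Data` name and its constructor
    # arguments below are assumptions based on the package layout, not
    # guaranteed API.
    #
    #   from kgwas import KGWAS, KGWAS_Data
    #   data = KGWAS_Data(data_path='./data')
    #   run = KGWAS(data, device='cuda', seed=42)
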
    def initialize_model(self, gnn_num_layers = 2, gnn_hidden_dim = 128, gnn_backbone = 'GAT', gnn_aggr = 'sum', gat_num_head = 1, no_relu = False):

        # Record the hyperparameters so the model can be re-instantiated by
        # load_pretrained; no_relu is included so reloads restore it as well.
        self.config = {
            'gnn_num_layers': gnn_num_layers,
            'gnn_hidden_dim': gnn_hidden_dim,
            'gnn_backbone': gnn_backbone,
            'gnn_aggr': gnn_aggr,
            'gat_num_head': gat_num_head,
            'no_relu': no_relu
        }

        self.gnn_num_layers = gnn_num_layers
        self.model = HeteroGNN(self.data.data, gnn_hidden_dim, 1,
                               gnn_num_layers, gnn_backbone, gnn_aggr,
                               self.data.snp_init_dim_size,
                               self.data.gene_init_dim_size,
                               self.data.go_init_dim_size,
                               gat_num_head,
                               no_relu = no_relu,
                              ).to(self.device)

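    # Usage sketch (illustrative only): the values below simply mirror the
    # defaults in the signature above; other combinations are examples, not
    # recommendations.
    #
    #   run.initialize_model(gnn_num_layers=2, gnn_hidden_dim=128,
    #                        gnn_backbone='GAT', gnn_aggr='sum', gat_num_head=1)
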
    def load_pretrained(self, path):
        # Rebuild the architecture from the saved config, then load weights.
        with open(os.path.join(path, 'config.pkl'), 'rb') as f:
            config = pickle.load(f)

        self.initialize_model(**config)
        self.config = config

        self.model = load_pretrained(path, self.model)
        self.best_model = self.model
        # sep=None with the python engine sniffs the delimiter (pred.csv is tab-separated).
        self.kgwas_res = pd.read_csv(os.path.join(path, 'pred.csv'), sep = None, engine = 'python')
        self.save_name = path.split('/')[-1]

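    # Usage sketch (illustrative only): `path` is the directory written by
    # train(save_best_model=True), i.e. <data_path>/model/<save_name>.
    #
    #   run.load_pretrained(os.path.join(data.data_path, 'model', 'KGWAS'))
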
    def train(self, batch_size = 512, num_workers = 0, lr = 1e-4,
                    weight_decay = 5e-4, epoch = 10, save_best_model = True,
                    save_name = None, data_to_cuda = False):
        total_epoch = epoch
        if save_name is None:
            save_name = self.exp_name
        self.save_name = save_name
        print_sys('Creating data loader...')
        kwargs = {'batch_size': batch_size, 'num_workers': num_workers, 'drop_last': True}
        eval_kwargs = {'batch_size': 512, 'num_workers': num_workers, 'drop_last': False}

        if data_to_cuda:
            # Move the full heterogeneous graph to the GPU up front.
            self.data.data = self.data.data.to(self.device)

        # num_neighbors=[-1] per layer samples the full neighborhood at every hop.
        self.train_loader = NeighborLoader(self.data.data, num_neighbors=[-1] * self.gnn_num_layers,
                                      sampler = None,
                                      input_nodes=self.data.train_input_nodes, **kwargs)
        self.val_loader = NeighborLoader(self.data.data, num_neighbors=[-1] * self.gnn_num_layers,
                                    input_nodes=self.data.val_input_nodes, **kwargs)
        self.test_loader = NeighborLoader(self.data.data, num_neighbors=[-1] * self.gnn_num_layers,
                                    input_nodes=self.data.test_input_nodes, **eval_kwargs)

        # Inference covers every SNP in the summary-statistics table.
        X_infer = self.data.lr_uni.ID.values
        infer_idx = np.array([self.data.id2idx['SNP'][i] for i in X_infer])
        infer_input_nodes = ('SNP', infer_idx)

        self.infer_loader = NeighborLoader(self.data.data, num_neighbors=[-1] * self.gnn_num_layers,
                                    input_nodes=infer_input_nodes, **eval_kwargs)

        ## model training
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay = weight_decay)

        loss_fct = F.mse_loss
        # Early stopping keeps the model with the highest validation Pearson r.
        earlystop_validation_metric = 'pearsonr'
        binary_output = False
        earlystop_direction = 'ascend'
        best_val = -1000

        self.best_model = deepcopy(self.model).to(self.device)
        print_sys('Start Training...')
        for epoch in range(total_epoch):
            self.model.train()

            for step, batch in enumerate(tqdm(self.train_loader, desc=f"Training Progress Epoch {epoch+1}/{total_epoch}", total=len(self.train_loader))):
                optimizer.zero_grad()
                if not data_to_cuda:
                    # The graph still lives on the CPU; move the sampled subgraph over.
                    batch = batch.to(self.device)
                bs_batch = batch['SNP'].batch_size

                out = self.model(batch.x_dict, batch.edge_index_dict, bs_batch)
                pred = out.reshape(-1)

                # Only the first bs_batch SNP nodes are seed nodes; the rest are
                # sampled neighbors, so labels and weights are sliced accordingly.
                y_batch = batch['SNP'].y[:bs_batch]
                rs_id = [self.data.idx2id['SNP'][i.item()] for i in batch['SNP']['n_id'][:bs_batch]]
                ld_weight = torch.tensor([self.data.rs_id_to_ldsc_weight[i] for i in rs_id]).to(self.device)

                loss = torch.mean(ld_weight * (pred - y_batch)**2)
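                # Written out, the objective above is a per-SNP-weighted MSE
                # over the B seed nodes of the batch:
                #
                #   loss = (1/B) * sum_i w_i * (pred_i - y_i)^2
                #
                # where w_i comes from rs_id_to_ldsc_weight. (Reading w_i as an
                # LD-score-derived weight is our interpretation of the variable
                # name, not something this file states.)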

                if self.wandb:
                    self.wandb.log({'training_loss': loss.item()})

                loss.backward()
                optimizer.step()

                if (step % 500 == 0) and (step >= 500):
                    log = "Epoch {} Step {} Train Loss: {:.4f}"
                    print_sys(log.format(epoch + 1, step + 1, loss.item()))

            val_res = evaluate_minibatch_clean(self.val_loader, self.model, self.device)
            val_metrics = compute_metrics(val_res, binary_output, -1, -1, loss_fct)

            log = "Epoch {}: Validation MSE: {:.4f} " \
                  "Validation Pearson: {:.4f}."
            print_sys(log.format(epoch + 1, val_metrics['mse'],
                                 val_metrics['pearsonr']))

            if self.wandb:
                for i, j in val_metrics.items():
                    self.wandb.log({'val_' + i: j})

            # Keep a copy of the model with the best validation Pearson r so far.
            if val_metrics[earlystop_validation_metric] > best_val:
                best_val = val_metrics[earlystop_validation_metric]
                self.best_model = deepcopy(self.model)
                best_epoch = epoch

        if save_best_model:
            save_model_path = self.data_path + '/model/'
            print_sys('Saving models to ' + os.path.join(save_model_path, save_name))
            save_model(self.best_model, self.config, os.path.join(save_model_path, save_name))

        test_res = evaluate_minibatch_clean(self.test_loader, self.best_model, self.device)
        test_metric = compute_metrics(test_res, binary_output, -1, -1, loss_fct)
        if self.wandb:
            for i, j in test_metric.items():
                self.wandb.log({'test_' + i: j})

        # Score every SNP with the best model.
        infer_res = evaluate_minibatch_clean(self.infer_loader, self.best_model, self.device)

        self.data.lr_uni['pred'] = infer_res['pred']
        lr_uni_to_save = deepcopy(self.data.lr_uni)

        self.data.lr_uni['abs_pred'] = np.abs(self.data.lr_uni['pred'])

        # Convert absolute predictions into per-SNP p-values, plus -log10 for plotting.
        self.data.lr_uni['SR_P_val'] = storey_ribshirani_integrate(self.data.lr_uni, column = 'abs_pred', num_bins = 500)
        self.data.lr_uni['SR'] = -(np.log10(self.data.lr_uni['SR_P_val'].astype(float).values))
        lr_uni_to_save['P_weighted'] = self.data.lr_uni['SR_P_val']

        ## calibration
        scale_factor = find_closest_x(lr_uni_to_save)
        lr_uni_to_save['KGWAS_P'] = scale_factor * lr_uni_to_save['P_weighted']
        lr_uni_to_save['KGWAS_P'] = lr_uni_to_save['KGWAS_P'].clip(lower=0, upper=1)

        # exist_ok avoids a crash when model_pred/ exists but new_experiments/ does not.
        os.makedirs(self.data_path + '/model_pred/new_experiments/', exist_ok=True)
        lr_uni_to_save.to_csv(self.data_path + '/model_pred/new_experiments/' + save_name + '_pred.csv', index = False, sep = '\t')
        print_sys('KGWAS prediction and p-values saved to ' + self.data_path + '/model_pred/new_experiments/' + save_name + '_pred.csv')
        if save_best_model:
            lr_uni_to_save.to_csv(self.data_path + '/model/' + save_name + '/pred.csv', index = False, sep = '\t')
        self.kgwas_res = lr_uni_to_save

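    # Usage sketch (illustrative only): train with the defaults and keep the
    # best checkpoint; the run name below is an arbitrary example.
    #
    #   run.train(batch_size=512, epoch=10, save_best_model=True,
    #             save_name='my_gwas_run')
    #   run.kgwas_res.head()   # per-SNP predictions plus calibrated KGWAS_P
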
    def run_magma(self, path_to_magma, bfile):
        # MAGMA needs the GWAS sample size; take it from the results table
        # when available, otherwise ask the user.
        if 'N' in self.kgwas_res.columns:
            n_value = self.kgwas_res['N'].values[0]
        else:
            n_value = input("Please provide the sample size for the GWAS analysis: ")

        url = "https://dataverse.harvard.edu/api/access/datafile/10731670"
        annot_file_path = os.path.join(self.data_path, 'gene_annotation.genes.annot')

        # Download the gene annotation file if it is not already cached locally.
        if not os.path.exists(annot_file_path):
            print('Annotation file not found locally. Downloading...')
            self.data._download_with_progress(url, annot_file_path)
            print('Annotation file downloaded successfully.')
        else:
            print('Annotation file already exists locally.')

        gene_annot = annot_file_path

        # Export SNP IDs and calibrated p-values in the two-column format MAGMA expects.
        magma_path = self.data_path + '/model_pred/new_experiments/' + self.save_name + '_magma_format.csv'
        self.kgwas_res[['ID', 'KGWAS_P']].rename(columns = {'ID': 'SNP', 'KGWAS_P': 'P'}).to_csv(magma_path, index = False, sep = '\t')

        # Construct the MAGMA command
        command = [
            path_to_magma,
            "--bfile", bfile,
            "--gene-annot", gene_annot,
            "--pval", magma_path, f"N={n_value}",
            "--out", self.data_path + '/model_pred/new_experiments/' + self.save_name + '_magma_out'
        ]

        try:
            # Run the command with real-time output
            process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
            print("Running MAGMA...")

            # Stream stdout line by line
            for line in process.stdout:
                print(line, end="")  # Print each line as it's received

            # Wait for the process to complete and capture stderr
            stderr = process.communicate()[1]

            if process.returncode == 0:
                print("MAGMA command executed successfully.")
            else:
                print("MAGMA encountered an error.")
                print("Error message:", stderr)
        except FileNotFoundError:
            print("MAGMA executable not found. Ensure it is in the specified path.")
        except Exception as e:
            print(f"An unexpected error occurred: {e}")

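    # Usage sketch (illustrative only): both arguments point at resources you
    # provide; the paths below are placeholders, not shipped with the package.
    #
    #   run.run_magma(path_to_magma='/usr/local/bin/magma',
    #                 bfile='/path/to/plink_bfile_prefix')
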
    def get_disease_critical_network(self, variant_threshold = 5e-8,
                magma_path = None, magma_threshold = 0.05, program_threshold = 0.05,
                K_neighbors = 3, num_cpus = 1):
        # Extract edge weights from the trained GNN, then distill them into a
        # variant-level interpretation table and a disease-critical subnetwork.
        df_network_weight = get_network_weight(self, self.data)
        df_variant_interpretation, disease_critical_network = generate_viz(
            self, df_network_weight, self.data_path, variant_threshold,
            magma_path, magma_threshold, program_threshold, K_neighbors, num_cpus)
        return df_network_weight, df_variant_interpretation, disease_critical_network
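
# End-to-end usage sketch (illustrative only; `KGWAS_Data` and its constructor
# are assumptions based on the package layout, and every path is a placeholder):
#
#   from kgwas import KGWAS, KGWAS_Data
#
#   data = KGWAS_Data(data_path='./data')      # knowledge graph + GWAS summary stats
#   run = KGWAS(data, device='cuda', seed=42)
#   run.initialize_model(gnn_num_layers=2, gnn_hidden_dim=128, gnn_backbone='GAT')
#   run.train(epoch=10, save_name='my_gwas_run')
#   run.run_magma('/usr/local/bin/magma', '/path/to/plink_bfile_prefix')
#   weights, interp, network = run.get_disease_critical_network()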