kgwas/kgwas.py

from copy import deepcopy
from tqdm import tqdm
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import pandas as pd
import numpy as np
import pickle
import subprocess

import torch
import torch.nn.functional as F
import torch.optim as optim

import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
from torch_geometric.loader import NeighborLoader

from .utils import print_sys, compute_metrics, save_dict, \
    load_dict, load_pretrained, save_model, \
    evaluate_minibatch_clean, process_data, \
    get_network_weight, generate_viz
from .eval_utils import storey_ribshirani_integrate, get_clumps_gold_label, get_meta_clumps, \
    get_mega_clump_query, get_curve, find_closest_x
from .model import HeteroGNN

class KGWAS:
    def __init__(self,
                 data,
                 weight_bias_track=False,
                 device='cuda',
                 proj_name='KGWAS',
                 exp_name='KGWAS',
                 seed=42):
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
        self.seed = seed
        torch.backends.cudnn.enabled = False
        use_cuda = torch.cuda.is_available()
        self.device = device if use_cuda else "cpu"

        self.data = data
        self.data_path = data.data_path
        if weight_bias_track:
            import wandb
            wandb.init(project=proj_name, name=exp_name)
            self.wandb = wandb
        else:
            self.wandb = False
        self.exp_name = exp_name

    def initialize_model(self, gnn_num_layers=2, gnn_hidden_dim=128, gnn_backbone='GAT',
                         gnn_aggr='sum', gat_num_head=1, no_relu=False):
        self.config = {
            'gnn_num_layers': gnn_num_layers,
            'gnn_hidden_dim': gnn_hidden_dim,
            'gnn_backbone': gnn_backbone,
            'gnn_aggr': gnn_aggr,
            'gat_num_head': gat_num_head
        }

        self.gnn_num_layers = gnn_num_layers
        self.model = HeteroGNN(self.data.data, gnn_hidden_dim, 1,
                               gnn_num_layers, gnn_backbone, gnn_aggr,
                               self.data.snp_init_dim_size,
                               self.data.gene_init_dim_size,
                               self.data.go_init_dim_size,
                               gat_num_head,
                               no_relu=no_relu,
                               ).to(self.device)

    def load_pretrained(self, path):
        with open(os.path.join(path, 'config.pkl'), 'rb') as f:
            config = pickle.load(f)

        self.initialize_model(**config)
        self.config = config

        self.model = load_pretrained(path, self.model)
        self.best_model = self.model
        self.kgwas_res = pd.read_csv(os.path.join(path, 'pred.csv'), sep=None, engine='python')
        self.save_name = path.split('/')[-1]

    def train(self, batch_size=512, num_workers=0, lr=1e-4,
              weight_decay=5e-4, epoch=10, save_best_model=True,
              save_name=None, data_to_cuda=False):
        total_epoch = epoch
        if save_name is None:
            save_name = self.exp_name
        self.save_name = save_name
        print_sys('Creating data loader...')
        kwargs = {'batch_size': batch_size, 'num_workers': num_workers, 'drop_last': True}
        eval_kwargs = {'batch_size': 512, 'num_workers': num_workers, 'drop_last': False}

        if data_to_cuda:
            self.data.data = self.data.data.to(self.device)

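        # NeighborLoader with num_neighbors=[-1] per layer keeps every neighbor at
        # each hop, so each seed SNP sees its complete k-hop subgraph during training.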
        self.train_loader = NeighborLoader(self.data.data, num_neighbors=[-1] * self.gnn_num_layers,
                                           sampler=None,
                                           input_nodes=self.data.train_input_nodes, **kwargs)
        self.val_loader = NeighborLoader(self.data.data, num_neighbors=[-1] * self.gnn_num_layers,
                                         input_nodes=self.data.val_input_nodes, **kwargs)
        self.test_loader = NeighborLoader(self.data.data, num_neighbors=[-1] * self.gnn_num_layers,
                                          input_nodes=self.data.test_input_nodes, **eval_kwargs)

        X_infer = self.data.lr_uni.ID.values
        #print_sys('# of to-infer SNPs: ' + str(len(X_infer)))
        infer_idx = np.array([self.data.id2idx['SNP'][i] for i in X_infer])
        infer_input_nodes = ('SNP', infer_idx)

        self.infer_loader = NeighborLoader(self.data.data, num_neighbors=[-1] * self.gnn_num_layers,
                                           input_nodes=infer_input_nodes, **eval_kwargs)

        ## model training
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)

        loss_fct = F.mse_loss
        earlystop_validation_metric = 'pearsonr'
        binary_output = False
        earlystop_direction = 'ascend'
        min_val = -1000

        self.best_model = deepcopy(self.model).to(self.device)
        print_sys('Start Training...')
        for epoch in range(total_epoch):
            self.model.train()

            for step, batch in enumerate(tqdm(self.train_loader, desc=f"Training Progress Epoch {epoch+1}/{total_epoch}", total=len(self.train_loader))):
                optimizer.zero_grad()
                if data_to_cuda:
                    pass
                    #batch = batch.to(self.device, 'edge_index')
                else:
                    batch = batch.to(self.device)
                bs_batch = batch['SNP'].batch_size

                out = self.model(batch.x_dict, batch.edge_index_dict, bs_batch)
                pred = out.reshape(-1)

                y_batch = batch['SNP'].y[:bs_batch]
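                # look up each seed SNP's LDSC weight and use it to reweight the
                # per-SNP squared error in the loss below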
                rs_id = [self.data.idx2id['SNP'][i.item()] for i in batch['SNP']['n_id'][:bs_batch]]
                ld_weight = torch.tensor([self.data.rs_id_to_ldsc_weight[i] for i in rs_id]).to(self.device)

                loss = torch.mean(ld_weight * (pred - y_batch)**2)

                if self.wandb:
                    self.wandb.log({'training_loss': loss.item()})

                loss.backward()
                optimizer.step()

                if (step % 500 == 0) and (step >= 500):
                    log = "Epoch {} Step {} Train Loss: {:.4f}"
                    print_sys(log.format(epoch + 1, step + 1, loss.item()))

            val_res = evaluate_minibatch_clean(self.val_loader, self.model, self.device)
            val_metrics = compute_metrics(val_res, binary_output, -1, -1, loss_fct)

            log = "Epoch {}: Validation MSE: {:.4f} " \
                  "Validation Pearson: {:.4f}. "
            print_sys(log.format(epoch + 1, val_metrics['mse'],
                                 val_metrics['pearsonr']))

            if self.wandb:
                for i, j in val_metrics.items():
                    self.wandb.log({'val_' + i: j})

            if val_metrics[earlystop_validation_metric] > min_val:
                min_val = val_metrics[earlystop_validation_metric]
                self.best_model = deepcopy(self.model)
                best_epoch = epoch

        if save_best_model:
            save_model_path = self.data_path + '/model/'
            print_sys('Saving models to ' + os.path.join(save_model_path, save_name))
            save_model(self.best_model, self.config, os.path.join(save_model_path, save_name))

        test_res = evaluate_minibatch_clean(self.test_loader, self.best_model, self.device)
        test_metric = compute_metrics(test_res, binary_output, -1, -1, loss_fct)
        if self.wandb:
            for i, j in test_metric.items():
                self.wandb.log({'test_' + i: j})

        infer_res = evaluate_minibatch_clean(self.infer_loader, self.best_model, self.device)

        self.data.lr_uni['pred'] = infer_res['pred']
        lr_uni_to_save = deepcopy(self.data.lr_uni)

        self.data.lr_uni['abs_pred'] = np.abs(self.data.lr_uni['pred'])

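        # integrate the network scores with the GWAS summary statistics to obtain
        # reweighted p-values (SR_P_val), binning SNPs by |pred|
        # (see eval_utils.storey_ribshirani_integrate)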
        self.data.lr_uni['SR_P_val'] = storey_ribshirani_integrate(self.data.lr_uni, column='abs_pred', num_bins=500)
        self.data.lr_uni['SR'] = -(np.log10(self.data.lr_uni['SR_P_val'].astype(float).values))
        lr_uni_to_save['P_weighted'] = self.data.lr_uni['SR_P_val']

        ## calibration
        scale_factor = find_closest_x(lr_uni_to_save)
        lr_uni_to_save['KGWAS_P'] = scale_factor * lr_uni_to_save['P_weighted']
        lr_uni_to_save['KGWAS_P'] = lr_uni_to_save['KGWAS_P'].clip(lower=0, upper=1)

        # create the output directory tree if any part of it is missing
        os.makedirs(self.data_path + '/model_pred/new_experiments/', exist_ok=True)
        lr_uni_to_save.to_csv(self.data_path + '/model_pred/new_experiments/' + save_name + '_pred.csv', index=False, sep='\t')
        print('KGWAS prediction and p-values saved to ' + self.data_path + '/model_pred/new_experiments/' + save_name + '_pred.csv')
        if save_best_model:
            lr_uni_to_save.to_csv(self.data_path + '/model/' + save_name + '/pred.csv', index=False, sep='\t')
        self.kgwas_res = lr_uni_to_save

    def run_magma(self, path_to_magma, bfile):
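        # Run a MAGMA gene-based analysis on the KGWAS-calibrated p-values.
        # Requires a local MAGMA binary (path_to_magma) and a PLINK reference
        # panel prefix (bfile); the gene annotation file is downloaded if missing.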
        if 'N' in self.kgwas_res.columns:
            n_value = self.kgwas_res['N'].values[0]
        else:
            n_value = input("Please provide the sample size for the GWAS analysis: ")

        url = "https://dataverse.harvard.edu/api/access/datafile/10731670"
        annot_file_path = os.path.join(self.data_path, 'gene_annotation.genes.annot')

        # Check if the annotation file is already downloaded
        if not os.path.exists(annot_file_path):
            print('Annotation file not found locally. Downloading...')
            self.data._download_with_progress(url, annot_file_path)
            print('Annotation file downloaded successfully.')
        else:
            print('Annotation file already exists locally.')

        gene_annot = annot_file_path

        magma_path = self.data_path + '/model_pred/new_experiments/' + self.save_name + '_magma_format.csv'
        self.kgwas_res[['ID', 'KGWAS_P']].rename(columns={'ID': 'SNP', 'KGWAS_P': 'P'}).to_csv(magma_path, index=False, sep='\t')

        # Construct the MAGMA command
        command = [
            path_to_magma,
            "--bfile", bfile,
            "--gene-annot", gene_annot,
            "--pval", magma_path, f"N={n_value}",
            "--out", self.data_path + '/model_pred/new_experiments/' + self.save_name + '_magma_out'
        ]

        try:
            # Run the command with real-time output
            process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
            print("Running MAGMA...")

            # Stream stdout line by line
            for line in process.stdout:
                print(line, end="")  # Print each line as it's received

            # Wait for the process to complete and capture stderr
            stderr = process.communicate()[1]

            if process.returncode == 0:
                print("MAGMA command executed successfully.")
            else:
                print("MAGMA encountered an error.")
                print("Error message:", stderr)
        except FileNotFoundError:
            print("MAGMA executable not found. Ensure it is at the specified path.")
        except Exception as e:
            print(f"An unexpected error occurred: {e}")

    def get_disease_critical_network(self, variant_threshold=5e-8,
                                     magma_path=None, magma_threshold=0.05, program_threshold=0.05,
                                     K_neighbors=3, num_cpus=1):
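        # extract the edge weights learned by the trained GNN, then build the
        # disease-critical subnetwork around variants passing the given thresholds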
        df_network_weight = get_network_weight(self, self.data)
        df_variant_interpretation, disease_critical_network = generate_viz(
            self, df_network_weight, self.data_path, variant_threshold, magma_path,
            magma_threshold, program_threshold, K_neighbors, num_cpus)
        return df_network_weight, df_variant_interpretation, disease_critical_network
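
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not executed by the library).
# It assumes a data object exposing the attributes this class reads
# (data.data, data.data_path, train/val/test input nodes, lr_uni, id2idx,
# rs_id_to_ldsc_weight); in the KGWAS package this is prepared by its data
# loading module. The magma/bfile paths below are placeholders.
#
#   from kgwas.kgwas import KGWAS
#
#   run = KGWAS(data, device='cuda', exp_name='my_trait')
#   run.initialize_model(gnn_num_layers=2, gnn_hidden_dim=128, gnn_backbone='GAT')
#   run.train(batch_size=512, epoch=10, save_best_model=True)
#   run.run_magma(path_to_magma='./magma', bfile='./reference_panel')  # optional gene-level analysis
#   network_weights, variant_interp, critical_net = run.get_disease_critical_network()
# ---------------------------------------------------------------------------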