# shepherd/hparams.py
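"""Hyperparameter definitions for the shepherd models.

This module provides:

* ``get_pretrain_hparams`` -- hyperparameters for pretraining the node embedder
* ``get_train_hparams``    -- hyperparameters for training the patient model
  (causal gene discovery, disease characterization, or patients-like-me)
* ``get_predict_hparams``  -- hyperparameters for running prediction

``get_run_type_args`` and ``get_patient_data_args`` fill in the settings that
depend on the run type and on the patient dataset, respectively.
"""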
import project_config


####################################################################
#
# NODE EMBEDDER MODEL HYPERPARAMETERS
#
####################################################################

def get_pretrain_hparams(args, combined=False):
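    """Return the hyperparameters for pretraining the node embedder.

    If ``combined`` is True, the tunable parameters (feature sizes, number of
    attention heads, weight decay, dropout, and learning rate) are set to fixed
    defaults instead of being read from ``args``.
    """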
    print('node embedder args: ', args)

    # Default
    hparams = {
        # Tunable parameters
        'nfeat': args.nfeat if not combined else 4096,
        'hidden': args.hidden if not combined else 256,
        'output': args.output if not combined else 128,
        'n_heads': args.n_heads if not combined else 2,
        'wd': args.wd if not combined else 5e-4,
        'dropout': args.dropout if not combined else 0.2,
        'lr': args.lr if not combined else 0.0001,

        # Fixed parameters
        'decoder_type': 'bilinear',
        'norm_method': "batch_layer",
        'loss': 'max-margin',
        'pred_threshold': 0.5,
        'negative_sampler_approach': 'by_edge_type',
        'filter_edges': True,
        'n_gpus': 1,
        'num_workers': 4,
        'batch_size': 512,
        'inference_batch_size': 64,
        'neighbor_sampler_sizes': [15, 10, 5],
        'max_epochs': 200,
        'gradclip': 1.0,
        'lr_factor': 0.01,
        'lr_patience': 1000,
        'lr_threshold': 1e-4,
        'lr_threshold_mode': 'rel',
        'lr_cooldown': 0,
        'min_lr': 0,
        'eps': 1e-8,
        'seed': 1,
        'profiler': None,
        'wandb_save_dir': project_config.PROJECT_DIR / 'wandb' / 'preprocess',
        'log_every_n_steps': 10,
        'time': False,
        'debug': False
    }

    print('Pretrain hparams: ', hparams)

    return hparams
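# Illustrative example: `args` is assumed to be an argparse-style Namespace exposing
# the attributes read above; the values are placeholders, and the actual command-line
# parser is defined elsewhere in the codebase.
#
#   from argparse import Namespace
#   args = Namespace(nfeat=2048, hidden=256, output=128, n_heads=2,
#                    wd=5e-4, dropout=0.2, lr=1e-4)
#   hparams = get_pretrain_hparams(args)                 # use the supplied values
#   hparams = get_pretrain_hparams(args, combined=True)  # fall back to the fixed defaults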


####################################################################
#
# TRAIN MODEL HYPERPARAMETERS
#
####################################################################


def get_train_hparams(args):
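    """Return the hyperparameters for training the patient model.

    Tunable values are read from ``args`` (sparse_sample, lr, upsample_cand,
    neighbor_sampler_size, lmbda, alpha, kappa, seed, batch_size, the gene
    augmentation settings, and the transformer depth/heads); the remaining
    entries are fixed. ``args.run_type`` and ``args.patient_data`` are then used
    by ``get_run_type_args`` and ``get_patient_data_args`` to fill in the
    run-type-specific settings and the patient data paths.
    """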
    print('Train model args: ', args)

    # Default
    hparams = {
        # Tunable parameters
        'sparse_sample': args.sparse_sample,  # Randomly sample N nodes from the KG
        'lr': args.lr,
        'upsample_cand': args.upsample_cand,
        'neighbor_sampler_sizes': [args.neighbor_sampler_size, 10, 5],
        'lambda': args.lmbda,  # Relative contribution of the two loss functions
        'alpha': args.alpha,  # Contribution of the GP gate. NOTE: not used for patients-like-me or novel disease characterization
        'kappa': (1 - args.lmbda) * args.kappa,
        'seed': args.seed,
        'batch_size': args.batch_size,

        'augment_genes': True if args.aug_gene_w > 0 else False,
        'n_sim_genes': args.n_sim_genes,
        'aug_gene_w': args.aug_gene_w,
        'aug_gene_by_deg': args.aug_gene_by_deg,

        'n_transformer_layers': args.n_transformer_layers,
        'n_transformer_heads': args.n_transformer_heads,

        # Fixed parameters
        'pos_weight': 1,
        'neg_weight': 20,
        'margin': 0.4,
        'thresh': 1,
        'filter_edges': False,
        'softmax_scale': 1,
        'leaky_relu': 0.1,
        'decoder_type': 'bilinear',
        'combined_training': True,
        'sample_from_gpd': True,
        'attention_type': 'bilinear',
        'n_cand_diseases': 1000,
        'test_n_cand_diseases': -1,
        'candidate_disease_type': 'all_kg_nodes',
        'patient_similarity_type': 'gene',  # How we determine labels for similar patients in "Patients Like Me"
        'n_similar_patients': 2,  # Number of patients with the same gene/disease that we add to the batch
        'only_hard_distractors': False,  # When True, only use the curated hard distractors at train time
        'sample_edges_from_train_patients': False,  # Preferentially sample edges connected to training patients
        'gradclip': 1.0,
        'inference_batch_size': 64,
        'max_epochs': 100,
        'n_gpus': 1,
        'num_workers': 4,
        'wandb_save_dir': project_config.PROJECT_DIR / 'wandb',
        'precision': 16,
        'reload_dataloaders_every_n_epochs': 0,
        'profiler': 'simple',
        'pin_memory': False,
        'time': False,
        'log_gpu_memory': True,
        'debug': False,
        'plot_softmax': False,
        'plot_intrain': False,  # Flag to plot gene rank vs. membership in the train set
        'plot_PG_embed': False,  # Flag to plot embeddings with phenotype and gene labels
        'plot_disease_embed': False,  # Flag to plot embeddings with disease labels
        'plot_patient_embed': False,  # Flag to plot embeddings for patients
        'plot_degree_rank': False,  # Flag to plot degree vs. gene rank
        'plot_nhops_rank': False,  # Flag to plot number of hops vs. gene rank
        'plot_frac_rank': False,  # Flag to plot fraction of ___ vs. gene rank
        'plot_gradients': False,  # Flag to plot gradients
        'plot_attn_nhops': False,  # Flag to plot attention weights vs. number of hops
        'plot_phen_gene_sims': False,  # Flag to plot phenotype-gene similarities
        'mrr_vs_percent_overlap': False,  # Flag to plot MRR vs. percent overlap of phenotypes
        'saved_checkpoint_path': project_config.PROJECT_DIR / f'{args.saved_node_embeddings_path}',
    }

    # Get hyperparameters based on run type arguments
    hparams = get_run_type_args(args, hparams)

    # Get hyperparameters based on patient data arguments
    hparams = get_patient_data_args(args, hparams)

    print('Train hparams: ', hparams)

    return hparams


def get_run_type_args(args, hparams):
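    """Update ``hparams`` with the settings specific to ``args.run_type``.

    Supported run types are 'causal_gene_discovery', 'disease_characterization',
    and 'patients_like_me'; each selects the model type, loss, candidate-disease
    and similar-patient flags, and the wandb project name.
    """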
    if args.run_type == 'causal_gene_discovery':
        hparams.update({
            'model_type': 'aligner',
            'loss': 'gene_multisimilarity',
            'use_diseases': False,
            'add_cand_diseases': False,
            'add_similar_patients': False,
            'wandb_project_name': 'causal-gene-discovery'
        })
    elif args.run_type == 'disease_characterization':
        hparams.update({
            'model_type': 'patient_NCA',
            'loss': 'patient_disease_NCA',
            'use_diseases': True,
            'add_cand_diseases': True,
            'add_similar_patients': False,
            'wandb_project_name': 'disease-heterogeneity',
        })
    elif args.run_type == 'patients_like_me':
        hparams.update({
            'model_type': 'patient_NCA',
            'loss': 'patient_patient_NCA',
            'use_diseases': False,
            'add_cand_diseases': False,
            'add_similar_patients': True,
            'wandb_project_name': 'patients-like-me',
        })
    else:
        raise Exception(f"Unrecognized run type: {args.run_type!r}. Expected 'causal_gene_discovery', "
                        "'disease_characterization', or 'patients_like_me'.")
    return hparams


def get_patient_data_args(args, hparams):
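    """Update ``hparams`` with the train/validation/test patient files.

    ``args.patient_data`` must be 'disease_simulated' (the simulated patient
    splits under ``simulated_patients/``) or 'my_data' (paths taken from
    ``project_config``).
    """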
    if args.patient_data == "disease_simulated":
        hparams.update({
            'train_data': f'simulated_patients/disease_split_train_sim_patients_{project_config.CURR_KG}.txt',
            'validation_data': f'simulated_patients/disease_split_val_sim_patients_{project_config.CURR_KG}.txt',
            'test_data': f'simulated_patients/disease_split_all_sim_patients_{project_config.CURR_KG}.txt',
            'spl': f'simulated_patients/disease_split_all_sim_patients_{project_config.CURR_KG}_spl_matrix.npy',
            'spl_index': f'simulated_patients/disease_split_all_sim_patients_{project_config.CURR_KG}_spl_index_dict.pkl'
        })
    elif args.patient_data == "my_data":
        hparams.update({
            'train_data': project_config.MY_TRAIN_DATA,
            'validation_data': project_config.MY_VAL_DATA,
            'test_data': project_config.MY_TEST_DATA,
            'spl': project_config.MY_SPL_DATA,  # Result of add_spl_to_patients.py (suffix: _spl_matrix.npy)
            'spl_index': project_config.MY_SPL_INDEX_DATA,  # Result of add_spl_to_patients.py (suffix: _spl_index_dict.pkl)
        })
    else:
        raise Exception(f"Unrecognized patient data: {args.patient_data!r}. Expected 'disease_simulated' or 'my_data'.")
    return hparams


####################################################################
#
# PREDICTION HYPERPARAMETERS
#
####################################################################


def get_predict_hparams(args):
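    """Return the hyperparameters for running prediction with a trained model.

    Reuses ``get_run_type_args`` and ``get_patient_data_args``, but runs on CPU
    (``n_gpus`` = 0) and forces ``add_similar_patients`` to False.
    """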
    hparams = {
        'seed': 33,
        'n_gpus': 0,  # NOTE: currently the predict scripts only work with CPU
        'num_workers': 4,
        'profiler': 'simple',
        'pin_memory': False,
        'time': False,
        'log_gpu_memory': False,
        'debug': False,

        'augment_genes': True,
        'n_sim_genes': 3,
        'aug_gene_w': 0.5,

        'wandb_save_dir': project_config.PROJECT_DIR / 'wandb',
        'saved_checkpoint_path': project_config.PROJECT_DIR / f'{args.saved_node_embeddings_path}',
        'test_n_cand_diseases': -1,
        'candidate_disease_type': 'all_kg_nodes',
        'only_hard_distractors': False,  # When True, only use the curated hard distractors at train time
        'patient_similarity_type': 'gene',  # How we determine labels for similar patients in "Patients Like Me"
        'n_similar_patients': 2,  # (Patients Like Me only) Number of patients with the same gene/disease that we add to the batch
    }

    # Get hyperparameters based on run type arguments
    hparams = get_run_type_args(args, hparams)
    hparams.update({'add_similar_patients': False})
    hparams = get_patient_data_args(args, hparams)

    print('Predict hparams: ', hparams)

    return hparams
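# Illustrative example: attribute values are placeholders, and the actual command-line
# parser is defined elsewhere in the codebase.
#
#   from argparse import Namespace
#   args = Namespace(run_type='disease_characterization', patient_data='my_data',
#                    saved_node_embeddings_path='path/to/node_embeddings_checkpoint')
#   hparams = get_predict_hparams(args)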