a b/shepherd/hparams.py
1
import project_config
2
3
####################################################################
4
#
5
# NODE EMBEDDER MODEL HYPERPARAMETERS
6
#
7
####################################################################
8
9
def get_pretrain_hparams(args, combined=False):
    """Assemble the hyperparameter dict for pretraining the node embedder.

    Args:
        args: parsed argument namespace. When ``combined`` is False its
            ``nfeat``, ``hidden``, ``output``, ``n_heads``, ``wd``,
            ``dropout`` and ``lr`` attributes supply the tunable values;
            when ``combined`` is True those attributes are not read.
        combined: if True, use the hard-coded preset for the tunable values
            instead of reading them from ``args``.

    Returns:
        dict: tunable values first, followed by the fixed settings.
    """
    print('node embedder args: ', args)

    tunable_names = ('nfeat', 'hidden', 'output', 'n_heads', 'wd', 'dropout', 'lr')
    if combined:
        # Preset tunable values used when ``combined`` is requested.
        preset = (4096, 256, 128, 2, 5e-4, 0.2, 0.0001)
        hparams = dict(zip(tunable_names, preset))
    else:
        hparams = {name: getattr(args, name) for name in tunable_names}

    # Settings that are not exposed on the command line.
    hparams.update({
        'decoder_type': 'bilinear',
        'norm_method': "batch_layer",
        'loss': 'max-margin',
        'pred_threshold': 0.5,
        'negative_sampler_approach': 'by_edge_type',
        'filter_edges': True,
        'n_gpus': 1,
        'num_workers': 4,
        'batch_size': 512,
        'inference_batch_size': 64,
        'neighbor_sampler_sizes': [15, 10, 5],
        'max_epochs': 200,
        'gradclip': 1.0,
        'lr_factor': 0.01,
        'lr_patience': 1000,
        'lr_threshold': 1e-4,
        'lr_threshold_mode': 'rel',
        'lr_cooldown': 0,
        'min_lr': 0,
        'eps': 1e-8,
        'seed': 1,
        'profiler': None,
        'wandb_save_dir': project_config.PROJECT_DIR / 'wandb' / 'preprocess',
        'log_every_n_steps': 10,
        'time': False,
        'debug': False,
    })

    print('Pretrain hparams: ', hparams)

    return hparams
55
56
57
58
####################################################################
59
#
60
# TRAIN MODEL HYPERPARAMETERS
61
#
62
####################################################################
63
64
65
def get_train_hparams(args):
    """Assemble the hyperparameter dict for training the patient model.

    Combines values read from ``args`` with fixed defaults. It then layers on
    the run-type settings (``get_run_type_args``) and the patient-data file
    locations (``get_patient_data_args``).

    Args:
        args: parsed argument namespace. Reads ``sparse_sample``, ``lr``,
            ``upsample_cand``, ``neighbor_sampler_size``, ``lmbda``,
            ``alpha``, ``kappa``, ``seed``, ``batch_size``, ``aug_gene_w``,
            ``n_sim_genes``, ``aug_gene_by_deg``, ``n_transformer_layers``,
            ``n_transformer_heads`` and ``saved_node_embeddings_path``, plus
            whatever the two helper functions read.

    Returns:
        dict: the complete training hyperparameters.

    Raises:
        Whatever ``get_run_type_args`` / ``get_patient_data_args`` raise for
        unsupported ``run_type`` / ``patient_data`` values.
    """
    print('Train model args: ', args)

    # Values taken from the command line.
    hparams = {
        'sparse_sample': args.sparse_sample,  # Randomly sample N nodes from KG
        'lr': args.lr,
        'upsample_cand': args.upsample_cand,
        'neighbor_sampler_sizes': [args.neighbor_sampler_size, 10, 5],
        'lambda': args.lmbda,  # Contribution of two loss functions
        'alpha': args.alpha,  # Contribution of GP gate. NOTE: not used for patients-like-me or novel disease characterization
        'kappa': (1 - args.lmbda) * args.kappa,
        'seed': args.seed,
        'batch_size': args.batch_size,
        # Gene augmentation is enabled exactly when it has a positive weight.
        'augment_genes': bool(args.aug_gene_w > 0),
        'n_sim_genes': args.n_sim_genes,
        'aug_gene_w': args.aug_gene_w,
        'aug_gene_by_deg': args.aug_gene_by_deg,
        'n_transformer_layers': args.n_transformer_layers,
        'n_transformer_heads': args.n_transformer_heads,
    }

    # Fixed training settings.
    hparams.update({
        'pos_weight': 1,
        'neg_weight': 20,
        'margin': 0.4,
        'thresh': 1,
        'filter_edges': False,
        'softmax_scale': 1,
        'leaky_relu': 0.1,
        'decoder_type': 'bilinear',
        'combined_training': True,
        'sample_from_gpd': True,
        'attention_type': 'bilinear',
        'n_cand_diseases': 1000,
        'test_n_cand_diseases': -1,
        'candidate_disease_type': 'all_kg_nodes',
        'patient_similarity_type': 'gene',  # How labels for similar patients are determined in "Patients Like Me"
        'n_similar_patients': 2,  # Number of patients with the same gene/disease added to the batch
        'only_hard_distractors': False,  # When True, only the curated hard distractors are used at train time
        'sample_edges_from_train_patients': False,  # Preferentially sample edges connected to training patients
        'gradclip': 1.0,
        'inference_batch_size': 64,
        'max_epochs': 100,
        'n_gpus': 1,
        'num_workers': 4,
        'wandb_save_dir': project_config.PROJECT_DIR / 'wandb',
        'precision': 16,
        'reload_dataloaders_every_n_epochs': 0,
        'profiler': 'simple',
        'pin_memory': False,
        'time': False,
        'log_gpu_memory': True,
        'debug': False,
    })

    # Diagnostic plotting / analysis switches; every one starts disabled.
    hparams.update(dict.fromkeys((
        'plot_softmax',
        'plot_intrain',  # gene rank vs. in-train sets
        'plot_PG_embed',  # embeddings with phenotype and gene labels
        'plot_disease_embed',  # embeddings with disease labels
        'plot_patient_embed',  # embeddings for patients
        'plot_degree_rank',  # degree vs. gene rank
        'plot_nhops_rank',  # nhops vs. gene rank
        'plot_frac_rank',  # fraction of ___ vs. gene rank
        'plot_gradients',  # gradients
        'plot_attn_nhops',  # attention weights vs. nhops
        'plot_phen_gene_sims',  # phenotype-gene similarities
        'mrr_vs_percent_overlap',  # MRR vs. percent overlap of phenotypes
    ), False))

    hparams['saved_checkpoint_path'] = project_config.PROJECT_DIR / f'{args.saved_node_embeddings_path}'

    # Settings that depend on the run type, then on the patient dataset.
    hparams = get_run_type_args(args, hparams)
    hparams = get_patient_data_args(args, hparams)

    print('Train hparams: ', hparams)

    return hparams
145
146
147
def get_run_type_args(args, hparams):
    """Add the model/loss settings for ``args.run_type`` to ``hparams``.

    Args:
        args: namespace with a ``run_type`` attribute. Supported values are
            ``'causal_gene_discovery'``, ``'disease_characterization'`` and
            ``'patients_like_me'``.
        hparams: dict that is updated in place.

    Returns:
        The same ``hparams`` object, with ``model_type``, ``loss``,
        ``use_diseases``, ``add_cand_diseases``, ``add_similar_patients`` and
        ``wandb_project_name`` set for the chosen run type.

    Raises:
        ValueError: if ``args.run_type`` is not a supported value. ``ValueError``
            subclasses ``Exception``, so handlers written for the previous bare
            ``Exception`` still catch it.
    """
    # Built per call so callers can never mutate shared state through it.
    settings_by_run_type = {
        'causal_gene_discovery': {
            'model_type': 'aligner',
            'loss': 'gene_multisimilarity',
            'use_diseases': False,
            'add_cand_diseases': False,
            'add_similar_patients': False,
            'wandb_project_name': 'causal-gene-discovery',
        },
        'disease_characterization': {
            'model_type': 'patient_NCA',
            'loss': 'patient_disease_NCA',
            'use_diseases': True,
            'add_cand_diseases': True,
            'add_similar_patients': False,
            'wandb_project_name': 'disease-heterogeneity',
        },
        'patients_like_me': {
            'model_type': 'patient_NCA',
            'loss': 'patient_patient_NCA',
            'use_diseases': False,
            'add_cand_diseases': False,
            'add_similar_patients': True,
            'wandb_project_name': 'patients-like-me',
        },
    }

    settings = settings_by_run_type.get(args.run_type)
    if settings is None:
        # Report the rejected value: a misspelled run type is far more common
        # than a missing one, and the old message hid which it was.
        supported = ', '.join(repr(name) for name in settings_by_run_type)
        raise ValueError(
            f'You must specify run type. Got {args.run_type!r}; '
            f'expected one of: {supported}.'
        )

    hparams.update(settings)
    return hparams
178
179
180
def get_patient_data_args(args, hparams):
    """Add patient-data file locations for ``args.patient_data`` to ``hparams``.

    Args:
        args: namespace with a ``patient_data`` attribute. Supported values
            are ``"disease_simulated"`` (simulated patients whose file names
            are derived from ``project_config.CURR_KG``) and ``"my_data"``
            (paths taken from the ``project_config.MY_*`` settings).
        hparams: dict that is updated in place.

    Returns:
        The same ``hparams`` object, with ``train_data``, ``validation_data``,
        ``test_data``, ``spl`` and ``spl_index`` set.

    Raises:
        ValueError: if ``args.patient_data`` is not a supported value.
            ``ValueError`` subclasses ``Exception``, so handlers written for
            the previous bare ``Exception`` still catch it.
    """
    if args.patient_data == "disease_simulated":
        # All simulated-patient files share this prefix and the KG-name suffix.
        prefix = 'simulated_patients/disease_split'
        kg = project_config.CURR_KG
        hparams.update({
            'train_data': f'{prefix}_train_sim_patients_{kg}.txt',
            'validation_data': f'{prefix}_val_sim_patients_{kg}.txt',
            'test_data': f'{prefix}_all_sim_patients_{kg}.txt',
            'spl': f'{prefix}_all_sim_patients_{kg}_spl_matrix.npy',
            'spl_index': f'{prefix}_all_sim_patients_{kg}_spl_index_dict.pkl',
        })
    elif args.patient_data == "my_data":
        hparams.update({
            'train_data': project_config.MY_TRAIN_DATA,
            'validation_data': project_config.MY_VAL_DATA,
            'test_data': project_config.MY_TEST_DATA,
            'spl': project_config.MY_SPL_DATA,  # Result of add_spl_to_patients.py (suffix: _spl_matrix.npy)
            'spl_index': project_config.MY_SPL_INDEX_DATA,  # Result of add_spl_to_patients.py (suffix: _spl_index_dict.pkl)
        })
    else:
        # Report the rejected value so a typo is distinguishable from an omission.
        raise ValueError(
            f'You must specify patient data. Got {args.patient_data!r}; '
            f"expected one of: 'disease_simulated', 'my_data'."
        )
    return hparams
198
199
200
201
####################################################################
202
#
203
# PREDICTION HYPERPARAMETERS
204
#
205
####################################################################
206
207
208
def get_predict_hparams(args):
    """Assemble the hyperparameter dict used by the prediction scripts.

    Args:
        args: parsed argument namespace. Reads ``saved_node_embeddings_path``
            here, plus whatever ``get_run_type_args`` and
            ``get_patient_data_args`` read.

    Returns:
        dict: prediction hyperparameters. ``add_similar_patients`` is always
        False regardless of run type.
    """
    hparams = dict(
        seed=33,
        n_gpus=0,  # NOTE: currently predict scripts only work with CPU
        num_workers=4,
        profiler='simple',
        pin_memory=False,
        time=False,
        log_gpu_memory=False,
        debug=False,

        augment_genes=True,
        n_sim_genes=3,
        aug_gene_w=0.5,

        wandb_save_dir=project_config.PROJECT_DIR / 'wandb',
        saved_checkpoint_path=project_config.PROJECT_DIR / f'{args.saved_node_embeddings_path}',
        test_n_cand_diseases=-1,
        candidate_disease_type='all_kg_nodes',
        only_hard_distractors=False,  # When True, only the curated hard distractors are used at train time
        patient_similarity_type='gene',  # How labels for similar patients are determined in "Patients Like Me"
        n_similar_patients=2,  # (Patients Like Me only) patients with the same gene/disease added to the batch
    )

    # Run-type settings first; then switch similar-patient batching off,
    # which prediction never uses.
    hparams = get_run_type_args(args, hparams)
    hparams['add_similar_patients'] = False

    hparams = get_patient_data_args(args, hparams)

    print('Predict hparams: ', hparams)

    return hparams