data_prep/preprocess_patients_and_kg.py
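'''
Preprocess the knowledge graph (KG) and simulated patients: build lookup dictionaries
mapping HPO phenotypes, genes (symbols and Ensembl IDs), and MONDO diseases to KG node
indices, map disease nodes to Orphanet IDs, filter out patients whose genes/phenotypes
are missing from the KG, and optionally split the patients into train/val sets by disease.
'''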

import pandas as pd
import jsonlines
import networkx as nx
import snap
import obonet
import numpy as np
import re
import argparse
import random
import pickle
from tqdm import tqdm
from collections import defaultdict, Counter
from pathlib import Path
from copy import deepcopy
from itertools import combinations
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import matplotlib
matplotlib.use('Agg')
import sys
sys.path.insert(0, '..') # add config to path
import preprocess
import project_config
from project_utils import read_simulated_patients, write_patients
pd.options.mode.chained_assignment = None
# input locations
ORPHANET_METADATA_FILE = str(project_config.PROJECT_DIR / 'preprocess' / 'orphanet' / 'orphanet_final_disease_metadata.tsv')
MONDO_MAP_FILE = str(project_config.PROJECT_DIR / 'preprocess' / 'mondo' / 'mondo_references.csv')
MONDO_OBO_FILE = str(project_config.PROJECT_DIR / 'preprocess' / 'mondo' / 'mondo.obo')
HP_TERMS = project_config.PROJECT_DIR / 'preprocess' / 'hp_terms.csv'
MONDOTOHPO = project_config.PROJECT_DIR /'preprocess'/ 'mondo' / 'mondo2hpo.csv'
# output locations
ORPHANET_TO_MONDO_DICT = str(project_config.PROJECT_DIR / 'preprocess' / 'orphanet' / 'orphanet_to_mondo_dict.pkl')
HPO_TO_IDX_DICT_FILE = project_config.PROJECT_DIR / 'knowledge_graph' / project_config.CURR_KG / f'hpo_to_idx_dict_{project_config.CURR_KG}.pkl'
HPO_TO_NAME_DICT_FILE = project_config.PROJECT_DIR / 'knowledge_graph'/ project_config.CURR_KG / f'hpo_to_name_dict_{project_config.CURR_KG}.pkl'
ENSEMBL_TO_IDX_DICT_FILE = project_config.PROJECT_DIR / 'knowledge_graph'/ project_config.CURR_KG / f'ensembl_to_idx_dict_{project_config.CURR_KG}.pkl'
GENE_SYMBOL_TO_IDX_DICT_FILE = project_config.PROJECT_DIR / 'knowledge_graph'/ project_config.CURR_KG / f'gene_symbol_to_idx_dict_{project_config.CURR_KG}.pkl'
MONDO_TO_NAME_DICT_FILE = project_config.PROJECT_DIR / 'knowledge_graph'/ project_config.CURR_KG / f'mondo_to_name_dict_{project_config.CURR_KG}.pkl'
MONDO_TO_IDX_DICT_FILE = project_config.PROJECT_DIR / 'knowledge_graph'/ project_config.CURR_KG / f'mondo_to_idx_dict_{project_config.CURR_KG}.pkl'
# extracted from mondo.obo file
OBSOLETE_MONDO_DICT = {'MONDO:0008646':'MONDO:0100316', 'MONDO:0016021': 'MONDO:0100062',
'MONDO:0017125':'MONDO:0010261', 'MONDO:0010624':'MONDO:0100213', 'MONDO:0019863':'MONDO:0011812',
'MONDO:0019523':'MONDO:0000171', 'MONDO:0018275':'MONDO:0018274', 'MONDO:0010071':'MONDO:0011939', 'MONDO:0010195':'MONDO:0008490',
'MONDO:0008897':'MONDO:0100251', 'MONDO:0011127':'MONDO:0100344', 'MONDO:0009304':'MONDO:0012853',
'MONDO:0008480':'MONDO:0031169', 'MONDO:0009641':'MONDO:0100294', 'MONDO:0010119':'MONDO:0031332',
'MONDO:0010766':'MONDO:0100250', 'MONDO:0008646':'MONDO:0100316', 'MONDO:0018641':'MONDO:0100245',
'MONDO:0010272':'MONDO:0010327', 'MONDO:0012189':'MONDO:0018274', 'MONDO:0007926':'MONDO:0100280',
'MONDO:0008032':'MONDO:0012215', 'MONDO:0009739':'MONDO:0024457', 'MONDO:0010419':'MONDO:0020721',
'MONDO:0007291':'MONDO:0031037', 'MONDO:0009245':'MONDO:0100339'} # to do starting with MONDO:0018275

def read_data(args):
    # read in KG nodes
    node_df = pd.read_csv(project_config.PROJECT_DIR / 'knowledge_graph' / project_config.CURR_KG / args.node_map, sep='\t')
    print(f'Unique node sources: {node_df["node_source"].unique()}')
    print(f'Unique node types: {node_df["node_type"].unique()}')
    node_type_dict = {idx: node_type for idx, node_type in zip(node_df['node_idx'], node_df['node_type'])}

    # read in patients
    sim_patients = read_simulated_patients(args.simulated_path)
    print(f'Number of sim patients: {len(sim_patients)}')

    # orphanet metadata
    orphanet_metadata = pd.read_csv(ORPHANET_METADATA_FILE, sep='\t', dtype=str)

    # orphanet to mondo map
    mondo_map_df = pd.read_csv(MONDO_MAP_FILE, sep=',', index_col=0)  # dtype=str
    obsolete_mondo_dict = {int(re.sub('MONDO:0*', '', k)): int(re.sub('MONDO:0*', '', v)) for k, v in OBSOLETE_MONDO_DICT.items()}
    mondo_map_df['mondo_id'] = mondo_map_df['mondo_id'].replace(obsolete_mondo_dict)
    mondo_map_df.to_csv(project_config.PROJECT_DIR / 'mondo_references_normalized.csv', sep=',')
    mondo_map_df = mondo_map_df.loc[mondo_map_df['ontology'] == 'Orphanet']
    mondo_orphanet_map = {str(mondo_id): [int(v) for v in mondo_map_df.loc[mondo_map_df['mondo_id'] == mondo_id, 'ontology_id'].tolist()] for mondo_id in mondo_map_df['mondo_id'].unique().tolist()}
    mondo_obo = obonet.read_obo(MONDO_OBO_FILE)
    mondo_to_orphanet_obo_map = {node_id: [r for r in node['xref'] if r.startswith('Orphanet')] for node_id, node in list(mondo_obo.nodes(data=True)) if 'xref' in node}
    mondo_to_orphanet_obo_map = {k.replace('MONDO:', ''): [int(v.replace('Orphanet:', '')) for v in vals] for k, vals in mondo_to_orphanet_obo_map.items() if len(vals) > 0}
    mondo_to_orphanet_obo_map = {re.split('^0*', k)[-1]: v for k, v in mondo_to_orphanet_obo_map.items()}
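    # At this point both maps key on the bare MONDO number ('MONDO:' prefix and leading
    # zeros stripped, e.g. 'MONDO:0007523' -> '7523') and map to lists of integer Orphanet
    # IDs, so the csv-derived and obo-derived mappings can be compared directly below.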
    # merge two sources of mondo to orphanet mappings
    missing_keys = set(list(mondo_to_orphanet_obo_map.keys())).difference(set(list(mondo_orphanet_map.keys())))
    missing_keys2 = set(list(mondo_orphanet_map.keys())).difference(set(list(mondo_to_orphanet_obo_map.keys())))
    overlapping_keys = set(list(mondo_to_orphanet_obo_map.keys())).intersection(set(list(mondo_orphanet_map.keys())))
    print('\n ############ Retrieving mondo to orphanet maps ############')
    print(f'There are {len(missing_keys)} missing mappings from the non-obo mondo to orphanet mapping')
    print(f'There are {len(missing_keys2)} missing mappings from the obo mondo to orphanet mapping')
    disagreement_keys = [(k, mondo_orphanet_map[k], mondo_to_orphanet_obo_map[k]) for k in overlapping_keys if len(set(mondo_orphanet_map[k]).intersection(set(mondo_to_orphanet_obo_map[k]))) == 0]
    print(f'There is/are {len(disagreement_keys)} mapping(s) from the two mondo dicts that don\'t agree with each other: {disagreement_keys}')
    merged_mondo_to_orphanet_map = {k: list(set(mondo_orphanet_map[k]).union(set(mondo_to_orphanet_obo_map[k]))) for k in overlapping_keys if k not in disagreement_keys}
    for k in missing_keys: merged_mondo_to_orphanet_map[k] = mondo_to_orphanet_obo_map[k]
    for k in missing_keys2: merged_mondo_to_orphanet_map[k] = mondo_orphanet_map[k]

    # create reverse - orphanet to mondo mapping
    orphanet_to_mondo_dict = defaultdict(list)
    for mondo, orphanet_list in merged_mondo_to_orphanet_map.items():
        for orphanet_id in orphanet_list:
            orphanet_to_mondo_dict[orphanet_id].append(mondo)
    with open(ORPHANET_TO_MONDO_DICT, 'wb') as handle:
        pickle.dump(orphanet_to_mondo_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    print('max number of mondo terms associated with an orphanet term: ', max([len(v) for k, v in orphanet_to_mondo_dict.items()]))

    # read in mapping from old to current phenotypes
    hp_terms = pd.read_csv(HP_TERMS)
    hp_map_dict = {'HP:' + ('0' * (7 - len(str(int(hp_old))))) + str(int(hp_old)): 'HP:' + '0' * (7 - len(str(int(hp_new)))) + str(int(hp_new)) for hp_old, hp_new in zip(hp_terms['id'], hp_terms['replacement_id']) if not pd.isnull(hp_new)}
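    # e.g. an hp_terms.csv row with id 40 and replacement_id 118 would yield the entry
    # {'HP:0000040': 'HP:0000118'} (values here are illustrative only)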
    # read in mapping from mondo diseases to HPO phenotypes (this mapping occurs when a single
    # entity is cross referenced by MONDO & HPO. In such cases we map to HPO)
    mondo2hpo = pd.read_csv(MONDOTOHPO)
    mondo_to_hpo_dict = {mondo: hpo for hpo, mondo in zip(mondo2hpo['ontology_id'], mondo2hpo['mondo_id'])}

    return node_df, node_type_dict, sim_patients, orphanet_metadata, merged_mondo_to_orphanet_map, orphanet_to_mondo_dict, hp_map_dict, mondo_to_hpo_dict

def create_networkx_graph(edges):
    G = nx.MultiDiGraph()
    edge_index = list(zip(edges['x_idx'], edges['y_idx']))
    G.add_edges_from(edge_index)
    return G
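
# Minimal usage sketch (illustrative): the edge list carries KG node indices in its
# 'x_idx'/'y_idx' columns, so the resulting multigraph is keyed by node_idx, e.g.
#   edges = pd.read_csv(project_config.KG_DIR / 'KG_edgelist_mask.txt', sep='\t')
#   G = create_networkx_graph(edges)
#   print(G.number_of_nodes(), G.number_of_edges())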
###################################################################
# create maps from phenotype/gene to the idx in the KG

def create_hpo_to_node_idx_dict(node_df, hp_old_new_map):
    # get HPO nodes
    hpo_nodes = node_df.loc[node_df['node_type'] == 'effect/phenotype']
    hpo_nodes['node_id'] = hpo_nodes['node_id'].astype(str)

    # convert HPO id to string version (e.g. 1 -> HP:0000001)
    HPO_LEN = 7
    padding_needed = HPO_LEN - hpo_nodes['node_id'].str.len()
    padded_hpo = padding_needed.apply(lambda x: 'HP:' + '0' * x)
    hpo_nodes['hpo_string'] = padded_hpo + hpo_nodes['node_id']

    # create dict from HPO ID to node index in graph
    hpo_to_idx_dict = {hpo: idx for hpo, idx in zip(hpo_nodes['hpo_string'].tolist(), hpo_nodes['node_idx'].tolist())}
    old_hpo_to_idx_dict = {old: hpo_to_idx_dict[new] for old, new in hp_old_new_map.items()}
    hpo_to_idx_dict = {**hpo_to_idx_dict, **old_hpo_to_idx_dict}
    hpo_to_name_dict = {hpo: name for hpo, name in zip(hpo_nodes['hpo_string'].tolist(), hpo_nodes['node_name'].tolist())}

    # save to file
    with open(HPO_TO_IDX_DICT_FILE, 'wb') as handle:
        pickle.dump(hpo_to_idx_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(HPO_TO_NAME_DICT_FILE, 'wb') as handle:
        pickle.dump(hpo_to_name_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return hpo_to_idx_dict

def create_gene_to_node_idx_dict(args, node_df):
    ensembl_node_map = Path(str(args.node_map).split('.txt')[0] + '_ensembl_ids.txt')
    if ensembl_node_map.exists():
        node_df = pd.read_csv(ensembl_node_map, sep='\t')
    else:
        print('Generating ensembl_ids for KG')
        preprocessor = preprocess.Preprocessor()  # NOTE: raw data to perform preprocessing is missing from dataverse, but we provide the already processed files for our KG

        # get gene nodes & map to ensembl IDs
        gene_nodes = node_df.loc[node_df['node_type'] == 'gene/protein']
        gene_nodes = preprocessor.map_genes(gene_nodes, ['node_name'])
        gene_nodes = gene_nodes.rename(columns={'node_name_ensembl': 'node_name', 'node_name': 'gene_symbol'})

        # merge gene names with the original node df
        node_df['old_node_name'] = node_df['node_name']
        node_df.loc[node_df['node_idx'].isin(gene_nodes['node_idx']), 'node_name'] = gene_nodes['node_name']

        # save modified node df back to file
        node_df.to_csv(ensembl_node_map, sep='\t')

    # create gene to idx dict
    gene_nodes = node_df.loc[node_df['node_type'] == 'gene/protein']
    gene_symbol_to_idx_dict = {gene: idx for gene, idx in zip(gene_nodes['old_node_name'].tolist(), gene_nodes['node_idx'].tolist())}
    ensembl_to_idx_dict = {gene: idx for gene, idx in zip(gene_nodes['node_name'].tolist(), gene_nodes['node_idx'].tolist())}

    # save to file
    with open(ENSEMBL_TO_IDX_DICT_FILE, 'wb') as handle:
        pickle.dump(ensembl_to_idx_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(GENE_SYMBOL_TO_IDX_DICT_FILE, 'wb') as handle:
        pickle.dump(gene_symbol_to_idx_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return node_df, gene_symbol_to_idx_dict, ensembl_to_idx_dict

def create_mondo_to_node_idx_dict(node_df, mondo_to_hpo_dict):
    '''create mondo disease to node_idx map'''
    # get disease nodes
    disease_nodes = node_df.loc[(node_df['node_type'] == 'disease')]
    disease_nodes['node_id'] = disease_nodes['node_id'].str.replace('.0', '', regex=False)
    mondo_strs = [str(mondo_str) for mondo, idx in zip(disease_nodes['node_id'].tolist(), disease_nodes['node_idx'].tolist()) for mondo_str in mondo.split('_')]
    assert len(mondo_strs) == len(list(set(mondo_strs))), 'The following dict may overwrite some mappings because there are duplicates in the mondo ids'
    mondo_to_idx_dict = {str(mondo_str): idx for mondo, idx in zip(disease_nodes['node_id'].tolist(), disease_nodes['node_idx'].tolist()) for mondo_str in mondo.split('_')}

    # get mapping from phenotypes to KG idx
    phenotype_nodes = node_df.loc[node_df['node_type'] == 'effect/phenotype']
    phen_to_idx_dict = {int(phen): idx for phen, idx in zip(phenotype_nodes['node_id'].tolist(), phenotype_nodes['node_idx'].tolist()) if int(phen) in mondo_to_hpo_dict.values()}
    disease_mapped_phen_to_idx_dict = {str(mondo): phen_to_idx_dict[hpo] for mondo, hpo in mondo_to_hpo_dict.items()}

    # merge two mappings
    mondo_to_idx_dict = {**mondo_to_idx_dict, **disease_mapped_phen_to_idx_dict}

    # TODO: this dict is missing some names from phenotype diseases. This needs to be fixed.
    mondo_to_name_dict = {mondo_str: name for mondo, name in zip(disease_nodes['node_id'].tolist(), disease_nodes['node_name'].tolist()) for mondo_str in mondo.split('_')}

    # save to file
    with open(MONDO_TO_NAME_DICT_FILE, 'wb') as handle:
        pickle.dump(mondo_to_name_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(MONDO_TO_IDX_DICT_FILE, 'wb') as handle:
        pickle.dump(mondo_to_idx_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return mondo_to_idx_dict

def map_diseases_to_orphanet(node_df, mondo_orphanet_map):
    all_orphanet_ids = []
    for node_id, node_type in zip(node_df['node_id'], node_df['node_type']):
        if node_type == 'disease':
            mondo_ids = node_id.split('_')
            orphanet_ids = [mondo_orphanet_map[m] for m in mondo_ids if m in mondo_orphanet_map]
            orphanet_ids = [str(o) for l in orphanet_ids for o in l]
            if len(orphanet_ids) == 0: all_orphanet_ids.append(None)
            else: all_orphanet_ids.append('_'.join(orphanet_ids))  # NOTE: some nodes that contain grouped MONDO ids are mapped to multiple orphanet ids
        else:
            all_orphanet_ids.append(None)  # keep the list aligned with node_df rows so the new column lines up
    node_df['orphanet_node_id'] = pd.Series(all_orphanet_ids)
    node_df.to_csv(project_config.KG_DIR / 'KG_node_map_ensembl_ids_orphanet_ids.txt', sep='\t')  # TODO: check where we use this downstream
###################################################################
## split data into train/val/test
def filter_patients(patients, hpo_to_idx_dict, ensembl_to_idx_dict):
    '''
    Filter patients out of the dataset if their causal gene, all of their distractor genes, or all of
    their phenotypes cannot be found in the KG.
    '''
    print(f'Number of patients pre-filtering: {len(patients)}')
    filtered_patients = [p for p in patients if len(set(p['true_genes']).intersection(set(ensembl_to_idx_dict.keys()))) > 0]
    print(f'Number of patients after filtering out those with no causal gene in the KG: {len(filtered_patients)}')
    if 'distractor_genes' in filtered_patients[0]:
        filtered_patients = [p for p in filtered_patients if len(set(p['distractor_genes']).intersection(set(ensembl_to_idx_dict.keys()))) > 0]
        print(f'Number of patients after filtering out those with no distractor genes in the KG: {len(filtered_patients)}')
    filtered_patients = [p for p in filtered_patients if len(set(p['positive_phenotypes']).intersection(set(hpo_to_idx_dict.keys()))) > 0]
    print(f'Number of patients after filtering out those with no phenotypes in the KG: {len(filtered_patients)}')
    return filtered_patients
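
# For reference, filter_patients and the split helpers below assume each patient record is a dict
# with at least these keys (values illustrative):
#   {'id': 0, 'disease_id': '...', 'true_genes': ['ENSG...'],
#    'distractor_genes': ['ENSG...'], 'positive_phenotypes': ['HP:0000118', ...]}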

def create_dataset_split_from_lists(filtered_patients, train_list_f, val_list_f):
    train_list = pd.read_csv(project_config.PROJECT_DIR / 'formatted_patients' / train_list_f, index_col=0)['ids'].tolist()
    val_list = pd.read_csv(project_config.PROJECT_DIR / 'formatted_patients' / val_list_f, index_col=0)['ids'].tolist()
    train_patients, val_patients, unsorted_patients = [], [], []
    for patient in filtered_patients:
        if patient['id'] in train_list: train_patients.append(patient)
        elif patient['id'] in val_list: val_patients.append(patient)
        else: unsorted_patients.append(patient)
    print(f'There are {len(train_patients)} patients in the train set and {len(val_patients)} in the val set.')
    print(f'There are {len(unsorted_patients)} unsorted patients.')
    return train_patients, val_patients

def create_disease_split_dataset(filtered_patients, frac_train=0.7, frac_val_test=0.15):
    # divide patients by disease ID into train/val/test
    diseases = list(set([p['disease_id'] for p in filtered_patients]))
    n_train = round(len(diseases) * frac_train)
    n_val_test = round(len(diseases) * frac_val_test)
    dx_train_patients = diseases[0:n_train]
    dx_val_patients = diseases[n_train:n_val_test + n_train]
    dx_test_patients = diseases[n_val_test + n_train:]
    print('Split of diseases into train/val/test: ', len(dx_train_patients), len(dx_val_patients), len(dx_test_patients))

    dx_split_train_patients = [p for p in filtered_patients if p['disease_id'] in dx_train_patients]
    dx_split_val_patients = [p for p in filtered_patients if p['disease_id'] in dx_val_patients]
    dx_split_test_patients = [p for p in filtered_patients if p['disease_id'] in dx_test_patients]
    dx_split_train_patient_ids = pd.DataFrame({'ids': [p['id'] for p in dx_split_train_patients]})
    dx_split_val_patient_ids = pd.DataFrame({'ids': [p['id'] for p in dx_split_val_patients]})
    dx_split_test_patient_ids = pd.DataFrame({'ids': [p['id'] for p in dx_split_test_patients]})

    # NOTE: we decided to merge the train & test sets into a single larger train set to be able to
    # train on more diseases. We are posthoc merging to keep the code as it was originally written.
    dx_split_train_patient_ids = pd.concat([dx_split_train_patient_ids, dx_split_test_patient_ids])
    dx_split_train_patients = dx_split_train_patients + dx_split_test_patients
    print(f'There are {len(dx_split_train_patients)} patients in the disease split train set and {len(dx_split_val_patients)} in the val set.')
    return dx_split_train_patients, dx_split_val_patients, dx_split_train_patient_ids, dx_split_val_patient_ids
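
# Note: splitting by disease ID (rather than by patient) keeps all patients who share a disease in
# the same partition, so the val set only contains diseases that are unseen during training.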
###################################################################
## main
'''
Example usage:
python preprocess_patients_and_kg.py \
    -split_dataset
'''
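# When -split_dataset_from_lists is also given, the train/val patient IDs are read from existing
# csv lists (see create_dataset_split_from_lists above) instead of generating a fresh disease-level split.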
def main():
    parser = argparse.ArgumentParser(description="Preprocessing Patients & KG.")
    parser.add_argument("-edgelist", type=str, default=f'KG_edgelist_mask.txt', help="File with edge list")
    parser.add_argument("-node_map", type=str, default=f'KG_node_map.txt', help="File with node list")
    parser.add_argument("-simulated_path", type=str, default=f'{project_config.PROJECT_DIR}/patients/simulated_patients/simulated_patients_formatted.jsonl', help="Path to simulated patients")
    parser.add_argument("-split_dataset", action='store_true', help="Split patient datasets into train/val/test.")
    parser.add_argument("-split_dataset_from_lists", action='store_true', help='Whether the train/val/test split IDs should be read from file.')
    args = parser.parse_args()

    ## read in data, normalize genes to ensembl ids, and create maps from genes/phenotypes to node idx
    node_df, node_type_dict, sim_patients, orphanet_metadata, mondo_orphanet_map, orphanet_mondo_map, hp_map_dict, mondo_to_hpo_dict = read_data(args)
    hpo_to_idx_dict = create_hpo_to_node_idx_dict(node_df, hp_map_dict)
    node_df, gene_symbol_to_idx_dict, ensembl_to_idx_dict = create_gene_to_node_idx_dict(args, node_df)
    mondo_to_node_idx_dict = create_mondo_to_node_idx_dict(node_df, mondo_to_hpo_dict)
    map_diseases_to_orphanet(node_df, mondo_orphanet_map)
    edges = pd.read_csv(project_config.KG_DIR / args.edgelist, sep="\t")
    graph = create_networkx_graph(edges)
    snap_graph = snap.LoadEdgeList(snap.TUNGraph, str(project_config.KG_DIR / args.edgelist), 0, 1, '\t')

    # filter patients to remove those with no causal gene, no distractor genes, or no phenotypes
    filtered_sim_patients = filter_patients(sim_patients, hpo_to_idx_dict, ensembl_to_idx_dict)

    # write patients to file
    write_patients(filtered_sim_patients, project_config.PROJECT_DIR / 'patients' / 'simulated_patients' / f'all_sim_patients_kg_{project_config.CURR_KG}.txt')

    if args.split_dataset:
        ## filter patients & split into train/val/test
        if args.split_dataset_from_lists:
            dx_split_train_patients, dx_split_val_patients = create_dataset_split_from_lists(filtered_sim_patients,
                f'simulated_patients/disease_split_train_sim_patients_kg_{project_config.CURR_KG}_patient_ids.csv',
                f'simulated_patients/disease_split_val_sim_patients_kg_{project_config.CURR_KG}_patient_ids.csv'
            )
        else:
            dx_split_train_patients, dx_split_val_patients, dx_split_train_patient_ids, dx_split_val_patient_ids = create_disease_split_dataset(filtered_sim_patients)

        ## Save to file
        if not args.split_dataset_from_lists:
            dx_split_train_patient_ids.to_csv(project_config.PROJECT_DIR / 'patients' / 'simulated_patients' / f'disease_split_train_sim_patients_kg_{project_config.CURR_KG}_patient_ids.csv')
            dx_split_val_patient_ids.to_csv(project_config.PROJECT_DIR / 'patients' / 'simulated_patients' / f'disease_split_val_sim_patients_kg_{project_config.CURR_KG}_patient_ids.csv')
        write_patients(dx_split_train_patients, project_config.PROJECT_DIR / 'patients' / 'simulated_patients' / f'disease_split_train_sim_patients_kg_{project_config.CURR_KG}.txt')
        write_patients(dx_split_val_patients, project_config.PROJECT_DIR / 'patients' / 'simulated_patients' / f'disease_split_val_sim_patients_kg_{project_config.CURR_KG}.txt')

if __name__ == "__main__":
    main()