import random
import numpy as np
import pandas as pd
import pickle
import sys
import torch
import time
import re
from torch.utils.data import Dataset
from collections import defaultdict
from project_utils import read_patients
import project_config
class PatientDataset(Dataset):
    """Map-style dataset of patients translated into knowledge-graph (KG) node indices.

    Patients are loaded with ``read_patients(filepath)``. All lookup tables
    (HPO, Ensembl, MONDO, Orphanet->MONDO, node degree) are unpickled from
    ``project_config.KG_DIR`` / ``project_config.PROJECT_DIR`` for the KG named
    ``project_config.CURR_KG``.
    """

    # Disease IDs made only of ASCII digits are converted to ``int`` before
    # being looked up in ``orpha_mondo_map`` (see ``_parse_orpha_id``).
    _NUMERIC_ID_RE = re.compile(r'[0-9]+')

    def __init__(self, filepath, gp_spl=None, raw_data=False, mondo_map_file=str(project_config.PROJECT_DIR / 'mondo_references.csv'), needs_disease_mapping=False, time=False):
        """Load patients and all KG lookup tables.

        Args:
            filepath: path passed to ``read_patients``.
            gp_spl: not used by this class; kept for backward compatibility.
            raw_data: if True, ``__getitem__`` returns Python lists instead of
                ``torch.LongTensor`` objects and keeps every correct gene.
            mondo_map_file: not used by this class; kept for backward compatibility.
            needs_disease_mapping: if True, ``true_diseases`` are mapped through
                ``orphanet_to_mondo_dict.pkl`` before the MONDO -> node lookup.
            time: if True, ``__getitem__`` prints how long each item took.
                (This parameter shadows the ``time`` module inside ``__init__`` only.)

        NOTE(review): the lookup tables are read with ``pickle``; only point
        ``project_config`` at trusted files.
        """
        self.filepath = filepath
        self.patients = read_patients(filepath)
        print('Dataset filepath: ', filepath)
        print('Number of patients: ', len(self.patients))

        # add placeholder for true genes/diseases if they don't exist
        for patient in self.patients:
            patient.setdefault('true_genes', [])
            patient.setdefault('true_diseases', [])

        self.raw_data = raw_data
        self.needs_disease_mapping = needs_disease_mapping
        self.time = time

        kg_dir = project_config.KG_DIR
        kg_name = project_config.CURR_KG

        # HPO term <-> node index; reverse map resolves to the HPO name when known
        self.hpo_to_idx_dict = self._load_pickle(kg_dir / f'hpo_to_idx_dict_{kg_name}.pkl')
        self.hpo_to_name_dict = self._load_pickle(kg_dir / f'hpo_to_name_dict_{kg_name}.pkl')
        self.idx_to_hpo_dict = {v: self.hpo_to_name_dict.get(k, k) for k, v in self.hpo_to_idx_dict.items()}

        # Ensembl gene ID <-> node index
        # NOTE: assumes conversion from gene symbols to ensembl IDs has already occurred
        self.ensembl_to_idx_dict = self._load_pickle(kg_dir / f'ensembl_to_idx_dict_{kg_name}.pkl')
        self.idx_to_ensembl_dict = {v: k for k, v in self.ensembl_to_idx_dict.items()}

        # orphanet -> list of mondo IDs
        self.orpha_mondo_map = self._load_pickle(project_config.PROJECT_DIR / 'preprocess' / 'orphanet' / 'orphanet_to_mondo_dict.pkl')

        # MONDO disease ID <-> node index; reverse map resolves to the disease name when known
        self.disease_to_idx_dict = self._load_pickle(kg_dir / f'mondo_to_idx_dict_{kg_name}.pkl')
        self.disease_to_name_dict = self._load_pickle(kg_dir / f'mondo_to_name_dict_{kg_name}.pkl')
        self.idx_to_disease_dict = {v: self.disease_to_name_dict.get(k, k) for k, v in self.disease_to_idx_dict.items()}

        # degree dict from idx to degree - used for debugging
        # NOTE: may need to subtract 1 to index into this dict
        self.degree_dict = self._load_pickle(kg_dir / f'degree_dict_{kg_name}.pkl')

        # every patient has 'true_genes'/'true_diseases' after the placeholder loop above
        self.patients_with_same_gene = self._group_by_shared_label('true_genes')
        self.patients_with_same_disease = self._group_by_shared_label('true_diseases')

        # map from patient id to index in dataset
        self.patient_id_to_index = {p['id']: i for i, p in enumerate(self.patients)}
        print('Finished initalizing dataset')

    @staticmethod
    def _load_pickle(path):
        """Unpickle and return the object stored at ``path`` (file is always closed)."""
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    def _group_by_shared_label(self, key):
        """Map each patient id to the ids of the other patients sharing any ``patient[key]`` value.

        A patient sharing several labels with another appears once per shared label.
        """
        label_to_patients = defaultdict(list)
        for patient in self.patients:
            for label in patient[key]:
                label_to_patients[label].append(patient['id'])

        same_label = defaultdict(list)
        for patient_ids in label_to_patients.values():
            for pid in patient_ids:
                same_label[pid].extend(other for other in patient_ids if other != pid)
        return same_label

    @classmethod
    def _parse_orpha_id(cls, disease_id):
        """Return ``int(disease_id)`` when it is all ASCII digits, else ``disease_id`` unchanged.

        Unlike a prefix match, IDs such as ``"123abc"`` are kept as-is instead of
        crashing ``int()``; already-numeric IDs are also accepted.
        """
        text = str(disease_id)
        return int(text) if cls._NUMERIC_ID_RE.fullmatch(text) else disease_id

    @staticmethod
    def _as_nested_list(values):
        """Normalise an additional-label value to a list of lists.

        ``None`` -> ``[[-1]]``; a non-list scalar ``v`` -> ``[[v]]``; an empty or flat
        list ``l`` -> ``[l]``; a list whose first element is a list is returned as-is.
        """
        if values is None:
            return [[-1]]
        if not isinstance(values, list):
            return [[values]]
        if len(values) == 0 or not isinstance(values[0], list):
            return [values]
        return values

    def __len__(self):
        '''
        Returns the length of the dataset
        '''
        return len(self.patients)

    def __getitem__(self, idx):
        '''
        Returns a single example from the dataset.

        Returns a 7-tuple:
            (phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx,
             disease_node_idx, label_idx, additional_labels_dict, patient_id)
        With ``raw_data=False`` the first four are ``torch.LongTensor``s built from
        lists (``disease_node_idx`` may be ``None``) and only the first correct gene
        is kept. With ``raw_data=True`` they are Python lists.
        ``label_idx`` is ``None`` when no correct gene maps to a KG node. Otherwise it
        holds the position(s) of the correct gene(s) inside ``candidate_gene_node_idx``:
        a tensor for the first correct gene, or a list in raw mode. It can be empty if
        the gene is not a candidate.

        Raises:
            AssertionError: if none of the patient's phenotypes map to a KG node.
        '''
        t0 = time.perf_counter()
        patient = self.patients[idx]

        additional_labels_dict = {}
        if 'additional_labels' in patient:
            for label, values in patient['additional_labels'].items():
                if label == "n_hops_cand_g_p":
                    continue
                additional_labels_dict[label] = self._as_nested_list(values)
            # always expose these keys, using the same [[-1]] sentinel as for None values
            additional_labels_dict.setdefault('max_percent_phen_overlap_train', [[-1]])
            additional_labels_dict.setdefault('max_phen_overlap_train', [[-1]])

        # IDs missing from the KG are silently dropped
        phenotype_node_idx = [self.hpo_to_idx_dict[p] for p in patient['positive_phenotypes'] if p in self.hpo_to_idx_dict]
        correct_genes_node_idx = [self.ensembl_to_idx_dict[g] for g in patient['true_genes'] if g in self.ensembl_to_idx_dict]
        candidate_gene_node_idx = [self.ensembl_to_idx_dict[g] for g in patient.get('all_candidate_genes', []) if g in self.ensembl_to_idx_dict]

        if 'true_diseases' in patient:
            if self.needs_disease_mapping:
                orpha_diseases = [self._parse_orpha_id(d) for d in patient['true_diseases']]
                # dedupe in input order (a set would make the order vary between runs)
                mondo_diseases = [
                    mondo_d
                    for orpha_d in dict.fromkeys(orpha_diseases)
                    if orpha_d in self.orpha_mondo_map
                    for mondo_d in self.orpha_mondo_map[orpha_d]
                ]
            else:
                mondo_diseases = [str(d) for d in patient['true_diseases']]
            disease_node_idx = [self.disease_to_idx_dict[d] for d in mondo_diseases if d in self.disease_to_idx_dict]
        else:
            disease_node_idx = None

        if not self.raw_data:
            phenotype_node_idx = torch.LongTensor(phenotype_node_idx)
            correct_genes_node_idx = torch.LongTensor(correct_genes_node_idx)
            candidate_gene_node_idx = torch.LongTensor(candidate_gene_node_idx)
            if disease_node_idx is not None:
                disease_node_idx = torch.LongTensor(disease_node_idx)

        assert len(phenotype_node_idx) >= 1, f'There are no phenotypes for patient: {patient}'

        # NOTE: assumes that patient has a single causal/correct gene (the model still
        # outputs a score for each candidate gene); extra correct genes are dropped.
        if not self.raw_data and len(correct_genes_node_idx) > 1:
            correct_genes_node_idx = correct_genes_node_idx[0].unsqueeze(-1)

        # get index of correct gene within the candidate list
        if len(correct_genes_node_idx) == 0:  # no correct genes available for the patient
            label_idx = None
        elif self.raw_data:
            # skip correct genes that are not candidates (list.index would raise),
            # mirroring the empty result of the tensor path below
            label_idx = [candidate_gene_node_idx.index(g) for g in correct_genes_node_idx if g in candidate_gene_node_idx]
        else:
            label_idx = (candidate_gene_node_idx == correct_genes_node_idx[0]).nonzero(as_tuple=True)[0]

        if self.time:
            print(f'It takes {time.perf_counter() - t0:0.4f}s to get an item from the dataset')
        return (phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, disease_node_idx, label_idx, additional_labels_dict, patient['id'])

    def node_idx_to_degree(self, idx):
        """Return the degree stored for node ``idx``, or -1 if it is not in ``degree_dict``."""
        return self.degree_dict[idx] if idx in self.degree_dict else -1

    def node_idx_to_name(self, idx):
        """Return a readable name for KG node ``idx``.

        Looks up phenotypes first (HPO name or ID), then genes (Ensembl ID), then
        diseases (MONDO name or ID). Index 0, if unmapped, is reported as ``'padding'``.

        Raises:
            KeyError: if ``idx`` is unknown (a subclass of ``Exception``, which was raised before).
        """
        for lookup in (self.idx_to_hpo_dict, self.idx_to_ensembl_dict, self.idx_to_disease_dict):
            if idx in lookup:
                return lookup[idx]
        if idx == 0:
            return 'padding'
        raise KeyError(f'Unknown node index: {idx}')

    def get_similar_patients(self, patient_id, similarity_type='gene'):
        """Return a shuffled numpy array of ids of patients sharing a gene or disease with ``patient_id``.

        Unknown ids yield an empty array. The lookup does not add entries to the
        underlying index.

        Raises:
            NotImplementedError: if ``similarity_type`` is not ``'gene'`` or ``'disease'``.
        """
        if similarity_type == 'gene':
            index = self.patients_with_same_gene
        elif similarity_type == 'disease':
            index = self.patients_with_same_disease
        else:
            raise NotImplementedError
        sim_pats = np.array(index.get(patient_id, []))
        np.random.shuffle(sim_pats)
        return sim_pats

    def get_candidate_diseases(self, cand_type='all_kg_nodes'):
        """Return the sorted, unique KG node indices of candidate diseases as a long tensor.

        ``'all_kg_nodes'``: every MONDO disease in the KG.
        ``'orphanet'``: the first MONDO ID of each Orphanet mapping, if it is in the KG.
        NOTE(review): ``__getitem__`` maps Orphanet IDs to *all* their MONDO IDs, so a
        true disease may fall outside this candidate set. Confirm this is intended.

        Raises:
            NotImplementedError: for any other ``cand_type``.
        """
        if cand_type == 'all_kg_nodes':
            disease_idx = set(self.disease_to_idx_dict.values())
        elif cand_type == 'orphanet':
            # skip empty mappings instead of failing with IndexError
            first_mondo = [mondo_ids[0] for mondo_ids in self.orpha_mondo_map.values() if len(mondo_ids) > 0]
            disease_idx = {self.disease_to_idx_dict[d] for d in first_mondo if d in self.disease_to_idx_dict}
        else:
            raise NotImplementedError
        return torch.tensor(sorted(disease_idx), dtype=torch.long)