[bdbb47]: / shepherd / dataset.py

Download this file

205 lines (167 with data), 10.0 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import random
import numpy as np
import pandas as pd
import pickle
import sys
import torch
import time
import re
from torch.utils.data import Dataset
from collections import defaultdict
from project_utils import read_patients
import project_config
class PatientDataset(Dataset):
    '''
    Torch ``Dataset`` over the patients returned by ``read_patients(filepath)``.

    Each patient is a dict with an ``'id'`` and ``'positive_phenotypes'`` and,
    optionally, ``'true_genes'``, ``'true_diseases'``, ``'all_candidate_genes'``
    and ``'additional_labels'``. ``__getitem__`` maps those identifiers onto
    knowledge-graph node indices using the lookup tables pickled under
    ``project_config.KG_DIR``.
    '''

    def __init__(self, filepath, gp_spl=None, raw_data=False, mondo_map_file=str(project_config.PROJECT_DIR / 'mondo_references.csv'), needs_disease_mapping=False, time=False):
        '''
        Load the patients and every KG lookup table.

        Args:
            filepath: path handed to ``read_patients``.
            gp_spl: not used by this class; kept for interface compatibility.
            raw_data: if True, ``__getitem__`` returns plain Python lists
                instead of ``torch.LongTensor`` objects and keeps every correct
                gene rather than only the first.
            mondo_map_file: not used by this class; kept for interface compatibility.
            needs_disease_mapping: if True, ``true_diseases`` are treated as
                Orphanet IDs and translated to MONDO IDs through
                ``orphanet_to_mondo_dict.pkl``.
            time: if True, ``__getitem__`` prints how long each item took.
                NOTE: this parameter shadows the ``time`` module inside
                ``__init__`` only, so ``time.time()`` must not be called here.
        '''
        self.filepath = filepath
        self.patients = read_patients(filepath)
        print('Dataset filepath: ', filepath)
        print('Number of patients: ', len(self.patients))

        # add placeholder for true genes/diseases if they don't exist
        for patient in self.patients:
            patient.setdefault('true_genes', [])
            patient.setdefault('true_diseases', [])

        self.raw_data = raw_data
        self.needs_disease_mapping = needs_disease_mapping
        self.time = time

        # create HPO to node_idx map
        self.hpo_to_idx_dict = self._load_pickle(project_config.KG_DIR / f'hpo_to_idx_dict_{project_config.CURR_KG}.pkl')
        self.hpo_to_name_dict = self._load_pickle(project_config.KG_DIR / f'hpo_to_name_dict_{project_config.CURR_KG}.pkl')
        # node idx -> HPO name, falling back to the HPO id when no name is known
        self.idx_to_hpo_dict = {v: self.hpo_to_name_dict.get(k, k) for k, v in self.hpo_to_idx_dict.items()}

        # create ensembl to node_idx map
        # NOTE: assumes conversion from gene symbols to ensembl IDs has already occurred
        self.ensembl_to_idx_dict = self._load_pickle(project_config.KG_DIR / f'ensembl_to_idx_dict_{project_config.CURR_KG}.pkl')
        self.idx_to_ensembl_dict = {v: k for k, v in self.ensembl_to_idx_dict.items()}

        # orphanet to mondo disease map
        self.orpha_mondo_map = self._load_pickle(project_config.PROJECT_DIR / 'preprocess' / 'orphanet' / 'orphanet_to_mondo_dict.pkl')
        self.disease_to_idx_dict = self._load_pickle(project_config.KG_DIR / f'mondo_to_idx_dict_{project_config.CURR_KG}.pkl')
        self.disease_to_name_dict = self._load_pickle(project_config.KG_DIR / f'mondo_to_name_dict_{project_config.CURR_KG}.pkl')
        # node idx -> disease name, falling back to the MONDO id when no name is known
        self.idx_to_disease_dict = {v: self.disease_to_name_dict.get(k, k) for k, v in self.disease_to_idx_dict.items()}

        # degree dict from idx to degree - used for debugging
        # NOTE: may need to subtract 1 to index into this dict
        self.degree_dict = self._load_pickle(project_config.KG_DIR / f'degree_dict_{project_config.CURR_KG}.pkl')

        # The placeholders above guarantee every patient has 'true_genes' and
        # 'true_diseases', so both similarity maps can always be built.
        self.patients_with_same_gene = self._group_similar_patients('true_genes')
        self.patients_with_same_disease = self._group_similar_patients('true_diseases')

        # map from patient id to index in dataset
        self.patient_id_to_index = {p['id']: i for i, p in enumerate(self.patients)}
        print('Finished initalizing dataset')

    @staticmethod
    def _load_pickle(path):
        '''Open ``path`` in binary mode and return the unpickled object.

        NOTE: pickle can execute code on load; only use this on trusted project files.
        '''
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    def _group_similar_patients(self, key):
        '''Map each patient id to the ids of other patients sharing a value in ``patient[key]``.

        A patient id can appear more than once if the two patients share several values.
        '''
        value_to_patients = defaultdict(list)
        for patient in self.patients:
            for value in patient[key]:
                value_to_patients[value].append(patient['id'])
        similar = defaultdict(list)
        for patient_ids in value_to_patients.values():
            for p in patient_ids:
                similar[p].extend(pat for pat in patient_ids if pat != p)
        return similar

    def __len__(self):
        '''
        Returns the length of the dataset
        '''
        return len(self.patients)

    @staticmethod
    def _format_additional_labels(patient):
        '''
        Normalize ``patient['additional_labels']`` so every value is a list of lists.

        ``None`` becomes ``[[-1]]``, non-list scalars become ``[[x]]`` and flat
        lists become ``[values]``; the ``n_hops_cand_g_p`` label is dropped.
        Missing ``max_percent_phen_overlap_train`` / ``max_phen_overlap_train``
        labels default to ``[[-1]]``. Returns ``{}`` if the patient has no
        ``additional_labels`` key.
        '''
        if 'additional_labels' not in patient:
            return {}
        labels = patient['additional_labels']
        formatted = {}
        for label, values in labels.items():
            if label == "n_hops_cand_g_p":
                continue
            # `is None` rather than `== None`: array-valued labels would otherwise
            # produce an elementwise comparison with an ambiguous truth value
            if values is None:
                values = [[-1]]
            if not isinstance(values, list):
                values = [[values]]  # wrap scalars
            elif len(values) == 0 or not isinstance(values[0], list):
                values = [values]  # wrap flat lists
            formatted[label] = values
        for label in ('max_percent_phen_overlap_train', 'max_phen_overlap_train'):
            if label not in labels:
                formatted[label] = [[-1]]
        return formatted

    @staticmethod
    def _parse_orphanet_id(disease):
        '''Return ``int(disease)`` when it is an all-digit ID (surrounding whitespace ignored), else ``disease`` unchanged.

        Accepts ints as well as strings, and never raises on IDs that merely start with digits.
        '''
        text = str(disease).strip()
        return int(text) if re.fullmatch('[0-9]+', text) else disease

    def _disease_node_idx(self, patient):
        '''Return the KG node indices of the patient's ``true_diseases`` (``None`` if the key is absent).

        Diseases missing from ``disease_to_idx_dict`` are dropped.
        '''
        if 'true_diseases' not in patient:
            return None
        if self.needs_disease_mapping:
            # numeric Orphanet IDs are converted to int, presumably because
            # orpha_mondo_map is keyed by int for them — TODO confirm
            orpha_diseases = [self._parse_orphanet_id(d) for d in patient['true_diseases']]
            mondo_diseases = [mondo_d for orpha_d in set(orpha_diseases).intersection(set(self.orpha_mondo_map.keys())) for mondo_d in self.orpha_mondo_map[orpha_d]]
        else:
            mondo_diseases = [str(d) for d in patient['true_diseases']]
        return [self.disease_to_idx_dict[d] for d in mondo_diseases if d in self.disease_to_idx_dict]

    def __getitem__(self, idx):
        '''
        Returns a single example from the dataset.

        The result is the tuple ``(phenotype_node_idx, candidate_gene_node_idx,
        correct_genes_node_idx, disease_node_idx, label_idx,
        additional_labels_dict, patient_id)``. Identifiers absent from the KG
        lookup tables are dropped.

        Unless ``raw_data`` is set:
            - the index collections are ``torch.LongTensor`` objects;
            - only the first correct gene is kept;
            - ``label_idx`` is the tensor of positions of that gene among the
              candidates (empty if it is not a candidate).

        In raw mode, ``label_idx`` lists the candidate positions of the correct
        genes that are candidates. ``label_idx`` is ``None`` when the patient
        has no correct gene in the KG.

        Raises:
            AssertionError: if none of the patient's phenotypes are in the KG.
        '''
        t0 = time.time()
        patient = self.patients[idx]
        additional_labels_dict = self._format_additional_labels(patient)

        phenotype_node_idx = [self.hpo_to_idx_dict[p] for p in patient['positive_phenotypes'] if p in self.hpo_to_idx_dict]
        correct_genes_node_idx = [self.ensembl_to_idx_dict[g] for g in patient['true_genes'] if g in self.ensembl_to_idx_dict]
        candidate_gene_node_idx = [self.ensembl_to_idx_dict[g] for g in patient.get('all_candidate_genes', []) if g in self.ensembl_to_idx_dict]
        disease_node_idx = self._disease_node_idx(patient)

        assert len(phenotype_node_idx) >= 1, f'There are no phenotypes for patient: {patient}'

        if self.raw_data:
            if len(correct_genes_node_idx) == 0:  # no correct genes available for the patient
                label_idx = None
            else:
                # skip correct genes that are not candidates (list.index would raise ValueError),
                # mirroring the empty result of the tensor branch below
                label_idx = [candidate_gene_node_idx.index(g) for g in correct_genes_node_idx if g in candidate_gene_node_idx]
        else:
            phenotype_node_idx = torch.LongTensor(phenotype_node_idx)
            correct_genes_node_idx = torch.LongTensor(correct_genes_node_idx)
            candidate_gene_node_idx = torch.LongTensor(candidate_gene_node_idx)
            if disease_node_idx is not None:
                disease_node_idx = torch.LongTensor(disease_node_idx)

            # NOTE: assumes that patient has a single causal/correct gene (the model still outputs a score for each candidate gene)
            if len(correct_genes_node_idx) > 1:
                correct_genes_node_idx = correct_genes_node_idx[0].unsqueeze(-1)

            if len(correct_genes_node_idx) == 0:  # no correct genes available for the patient
                label_idx = None
            else:
                label_idx = (candidate_gene_node_idx == correct_genes_node_idx[0]).nonzero(as_tuple=True)[0]

        if self.time:
            t1 = time.time()
            print(f'It takes {t1-t0:0.4f}s to get an item from the dataset')
        return (phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, disease_node_idx, label_idx, additional_labels_dict, patient['id'])

    def node_idx_to_degree(self, idx):
        '''Return the degree recorded for node ``idx`` in ``degree_dict``, or -1 if it has none.'''
        return self.degree_dict.get(idx, -1)

    def node_idx_to_name(self, idx):
        '''
        Return a readable name for node ``idx``.

        Lookup order: HPO name/id, then Ensembl id, then disease name/id. Index
        0 with no entry is reported as ``'padding'``.

        Raises:
            KeyError: if ``idx`` is not known to any lookup table.
        '''
        for lookup in (self.idx_to_hpo_dict, self.idx_to_ensembl_dict, self.idx_to_disease_dict):
            if idx in lookup:
                return lookup[idx]
        if idx == 0:
            return 'padding'
        raise KeyError(f'Unknown node index: {idx}')

    def get_similar_patients(self, patient_id, similarity_type='gene'):
        '''
        Return a shuffled ``np.ndarray`` of ids of other patients similar to ``patient_id``.

        ``similarity_type`` is ``'gene'`` (a shared true gene) or ``'disease'``
        (a shared true disease). Unknown patient ids yield an empty array.

        Raises:
            NotImplementedError: for any other ``similarity_type``.
        '''
        if similarity_type == 'gene':
            similar = self.patients_with_same_gene
        elif similarity_type == 'disease':
            similar = self.patients_with_same_disease
        else:
            raise NotImplementedError
        # .get instead of [] so looking up an unknown id does not insert an
        # empty entry into the defaultdict
        sim_pats = np.array(similar.get(patient_id, []))
        np.random.shuffle(sim_pats)
        return sim_pats

    def get_candidate_diseases(self, cand_type='all_kg_nodes'):
        '''
        Return the sorted, unique KG node indices of candidate diseases as a ``torch.LongTensor``.

        ``cand_type='all_kg_nodes'`` returns every disease in the KG.
        ``cand_type='orphanet'`` returns the first MONDO mapping of each
        Orphanet disease that is in the KG; Orphanet entries with no mapping
        are skipped.

        Raises:
            NotImplementedError: for any other ``cand_type``.
        '''
        if cand_type == 'all_kg_nodes':
            disease_idx = list(self.disease_to_idx_dict.values())  # get idx of all diseases in KG
        elif cand_type == 'orphanet':
            orpha_mondo_diseases = [mondo[0] for mondo in self.orpha_mondo_map.values() if len(mondo) > 0]
            disease_idx = [self.disease_to_idx_dict[d] for d in orpha_mondo_diseases if d in self.disease_to_idx_dict]
        else:
            raise NotImplementedError
        # pin int64 so an empty selection does not become a float array
        return torch.LongTensor(np.unique(np.asarray(disease_idx, dtype=np.int64)))