|
a |
|
b/shepherd/dataset.py |
|
|
1 |
|
|
|
2 |
|
|
|
3 |
import random |
|
|
4 |
import numpy as np |
|
|
5 |
import pandas as pd |
|
|
6 |
import pickle |
|
|
7 |
import sys |
|
|
8 |
import torch |
|
|
9 |
import time |
|
|
10 |
import re |
|
|
11 |
from torch.utils.data import Dataset |
|
|
12 |
from collections import defaultdict |
|
|
13 |
|
|
|
14 |
from project_utils import read_patients |
|
|
15 |
import project_config |
|
|
16 |
|
|
|
17 |
|
|
|
18 |
class PatientDataset(Dataset): |
|
|
19 |
|
|
|
20 |
def __init__(self, filepath, gp_spl=None, raw_data=False, mondo_map_file=str(project_config.PROJECT_DIR / 'mondo_references.csv'), needs_disease_mapping=False, time=False): |
|
|
21 |
self.filepath = filepath |
|
|
22 |
self.patients = read_patients(filepath) |
|
|
23 |
print('Dataset filepath: ', filepath) |
|
|
24 |
print('Number of patients: ', len(self.patients)) |
|
|
25 |
|
|
|
26 |
# add placeholder for true genes/diseases if they don't exist |
|
|
27 |
for patient in self.patients: |
|
|
28 |
if 'true_genes' not in patient: patient['true_genes'] = [] |
|
|
29 |
if 'true_diseases' not in patient: patient['true_diseases'] = [] |
|
|
30 |
|
|
|
31 |
self.raw_data = raw_data |
|
|
32 |
self.needs_disease_mapping = needs_disease_mapping |
|
|
33 |
self.time = time |
|
|
34 |
|
|
|
35 |
# create HPO to node_idx map |
|
|
36 |
with open(project_config.KG_DIR / f'hpo_to_idx_dict_{project_config.CURR_KG}.pkl', 'rb') as handle: |
|
|
37 |
self.hpo_to_idx_dict = pickle.load(handle) |
|
|
38 |
with open(project_config.KG_DIR / f'hpo_to_name_dict_{project_config.CURR_KG}.pkl', 'rb') as handle: |
|
|
39 |
self.hpo_to_name_dict = pickle.load(handle) |
|
|
40 |
self.idx_to_hpo_dict = {v:self.hpo_to_name_dict[k] if k in self.hpo_to_name_dict else k for k, v in self.hpo_to_idx_dict.items()} |
|
|
41 |
|
|
|
42 |
|
|
|
43 |
# create ensembl to node_idx map |
|
|
44 |
# NOTE: assumes conversion from gene symbols to ensembl IDs has already occurred |
|
|
45 |
with open(str(project_config.KG_DIR / f'ensembl_to_idx_dict_{project_config.CURR_KG}.pkl'), 'rb') as handle: |
|
|
46 |
self.ensembl_to_idx_dict = pickle.load(handle) |
|
|
47 |
self.idx_to_ensembl_dict = {v:k for k, v in self.ensembl_to_idx_dict.items()} |
|
|
48 |
|
|
|
49 |
# orphanet to mondo disease map |
|
|
50 |
with open(str(project_config.PROJECT_DIR / 'preprocess' / 'orphanet' / 'orphanet_to_mondo_dict.pkl'), 'rb') as handle: |
|
|
51 |
self.orpha_mondo_map = pickle.load(handle) |
|
|
52 |
|
|
|
53 |
with open(project_config.KG_DIR / f'mondo_to_idx_dict_{project_config.CURR_KG}.pkl', 'rb') as handle: |
|
|
54 |
self.disease_to_idx_dict = pickle.load(handle) |
|
|
55 |
with open(project_config.KG_DIR / f'mondo_to_name_dict_{project_config.CURR_KG}.pkl', 'rb') as handle: |
|
|
56 |
self.disease_to_name_dict = pickle.load(handle) |
|
|
57 |
self.idx_to_disease_dict = {v:self.disease_to_name_dict[k] if k in self.disease_to_name_dict else k for k, v in self.disease_to_idx_dict.items()} |
|
|
58 |
|
|
|
59 |
|
|
|
60 |
# degree dict from idx to degree - used for debugging |
|
|
61 |
#NOTE: may need to subtract 1 to index into this dict |
|
|
62 |
with open(str(project_config.KG_DIR / f'degree_dict_{project_config.CURR_KG}.pkl'), 'rb') as handle: |
|
|
63 |
self.degree_dict = pickle.load(handle) |
|
|
64 |
|
|
|
65 |
# get patients with similar genes |
|
|
66 |
if all(['true_genes' in patient for patient in self.patients]): # first check to make sure all patients have true genes |
|
|
67 |
genes_to_patients = defaultdict(list) |
|
|
68 |
for patient in self.patients: |
|
|
69 |
for g in patient['true_genes']: |
|
|
70 |
genes_to_patients[g].append(patient['id']) |
|
|
71 |
self.patients_with_same_gene = defaultdict(list) |
|
|
72 |
for patients in genes_to_patients.values(): |
|
|
73 |
for p in patients: |
|
|
74 |
self.patients_with_same_gene[p].extend([pat for pat in patients if pat != p]) |
|
|
75 |
|
|
|
76 |
# get patients with similar diseases |
|
|
77 |
if all(['true_diseases' in patient for patient in self.patients]): # first check to make sure all patients have true diseases |
|
|
78 |
dis_to_patients = defaultdict(list) |
|
|
79 |
for patient in self.patients: |
|
|
80 |
patient_diseases = patient['true_diseases'] |
|
|
81 |
for d in patient_diseases: |
|
|
82 |
dis_to_patients[d].append(patient['id']) |
|
|
83 |
self.patients_with_same_disease = defaultdict(list) |
|
|
84 |
for patients in dis_to_patients.values(): |
|
|
85 |
for p in patients: |
|
|
86 |
self.patients_with_same_disease[p].extend([pat for pat in patients if pat != p]) |
|
|
87 |
|
|
|
88 |
# map from patient id to index in dataset |
|
|
89 |
self.patient_id_to_index = {p['id']:i for i, p in enumerate(self.patients)} |
|
|
90 |
|
|
|
91 |
print('Finished initalizing dataset') |
|
|
92 |
|
|
|
93 |
|
|
|
94 |
def __len__(self): |
|
|
95 |
''' |
|
|
96 |
Returns the length of the dataset |
|
|
97 |
''' |
|
|
98 |
return len(self.patients) |
|
|
99 |
|
|
|
100 |
def __getitem__(self, idx): |
|
|
101 |
''' |
|
|
102 |
Returns a single example from the dataset |
|
|
103 |
''' |
|
|
104 |
t0 = time.time() |
|
|
105 |
patient = self.patients[idx] |
|
|
106 |
|
|
|
107 |
additional_labels_dict = {} |
|
|
108 |
if 'additional_labels' in patient: |
|
|
109 |
for label, values in patient['additional_labels'].items(): |
|
|
110 |
if label == "n_hops_cand_g_p": continue |
|
|
111 |
if values == None: values = [[-1]] |
|
|
112 |
if type(values) != list: values = [[values]] # wrap in list if needed |
|
|
113 |
if type(values) == list and (len(values)==0 or type(values[0]) != list): values = [values] |
|
|
114 |
additional_labels_dict[label] = values |
|
|
115 |
if 'max_percent_phen_overlap_train' not in patient['additional_labels']: additional_labels_dict['max_percent_phen_overlap_train'] = [[-1]] |
|
|
116 |
if 'max_phen_overlap_train' not in patient['additional_labels']: additional_labels_dict['max_phen_overlap_train'] = [[-1]] |
|
|
117 |
|
|
|
118 |
phenotype_node_idx = [self.hpo_to_idx_dict[p] for p in patient['positive_phenotypes'] if p in self.hpo_to_idx_dict ] |
|
|
119 |
correct_genes_node_idx = [self.ensembl_to_idx_dict[g] for g in patient['true_genes'] if g in self.ensembl_to_idx_dict ] |
|
|
120 |
if 'all_candidate_genes' in patient: |
|
|
121 |
candidate_gene_node_idx = [self.ensembl_to_idx_dict[g] for g in patient['all_candidate_genes'] if g in self.ensembl_to_idx_dict ] |
|
|
122 |
else: candidate_gene_node_idx = [] |
|
|
123 |
|
|
|
124 |
if 'true_diseases' in patient: |
|
|
125 |
if self.needs_disease_mapping: |
|
|
126 |
orpha_diseases = [ int(d) if len(re.match("^[0-9]*", d)[0]) > 0 else d for d in patient['true_diseases']] |
|
|
127 |
mondo_diseases = [mondo_d for orpha_d in set(orpha_diseases).intersection(set(self.orpha_mondo_map.keys())) for mondo_d in self.orpha_mondo_map[orpha_d]] |
|
|
128 |
else: |
|
|
129 |
mondo_diseases = [str(d) for d in patient['true_diseases']] |
|
|
130 |
disease_node_idx = [self.disease_to_idx_dict[d] for d in mondo_diseases if d in self.disease_to_idx_dict] |
|
|
131 |
else: disease_node_idx = None |
|
|
132 |
|
|
|
133 |
if not self.raw_data: |
|
|
134 |
phenotype_node_idx = torch.LongTensor(phenotype_node_idx) |
|
|
135 |
correct_genes_node_idx = torch.LongTensor(correct_genes_node_idx) |
|
|
136 |
candidate_gene_node_idx = torch.LongTensor(candidate_gene_node_idx) |
|
|
137 |
if 'true_diseases' in patient: disease_node_idx = torch.LongTensor(disease_node_idx) |
|
|
138 |
|
|
|
139 |
|
|
|
140 |
assert len(phenotype_node_idx) >= 1, f'There are no phenotypes for patient: {patient}' |
|
|
141 |
|
|
|
142 |
#NOTE: assumes that patient has a single causal/correct gene (the model still outputs a score for each candidate gene) |
|
|
143 |
if not self.raw_data: |
|
|
144 |
if len(correct_genes_node_idx) > 1: |
|
|
145 |
#print('NOTE: The patient has multiple correct genes, but we\'re only selecting the first.') |
|
|
146 |
correct_genes_node_idx = correct_genes_node_idx[0].unsqueeze(-1) |
|
|
147 |
|
|
|
148 |
# get index of correct gene |
|
|
149 |
if len(correct_genes_node_idx) == 0: # no correct genes available for the patient |
|
|
150 |
label_idx = None |
|
|
151 |
else: |
|
|
152 |
if self.raw_data: |
|
|
153 |
label_idx = [candidate_gene_node_idx.index(g) for g in correct_genes_node_idx] |
|
|
154 |
else: |
|
|
155 |
label_idx = (candidate_gene_node_idx == correct_genes_node_idx[0]).nonzero(as_tuple=True)[0] |
|
|
156 |
|
|
|
157 |
if self.time: |
|
|
158 |
t1 = time.time() |
|
|
159 |
print(f'It takes {t1-t0:0.4f}s to get an item from the dataset') |
|
|
160 |
|
|
|
161 |
return (phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, disease_node_idx, label_idx, additional_labels_dict, patient['id']) |
|
|
162 |
|
|
|
163 |
def node_idx_to_degree(self, idx): |
|
|
164 |
if idx in self.degree_dict: |
|
|
165 |
return self.degree_dict[idx] |
|
|
166 |
else: |
|
|
167 |
return -1 |
|
|
168 |
|
|
|
169 |
def node_idx_to_name(self, idx): |
|
|
170 |
if idx in self.idx_to_hpo_dict: |
|
|
171 |
return self.idx_to_hpo_dict[idx] |
|
|
172 |
elif idx in self.idx_to_ensembl_dict: |
|
|
173 |
return self.idx_to_ensembl_dict[idx] |
|
|
174 |
elif idx in self.idx_to_disease_dict: |
|
|
175 |
return self.idx_to_disease_dict[idx] |
|
|
176 |
elif idx == 0: |
|
|
177 |
return 'padding' |
|
|
178 |
else: |
|
|
179 |
print("Exception on:", idx) |
|
|
180 |
raise Exception |
|
|
181 |
|
|
|
182 |
def get_similar_patients(self, patient_id, similarity_type='gene'): |
|
|
183 |
if similarity_type == 'gene': |
|
|
184 |
sim_pats = np.array(self.patients_with_same_gene[patient_id]) |
|
|
185 |
np.random.shuffle(sim_pats) |
|
|
186 |
return sim_pats |
|
|
187 |
elif similarity_type == 'disease': |
|
|
188 |
sim_pats = np.array(self.patients_with_same_disease[patient_id]) |
|
|
189 |
np.random.shuffle(sim_pats) |
|
|
190 |
return sim_pats |
|
|
191 |
else: |
|
|
192 |
raise NotImplementedError |
|
|
193 |
|
|
|
194 |
def get_candidate_diseases(self, cand_type='all_kg_nodes'): |
|
|
195 |
if cand_type == 'all_kg_nodes': |
|
|
196 |
all_kg_diseases_idx = np.unique(list(self.disease_to_idx_dict.values())) # get idx of all diseases in KG |
|
|
197 |
return torch.LongTensor(all_kg_diseases_idx) |
|
|
198 |
elif cand_type == 'orphanet': |
|
|
199 |
orpha_mondo_diseases = [d[0] for d in list(self.orpha_mondo_map.values())] |
|
|
200 |
orpha_mondo_idx = np.unique([self.disease_to_idx_dict[d] for d in orpha_mondo_diseases if d in self.disease_to_idx_dict]) |
|
|
201 |
return torch.LongTensor(orpha_mondo_idx) |
|
|
202 |
else: |
|
|
203 |
raise NotImplementedError |
|
|
204 |
|