Diff of /shepherd/dataset.py [000000] .. [db6163]

Switch to unified view

a b/shepherd/dataset.py
1
2
3
import random
4
import numpy as np
5
import pandas as pd
6
import pickle
7
import sys
8
import torch
9
import time
10
import re
11
from torch.utils.data import Dataset
12
from collections import defaultdict
13
14
from project_utils import read_patients
15
import project_config
16
17
18
class PatientDataset(Dataset):
19
20
    def __init__(self, filepath, gp_spl=None, raw_data=False, mondo_map_file=str(project_config.PROJECT_DIR / 'mondo_references.csv'), needs_disease_mapping=False, time=False): 
21
        self.filepath = filepath
22
        self.patients = read_patients(filepath)
23
        print('Dataset filepath: ', filepath)
24
        print('Number of patients: ', len(self.patients))
25
        
26
        # add placeholder for true genes/diseases if they don't exist
27
        for patient in self.patients:
28
            if 'true_genes' not in patient: patient['true_genes'] = []
29
            if 'true_diseases' not in patient: patient['true_diseases'] = []
30
31
        self.raw_data = raw_data
32
        self.needs_disease_mapping = needs_disease_mapping
33
        self.time = time
34
35
        # create HPO to node_idx map
36
        with open(project_config.KG_DIR / f'hpo_to_idx_dict_{project_config.CURR_KG}.pkl', 'rb') as handle:
37
            self.hpo_to_idx_dict = pickle.load(handle)
38
        with open(project_config.KG_DIR / f'hpo_to_name_dict_{project_config.CURR_KG}.pkl', 'rb') as handle:
39
            self.hpo_to_name_dict = pickle.load(handle)
40
        self.idx_to_hpo_dict = {v:self.hpo_to_name_dict[k] if k in self.hpo_to_name_dict else k for k, v in self.hpo_to_idx_dict.items()}
41
42
43
        # create ensembl to node_idx map
44
        # NOTE: assumes conversion from gene symbols to ensembl IDs has already occurred
45
        with open(str(project_config.KG_DIR  / f'ensembl_to_idx_dict_{project_config.CURR_KG}.pkl'), 'rb') as handle:
46
            self.ensembl_to_idx_dict = pickle.load(handle)
47
        self.idx_to_ensembl_dict = {v:k for k, v in self.ensembl_to_idx_dict.items()}
48
49
        # orphanet to mondo disease map
50
        with open(str(project_config.PROJECT_DIR / 'preprocess' / 'orphanet' / 'orphanet_to_mondo_dict.pkl'), 'rb') as handle:
51
            self.orpha_mondo_map = pickle.load(handle)
52
53
        with open(project_config.KG_DIR / f'mondo_to_idx_dict_{project_config.CURR_KG}.pkl', 'rb') as handle:
54
            self.disease_to_idx_dict = pickle.load(handle)
55
        with open(project_config.KG_DIR / f'mondo_to_name_dict_{project_config.CURR_KG}.pkl', 'rb') as handle:
56
            self.disease_to_name_dict = pickle.load(handle)
57
        self.idx_to_disease_dict = {v:self.disease_to_name_dict[k] if k in self.disease_to_name_dict else k for k, v in self.disease_to_idx_dict.items()}
58
59
60
        # degree dict from idx to degree - used for debugging
61
        #NOTE: may need to subtract 1 to index into this dict
62
        with open(str(project_config.KG_DIR / f'degree_dict_{project_config.CURR_KG}.pkl'), 'rb') as handle:
63
            self.degree_dict = pickle.load(handle)
64
65
        # get patients with similar genes
66
        if all(['true_genes' in patient for patient in self.patients]): # first check to make sure all patients have true genes
67
            genes_to_patients = defaultdict(list)
68
            for patient in self.patients:
69
                for g in patient['true_genes']:
70
                    genes_to_patients[g].append(patient['id'])
71
            self.patients_with_same_gene = defaultdict(list)
72
            for patients in genes_to_patients.values():
73
                for p in patients:
74
                    self.patients_with_same_gene[p].extend([pat for pat in patients if pat != p])
75
76
        # get patients with similar diseases
77
        if all(['true_diseases' in patient for patient in self.patients]): # first check to make sure all patients have true diseases
78
            dis_to_patients = defaultdict(list)
79
            for patient in self.patients:
80
                patient_diseases = patient['true_diseases']
81
                for d in patient_diseases: 
82
                    dis_to_patients[d].append(patient['id'])
83
            self.patients_with_same_disease = defaultdict(list)
84
            for patients in dis_to_patients.values():
85
                for p in patients:
86
                    self.patients_with_same_disease[p].extend([pat for pat in patients if pat != p])
87
88
        # map from patient id to index in dataset
89
        self.patient_id_to_index = {p['id']:i for i, p in enumerate(self.patients)}
90
91
        print('Finished initalizing dataset')
92
93
94
    def __len__(self):
95
        '''
96
        Returns the length of the dataset
97
        '''
98
        return len(self.patients)
99
100
    def __getitem__(self, idx):
101
        '''
102
        Returns a single example from the dataset
103
        '''
104
        t0 = time.time()
105
        patient = self.patients[idx]
106
107
        additional_labels_dict = {}
108
        if 'additional_labels' in patient:
109
            for label, values in patient['additional_labels'].items():
110
                if label == "n_hops_cand_g_p": continue
111
                if values == None: values = [[-1]]
112
                if type(values) != list: values = [[values]] #  wrap in list if needed
113
                if type(values) == list and (len(values)==0 or type(values[0]) != list): values = [values]
114
                additional_labels_dict[label] = values 
115
            if 'max_percent_phen_overlap_train' not in patient['additional_labels']: additional_labels_dict['max_percent_phen_overlap_train'] = [[-1]] 
116
            if 'max_phen_overlap_train' not in patient['additional_labels']: additional_labels_dict['max_phen_overlap_train'] = [[-1]]
117
       
118
        phenotype_node_idx = [self.hpo_to_idx_dict[p] for p in patient['positive_phenotypes'] if p in self.hpo_to_idx_dict ]
119
        correct_genes_node_idx = [self.ensembl_to_idx_dict[g] for g in patient['true_genes'] if g in self.ensembl_to_idx_dict ]
120
        if 'all_candidate_genes' in patient:
121
            candidate_gene_node_idx = [self.ensembl_to_idx_dict[g] for g in patient['all_candidate_genes'] if g in self.ensembl_to_idx_dict ]
122
        else: candidate_gene_node_idx = []
123
124
        if 'true_diseases' in patient:
125
            if self.needs_disease_mapping:
126
                orpha_diseases = [ int(d) if len(re.match("^[0-9]*", d)[0]) > 0 else d for d in patient['true_diseases']]
127
                mondo_diseases = [mondo_d for orpha_d in set(orpha_diseases).intersection(set(self.orpha_mondo_map.keys())) for mondo_d in self.orpha_mondo_map[orpha_d]]
128
            else:
129
                mondo_diseases = [str(d) for d in patient['true_diseases']]
130
            disease_node_idx = [self.disease_to_idx_dict[d] for d in mondo_diseases if d in self.disease_to_idx_dict]
131
        else: disease_node_idx = None
132
        
133
        if not self.raw_data:
134
            phenotype_node_idx = torch.LongTensor(phenotype_node_idx)
135
            correct_genes_node_idx = torch.LongTensor(correct_genes_node_idx)
136
            candidate_gene_node_idx = torch.LongTensor(candidate_gene_node_idx)
137
            if 'true_diseases' in patient: disease_node_idx = torch.LongTensor(disease_node_idx)
138
            
139
        
140
        assert len(phenotype_node_idx) >= 1, f'There are no phenotypes for patient: {patient}'
141
        
142
        #NOTE: assumes that patient has a single causal/correct gene (the model still outputs a score for each candidate gene)
143
        if not self.raw_data:
144
            if len(correct_genes_node_idx) > 1:
145
                #print('NOTE: The patient has multiple correct genes, but we\'re only selecting the first.')
146
                correct_genes_node_idx = correct_genes_node_idx[0].unsqueeze(-1)
147
148
        # get index of correct gene
149
        if len(correct_genes_node_idx) == 0: # no correct genes available for the patient
150
            label_idx = None
151
        else:
152
            if self.raw_data:
153
                label_idx = [candidate_gene_node_idx.index(g) for g in correct_genes_node_idx]
154
            else:
155
                label_idx = (candidate_gene_node_idx == correct_genes_node_idx[0]).nonzero(as_tuple=True)[0]
156
157
        if self.time:
158
            t1 = time.time()
159
            print(f'It takes {t1-t0:0.4f}s to get an item from the dataset')
160
            
161
        return (phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, disease_node_idx, label_idx, additional_labels_dict, patient['id'])
162
163
    def node_idx_to_degree(self, idx):
164
        if idx in self.degree_dict:
165
            return self.degree_dict[idx]
166
        else:
167
            return -1
168
    
169
    def node_idx_to_name(self, idx):
170
        if idx in self.idx_to_hpo_dict:
171
            return self.idx_to_hpo_dict[idx]
172
        elif idx in self.idx_to_ensembl_dict:
173
            return self.idx_to_ensembl_dict[idx]
174
        elif idx in self.idx_to_disease_dict:
175
            return self.idx_to_disease_dict[idx]
176
        elif idx == 0:
177
            return 'padding'
178
        else:
179
            print("Exception on:", idx)
180
            raise Exception
181
182
    def get_similar_patients(self, patient_id, similarity_type='gene'):
183
        if similarity_type == 'gene':
184
            sim_pats = np.array(self.patients_with_same_gene[patient_id])
185
            np.random.shuffle(sim_pats)
186
            return sim_pats
187
        elif similarity_type == 'disease':
188
            sim_pats = np.array(self.patients_with_same_disease[patient_id])
189
            np.random.shuffle(sim_pats)
190
            return sim_pats
191
        else:
192
            raise NotImplementedError
193
194
    def get_candidate_diseases(self, cand_type='all_kg_nodes'):
195
        if cand_type == 'all_kg_nodes':
196
            all_kg_diseases_idx = np.unique(list(self.disease_to_idx_dict.values())) # get idx of all diseases in KG
197
            return torch.LongTensor(all_kg_diseases_idx)
198
        elif cand_type == 'orphanet':
199
            orpha_mondo_diseases = [d[0] for d in list(self.orpha_mondo_map.values())]
200
            orpha_mondo_idx = np.unique([self.disease_to_idx_dict[d] for d in orpha_mondo_diseases if d in self.disease_to_idx_dict])
201
            return torch.LongTensor(orpha_mondo_idx)
202
        else:
203
            raise NotImplementedError
204