Diff of /loader.py [000000] .. [7b3b0e]

Switch to unified view

a b/loader.py
1
'''
2
This file contains a class which will load the different
3
attributes of the patients of a given dataset.
4
The first dataset handled is NSCLC-Radiogenomics.
5
'''
6
7
from copy import deepcopy
8
import pandas as pd
9
import numpy as np
10
import os
11
from itertools import chain
12
from sklearn.preprocessing import LabelBinarizer
13
import nibabel as nib
14
15
16
class Dataset(object):
    '''
    Per-patient index of images, segmentations, radiomic / deep features,
    genomics, recurrence, survival and EGFR-mutation information.

    Every ``load_*`` method fills a dict keyed by patient ID. When the data
    for a patient is unavailable the string 'N/A' is stored instead, so each
    dict has one entry per patient in ``patient_list``.

    ``config`` must expose the attributes read below (``location``,
    ``clinical``, ``images``, ``cropped``, ``segs``, ``pyradiomics``,
    ``densenet``, ``genomics``, ``recurrence``, the ``*_dict`` replacement
    tables and the ``*_mapping`` tables). Directory-like attributes are
    concatenated directly with the patient ID, so they are expected to end
    with a path separator (or filename prefix).
    '''

    # Marker stored for a patient whose data is unavailable.
    _MISSING = 'N/A'

    # Name accepted by select_subset_patients -> attribute holding that data.
    _SELECTABLE = {
        'pyradiomics': 'feature_list',
        'gene_expressions': 'genomics_list',
        'clinical': 'clinical_list',
        'recurrence': 'recurrence_bool',
        'densenet': 'densenet_features',
        'egfr': 'egfr_mutation',
    }

    def __init__(self, config, dataset='NSCLC-Radiogenomics'):
        '''
        :param config: object carrying the file locations and value mappings
        :param dataset: name of the dataset, stored in ``dataset_info``
        '''
        self.dataset_info = dataset
        self.config = config
        self.get_patient_list()

        # these are list of files which should be read appropriately
        self.image_list = {}
        self.images = {}
        self.seg_list = {}
        self.feature_list = {}
        self.densenet_features = {}

        # these are already loaded attributes
        self.genomics_list = {}
        self.egfr_mutation = {}
        self.recurrence_bool = {}
        self.recurrence_value = {}
        self.survival_bool = {}
        self.survival_value = {}
        self.durations = {}

        #TODO: include this:       self.last_known_alive = []
        self.clinical_list = {}

        self.load_all()

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_missing(value):
        '''Return True when ``value`` is the 'N/A' marker.'''
        # isinstance guard: stored values may be lists or numpy scalars, and
        # comparing those against a string is not meaningful.
        return isinstance(value, str) and value == 'N/A'

    def _existing_paths(self, prefix, suffix):
        '''
        Map every patient ID to ``prefix + ID + suffix`` when that file
        exists on disk, and to 'N/A' otherwise.
        '''
        paths = {}
        for patient_id in self.patient_list:
            candidate = prefix + patient_id + suffix
            paths[patient_id] = candidate if os.path.exists(candidate) else self._MISSING
        return paths

    def _read_genomics(self):
        '''Read the genomics CSV and return it with one row per patient.'''
        genomics = pd.read_csv(self.config.genomics, index_col=False)
        genomics = genomics.set_index('Unnamed: 0.1').drop('Unnamed: 0', axis=1)
        # After the transpose the former column labels become the row index,
        # which is what the callers look patients up by.
        return genomics.transpose()

    def _read_clinical_predictors(self):
        '''
        Read the clinical CSV and return the numeric-encoded predictor table
        indexed by 'Case ID'.
        '''
        clinical = pd.read_csv(self.config.clinical)
        list_of_variables = clinical.columns.values

        # Columns are selected by position: 0, 2, 6-7, 9-21 and 23.
        # NOTE(review): position 0 is presumably 'Case ID' (it is used as the
        # index right below) -- confirm against the CSV layout.
        predictors_labels = list(chain(
            [list_of_variables[0]],
            [list_of_variables[2]],
            list_of_variables[6:8],
            list_of_variables[9:22],
            list_of_variables[23:24],
        ))
        # Explicit copy + column re-assignment: the previous chained
        # ``df[col].replace(..., inplace=True)`` operates on a temporary and
        # does not update the frame under pandas Copy-on-Write.
        predictors = clinical[predictors_labels].copy()
        predictors = predictors.set_index('Case ID')

        predictors['Smoking status'] = predictors['Smoking status'].replace(self.config.smoking_dict)
        predictors['Pack Years'] = predictors['Pack Years'].replace({'N/A': 0, 'Not Collected': 40})
        predictors['%GG'] = predictors['%GG'].replace(self.config.gg_dict)
        for idx in range(10, 17):
            column = list_of_variables[idx]
            predictors[column] = predictors[column].replace(self.config.location_dict)

        encoder = LabelBinarizer()
        for idx in range(17, 24):
            # Column 22 is not among the selected predictors (see above).
            if idx == 22:
                continue
            column = list_of_variables[idx]
            predictors[column] = encoder.fit_transform(predictors[column])

        return predictors.fillna(0)

    # ------------------------------------------------------------------
    # patient list
    # ------------------------------------------------------------------

    def get_patient_list(self):
        '''Read the patient IDs ('Case ID' column) from the clinical CSV.'''
        self.data_location = self.config.location
        self.patient_list = list(pd.read_csv(self.config.clinical)['Case ID'])
        return

    def set_patient_list(self, patient_list):
        '''Replace the current patient list (loaded dicts are not refreshed).'''
        self.patient_list = patient_list

    def load_all(self):
        '''Run every loader for the current patient list.'''
        self.load_images()
        self.load_segmentations()
        self.load_pyradiomics()
        self.load_genomics()
        self.load_recurrence()
        self.load_survival()
        self.load_egfr_mutation()
#         self.load_clinical()
        self.load_densenet_features()

    # ------------------------------------------------------------------
    # on-demand getters (read the actual data for a list of patients)
    # ------------------------------------------------------------------

    def get_images_cropped(self):
        '''
        Load the cropped nodule volume of every patient into ``self.images``
        ('N/A' when the file does not exist).
        '''
        self.images = {}
        for patient_id in self.patient_list:
            image_path = self.config.cropped + patient_id + '_cropped_nodule.nii'
            if os.path.exists(image_path):
                self.images[patient_id] = nib.load(image_path).get_fdata()
            else:
                self.images[patient_id] = self._MISSING
        return

    def get_pyradiomics(self, patient_list):
        '''
        :param patient_list: patient IDs to read
        :return: list with the 'arr_0' array of each patient's
                 ``<ID>_dilated.npz`` file, in the order given
        '''
        features = []
        for patient_id in patient_list:
            feature_path = self.config.pyradiomics + patient_id + '_dilated.npz'
            # Context manager closes the underlying archive handle.
            with np.load(feature_path) as archive:
                features.append(archive['arr_0'])
        return features

    def get_densenet_features(self, patient_list):
        '''
        :param patient_list: patient IDs to read
        :return: array stacking each patient's ``<ID>_densenet.npy`` content,
                 with all length-1 axes squeezed out
        '''
        stacked = np.array([
            np.load(self.config.densenet + patient_id + '_densenet.npy')
            for patient_id in patient_list
        ])
        return np.squeeze(stacked)

    def get_genomics(self, patient_list):
        '''
        :param patient_list: patient IDs to read
        :return: (list of per-patient value lists, column labels)
        '''
        genomics = self._read_genomics()

        # TODO: we can add code here to normalize the genomics data
        genomics_list = [list(genomics.loc[patient_id]) for patient_id in patient_list]
        return genomics_list, genomics.columns

    def get_clinical(self, patient_list):
        '''
        :param patient_list: patient IDs to read
        :return: list with one list of int-cast predictor values per patient
        '''
        predictors = self._read_clinical_predictors()
        return [
            [int(x) for x in list(predictors.loc[patient_id])]
            for patient_id in patient_list
        ]

    # ------------------------------------------------------------------
    # loaders (fill the per-patient dicts, 'N/A' when unavailable)
    # ------------------------------------------------------------------

    def load_images(self):
        '''Record the ``<ID>.nii`` image path of every patient.'''
        # Rebuilt from scratch like the other loaders, so entries of patients
        # no longer in patient_list do not linger after a reload.
        self.image_list = self._existing_paths(self.config.images, '.nii')
        return

    def load_segmentations(self):
        '''Record the ``<ID>.nii.gz`` segmentation path of every patient.'''
        self.seg_list = self._existing_paths(self.config.segs, '.nii.gz')
        return

    def load_pyradiomics(self):
        '''Record the ``<ID>.npz`` radiomics path of every patient.'''
        # NOTE(review): availability is checked on '<ID>.npz' while
        # get_pyradiomics reads '<ID>_dilated.npz' -- confirm which is meant.
        self.feature_list = self._existing_paths(self.config.pyradiomics, '.npz')
        return

    def load_densenet_features(self):
        '''
        Record the ``<ID>_densenet.npy`` path of every patient; a file that
        holds a single element is treated as unavailable.
        '''
        self.densenet_features = {}
        for patient_id in self.patient_list:
            feature_path = self.config.densenet + patient_id + '_densenet.npy'
            if os.path.exists(feature_path) and np.size(np.load(feature_path)) != 1:
                self.densenet_features[patient_id] = feature_path
            else:
                self.densenet_features[patient_id] = self._MISSING
        return

    def load_genomics(self):
        '''Store each patient's genomics row as a list of values.'''
        self.genomics_list = {}
        genomics = self._read_genomics()

        for patient_id in self.patient_list:
            if patient_id in genomics.index:
                self.genomics_list[patient_id] = list(genomics.loc[patient_id])
            else:
                self.genomics_list[patient_id] = self._MISSING
        return

    def load_recurrence(self):
        '''
        Store the 'Recurrence' flag and 'Days' value of every patient from
        the recurrence CSV.
        '''
        #TODO: Include the location information as well

        self.recurrence_value = {}
        self.recurrence_bool = {}
        self.durations = {}
        recurrence = pd.read_csv(self.config.recurrence, index_col=False)
        recurrence.set_index('Case ID', inplace=True)

        for patient_id in self.patient_list:
            if patient_id in recurrence.index:
                curr_patient = recurrence.loc[patient_id]
                self.recurrence_bool[patient_id] = curr_patient['Recurrence']
                self.recurrence_value[patient_id] = curr_patient['Days']
            else:
                # Keep one entry per patient, as the other loaders do;
                # select_subset_patients indexes these dicts directly.
                self.recurrence_bool[patient_id] = self._MISSING
                self.recurrence_value[patient_id] = self._MISSING

        return

    def load_survival(self):
        '''
        Store the mapped 'Survival Status' of every patient and, when the
        mapped status equals 1, the 'Time to Death (days)' value.
        '''
        self.survival_value = {}
        self.survival_bool = {}
        clinical = pd.read_csv(self.config.clinical, index_col=False)
        clinical.set_index('Case ID', inplace=True)

        for patient_id in self.patient_list:
            if patient_id not in clinical.index:
                self.survival_bool[patient_id] = self._MISSING
                self.survival_value[patient_id] = self._MISSING
                continue

            curr_patient = clinical.loc[patient_id]
            mapped_value = self.config.survival_mapping[curr_patient['Survival Status']]
            self.survival_bool[patient_id] = mapped_value

            if mapped_value == 1:
                self.survival_value[patient_id] = curr_patient['Time to Death (days)']
            else:
                self.survival_value[patient_id] = self._MISSING
        return

    def load_egfr_mutation(self):
        '''Store the mapped 'EGFR mutation status' of every patient.'''
        self.egfr_mutation = {}
        egfr = pd.read_csv(self.config.clinical, index_col=False)
        egfr.set_index('Case ID', inplace=True)

        for patient_id in self.patient_list:
            if patient_id in egfr.index:
                value = egfr.loc[patient_id]['EGFR mutation status']
                self.egfr_mutation[patient_id] = self.config.mutation_mapping[value]
            else:
                self.egfr_mutation[patient_id] = self._MISSING
        return

    def load_clinical(self):
        '''Store each patient's int-cast clinical predictor values.'''
        self.clinical_list = {}
        predictors = self._read_clinical_predictors()

        print(predictors.columns.values)
        for patient_id in self.patient_list:
            self.clinical_list[patient_id] = [int(x) for x in list(predictors.loc[patient_id])]
        return

    # ------------------------------------------------------------------
    # selection
    # ------------------------------------------------------------------

    def select_subset_patients(self, to_select, replace_list=False):
        '''
        Keep only the patients that have data for every requested attribute.

        :param to_select: iterable of attribute names; recognised names are
                          'pyradiomics', 'gene_expressions', 'clinical',
                          'recurrence', 'densenet' and 'egfr' (anything else
                          is ignored)
        :param replace_list: when true, the filtered list also becomes the
                             dataset's patient list and everything is reloaded
        :return: list of the patient IDs that passed the filter
        '''
        stores = [
            getattr(self, self._SELECTABLE[attr])
            for attr in to_select
            if attr in self._SELECTABLE
        ]

        # A patient with no entry at all (e.g. clinical data never loaded)
        # is treated as unavailable instead of raising KeyError.
        patient_list = [
            patient_id
            for patient_id in self.patient_list
            if not any(
                self._is_missing(store.get(patient_id, self._MISSING))
                for store in stores
            )
        ]

        if replace_list:
            self.set_patient_list(patient_list)
            self.load_all()

        return patient_list
328
329
# NOTE(review): Dataset.__init__ declares `config` as a required positional
# argument, so the call below raises TypeError as written -- a config object
# needs to be passed in here.
if __name__ == '__main__':
330
    nrg = Dataset()