|
a |
|
b/sybil/datasets/mgh.py |
|
|
1 |
import numpy as np |
|
|
2 |
from tqdm import tqdm |
|
|
3 |
from ast import literal_eval |
|
|
4 |
from sybil.datasets.nlst import NLST_Survival_Dataset |
|
|
5 |
from collections import Counter |
|
|
6 |
import copy |
|
|
7 |
|
|
|
8 |
# Maps the DICOM ``Manufacturer`` field of a series to a small integer id.
# NOTE(review): "LightSpeed16" is a GE scanner model name, not a manufacturer;
# presumably it appears verbatim in this field for some scans -- verify against
# the metadata json.
DEVICE_ID = dict(
    [
        ("GE MEDICAL SYSTEMS", 0),
        ("TOSHIBA", 1),
        ("Philips", 2),
        ("SIEMENS", 3),
        ("Siemens Healthcare", 3),  # note: same id as SIEMENS
        ("Vital Images, Inc.", 4),
        ("Hitachi Medical Corporation", 5),
        ("LightSpeed16", 6),
    ]
)
|
|
18 |
|
|
|
19 |
|
|
|
20 |
class MGH_Dataset(NLST_Survival_Dataset):
    """
    MGH Dataset Cohort 1.

    Builds one survival sample per usable CT image series from the
    metadata json loaded by the parent class.
    """

    def create_dataset(self, split_group):
        """
        Gets the dataset from the paths and labels in the json.

        Arguments:
            split_group(str): One of ['train'|'dev'|'test'].
        Returns:
            The dataset as a list of dicts with img paths, label,
            and additional information regarding exam or participant.
        """
        dataset = []

        # if split probs is set, randomly assign new splits
        # (otherwise default is 70% train, 15% dev and 15% test)
        if self.args.assign_splits:
            np.random.seed(self.args.cross_val_seed)
            self.assign_splits(self.metadata_json)

        for mrn_row in tqdm(self.metadata_json):
            pid, exams = mrn_row["pid"], mrn_row["accessions"]
            # pt_metadata missing

            for exam_dict in exams:
                studyuid = exam_dict["StudyInstanceUID"]
                bridge_uid = exam_dict["bridge_uid"]
                # no. of days to the oldest exam (0 or a negative int)
                days_to_last_exam = -int(exam_dict["diff_days"])

                exam_no = self.get_exam_no(days_to_last_exam, exams)

                y, y_seq, y_mask, time_at_event = self.get_label(exam_dict, exams)

                for series_id, series_dict in exam_dict["image_series"].items():

                    if self.skip_sample(series_dict, exam_dict, mrn_row, split_group):
                        continue

                    # dicom paths were converted to pngs under a parallel root
                    img_paths = [
                        p.replace("Data082021", "pngs") for p in series_dict["paths"]
                    ]
                    slice_locations = series_dict["image_posn"]
                    series_data = series_dict["series_data"]
                    device = DEVICE_ID[series_data["Manufacturer"]]

                    sorted_img_paths, sorted_slice_locs = self.order_slices(
                        img_paths, slice_locations
                    )

                    # BUGFIX: the original tested `exam_dict["lung_rads"] == np.nan`,
                    # which is always False (NaN never compares equal to anything),
                    # so missing scores leaked through as NaN. Detect NaN explicitly
                    # so missing lung_rads actually maps to -1.
                    lung_rads = exam_dict["lung_rads"]
                    if isinstance(lung_rads, float) and np.isnan(lung_rads):
                        lung_rads = -1

                    sample = {
                        "paths": sorted_img_paths,
                        "slice_locations": sorted_slice_locs,
                        "y": int(y),
                        "time_at_event": time_at_event,
                        "y_seq": y_seq,
                        "y_mask": y_mask,
                        "exam": int(
                            "{}{}".format(
                                studyuid.replace(".", "")[-5:],
                                series_id.replace(".", "")[-5:],
                            )
                        ),  # last 5 of study id + last 5 of series id
                        "exam_str": "{}_{}".format(bridge_uid, exam_no),
                        "accession": exam_no,
                        "study": studyuid,
                        "series": series_id,
                        "pid": pid,
                        "device": device,
                        "lung_rads": lung_rads,
                        "IV_contrast": exam_dict["IV_contrast"],
                        "lung_cancer_screening": exam_dict["lung_cancer_screening"],
                        "cancer_location": np.zeros(14),  # mgh has no annotations
                        # has to be int, while cancer_location has to be float.
                        # BUGFIX: np.int was removed in NumPy >= 1.24; the builtin
                        # int is the documented replacement.
                        "cancer_laterality": np.zeros(3, dtype=int),
                        "num_original_slices": len(series_dict["paths"]),
                        "annotations": [],
                        "pixel_spacing": series_dict["pixel_spacing"]
                        + [series_dict["slice_thickness"]],
                        "slice_thickness": self.get_slice_thickness_class(
                            series_dict["slice_thickness"]
                        ),
                    }

                    if self.args.use_risk_factors:
                        sample["risk_factors"] = self.get_risk_factors(
                            exam_dict, return_dict=False
                        )

                    if self.args.use_annotations:
                        # mgh has no annotations, so set everything to zero / false
                        sample["volume_annotations"] = np.array(
                            [0 for _ in sample["paths"]]
                        )
                        sample["annotations"] = [
                            {"image_annotations": None} for path in sample["paths"]
                        ]

                    dataset.append(sample)

        return dataset

    def skip_sample(self, series_dict, exam_dict, mrn_row, split):
        """Return True when this series must be excluded from `split`."""
        if not mrn_row["split"] == split:
            return True

        # cohort 2 rows are handled by MGH_Screening instead
        if mrn_row["in_cohort2"]:
            return True

        # check if screen is localizer screen or not enough images
        if self.is_localizer(series_dict["series_data"]):
            return True

        slice_thickness = series_dict["slice_thickness"]
        # check if restricting to specific slice thicknesses
        if (self.args.slice_thickness_filter is not None) and (
            (slice_thickness in ["", None])
            or (slice_thickness > self.args.slice_thickness_filter)
            or (slice_thickness < 0)
        ):
            return True

        if series_dict["pixel_spacing"] is None:
            return True

        # remove where slice location doesn't change (different axis):
        if len(set(series_dict["image_posn"])) < 2:
            return True

        if len(series_dict["paths"]) < self.args.min_num_images:
            return True

        return False

    def get_exam_no(self, diff_days, exams):
        """Gets the index of the exam, compared to the other exams.

        `diff_days` must be the negated `diff_days` of one of `exams`
        (as computed in create_dataset); exams are ranked newest first.
        Raises ValueError if `diff_days` is not among the exams' values
        (e.g. if `diff_days` was truncated via int() but the json stores
        floats -- TODO confirm diff_days is integral in the metadata).
        """
        sorted_days = sorted([-exam["diff_days"] for exam in exams], reverse=True)
        return sorted_days.index(diff_days)

    def get_label(self, exam_dict, exams):
        """Compute survival targets for one exam.

        Returns:
            y: truthy iff cancer occurs within args.max_followup years.
            y_seq (float64 array): per-year indicator, 1 from diagnosis year on.
            y_mask (float64 array): 1 for observed years up to the event/censor.
            time_at_event (int): year index of diagnosis or censoring.
        """
        is_cancer_cohort = exam_dict["cancer_cohort_yes_no"] == "yes"
        days_to_last_followup = -exam_dict["diff_days"]
        years_to_last_followup = days_to_last_followup // 365

        y = 0
        y_seq = np.zeros(self.args.max_followup)
        if is_cancer_cohort:
            days_to_cancer = -exam_dict["diff_days_exam_lung_cancer_diagnosis"]
            years_to_cancer = int(days_to_cancer // 365)
            y = years_to_cancer < self.args.max_followup

            time_at_event = min(years_to_cancer, self.args.max_followup - 1)
            # no-op slice when years_to_cancer >= max_followup (y stays falsy)
            y_seq[years_to_cancer:] = 1
        else:
            time_at_event = min(years_to_last_followup, self.args.max_followup - 1)

        y_mask = np.array(
            [1] * (time_at_event + 1)
            + [0] * (self.args.max_followup - (time_at_event + 1))
        )
        y_mask = y_mask[: self.args.max_followup]
        return y, y_seq.astype("float64"), y_mask.astype("float64"), time_at_event

    def get_risk_factors(self, exam_dict, return_dict=False):
        """Collect risk factors.

        Returns the full dict when return_dict is True, otherwise a numpy
        array of only the non-string (numeric) values.
        """
        risk_factors = {
            "age_at_exam": exam_dict["age_at_exam"],
            "pack_years": exam_dict["pack_years"],
            "race": exam_dict["race"],
            "sex": exam_dict["sex"],
            "smoking_status": exam_dict["smoking_status"],
        }

        if return_dict:
            return risk_factors
        # drop string-valued factors so the result is a numeric vector
        return np.array(
            [v for v in risk_factors.values() if not isinstance(v, str)]
        )

    def is_localizer(self, series_dict):
        """True when the DICOM ImageType marks the series as a localizer scan."""
        # ImageType is stored as the string repr of a list, hence literal_eval
        return "LOCALIZER" in literal_eval(series_dict["ImageType"])

    @staticmethod
    def set_args(args):
        # one output class per follow-up year
        args.num_classes = args.max_followup

    def get_summary_statement(self, dataset, split_group):
        """Build a human-readable summary of the constructed split."""
        summary = "Constructed MGH CT Cancer Survival {} dataset with {} records, {} exams, {} patients, and the following class balance \n {}"
        class_balance = Counter([d["y"] for d in dataset])
        exams = set([d["exam"] for d in dataset])
        patients = set([d["pid"] for d in dataset])
        statement = summary.format(
            split_group,
            len(dataset),
            len(exams),
            len(patients),
            class_balance,
        )
        statement += "\n" + "Censor Times: {}".format(
            Counter([d["time_at_event"] for d in dataset])
        )
        return statement

    def assign_splits(self, meta):
        """Randomly assign each patient row a split per args.split_probs (in place)."""
        for idx in range(len(meta)):
            meta[idx]["split"] = np.random.choice(
                ["train", "dev", "test"], p=self.args.split_probs
            )
|
|
233 |
|
|
|
234 |
|
|
|
235 |
class MGH_Screening(NLST_Survival_Dataset):
    """
    MGH Dataset Cohort 2.

    Screening cohort used for evaluation only -- never for training.
    """

    def create_dataset(self, split_group):
        """
        Gets the dataset from the paths and labels in the json.

        Arguments:
            split_group(str): One of ['train'|'dev'|'test'].
        Returns:
            The dataset as a list of dicts with img paths, label,
            and additional information regarding exam or participant.
        """
        # NOTE(review): assert is stripped under `python -O`; kept as-is so the
        # raised exception type (AssertionError) is unchanged for callers.
        assert not self.args.train, "Cohort 2 should not be used for training"

        dataset = []

        for mrn_row in tqdm(self.metadata_json):
            exams = mrn_row["accessions"]

            for exam_dict in exams:

                for series_id, series_dict in exam_dict["image_series"].items():
                    if self.skip_sample(series_dict, exam_dict, mrn_row):
                        continue

                    sample = self.get_volume_dict(
                        series_id, series_dict, exam_dict, mrn_row
                    )
                    if len(sample) == 0:
                        continue

                    dataset.append(sample)

        return dataset

    def skip_sample(self, series_dict, exam_dict, mrn_row):
        """Return True when this series must be excluded."""
        # unknown cancer status ("unkown" is the literal spelling stored in the
        # metadata json -- do not "correct" it here without fixing the data)
        if exam_dict["Future_cancer"] == "unkown":
            return True

        # negative intervals indicate inconsistent dates; NaN values compare
        # False here and therefore pass through to get_label
        if (exam_dict["days_before_cancer_dx"] < 0) or (
            exam_dict["days_to_last_follow_up"] < 0
        ):
            return True

        # check if screen is localizer screen or not enough images
        if self.is_localizer(series_dict["series_data"]):
            return True

        slice_thickness = series_dict["SliceThickness"]
        # check if restricting to specific slice thicknesses
        if (self.args.slice_thickness_filter is not None) and (
            (slice_thickness in ["", None])
            or (slice_thickness > self.args.slice_thickness_filter)
            or (slice_thickness < 0)
        ):
            return True

        if series_dict["PixelSpacing"] is None:
            return True

        if len(series_dict["paths"]) < self.args.min_num_images:
            return True

        return False

    def get_volume_dict(self, series_id, series_dict, exam_dict, mrn_row):
        """Build the sample dict for a single CT series."""
        # dicoms were converted to pngs under a parallel root
        img_paths = [
            p.replace("MIT_Lung_Cancer_Screening", "screening_pngs").replace(
                ".dcm", ".png"
            )
            for p in series_dict["paths"]
        ]
        slice_locations = series_dict["slice_location"]
        series_data = series_dict["series_data"]
        pixel_spacing = series_dict["PixelSpacing"] + [series_dict["SliceThickness"]]
        sorted_img_paths, sorted_slice_locs = self.order_slices(
            img_paths, slice_locations, reverse=True
        )

        device = DEVICE_ID[series_data["Manufacturer"]]

        studyuid = exam_dict["StudyInstanceUID"]
        bridge_uid = exam_dict["bridge_uid"]

        y, y_seq, y_mask, time_at_event = self.get_label(exam_dict, mrn_row)

        sample = {
            "paths": sorted_img_paths,
            "slice_locations": sorted_slice_locs,
            "y": int(y),
            "time_at_event": time_at_event,
            "y_seq": y_seq,
            "y_mask": y_mask,
            "exam": int(
                "{}{}".format(
                    studyuid.replace(".", "")[-5:],
                    series_id.replace(".", "")[-5:],
                )
            ),  # last 5 of study id + last 5 of series id
            "study": studyuid,
            "series": series_id,
            "pid": mrn_row["pid"],
            "bridge_uid": bridge_uid,
            "device": device,
            "lung_rads": exam_dict["LR Score"],
            "cancer_location": np.zeros(14),  # mgh has no annotations
            # has to be int, while cancer_location has to be float.
            # BUGFIX: np.int was removed in NumPy >= 1.24; use builtin int.
            "cancer_laterality": np.zeros(3, dtype=int),
            "num_original_slices": len(series_dict["paths"]),
            "marital_status": exam_dict["marital_status"],
            "religion": exam_dict["religion"],
            "primary_site": exam_dict["Primary Site"],
            "laterality1": exam_dict["Laterality"],
            "laterality2": exam_dict["Laterality.1"],
            "icdo3": exam_dict["Histo/Behavior ICD-O-3"],
            "pixel_spacing": pixel_spacing,
            "slice_thickness": self.get_slice_thickness_class(pixel_spacing[-1]),
        }

        if self.args.use_risk_factors:
            sample["risk_factors"] = self.get_risk_factors(exam_dict, return_dict=False)

        if self.args.use_annotations:
            # mgh has no annotations, so set everything to zero / false
            sample["volume_annotations"] = np.array([0 for _ in sample["paths"]])
            sample["annotations"] = [
                {"image_annotations": None} for path in sample["paths"]
            ]
        return sample

    def get_label(self, exam_dict, mrn_row):
        """Compute survival targets for one exam.

        Returns:
            y (bool): cancer diagnosed within args.max_followup years.
            y_seq (float64 array): per-year indicator, 1 from diagnosis year on.
            y_mask (float64 array): 1 for observed years up to the event/censor.
            time_at_event (int): year index of diagnosis or censoring.
        """
        is_cancer_cohort = exam_dict["Future_cancer"].lower().strip() == "yes"
        days_to_cancer = exam_dict["days_before_cancer_dx"]

        y = False
        years_to_cancer = None  # only meaningful when days_to_cancer is usable
        if (
            is_cancer_cohort
            and (not np.isnan(days_to_cancer))
            and (days_to_cancer > -1)
        ):
            years_to_cancer = int(days_to_cancer // 365)
            y = years_to_cancer < self.args.max_followup

        y_seq = np.zeros(self.args.max_followup)

        if y:
            time_at_event = years_to_cancer
            y_seq[years_to_cancer:] = 1
        else:
            if is_cancer_cohort:
                # BUGFIX: the original assert referenced `years_to_cancer`,
                # which is unbound when days_to_cancer is NaN, raising a
                # NameError instead of evaluating the sanity check. Treat NaN
                # as an accepted (censored-at-horizon) case explicitly.
                assert np.isnan(days_to_cancer) or (days_to_cancer < 0) or (
                    years_to_cancer >= self.args.max_followup
                )
                time_at_event = self.args.max_followup - 1
            else:
                days_to_last_neg_followup = exam_dict["days_to_last_follow_up"]
                years_to_last_neg_followup = int(days_to_last_neg_followup // 365)
                time_at_event = min(
                    years_to_last_neg_followup, self.args.max_followup - 1
                )

        y_mask = np.array(
            [1] * (time_at_event + 1)
            + [0] * (self.args.max_followup - (time_at_event + 1))
        )
        y_mask = y_mask[: self.args.max_followup]
        return y, y_seq.astype("float64"), y_mask.astype("float64"), time_at_event

    def get_risk_factors(self, exam_dict, return_dict=False):
        """Collect risk factors.

        Returns the full dict when return_dict is True, otherwise a numpy
        array of only the non-string (numeric) values.
        """
        risk_factors = {
            "race": exam_dict["race"],
            "pack_years": exam_dict["Packs Years"],
            "age_at_exam": exam_dict["age at the exam"],
            "gender": exam_dict["gender"],
            "smoking_status": exam_dict["Smoking Status"],
            "lung_rads": exam_dict["LR Score"],
            "years_since_quit_smoking": exam_dict["Year Since Last Smoked"],
        }

        if return_dict:
            return risk_factors
        # drop string-valued factors so the result is a numeric vector
        return np.array(
            [v for v in risk_factors.values() if not isinstance(v, str)]
        )

    def is_localizer(self, series_dict):
        """True when the DICOM ImageType marks the series as a localizer scan."""
        # ImageType is stored as the string repr of a list, hence literal_eval
        return "LOCALIZER" in literal_eval(series_dict["ImageType"])

    @staticmethod
    def set_args(args):
        # one output class per follow-up year
        args.num_classes = args.max_followup

    def get_summary_statement(self, dataset, split_group):
        """Build a human-readable summary of the constructed split."""
        summary = "Constructed MGH CT Cancer Survival {} dataset with {} records, {} exams, {} patients, and the following class balance \n {}"
        class_balance = Counter([d["y"] for d in dataset])
        exams = set([d["exam"] for d in dataset])
        patients = set([d["pid"] for d in dataset])
        statement = summary.format(
            split_group,
            len(dataset),
            len(exams),
            len(patients),
            class_balance,
        )
        statement += "\n" + "Censor Times: {}".format(
            Counter([d["time_at_event"] for d in dataset])
        )
        return statement

    def assign_splits(self, meta):
        """Randomly assign each patient row a split per args.split_probs (in place)."""
        for idx in range(len(meta)):
            meta[idx]["split"] = np.random.choice(
                ["train", "dev", "test"], p=self.args.split_probs
            )