# sybil/datasets/mgh.py
import copy
from ast import literal_eval
from collections import Counter

import numpy as np
from tqdm import tqdm

from sybil.datasets.nlst import NLST_Survival_Dataset

# Integer code stored under the "device" key of each sample, looked up by the
# "Manufacturer" string found in a series' "series_data".
DEVICE_ID = {
    "GE MEDICAL SYSTEMS": 0,
    "TOSHIBA": 1,
    "Philips": 2,
    # Both Siemens spellings intentionally share a single id.
    "SIEMENS": 3,
    "Siemens Healthcare": 3,
    "Vital Images, Inc.": 4,
    "Hitachi Medical Corporation": 5,
    "LightSpeed16": 6,
}


class MGH_Dataset(NLST_Survival_Dataset):
    """
    MGH Dataset Cohort 1
    """

    def create_dataset(self, split_group):
        """
        Gets the dataset from the paths and labels in the json.

        Arguments:
            split_group(str): One of ['train'|'dev'|'test'].
        Returns:
            A list with one dictionary per retained image series, holding the
            img paths, label, and additional information regarding the exam
            or participant.
        """
        dataset = []

        # When requested, overwrite the "split" stored on every patient row
        # with a seeded random draw; otherwise the split already present in
        # the metadata is used as-is.
        if self.args.assign_splits:
            np.random.seed(self.args.cross_val_seed)
            self.assign_splits(self.metadata_json)

        for mrn_row in tqdm(self.metadata_json):
            pid, exams = mrn_row["pid"], mrn_row["accessions"]
            # pt_metadata missing

            for exam_dict in exams:
                # Negated "diff_days"; used to rank this exam among the
                # patient's exams.
                days_to_last_exam = -int(exam_dict["diff_days"])
                exam_no = self.get_exam_no(days_to_last_exam, exams)

                # The label depends only on the exam, so it is computed once
                # and shared by every series of that exam.
                label = self.get_label(exam_dict, exams)

                for series_id, series_dict in exam_dict["image_series"].items():
                    if self.skip_sample(series_dict, exam_dict, mrn_row, split_group):
                        continue

                    dataset.append(
                        self._build_sample(
                            pid, exam_dict, exam_no, label, series_id, series_dict
                        )
                    )

        return dataset

    def _build_sample(self, pid, exam_dict, exam_no, label, series_id, series_dict):
        """Assembles the sample dictionary for a single image series.

        Arguments:
            pid: Patient identifier taken from the metadata row.
            exam_dict(dict): Metadata of the exam the series belongs to.
            exam_no(int): Index of the exam as returned by ``get_exam_no``.
            label(tuple): ``(y, y_seq, y_mask, time_at_event)`` as returned
                by ``get_label``.
            series_id(str): Key of the series in ``exam_dict["image_series"]``.
            series_dict(dict): Metadata of the series itself.
        Returns:
            The sample dictionary appended to the dataset.
        """
        y, y_seq, y_mask, time_at_event = label
        studyuid = exam_dict["StudyInstanceUID"]
        slice_thickness = series_dict["slice_thickness"]

        img_paths = [p.replace("Data082021", "pngs") for p in series_dict["paths"]]
        sorted_img_paths, sorted_slice_locs = self.order_slices(
            img_paths, series_dict["image_posn"]
        )

        sample = {
            "paths": sorted_img_paths,
            "slice_locations": sorted_slice_locs,
            "y": int(y),
            "time_at_event": time_at_event,
            "y_seq": y_seq,
            "y_mask": y_mask,
            # last 5 of study id + last 5 of series id
            "exam": int(
                "{}{}".format(
                    studyuid.replace(".", "")[-5:],
                    series_id.replace(".", "")[-5:],
                )
            ),
            "exam_str": "{}_{}".format(exam_dict["bridge_uid"], exam_no),
            "accession": exam_no,
            "study": studyuid,
            "series": series_id,
            "pid": pid,
            "device": DEVICE_ID[series_dict["series_data"]["Manufacturer"]],
            "lung_rads": self._lung_rads_or_missing(exam_dict["lung_rads"]),
            "IV_contrast": exam_dict["IV_contrast"],
            "lung_cancer_screening": exam_dict["lung_cancer_screening"],
            "cancer_location": np.zeros(14),  # mgh has no annotations
            # has to be int, while cancer_location has to be float
            # (np.int was only an alias of the builtin and was removed from NumPy)
            "cancer_laterality": np.zeros(3, dtype=int),
            "num_original_slices": len(series_dict["paths"]),
            "annotations": [],
            "pixel_spacing": series_dict["pixel_spacing"] + [slice_thickness],
            "slice_thickness": self.get_slice_thickness_class(slice_thickness),
        }

        if self.args.use_risk_factors:
            sample["risk_factors"] = self.get_risk_factors(
                exam_dict, return_dict=False
            )

        if self.args.use_annotations:
            # mgh has no annotations, so set everything to zero / false
            sample["volume_annotations"] = np.array([0 for _ in sample["paths"]])
            sample["annotations"] = [
                {"image_annotations": None} for _ in sample["paths"]
            ]

        return sample

    @staticmethod
    def _lung_rads_or_missing(value):
        """Returns -1 when ``value`` is a float NaN, otherwise ``value`` unchanged.

        NaN never compares equal to anything (itself included), so an ``==``
        comparison against ``np.nan`` cannot detect it.
        """
        if isinstance(value, (float, np.floating)) and np.isnan(value):
            return -1
        return value

    def skip_sample(self, series_dict, exam_dict, mrn_row, split):
        """Returns True when the series must be left out of the ``split`` dataset."""
        if mrn_row["split"] != split:
            return True

        if mrn_row["in_cohort2"]:
            return True

        # check if screen is localizer screen or not enough images
        if self.is_localizer(series_dict["series_data"]):
            return True

        slice_thickness = series_dict["slice_thickness"]
        # check if restricting to specific slice thicknesses
        if (self.args.slice_thickness_filter is not None) and (
            (slice_thickness in ["", None])
            or (slice_thickness > self.args.slice_thickness_filter)
            or (slice_thickness < 0)
        ):
            return True

        if series_dict["pixel_spacing"] is None:
            return True

        # remove where slice location doesn't change (different axis):
        if len(set(series_dict["image_posn"])) < 2:
            return True

        if len(series_dict["paths"]) < self.args.min_num_images:
            return True

        return False

    def get_exam_no(self, diff_days, exams):
        """Gets the index of the exam, compared to the other exams.

        The negated "diff_days" of all exams are sorted in descending order
        and the position of ``diff_days`` in that ordering is returned.
        """
        sorted_days = sorted([-exam["diff_days"] for exam in exams], reverse=True)
        return sorted_days.index(diff_days)

    def get_label(self, exam_dict, exams):
        """Builds the survival label of an exam.

        Arguments:
            exam_dict(dict): Exam metadata; reads "cancer_cohort_yes_no",
                "diff_days" and, for cancer exams,
                "diff_days_exam_lung_cancer_diagnosis".
            exams(list): All exams of the patient (unused here).
        Returns:
            Tuple ``(y, y_seq, y_mask, time_at_event)``: ``y`` tells whether
            the diagnosis falls within ``args.max_followup`` years, ``y_seq``
            is 1 from the diagnosis year onwards, ``y_mask`` is 1 for every
            year up to and including ``time_at_event``.
        """
        max_followup = self.args.max_followup
        is_cancer_cohort = exam_dict["cancer_cohort_yes_no"] == "yes"

        y = 0
        y_seq = np.zeros(max_followup)
        if is_cancer_cohort:
            days_to_cancer = -exam_dict["diff_days_exam_lung_cancer_diagnosis"]
            years_to_cancer = int(days_to_cancer // 365)
            y = years_to_cancer < max_followup

            time_at_event = min(years_to_cancer, max_followup - 1)
            y_seq[years_to_cancer:] = 1
        else:
            days_to_last_followup = -exam_dict["diff_days"]
            # int() mirrors the cancer branch: a float here would make the
            # list repetition used for y_mask below raise a TypeError.
            years_to_last_followup = int(days_to_last_followup // 365)
            time_at_event = min(years_to_last_followup, max_followup - 1)

        y_mask = np.array(
            [1] * (time_at_event + 1) + [0] * (max_followup - (time_at_event + 1))
        )
        y_mask = y_mask[:max_followup]
        return y, y_seq.astype("float64"), y_mask.astype("float64"), time_at_event

    def get_risk_factors(self, exam_dict, return_dict=False):
        """Collects the risk factors stored on the exam.

        Returns the full dictionary when ``return_dict`` is True; otherwise
        an array of the values that are not strings.
        """
        risk_factors = {
            "age_at_exam": exam_dict["age_at_exam"],
            "pack_years": exam_dict["pack_years"],
            "race": exam_dict["race"],
            "sex": exam_dict["sex"],
            "smoking_status": exam_dict["smoking_status"],
        }

        if return_dict:
            return risk_factors

        return np.array([v for v in risk_factors.values() if not isinstance(v, str)])

    def is_localizer(self, series_dict):
        """Returns True when "LOCALIZER" appears in the parsed "ImageType" literal."""
        return "LOCALIZER" in literal_eval(series_dict["ImageType"])

    @staticmethod
    def set_args(args):
        """Sets ``args.num_classes`` to ``args.max_followup``."""
        args.num_classes = args.max_followup

    def get_summary_statement(self, dataset, split_group):
        """Returns a text summary (counts, class balance, censor times) of ``dataset``."""
        summary = "Constructed MGH CT Cancer Survival {} dataset with {} records, {} exams, {} patients, and the following class balance \n {}"
        class_balance = Counter(d["y"] for d in dataset)
        exams = {d["exam"] for d in dataset}
        patients = {d["pid"] for d in dataset}
        statement = summary.format(
            split_group,
            len(dataset),
            len(exams),
            len(patients),
            class_balance,
        )
        statement += "\n" + "Censor Times: {}".format(
            Counter(d["time_at_event"] for d in dataset)
        )
        return statement

    def assign_splits(self, meta):
        """Overwrites the "split" of every row in ``meta`` with a random draw.

        Splits are drawn from ['train', 'dev', 'test'] with probabilities
        ``args.split_probs``; rows are modified in place.
        """
        for row in meta:
            row["split"] = np.random.choice(
                ["train", "dev", "test"], p=self.args.split_probs
            )


class MGH_Screening(NLST_Survival_Dataset):
    """
    MGH Dataset Cohort 2
    """

    def create_dataset(self, split_group):
        """
        Gets the dataset from the paths and labels in the json.

        Arguments:
            split_group(str): One of ['train'|'dev'|'test'].
        Returns:
            A list with one dictionary per retained image series, holding the
            img paths, label, and additional information regarding the exam
            or participant.
        """
        assert not self.args.train, "Cohort 2 should not be used for training"

        dataset = []

        for mrn_row in tqdm(self.metadata_json):
            for exam_dict in mrn_row["accessions"]:
                for series_id, series_dict in exam_dict["image_series"].items():
                    if self.skip_sample(series_dict, exam_dict, mrn_row):
                        continue

                    sample = self.get_volume_dict(
                        series_id, series_dict, exam_dict, mrn_row
                    )
                    # an empty sample means there is nothing to add
                    if not sample:
                        continue

                    dataset.append(sample)

        return dataset

    def skip_sample(self, series_dict, exam_dict, mrn_row):
        """Returns True when the series must be left out of the dataset."""
        # unknown cancer status; the misspelling "unkown" is kept because it
        # is the value this check has always matched, and the correct
        # spelling is accepted as well. Normalisation mirrors get_label.
        cancer_status = exam_dict["Future_cancer"]
        if isinstance(cancer_status, str) and cancer_status.lower().strip() in (
            "unkown",
            "unknown",
        ):
            return True

        if (exam_dict["days_before_cancer_dx"] < 0) or (
            exam_dict["days_to_last_follow_up"] < 0
        ):
            return True

        # check if screen is localizer screen or not enough images
        if self.is_localizer(series_dict["series_data"]):
            return True

        slice_thickness = series_dict["SliceThickness"]
        # check if restricting to specific slice thicknesses
        if (self.args.slice_thickness_filter is not None) and (
            (slice_thickness in ["", None])
            or (slice_thickness > self.args.slice_thickness_filter)
            or (slice_thickness < 0)
        ):
            return True

        if series_dict["PixelSpacing"] is None:
            return True

        if len(series_dict["paths"]) < self.args.min_num_images:
            return True

        return False

    def get_volume_dict(self, series_id, series_dict, exam_dict, mrn_row):
        """Assembles the sample dictionary for a single image series.

        Arguments:
            series_id(str): Key of the series in ``exam_dict["image_series"]``.
            series_dict(dict): Metadata of the series itself.
            exam_dict(dict): Metadata of the exam the series belongs to.
            mrn_row(dict): Metadata row of the patient.
        Returns:
            The sample dictionary for the series.
        """
        img_paths = [
            p.replace("MIT_Lung_Cancer_Screening", "screening_pngs").replace(
                ".dcm", ".png"
            )
            for p in series_dict["paths"]
        ]
        pixel_spacing = series_dict["PixelSpacing"] + [series_dict["SliceThickness"]]
        sorted_img_paths, sorted_slice_locs = self.order_slices(
            img_paths, series_dict["slice_location"], reverse=True
        )

        device = DEVICE_ID[series_dict["series_data"]["Manufacturer"]]
        studyuid = exam_dict["StudyInstanceUID"]

        y, y_seq, y_mask, time_at_event = self.get_label(exam_dict, mrn_row)

        sample = {
            "paths": sorted_img_paths,
            "slice_locations": sorted_slice_locs,
            "y": int(y),
            "time_at_event": time_at_event,
            "y_seq": y_seq,
            "y_mask": y_mask,
            # last 5 of study id + last 5 of series id
            "exam": int(
                "{}{}".format(
                    studyuid.replace(".", "")[-5:],
                    series_id.replace(".", "")[-5:],
                )
            ),
            "study": studyuid,
            "series": series_id,
            "pid": mrn_row["pid"],
            "bridge_uid": exam_dict["bridge_uid"],
            "device": device,
            "lung_rads": exam_dict["LR Score"],
            "cancer_location": np.zeros(14),  # mgh has no annotations
            # has to be int, while cancer_location has to be float
            # (np.int was only an alias of the builtin and was removed from NumPy)
            "cancer_laterality": np.zeros(3, dtype=int),
            "num_original_slices": len(series_dict["paths"]),
            "marital_status": exam_dict["marital_status"],
            "religion": exam_dict["religion"],
            "primary_site": exam_dict["Primary Site"],
            "laterality1": exam_dict["Laterality"],
            "laterality2": exam_dict["Laterality.1"],
            "icdo3": exam_dict["Histo/Behavior ICD-O-3"],
            "pixel_spacing": pixel_spacing,
            "slice_thickness": self.get_slice_thickness_class(pixel_spacing[-1]),
        }

        if self.args.use_risk_factors:
            sample["risk_factors"] = self.get_risk_factors(exam_dict, return_dict=False)

        if self.args.use_annotations:
            # mgh has no annotations, so set everything to zero / false
            sample["volume_annotations"] = np.array([0 for _ in sample["paths"]])
            sample["annotations"] = [
                {"image_annotations": None} for _ in sample["paths"]
            ]
        return sample

    def get_label(self, exam_dict, mrn_row):
        """Builds the survival label of an exam.

        Arguments:
            exam_dict(dict): Exam metadata; reads "Future_cancer",
                "days_before_cancer_dx" and, for non-cancer exams,
                "days_to_last_follow_up".
            mrn_row(dict): Metadata row of the patient (unused here).
        Returns:
            Tuple ``(y, y_seq, y_mask, time_at_event)``: ``y`` tells whether
            the diagnosis falls within ``args.max_followup`` years, ``y_seq``
            is 1 from the diagnosis year onwards, ``y_mask`` is 1 for every
            year up to and including ``time_at_event``.
        Raises:
            ValueError: If the exam is flagged as a future cancer but its
                "days_before_cancer_dx" is NaN.
        """
        max_followup = self.args.max_followup
        is_cancer_cohort = exam_dict["Future_cancer"].lower().strip() == "yes"
        days_to_cancer = exam_dict["days_before_cancer_dx"]

        has_dx_date = (
            is_cancer_cohort
            and (not np.isnan(days_to_cancer))
            and (days_to_cancer > -1)
        )
        # Always bound, so the branches below never read an unassigned name.
        years_to_cancer = int(days_to_cancer // 365) if has_dx_date else None
        y = bool(has_dx_date and years_to_cancer < max_followup)

        y_seq = np.zeros(max_followup)

        if y:
            time_at_event = years_to_cancer
            y_seq[years_to_cancer:] = 1
        elif is_cancer_cohort:
            # Reached when the diagnosis lies beyond the follow-up window or
            # the day count is <= -1. A NaN day count cannot be placed in
            # time, so fail with an explicit error.
            if np.isnan(days_to_cancer):
                raise ValueError(
                    "Exam is marked as future cancer but "
                    "'days_before_cancer_dx' is NaN"
                )
            time_at_event = max_followup - 1
        else:
            days_to_last_neg_followup = exam_dict["days_to_last_follow_up"]
            years_to_last_neg_followup = int(days_to_last_neg_followup // 365)
            time_at_event = min(years_to_last_neg_followup, max_followup - 1)

        y_mask = np.array(
            [1] * (time_at_event + 1) + [0] * (max_followup - (time_at_event + 1))
        )
        y_mask = y_mask[:max_followup]
        return y, y_seq.astype("float64"), y_mask.astype("float64"), time_at_event

    def get_risk_factors(self, exam_dict, return_dict=False):
        """Collects the risk factors stored on the exam.

        Returns the full dictionary when ``return_dict`` is True; otherwise
        an array of the values that are not strings.
        """
        risk_factors = {
            "race": exam_dict["race"],
            "pack_years": exam_dict["Packs Years"],
            "age_at_exam": exam_dict["age at the exam"],
            "gender": exam_dict["gender"],
            "smoking_status": exam_dict["Smoking Status"],
            "lung_rads": exam_dict["LR Score"],
            "years_since_quit_smoking": exam_dict["Year Since Last Smoked"],
        }

        if return_dict:
            return risk_factors

        return np.array([v for v in risk_factors.values() if not isinstance(v, str)])

    def is_localizer(self, series_dict):
        """Returns True when "LOCALIZER" appears in the parsed "ImageType" literal."""
        return "LOCALIZER" in literal_eval(series_dict["ImageType"])

    @staticmethod
    def set_args(args):
        """Sets ``args.num_classes`` to ``args.max_followup``."""
        args.num_classes = args.max_followup

    def get_summary_statement(self, dataset, split_group):
        """Returns a text summary (counts, class balance, censor times) of ``dataset``."""
        summary = "Constructed MGH CT Cancer Survival {} dataset with {} records, {} exams, {} patients, and the following class balance \n {}"
        class_balance = Counter(d["y"] for d in dataset)
        exams = {d["exam"] for d in dataset}
        patients = {d["pid"] for d in dataset}
        statement = summary.format(
            split_group,
            len(dataset),
            len(exams),
            len(patients),
            class_balance,
        )
        statement += "\n" + "Censor Times: {}".format(
            Counter(d["time_at_event"] for d in dataset)
        )
        return statement

    def assign_splits(self, meta):
        """Overwrites the "split" of every row in ``meta`` with a random draw.

        Splits are drawn from ['train', 'dev', 'test'] with probabilities
        ``args.split_probs``; rows are modified in place.
        """
        for row in meta:
            row["split"] = np.random.choice(
                ["train", "dev", "test"], p=self.args.split_probs
            )