Diff of /sybil/datasets/mgh.py [000000] .. [d9566e]

Switch to side-by-side view

--- a
+++ b/sybil/datasets/mgh.py
@@ -0,0 +1,456 @@
+import numpy as np
+from tqdm import tqdm
+from ast import literal_eval
+from sybil.datasets.nlst import NLST_Survival_Dataset
+from collections import Counter
+import copy
+
+DEVICE_ID = {
+    "GE MEDICAL SYSTEMS": 0,
+    "TOSHIBA": 1,
+    "Philips": 2,
+    "SIEMENS": 3,
+    "Siemens Healthcare": 3,  # note: same id as SIEMENS
+    "Vital Images, Inc.": 4,
+    "Hitachi Medical Corporation": 5,
+    "LightSpeed16": 6,
+}
+
+
+class MGH_Dataset(NLST_Survival_Dataset):
+    """
+    MGH Dataset Cohort 1
+    """
+
+    def create_dataset(self, split_group):
+        """
+        Gets the dataset from the paths and labels in the json.
+        Arguments:
+            split_group(str): One of ['train'|'dev'|'test'].
+        Returns:
+            The dataset as a dictionary with img paths, label,
+            and additional information regarding exam or participant
+        """
+        dataset = []
+
+        # if split probs is set, randomly assign new splits, (otherwise default is 70% train, 15% dev and 15% test)
+        if self.args.assign_splits:
+            np.random.seed(self.args.cross_val_seed)
+            self.assign_splits(self.metadata_json)
+
+        for mrn_row in tqdm(self.metadata_json):
+            pid, split, exams = mrn_row["pid"], mrn_row["split"], mrn_row["accessions"]
+            # pt_metadata missing
+
+            for exam_dict in exams:
+                studyuid = exam_dict["StudyInstanceUID"]
+                bridge_uid = exam_dict["bridge_uid"]
+                days_to_last_exam = -int(
+                    exam_dict["diff_days"]
+                )  # no. of days to the oldest exam (0 or a negative int)
+
+                exam_no = self.get_exam_no(days_to_last_exam, exams)
+
+                y, y_seq, y_mask, time_at_event = self.get_label(exam_dict, exams)
+
+                for series_id, series_dict in exam_dict["image_series"].items():
+
+                    if self.skip_sample(series_dict, exam_dict, mrn_row, split_group):
+                        continue
+
+                    img_paths = series_dict["paths"]
+                    img_paths = [p.replace("Data082021", "pngs") for p in img_paths]
+                    slice_locations = series_dict["image_posn"]
+                    series_data = series_dict["series_data"]
+                    device = DEVICE_ID[series_data["Manufacturer"]]
+
+                    sorted_img_paths, sorted_slice_locs = self.order_slices(
+                        img_paths, slice_locations
+                    )
+
+                    sample = {
+                        "paths": sorted_img_paths,
+                        "slice_locations": sorted_slice_locs,
+                        "y": int(y),
+                        "time_at_event": time_at_event,
+                        "y_seq": y_seq,
+                        "y_mask": y_mask,
+                        "exam": int(
+                            "{}{}".format(
+                                studyuid.replace(".", "")[-5:],
+                                series_id.replace(".", "")[-5:],
+                            )
+                        ),  # last 5 of study id + last 5 of series id
+                        "exam_str": "{}_{}".format(bridge_uid, exam_no),
+                        "accession": exam_no,
+                        "study": studyuid,
+                        "series": series_id,
+                        "pid": pid,
+                        "device": device,
+                        "lung_rads": -1
+                        if exam_dict["lung_rads"] == np.nan
+                        else exam_dict["lung_rads"],
+                        "IV_contrast": exam_dict["IV_contrast"],
+                        "lung_cancer_screening": exam_dict["lung_cancer_screening"],
+                        "cancer_location": np.zeros(14),  # mgh has no annotations
+                        "cancer_laterality": np.zeros(
+                            3, dtype=np.int
+                        ),  # has to be int, while cancer_location has to be float
+                        "num_original_slices": len(series_dict["paths"]),
+                        "annotations": [],
+                        "pixel_spacing": series_dict["pixel_spacing"]
+                        + [series_dict["slice_thickness"]],
+                        "slice_thickness": self.get_slice_thickness_class(
+                            series_dict["slice_thickness"]
+                        ),
+                    }
+
+                    if self.args.use_risk_factors:
+                        sample["risk_factors"] = self.get_risk_factors(
+                            exam_dict, return_dict=False
+                        )
+
+                    if self.args.use_annotations:
+                        # mgh has no annotations, so set everything to zero / false
+                        sample["volume_annotations"] = np.array(
+                            [0 for _ in sample["paths"]]
+                        )
+                        sample["annotations"] = [
+                            {"image_annotations": None} for path in sample["paths"]
+                        ]
+
+                    dataset.append(sample)
+
+        return dataset
+
+    def skip_sample(self, series_dict, exam_dict, mrn_row, split):
+        if not mrn_row["split"] == split:
+            return True
+
+        if mrn_row["in_cohort2"]:
+            return True
+
+        # check if screen is localizer screen or not enough images
+        if self.is_localizer(series_dict["series_data"]):
+            return True
+
+        slice_thickness = series_dict["slice_thickness"]
+        # check if restricting to specific slice thicknesses
+        if (self.args.slice_thickness_filter is not None) and (
+            (slice_thickness in ["", None])
+            or (slice_thickness > self.args.slice_thickness_filter)
+            or (slice_thickness < 0)
+        ):
+            return True
+
+        if series_dict["pixel_spacing"] is None:
+            return True
+
+        # remove where slice location doesn't change (different axis):
+        if len(set(series_dict["image_posn"])) < 2:
+            return True
+
+        if len(series_dict["paths"]) < self.args.min_num_images:
+            return True
+
+        return False
+
+    def get_exam_no(self, diff_days, exams):
+        """Gets the index of the exam, compared to the other exams"""
+        sorted_days = sorted([-exam["diff_days"] for exam in exams], reverse=True)
+        return sorted_days.index(diff_days)
+
+    def get_label(self, exam_dict, exams):
+        is_cancer_cohort = exam_dict["cancer_cohort_yes_no"] == "yes"
+        days_to_last_followup = -exam_dict["diff_days"]
+        years_to_last_followup = days_to_last_followup // 365
+
+        y = 0
+        y_seq = np.zeros(self.args.max_followup)
+        if is_cancer_cohort:
+            days_to_cancer = -exam_dict["diff_days_exam_lung_cancer_diagnosis"]
+            years_to_cancer = int(days_to_cancer // 365)
+            y = years_to_cancer < self.args.max_followup
+
+            time_at_event = min(years_to_cancer, self.args.max_followup - 1)
+            y_seq[years_to_cancer:] = 1
+        else:
+            time_at_event = min(years_to_last_followup, self.args.max_followup - 1)
+
+        y_mask = np.array(
+            [1] * (time_at_event + 1)
+            + [0] * (self.args.max_followup - (time_at_event + 1))
+        )
+        y_mask = y_mask[: self.args.max_followup]
+        return y, y_seq.astype("float64"), y_mask.astype("float64"), time_at_event
+
+    def get_risk_factors(self, exam_dict, return_dict=False):
+        risk_factors = {
+            "age_at_exam": exam_dict["age_at_exam"],
+            "pack_years": exam_dict["pack_years"],
+            "race": exam_dict["race"],
+            "sex": exam_dict["sex"],
+            "smoking_status": exam_dict["smoking_status"],
+        }
+
+        if return_dict:
+            return risk_factors
+        else:
+            return np.array(
+                [v for v in risk_factors.values() if not isinstance(v, str)]
+            )
+
+    def is_localizer(self, series_dict):
+        is_localizer = "LOCALIZER" in literal_eval(series_dict["ImageType"])
+        return is_localizer
+
+    @staticmethod
+    def set_args(args):
+        args.num_classes = args.max_followup
+
+    def get_summary_statement(self, dataset, split_group):
+        summary = "Constructed MGH CT Cancer Survival {} dataset with {} records, {} exams, {} patients, and the following class balance \n {}"
+        class_balance = Counter([d["y"] for d in dataset])
+        exams = set([d["exam"] for d in dataset])
+        patients = set([d["pid"] for d in dataset])
+        statement = summary.format(
+            split_group,
+            len(dataset),
+            len(exams),
+            len(patients),
+            class_balance,
+        )
+        statement += "\n" + "Censor Times: {}".format(
+            Counter([d["time_at_event"] for d in dataset])
+        )
+        return statement
+
+    def assign_splits(self, meta):
+        for idx in range(len(meta)):
+            meta[idx]["split"] = np.random.choice(
+                ["train", "dev", "test"], p=self.args.split_probs
+            )
+
+
+class MGH_Screening(NLST_Survival_Dataset):
+    """
+    MGH Dataset Cohort 2
+    """
+
+    def create_dataset(self, split_group):
+        """
+        Gets the dataset from the paths and labels in the json.
+        Arguments:
+            split_group(str): One of ['train'|'dev'|'test'].
+        Returns:
+            The dataset as a dictionary with img paths, label,
+            and additional information regarding exam or participant
+        """
+        assert not self.args.train, "Cohort 2 should not be used for training"
+
+        dataset = []
+
+        for mrn_row in tqdm(self.metadata_json):
+            pid, exams = mrn_row["pid"], mrn_row["accessions"]
+
+            for exam_dict in exams:
+
+                for series_id, series_dict in exam_dict["image_series"].items():
+                    if self.skip_sample(series_dict, exam_dict, mrn_row):
+                        continue
+
+                    sample = self.get_volume_dict(
+                        series_id, series_dict, exam_dict, mrn_row
+                    )
+                    if len(sample) == 0:
+                        continue
+
+                    dataset.append(sample)
+
+        return dataset
+
+    def skip_sample(self, series_dict, exam_dict, mrn_row):
+        # unknown cancer status
+        if exam_dict["Future_cancer"] == "unkown":
+            return True
+
+        if (exam_dict["days_before_cancer_dx"] < 0) or (
+            exam_dict["days_to_last_follow_up"] < 0
+        ):
+            return True
+
+        # check if screen is localizer screen or not enough images
+        if self.is_localizer(series_dict["series_data"]):
+            return True
+
+        slice_thickness = series_dict["SliceThickness"]
+        # check if restricting to specific slice thicknesses
+        if (self.args.slice_thickness_filter is not None) and (
+            (slice_thickness in ["", None])
+            or (slice_thickness > self.args.slice_thickness_filter)
+            or (slice_thickness < 0)
+        ):
+            return True
+
+        if series_dict["PixelSpacing"] is None:
+            return True
+
+        if len(series_dict["paths"]) < self.args.min_num_images:
+            return True
+
+        return False
+
+    def get_volume_dict(self, series_id, series_dict, exam_dict, mrn_row):
+
+        img_paths = series_dict["paths"]
+        img_paths = [
+            p.replace("MIT_Lung_Cancer_Screening", "screening_pngs").replace(
+                ".dcm", ".png"
+            )
+            for p in img_paths
+        ]
+        slice_locations = series_dict["slice_location"]
+        series_data = series_dict["series_data"]
+        pixel_spacing = series_dict["PixelSpacing"] + [series_dict["SliceThickness"]]
+        sorted_img_paths, sorted_slice_locs = self.order_slices(
+            img_paths, slice_locations, reverse=True
+        )
+
+        device = DEVICE_ID[series_data["Manufacturer"]]
+
+        studyuid = exam_dict["StudyInstanceUID"]
+        bridge_uid = exam_dict["bridge_uid"]
+
+        y, y_seq, y_mask, time_at_event = self.get_label(exam_dict, mrn_row)
+
+        sample = {
+            "paths": sorted_img_paths,
+            "slice_locations": sorted_slice_locs,
+            "y": int(y),
+            "time_at_event": time_at_event,
+            "y_seq": y_seq,
+            "y_mask": y_mask,
+            "exam": int(
+                "{}{}".format(
+                    studyuid.replace(".", "")[-5:],
+                    series_id.replace(".", "")[-5:],
+                )
+            ),  # last 5 of study id + last 5 of series id
+            "study": studyuid,
+            "series": series_id,
+            "pid": mrn_row["pid"],
+            "bridge_uid": bridge_uid,
+            "device": device,
+            "lung_rads": exam_dict["LR Score"],
+            "cancer_location": np.zeros(14),  # mgh has no annotations
+            "cancer_laterality": np.zeros(
+                3, dtype=np.int
+            ),  # has to be int, while cancer_location has to be float
+            "num_original_slices": len(series_dict["paths"]),
+            "marital_status": exam_dict["marital_status"],
+            "religion": exam_dict["religion"],
+            "primary_site": exam_dict["Primary Site"],
+            "laterality1": exam_dict["Laterality"],
+            "laterality2": exam_dict["Laterality.1"],
+            "icdo3": exam_dict["Histo/Behavior ICD-O-3"],
+            "pixel_spacing": pixel_spacing,
+            "slice_thickness": self.get_slice_thickness_class(pixel_spacing[-1]),
+        }
+
+        if self.args.use_risk_factors:
+            sample["risk_factors"] = self.get_risk_factors(exam_dict, return_dict=False)
+
+        if self.args.use_annotations:
+            # mgh has no annotations, so set everything to zero / false
+            sample["volume_annotations"] = np.array([0 for _ in sample["paths"]])
+            sample["annotations"] = [
+                {"image_annotations": None} for path in sample["paths"]
+            ]
+        return sample
+
+    def get_label(self, exam_dict, mrn_row):
+        is_cancer_cohort = exam_dict["Future_cancer"].lower().strip() == "yes"
+        days_to_cancer = exam_dict["days_before_cancer_dx"]
+
+        y = False
+        if (
+            is_cancer_cohort
+            and (not np.isnan(days_to_cancer))
+            and (days_to_cancer > -1)
+        ):
+            years_to_cancer = int(days_to_cancer // 365)
+            y = years_to_cancer < self.args.max_followup
+
+        y_seq = np.zeros(self.args.max_followup)
+
+        if y:
+            time_at_event = years_to_cancer
+            y_seq[years_to_cancer:] = 1
+        else:
+            if is_cancer_cohort:
+                assert (days_to_cancer < 0) or (
+                    years_to_cancer >= self.args.max_followup
+                )
+                time_at_event = self.args.max_followup - 1
+            else:
+                days_to_last_neg_followup = exam_dict["days_to_last_follow_up"]
+                years_to_last_neg_followup = int(days_to_last_neg_followup // 365)
+                time_at_event = min(
+                    years_to_last_neg_followup, self.args.max_followup - 1
+                )
+
+        y_mask = np.array(
+            [1] * (time_at_event + 1)
+            + [0] * (self.args.max_followup - (time_at_event + 1))
+        )
+        y_mask = y_mask[: self.args.max_followup]
+        return y, y_seq.astype("float64"), y_mask.astype("float64"), time_at_event
+
+    def get_risk_factors(self, exam_dict, return_dict=False):
+        risk_factors = {
+            "race": exam_dict["race"],
+            "pack_years": exam_dict["Packs Years"],
+            "age_at_exam": exam_dict["age at the exam"],
+            "gender": exam_dict["gender"],
+            "smoking_status": exam_dict["Smoking Status"],
+            "lung_rads": exam_dict["LR Score"],
+            "years_since_quit_smoking": exam_dict["Year Since Last Smoked"],
+        }
+
+        if return_dict:
+            return risk_factors
+        else:
+            return np.array(
+                [v for v in risk_factors.values() if not isinstance(v, str)]
+            )
+
+    def is_localizer(self, series_dict):
+        is_localizer = "LOCALIZER" in literal_eval(series_dict["ImageType"])
+        return is_localizer
+
+    @staticmethod
+    def set_args(args):
+        args.num_classes = args.max_followup
+
+    def get_summary_statement(self, dataset, split_group):
+        summary = "Constructed MGH CT Cancer Survival {} dataset with {} records, {} exams, {} patients, and the following class balance \n {}"
+        class_balance = Counter([d["y"] for d in dataset])
+        exams = set([d["exam"] for d in dataset])
+        patients = set([d["pid"] for d in dataset])
+        statement = summary.format(
+            split_group,
+            len(dataset),
+            len(exams),
+            len(patients),
+            class_balance,
+        )
+        statement += "\n" + "Censor Times: {}".format(
+            Counter([d["time_at_event"] for d in dataset])
+        )
+        return statement
+
+    def assign_splits(self, meta):
+        for idx in range(len(meta)):
+            meta[idx]["split"] = np.random.choice(
+                ["train", "dev", "test"], p=self.args.split_probs
+            )