Switch to side-by-side view

--- a
+++ b/dataset_builder/preprocess_label.py
@@ -0,0 +1,549 @@
+import os
+import json
+import argparse
+import numpy as np
+import pandas as pd
+
+# from tqdm import tqdm
+
+
+def config():
+    parser = argparse.ArgumentParser(description="preprocessing label information")
+
+    # debug
+    parser.add_argument("--debug", action="store_true", help="debug mode")
+    parser.add_argument("--debug_nrows", default=100000, type=int, help="debug mode - nrows")
+
+    # file directory
+    parser.add_argument("--save_dir", default="./preprocessed_data", type=str)
+    parser.add_argument("--mimic_cxr_jpg_dir", default="../mimic-cxr-jpg/", type=str)
+    parser.add_argument("--chest_imagenome_dir", default="../chest-imagenome/", type=str)
+
+    args = parser.parse_args()
+
+    return args
+
+
+class LabelPreprocessor:
+    def __init__(
+        self,
+        args,
+    ):
+        self.args = args
+
+        # check debug
+        self.nrows = args.debug_nrows if args.debug else None
+
+        # load
+        self._load_silver_attributes_relations()
+        self._load_gold_attributes_relations()
+        self._load_and_modify_chest_imagenome_ontology()
+
+    def _load_silver_attributes_relations(self):
+        # read
+        silver_attributes_relations = pd.read_csv(
+            os.path.join(self.args.chest_imagenome_dir, "silver_dataset/scene_tabular/attribute_relations_tabular.txt"),
+            sep="\t",
+            # usecols=None,
+        )
+        self.silver_attributes_relations = silver_attributes_relations
+
+        # arrange
+        silver_dataset = silver_attributes_relations.copy()
+        silver_dataset["image_id"] = silver_dataset["image_id"].str.replace(".dcm", "")
+        silver_dataset["object_id"] = silver_dataset["image_id"] + "_" + silver_dataset["bbox"]
+        silver_dataset["sent_loc"] = silver_dataset["row_id"].apply(lambda x: float(x.split("|")[-1]))
+        silver_dataset["annot_id"] = (
+            silver_dataset["study_id"].astype(str) + "|" + silver_dataset["bbox"] + "|" + silver_dataset["relation"].astype(str) + "|" + silver_dataset["label_name"]
+        )  # erase sent_loc
+        silver_dataset = silver_dataset[["study_id", "image_id", "sent_loc", "bbox", "relation", "label_name", "categoryID", "annot_id", "object_id"]]
+
+        self.silver_dataset = silver_dataset.reset_index(drop=True)
+
+    def _load_gold_attributes_relations(self):
+        # read
+        gold_attributes_relations = pd.read_csv(
+            os.path.join(self.args.chest_imagenome_dir, "gold_dataset/gold_attributes_relations_500pts_500studies1st.txt"),
+            sep="\t",
+            # usecols=None,
+        )
+        self.gold_attributes_relations = gold_attributes_relations
+
+        # arrange
+        gold_dataset = gold_attributes_relations.copy()
+        gold_dataset["image_id"] = gold_dataset["image_id"].str.replace(".dcm", "")
+        gold_dataset["object_id"] = gold_dataset["image_id"] + "_" + gold_dataset["bbox"]
+        gold_dataset["sent_loc"] = gold_dataset["row_id"].apply(lambda x: float(x.split("|")[-1]))
+        gold_dataset["annot_id"] = gold_dataset["study_id"].astype(str) + "|" + gold_dataset["bbox"] + "|" + gold_dataset["relation"].astype(str) + "|" + gold_dataset["label_name"]  # erase sent_loc
+        gold_dataset = gold_dataset[["study_id", "image_id", "sent_loc", "bbox", "relation", "label_name", "categoryID", "annot_id", "object_id"]]
+
+        self.gold_dataset = gold_dataset.reset_index(drop=True)
+
+    def _load_and_modify_chest_imagenome_ontology(self):
+        target_categories = ["anatomicalfinding", "technicalassessment", "disease", "tubesandlines", "device"]  # nlp
+
+        # object list: 38 objects in total
+        with open(os.path.join(self.args.chest_imagenome_dir, "semantics/objects_extracted_from_reports_v1.txt"), "r") as f:
+            obj_v1 = [line.strip().replace(",", "") for line in f.readlines()]
+        print(f"chest imagenome ontology - {len(obj_v1)} objects loaded: {obj_v1}")
+
+        # attribute list: 76 attributes in total
+        # anatomicalfinding(43), technicalassessment(5), disease(10), tubesandlines (12), device (5)
+        with open(os.path.join(self.args.chest_imagenome_dir, "semantics/attribute_relations_v1.txt"), "r") as f:
+            cat_attr_v1 = {}
+            for line in f.readlines():
+                cat, rel, attr = line.strip().replace(",", "").split("|")
+                if cat in target_categories:
+                    if cat not in cat_attr_v1:
+                        cat_attr_v1[cat] = []
+                    cat_attr_v1[cat] += [attr]
+            attr_v1 = [vv for v in cat_attr_v1.values() for vv in v]
+            attr_v1.remove("sternotomy wires")
+        print(f"chest imagenome ontology - {len(attr_v1)} attributes loaded: {attr_v1}")
+
+        attr_cat_v1 = {}
+        for k, v in cat_attr_v1.items():
+            for vv in v:
+                attr_cat_v1[vv] = k
+
+        # load ontology
+        with open(os.path.join(self.args.chest_imagenome_dir, "semantics/label_to_UMLS_mapping.json"), "r") as f:
+            ont = json.load(f)
+
+            # NOTE: modify ontology
+            # added; new obj-obj p2c ontology
+            ont["all_children"]["lungs"] += [
+                "hilar structures",
+                "right hilar structures",
+                "left hilar structures",
+                # "main stem bronchus",
+                # "left main stem bronchus",
+                # "right main stem bronchus",
+            ]
+            ont["all_children"]["left lung"] += [
+                "left hilar structures",
+            ]
+            ont["all_children"]["right lung"] += [
+                "right hilar structures",
+            ]
+            ont["all_children"]["hilar structures"] = [
+                "right hilar structures",
+                "left hilar structures",
+            ]
+            ont["all_children"]["mediastinum"] += [
+                "aortic arch",
+                "svc",
+            ]
+
+            # build attr-attr p2c ontology
+            attr_p2c_map = {}
+            for parent, childs in ont["all_children"].items():
+                if parent in attr_v1:
+                    attr_p2c_map[parent] = [child for child in childs if child in attr_v1]
+
+            # build attr-attr c2p ontology (reversed)
+            attr_c2p_map = {}
+            for parent, childs in attr_p2c_map.items():
+                for child in childs:
+                    if child not in attr_c2p_map:
+                        attr_c2p_map[child] = []
+                    attr_c2p_map[child].append(parent)
+
+            # build obj-obj p2c ontology
+            obj_p2c_map = {}
+            for parent, childs in ont["all_children"].items():
+                if parent in obj_v1:
+                    new_childs = [child for child in childs if child in obj_v1]
+                    if len(new_childs) > 0:
+                        obj_p2c_map[parent] = new_childs
+
+            # build obj-obj c2p ontology
+            obj_c2p_map = {}
+            for parent, childs in obj_p2c_map.items():
+                for child in childs:
+                    if child in obj_c2p_map.keys():
+                        obj_c2p_map[child].append(parent)
+                    else:
+                        obj_c2p_map[child] = [parent]
+
+            # NOTE: modify ontology
+            # added: obj-attr possible relationship w/ obj-obj p2c ontology
+            for k, v in ont["possible_attribute_of"].items():
+                for child in obj_c2p_map:
+                    parents = obj_c2p_map[child]
+                    for parent in parents:
+                        if (child in v) and (k in attr_v1):
+                            if parent not in v:
+                                ont["possible_attribute_of"][k] += [parent]
+                                print(f"{k}-{child} \t => {k}-{parent}")
+
+            # NOTE: modify ontology
+            # added: obj-attr possible relationship w/ attr-attr p2c ontology
+            for kc in ont["possible_attribute_of"].keys():
+                if kc in attr_c2p_map:
+                    kps = attr_c2p_map[kc]
+                    for kp in kps:
+                        vc = ont["possible_attribute_of"][kc]
+                        vp = ont["possible_attribute_of"][kp]
+                        vc_diff_vp = set(vc) - set(vp)
+                        # vp_diff_vc = set(vp) - set(vc)
+                        if len(vc_diff_vp) > 0:
+                            ont["possible_attribute_of"][kp] += list(vc_diff_vp)
+                            print(f"{kc}:{vc_diff_vp} \t => {kp}+={vc_diff_vp}")
+
+        # list
+        self.obj_v1 = obj_v1
+        self.attr_v1 = attr_v1
+        # cat-attr
+        self.cat_attr_v1 = cat_attr_v1
+        self.attr_cat_v1 = attr_cat_v1
+        # p2c
+        self.attr_p2c_map = attr_p2c_map
+        self.attr_c2p_map = attr_c2p_map
+        self.obj_p2c_map = obj_p2c_map
+        self.obj_c2p_map = obj_c2p_map
+        # full ont
+        self.ont = ont
+
+    def get_dataset_by_flag(self, flag="silver"):
+        datasets = {
+            "silver": self.silver_dataset,
+            "gold": self.gold_dataset,
+            "gold+": self.gold_dataset,
+        }
+
+        if flag not in datasets:
+            raise ValueError("flag must be either 'silver', 'gold', or 'gold+'")
+
+        dataset = datasets[flag].copy()
+        return dataset
+
+    def keep_dataset_by_flag(self, dataset, flag="silver"):
+        if flag not in ["silver", "gold", "gold+"]:
+            raise ValueError("flag must be either 'silver', 'gold', or 'gold+'")
+
+        if flag == "silver":
+            self.silver_dataset = dataset
+        else:
+            self.gold_dataset = dataset
+
+    def preprocessLabel_category(self, flag="silver"):
+        dataset = self.get_dataset_by_flag(flag)
+
+        dataset = dataset[dataset["categoryID"] != "nlp"]
+        assert "normal" not in dataset.label_name.unique()
+        assert "abnormal" not in dataset.label_name.unique()
+        dataset = dataset.reset_index(drop=True)
+
+        self.keep_dataset_by_flag(dataset, flag)
+
+    def preprocessLabel_bbox(self, flag="silver"):
+        dataset = self.get_dataset_by_flag(flag)
+
+        dataset = dataset[dataset["bbox"] != "unknown"]
+        dataset = dataset.reset_index(drop=True)
+
+        self.keep_dataset_by_flag(dataset, flag)
+
+    def preprocessLabel_cohort(self, flag="silver"):
+        if flag == "silver":
+            silver_dataset = self.silver_dataset.copy()
+            silver_cohort = pd.read_csv(
+                os.path.join(self.args.save_dir, f"cohort_silver.csv"),
+            )
+            silver_iids = silver_cohort["image_id"].unique()
+            silver_dataset = silver_dataset[silver_dataset.image_id.isin(silver_iids)]
+            silver_dataset["subject_id"] = silver_dataset["image_id"].map(silver_cohort.set_index("image_id")["subject_id"])
+            silver_dataset["subject_id"] = silver_dataset["subject_id"].astype(int)
+            assert silver_dataset["subject_id"].isna().sum() == 0
+
+            silver_dataset = silver_dataset[["subject_id", "study_id", "image_id", "sent_loc", "bbox", "relation", "label_name", "categoryID", "annot_id", "object_id"]]
+            self.silver_dataset = silver_dataset.reset_index(drop=True)
+
+        elif flag == "gold":
+            gold_dataset = self.gold_dataset.copy()
+            gold_cohort = pd.read_csv(
+                os.path.join(self.args.save_dir, f"cohort_gold.csv"),
+            )
+            gold_iids = gold_cohort["image_id"].unique()
+            gold_dataset = gold_dataset[gold_dataset["image_id"].isin(gold_iids)]
+            assert len(gold_dataset["image_id"].unique()) == 500
+            gold_dataset["subject_id"] = gold_dataset["image_id"].map(gold_cohort.set_index("image_id")["subject_id"])
+            gold_dataset["subject_id"] = gold_dataset["subject_id"].astype(int)
+            assert gold_dataset["subject_id"].isna().sum() == 0
+
+            gold_dataset = gold_dataset[["subject_id", "study_id", "image_id", "sent_loc", "bbox", "relation", "label_name", "categoryID", "annot_id", "object_id"]]
+            self.gold_dataset = gold_dataset.reset_index(drop=True)
+
+        elif flag == "gold+":
+            gold_dataset = self.gold_dataset.copy()
+            gold_cohort = pd.read_csv(
+                os.path.join(self.args.save_dir, f"cohort_gold.csv"),
+            )
+            gold_iids = gold_cohort["image_id"].unique()  # 2338 images
+            gold_dataset = gold_dataset[gold_dataset["image_id"].isin(gold_iids)]  # 500 images
+            assert len(gold_dataset["image_id"].unique()) == 500
+
+            # NOTE: For studies where order >= 2, we use the silver labels
+            # the pre-processing code of silver dataset (until here) should be run again
+            self._load_silver_attributes_relations()
+            self.preprocessLabel_category(flag="silver")
+            self.preprocessLabel_bbox(flag="silver")
+            silver_dataset = self.silver_dataset.copy()
+            silver_dataset = silver_dataset[(~silver_dataset["image_id"].isin(gold_dataset.image_id.unique())) & (silver_dataset["image_id"].isin(gold_iids))]
+
+            # The number of images of gold patients w/ 1st study (in cohort_gold.csv)
+            print("# of images for gold patients w/ 1st study", gold_dataset["image_id"].nunique())
+            # The number of images of gold patients w/ >=2nd study (in cohort_gold.csv)
+            print("# of images for gold patients w/ >=2nd study", gold_cohort[gold_cohort["StudyOrder"] >= 2]["image_id"].nunique())
+            # The number of images of gold patients w/ >=2nd study (in silver_dataset.csv ~ attribute_relations_tabular.txt)
+            print("# of images for gold patients w/ >=2nd study with silver labels", silver_dataset["image_id"].nunique())
+            print("NOTE: Some gold studies (>=2nd) are missing in the silver dataset!!!")
+            gold_dataset = pd.concat([gold_dataset, silver_dataset], axis=0)
+
+            gold_dataset["subject_id"] = gold_dataset["image_id"].map(gold_cohort.set_index("image_id")["subject_id"])
+            gold_dataset["subject_id"] = gold_dataset["subject_id"].astype(int)
+            assert gold_dataset["subject_id"].isna().sum() == 0
+
+            gold_dataset = gold_dataset[["subject_id", "study_id", "image_id", "sent_loc", "bbox", "relation", "label_name", "categoryID", "annot_id", "object_id"]]
+            self.gold_dataset = gold_dataset.reset_index(drop=True)
+            print(gold_dataset.head())
+
+        else:
+            raise ValueError("flag must be either 'silver' or 'gold'")
+
+    def aggregate_labels_by_report_level(self, flag="silver", agg_option="last"):
+        sort_columns = ["subject_id", "study_id", "image_id", "sent_loc", "bbox", "label_name"]
+        agg_columns = ["subject_id", "study_id", "image_id", "bbox", "label_name"]
+
+        dataset = self.get_dataset_by_flag(flag)
+
+        dataset = dataset.sort_values(by=sort_columns)
+        dataset = dataset.drop_duplicates(subset=agg_columns, keep=agg_option)
+        dataset = dataset.reset_index(drop=True)
+
+        self.keep_dataset_by_flag(dataset, flag)
+
+    def apply_attribute_p2c_ontology_to_dataset(self, flag="silver"):
+        def _apply_attribute_p2c_ontology_to_dataset(dataset, attr_c2p_map, attr_cat_v1):
+            sort_columns = ["subject_id", "study_id", "image_id", "sent_loc", "bbox", "label_name"]
+            agg_columns = ["subject_id", "study_id", "image_id", "bbox", "label_name"]
+            agg_option = "last"
+
+            dataset_parents = None
+            for child, parents in attr_c2p_map.items():
+                for parent in parents:
+                    dataset_parent = dataset[dataset.label_name == child].copy()
+                    dataset_parent["label_name"] = parent
+                    dataset_parent["categoryID"] = attr_cat_v1[parent]
+                    dataset_parent["annot_id"] = (
+                        dataset_parent["study_id"].astype(str) + "|" + dataset_parent["bbox"] + "|" + dataset_parent["relation"].astype(str) + "|" + dataset_parent["label_name"]
+                    )
+                    dataset_parents = pd.concat([dataset_parents, dataset_parent], axis=0) if dataset_parents is not None else dataset_parent
+            dataset = pd.concat([dataset, dataset_parents], axis=0)
+            dataset = dataset.sort_values(by=sort_columns)
+            dataset = dataset.drop_duplicates(subset=agg_columns, keep=agg_option)
+            return dataset
+
+        dataset = self.get_dataset_by_flag(flag)
+
+        dataset = _apply_attribute_p2c_ontology_to_dataset(dataset=dataset, attr_c2p_map=self.attr_c2p_map, attr_cat_v1=self.attr_cat_v1)
+        dataset = dataset.reset_index(drop=True)
+
+        self.keep_dataset_by_flag(dataset, flag)
+
+    def apply_object_p2c_ontology_to_dataset(self, flag="silver"):
+        def _apply_object_p2c_ontology_to_dataset(dataset, obj_c2p_map):
+            sort_columns = ["subject_id", "study_id", "image_id", "sent_loc", "bbox", "label_name"]
+            agg_columns = ["subject_id", "study_id", "image_id", "bbox", "label_name"]
+            agg_option = "last"
+
+            dataset_parents = None
+            for child, parents in obj_c2p_map.items():
+                for parent in parents:
+                    dataset_parent = dataset[dataset.bbox == child].copy()
+                    dataset_parent["bbox"] = parent
+                    dataset_parent["annot_id"] = (
+                        dataset_parent["study_id"].astype(str) + "|" + dataset_parent["bbox"] + "|" + dataset_parent["relation"].astype(str) + "|" + dataset_parent["label_name"]
+                    )
+                    dataset_parent["object_id"] = dataset_parent["image_id"] + "_" + dataset_parent["bbox"]
+                    dataset_parents = pd.concat([dataset_parents, dataset_parent], axis=0) if dataset_parents is not None else dataset_parent
+            dataset = pd.concat([dataset, dataset_parents], axis=0)
+            dataset = dataset.sort_values(by=sort_columns)
+            dataset = dataset.drop_duplicates(subset=agg_columns, keep=agg_option)
+            return dataset
+
+        dataset = self.get_dataset_by_flag(flag)
+
+        dataset = _apply_object_p2c_ontology_to_dataset(dataset=dataset, obj_c2p_map=self.obj_c2p_map)
+        dataset = dataset.reset_index(drop=True)
+
+        self.keep_dataset_by_flag(dataset, flag)
+
+    def apply_object_attribute_possible_relationship_to_dataset(self, flag="silver"):
+        def _apply_object_attribute_possible_relationship_to_dataset(dataset, ont):
+            removed_annot_ids = []
+            obj_attr_combs = dataset[["bbox", "label_name"]].drop_duplicates().values
+            for obj_attr_comb in obj_attr_combs:
+                obj, attr = obj_attr_comb
+                if attr not in ["abnormal", "normal", "artifact"]:
+                    possible_objs = ont["possible_attribute_of"][attr]
+                    if obj not in possible_objs:
+                        removed_annot_ids += dataset[(dataset["bbox"] == obj) & (dataset["label_name"] == attr)].annot_id.tolist()
+
+            dataset = dataset[~dataset.annot_id.isin(removed_annot_ids)]
+            return dataset
+
+        dataset = self.get_dataset_by_flag(flag)
+
+        dataset = _apply_object_attribute_possible_relationship_to_dataset(dataset=dataset, ont=self.ont)
+        dataset = dataset.reset_index(drop=True)
+
+        self.keep_dataset_by_flag(dataset, flag)
+
+    def sanity_check(self, flag="silver"):
+        if flag == "silver":
+            dataset = self.silver_dataset
+        elif flag in ["gold", "gold+"]:
+            dataset = self.gold_dataset
+        else:
+            raise ValueError("flag must be either 'silver' or 'gold'")
+
+        assert (dataset.object_id.apply(lambda x: x.split("_")[-1]) != dataset.bbox).sum() == 0
+        assert (dataset.object_id.apply(lambda x: x.split("_")[0]) != dataset.image_id).sum() == 0
+        assert (dataset.annot_id.apply(lambda x: x.split("|")[0]) != dataset.study_id.astype(str)).sum() == 0
+        assert (dataset.annot_id.apply(lambda x: x.split("|")[1]) != dataset.bbox).sum() == 0
+        assert (dataset.annot_id.apply(lambda x: x.split("|")[2]) != dataset.relation.astype(str)).sum() == 0
+        assert (dataset.annot_id.apply(lambda x: x.split("|")[-1]) != dataset.label_name).sum() == 0
+
+    def remove_minority_label(self, flag="silver"):
+        REMOVED_OBJS = ["left arm", "right arm"]
+        REMOVED_ATTRS = [
+            "artifact",
+            "bronchiectasis",
+            "pigtail catheter",
+            "skin fold",
+            "aortic graft/repair",
+            "diaphragmatic eventration (benign)",
+            "sternotomy wires",
+        ]
+
+        dataset = self.get_dataset_by_flag(flag)
+
+        dataset = dataset[(~dataset["bbox"].isin(REMOVED_OBJS)) & (~dataset["label_name"].isin(REMOVED_ATTRS))]
+        dataset = dataset.reset_index(drop=True)
+
+        self.keep_dataset_by_flag(dataset, flag)
+
+    def restore_normal_relation(self, flag="silver"):
+        sort_columns = ["subject_id", "study_id", "image_id", "bbox", "label_name", "relation"]
+        agg_columns = ["subject_id", "study_id", "image_id", "bbox", "label_name"]
+        agg_option = "last"
+
+        # compute possible object-attribute combinations (in here, a total of 609 combinations)
+        objattr_combs = [(vv, 0, k, self.attr_cat_v1[k]) for k, v in self.ont["possible_attribute_of"].items() if k in self.attr_v1 for vv in v if vv in self.obj_v1]
+        objattr_combs = pd.DataFrame(objattr_combs, columns=["bbox", "relation", "label_name", "categoryID"])
+
+        dataset = self.get_dataset_by_flag(flag)
+
+        all_data = pd.merge(dataset[["subject_id", "study_id", "image_id"]].drop_duplicates().assign(key=1), objattr_combs.assign(key=1), on="key").drop("key", axis=1)
+        all_data["annot_id"] = all_data["study_id"].astype(str) + "|" + all_data["bbox"] + "|" + all_data["relation"].astype(str) + "|" + all_data["label_name"]
+        all_data["object_id"] = all_data["image_id"] + "_" + all_data["bbox"]
+
+        dataset = pd.concat([dataset, all_data])
+        dataset = dataset.sort_values(by=sort_columns, ascending=True)
+        dataset = dataset.drop_duplicates(subset=agg_columns, keep=agg_option)
+        dataset = dataset.reset_index(drop=True)
+
+        self.keep_dataset_by_flag(dataset, flag)
+
+    def save_label_dataset(self, flag="silver"):
+        dataset = self.get_dataset_by_flag(flag)
+
+        dataset.to_csv(os.path.join(self.args.save_dir, f"{flag}_dataset.csv"), index=False)
+        print(f"{flag} dataset saved, shape: {dataset.shape}")
+
+    def split_and_save_dataset(self):
+        from sklearn.model_selection import train_test_split
+
+        # 0
+        # gold_dataset = self.gold_dataset
+        # silver_dataset = self.silver_dataset
+        # To avoid the problem of the order of the dataset, we load the dataset from the saved csv file.
+        gold_dataset = pd.read_csv(os.path.join(self.args.save_dir, "gold_dataset.csv"))
+        silver_dataset = pd.read_csv(os.path.join(self.args.save_dir, "silver_dataset.csv"))
+
+        # 1: divide by relation type
+        silver_dataset_abn = silver_dataset[silver_dataset["relation"] == 1]
+        silver_dataset_nm = silver_dataset[~silver_dataset.study_id.isin(silver_dataset_abn.study_id.unique())]
+
+        SEED = 103
+        TEST_SIZE = 0.05
+
+        # 2: split subject_ids (train, valid)
+        train_subject_ids_abn = silver_dataset_abn.subject_id.unique()
+        train_subject_ids_abn, valid_subject_ids_abn = train_test_split(train_subject_ids_abn, test_size=TEST_SIZE, random_state=SEED)
+        train_subject_ids_nm = silver_dataset_nm.subject_id.unique()
+        train_subject_ids_nm, valid_subject_ids_nm = train_test_split(train_subject_ids_nm, test_size=TEST_SIZE, random_state=SEED)
+
+        train_sids = (
+            silver_dataset_abn[silver_dataset_abn.subject_id.isin(train_subject_ids_abn)].study_id.unique().tolist()
+            + silver_dataset_nm[silver_dataset_nm.subject_id.isin(train_subject_ids_nm)].study_id.unique().tolist()
+        )
+        valid_sids = (
+            silver_dataset_abn[silver_dataset_abn.subject_id.isin(valid_subject_ids_abn)].study_id.unique().tolist()
+            + silver_dataset_nm[silver_dataset_nm.subject_id.isin(valid_subject_ids_nm)].study_id.unique().tolist()
+        )
+
+        train_dataset = silver_dataset.loc[silver_dataset["study_id"].isin(train_sids)]
+        valid_dataset = silver_dataset.loc[silver_dataset["study_id"].isin(valid_sids)]
+        # test_dataset = gold_dataset.loc[gold_dataset["study_id"].isin(test_sids)]
+        test_dataset = gold_dataset.copy()
+
+        # 3: save
+        train_dataset = train_dataset.reset_index(drop=True)
+        valid_dataset = valid_dataset.reset_index(drop=True)
+        test_dataset = test_dataset.reset_index(drop=True)
+
+        train_dataset.to_csv(os.path.join(self.args.save_dir, "train_dataset.csv"), index=False)  # silver
+        valid_dataset.to_csv(os.path.join(self.args.save_dir, "valid_dataset.csv"), index=False)  # silver
+        test_dataset.to_csv(os.path.join(self.args.save_dir, "test_dataset.csv"), index=False)  # gold
+        print("silver, gold dataset -> train, valid, test dataset saved")
+
+
+def main(args):
+    # load preprocessor
+    label_preproc = LabelPreprocessor(args)
+
+    FLAGS = ["silver", "gold", "gold+"]
+    # FLAGS = ["gold+"]
+
+    for flag in FLAGS:
+        # preprocessing - labels
+        label_preproc.preprocessLabel_category(flag=flag)
+        label_preproc.preprocessLabel_bbox(flag=flag)
+        label_preproc.preprocessLabel_cohort(flag=flag)
+        label_preproc.aggregate_labels_by_report_level(flag=flag, agg_option="last")
+
+        # preprocessing - ontology
+        label_preproc.apply_attribute_p2c_ontology_to_dataset(flag=flag)
+        label_preproc.apply_object_p2c_ontology_to_dataset(flag=flag)
+        label_preproc.apply_object_attribute_possible_relationship_to_dataset(flag=flag)
+
+        label_preproc.sanity_check(flag=flag)
+
+        # preprocessing
+        label_preproc.remove_minority_label(flag=flag)
+        label_preproc.restore_normal_relation(flag=flag)
+
+        # save
+        label_preproc.save_label_dataset(flag=flag)
+
+    if "silver" in FLAGS and "gold" in FLAGS:
+        label_preproc.split_and_save_dataset()
+
+
+if __name__ == "__main__":
+    args = config()
+    main(args)
+    print("Done")