--- /dev/null
+++ b/dataset_builder/preprocess_label.py
@@ -0,0 +1,549 @@
+import os
+import json
+import argparse
+import numpy as np
+import pandas as pd
+
+# from tqdm import tqdm
+
+
+def config():
+    parser = argparse.ArgumentParser(description="preprocessing label information")
+
+    # debug
+    parser.add_argument("--debug", action="store_true", help="debug mode")
+    parser.add_argument("--debug_nrows", default=100000, type=int, help="debug mode - nrows")
+
+    # file directory
+    parser.add_argument("--save_dir", default="./preprocessed_data", type=str)
+    parser.add_argument("--mimic_cxr_jpg_dir", default="../mimic-cxr-jpg/", type=str)
+    parser.add_argument("--chest_imagenome_dir", default="../chest-imagenome/", type=str)
+
+    args = parser.parse_args()
+
+    return args
+
+
+class LabelPreprocessor:
+    def __init__(
+        self,
+        args,
+    ):
+        self.args = args
+
+        # check debug
+        self.nrows = args.debug_nrows if args.debug else None
+
+        # load
+        self._load_silver_attributes_relations()
+        self._load_gold_attributes_relations()
+        self._load_and_modify_chest_imagenome_ontology()
+
+    def _load_silver_attributes_relations(self):
+        # read
+        silver_attributes_relations = pd.read_csv(
+            os.path.join(self.args.chest_imagenome_dir, "silver_dataset/scene_tabular/attribute_relations_tabular.txt"),
+            sep="\t",
+            nrows=self.nrows,  # honor --debug by reading only the first debug_nrows rows
+        )
+        self.silver_attributes_relations = silver_attributes_relations
+
+        # arrange
+        silver_dataset = silver_attributes_relations.copy()
+        silver_dataset["image_id"] = silver_dataset["image_id"].str.replace(".dcm", "", regex=False)
+        silver_dataset["object_id"] = silver_dataset["image_id"] + "_" + silver_dataset["bbox"]
+        silver_dataset["sent_loc"] = silver_dataset["row_id"].apply(lambda x: float(x.split("|")[-1]))
+        silver_dataset["annot_id"] = (
+            silver_dataset["study_id"].astype(str) + "|" + silver_dataset["bbox"] + "|" + silver_dataset["relation"].astype(str) + "|" + silver_dataset["label_name"]
+        )  # erase sent_loc
+        silver_dataset = silver_dataset[["study_id", "image_id", "sent_loc", "bbox", "relation", "label_name", "categoryID", "annot_id", "object_id"]]
+
+        self.silver_dataset = silver_dataset.reset_index(drop=True)
+
+    def _load_gold_attributes_relations(self):
+        # read
+        gold_attributes_relations = pd.read_csv(
+            os.path.join(self.args.chest_imagenome_dir, "gold_dataset/gold_attributes_relations_500pts_500studies1st.txt"),
+            sep="\t",
+            # usecols=None,
+        )
+        self.gold_attributes_relations = gold_attributes_relations
+
+        # arrange
+        gold_dataset = gold_attributes_relations.copy()
+        gold_dataset["image_id"] = gold_dataset["image_id"].str.replace(".dcm", "", regex=False)
+        gold_dataset["object_id"] = gold_dataset["image_id"] + "_" + gold_dataset["bbox"]
+        gold_dataset["sent_loc"] = gold_dataset["row_id"].apply(lambda x: float(x.split("|")[-1]))
+        gold_dataset["annot_id"] = gold_dataset["study_id"].astype(str) + "|" + gold_dataset["bbox"] + "|" + gold_dataset["relation"].astype(str) + "|" + gold_dataset["label_name"]  # erase sent_loc
+        gold_dataset = gold_dataset[["study_id", "image_id", "sent_loc", "bbox", "relation", "label_name", "categoryID", "annot_id", "object_id"]]
+
+        self.gold_dataset = gold_dataset.reset_index(drop=True)
+
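+    # Illustration (hypothetical values, not taken from the dataset): the keys built
+    # above follow
+    #   annot_id  = "<study_id>|<bbox>|<relation>|<label_name>"
+    #   object_id = "<image_id>_<bbox>"
+    # e.g. annot_id "123456|left lung|1|lung opacity" would encode a positive
+    # (relation=1) "lung opacity" annotation on the left lung of study 123456.
+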
+    def _load_and_modify_chest_imagenome_ontology(self):
+        target_categories = ["anatomicalfinding", "technicalassessment", "disease", "tubesandlines", "device"]  # excludes the "nlp" category
+
+        # object list: 38 objects in total
+        with open(os.path.join(self.args.chest_imagenome_dir, "semantics/objects_extracted_from_reports_v1.txt"), "r") as f:
+            obj_v1 = [line.strip().replace(",", "") for line in f.readlines()]
+        print(f"chest imagenome ontology - {len(obj_v1)} objects loaded: {obj_v1}")
+
+        # attribute list: 76 attributes in total
+        # anatomicalfinding (43), technicalassessment (5), disease (10), tubesandlines (12), device (5)
+        with open(os.path.join(self.args.chest_imagenome_dir, "semantics/attribute_relations_v1.txt"), "r") as f:
+            cat_attr_v1 = {}
+            for line in f.readlines():
+                cat, rel, attr = line.strip().replace(",", "").split("|")
+                if cat in target_categories:
+                    if cat not in cat_attr_v1:
+                        cat_attr_v1[cat] = []
+                    cat_attr_v1[cat] += [attr]
+        attr_v1 = [vv for v in cat_attr_v1.values() for vv in v]
+        attr_v1.remove("sternotomy wires")
+        print(f"chest imagenome ontology - {len(attr_v1)} attributes loaded: {attr_v1}")
+
+        attr_cat_v1 = {}
+        for k, v in cat_attr_v1.items():
+            for vv in v:
+                attr_cat_v1[vv] = k
+
+        # load ontology
+        with open(os.path.join(self.args.chest_imagenome_dir, "semantics/label_to_UMLS_mapping.json"), "r") as f:
+            ont = json.load(f)
+
+        # NOTE: modify ontology
+        # added: new obj-obj p2c ontology
+        ont["all_children"]["lungs"] += [
+            "hilar structures",
+            "right hilar structures",
+            "left hilar structures",
+            # "main stem bronchus",
+            # "left main stem bronchus",
+            # "right main stem bronchus",
+        ]
+        ont["all_children"]["left lung"] += [
+            "left hilar structures",
+        ]
+        ont["all_children"]["right lung"] += [
+            "right hilar structures",
+        ]
+        ont["all_children"]["hilar structures"] = [
+            "right hilar structures",
+            "left hilar structures",
+        ]
+        ont["all_children"]["mediastinum"] += [
+            "aortic arch",
+            "svc",
+        ]
+
+        # build attr-attr p2c ontology
+        attr_p2c_map = {}
+        for parent, childs in ont["all_children"].items():
+            if parent in attr_v1:
+                attr_p2c_map[parent] = [child for child in childs if child in attr_v1]
+
+        # build attr-attr c2p ontology (reversed)
+        attr_c2p_map = {}
+        for parent, childs in attr_p2c_map.items():
+            for child in childs:
+                if child not in attr_c2p_map:
+                    attr_c2p_map[child] = []
+                attr_c2p_map[child].append(parent)
+
+        # build obj-obj p2c ontology
+        obj_p2c_map = {}
+        for parent, childs in ont["all_children"].items():
+            if parent in obj_v1:
+                new_childs = [child for child in childs if child in obj_v1]
+                if len(new_childs) > 0:
+                    obj_p2c_map[parent] = new_childs
+
+        # build obj-obj c2p ontology
+        obj_c2p_map = {}
+        for parent, childs in obj_p2c_map.items():
+            for child in childs:
+                if child in obj_c2p_map:
+                    obj_c2p_map[child].append(parent)
+                else:
+                    obj_c2p_map[child] = [parent]
+
+        # NOTE: modify ontology
+        # added: obj-attr possible relationship w/ obj-obj p2c ontology
+        for k, v in ont["possible_attribute_of"].items():
+            for child in obj_c2p_map:
+                parents = obj_c2p_map[child]
+                for parent in parents:
+                    if (child in v) and (k in attr_v1):
+                        if parent not in v:
+                            ont["possible_attribute_of"][k] += [parent]
+                            print(f"{k}-{child} \t => {k}-{parent}")
+
+        # NOTE: modify ontology
+        # added: obj-attr possible relationship w/ attr-attr p2c ontology
+        for kc in ont["possible_attribute_of"]:
+            if kc in attr_c2p_map:
+                kps = attr_c2p_map[kc]
+                for kp in kps:
+                    vc = ont["possible_attribute_of"][kc]
+                    vp = ont["possible_attribute_of"][kp]
+                    vc_diff_vp = set(vc) - set(vp)
+                    # vp_diff_vc = set(vp) - set(vc)
+                    if len(vc_diff_vp) > 0:
+                        ont["possible_attribute_of"][kp] += list(vc_diff_vp)
+                        print(f"{kc}:{vc_diff_vp} \t => {kp}+={vc_diff_vp}")
+
+        # list
+        self.obj_v1 = obj_v1
+        self.attr_v1 = attr_v1
+        # cat-attr
+        self.cat_attr_v1 = cat_attr_v1
+        self.attr_cat_v1 = attr_cat_v1
+        # p2c
+        self.attr_p2c_map = attr_p2c_map
+        self.attr_c2p_map = attr_c2p_map
+        self.obj_p2c_map = obj_p2c_map
+        self.obj_c2p_map = obj_c2p_map
+        # full ont
+        self.ont = ont
+
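+    # Shape of the resulting lookups (entries shown are illustrative only):
+    #   attr_p2c_map: {"opacity": ["lung opacity", ...]}   attribute parent -> children
+    #   attr_c2p_map: {"lung opacity": ["opacity"], ...}   attribute child  -> parents
+    #   obj_c2p_map:  {"left hilar structures": ["lungs", "left lung", "hilar structures"], ...}
+    # The methods below use the c2p maps to propagate child annotations up to their parents.
+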
+    def get_dataset_by_flag(self, flag="silver"):
+        datasets = {
+            "silver": self.silver_dataset,
+            "gold": self.gold_dataset,
+            "gold+": self.gold_dataset,
+        }
+
+        if flag not in datasets:
+            raise ValueError("flag must be either 'silver', 'gold', or 'gold+'")
+
+        dataset = datasets[flag].copy()
+        return dataset
+
+    def keep_dataset_by_flag(self, dataset, flag="silver"):
+        if flag not in ["silver", "gold", "gold+"]:
+            raise ValueError("flag must be either 'silver', 'gold', or 'gold+'")
+
+        if flag == "silver":
+            self.silver_dataset = dataset
+        else:
+            self.gold_dataset = dataset
+
+    def preprocessLabel_category(self, flag="silver"):
+        dataset = self.get_dataset_by_flag(flag)
+
+        dataset = dataset[dataset["categoryID"] != "nlp"]
+        assert "normal" not in dataset.label_name.unique()
+        assert "abnormal" not in dataset.label_name.unique()
+        dataset = dataset.reset_index(drop=True)
+
+        self.keep_dataset_by_flag(dataset, flag)
+
+    def preprocessLabel_bbox(self, flag="silver"):
+        dataset = self.get_dataset_by_flag(flag)
+
+        dataset = dataset[dataset["bbox"] != "unknown"]
+        dataset = dataset.reset_index(drop=True)
+
+        self.keep_dataset_by_flag(dataset, flag)
+
+    def preprocessLabel_cohort(self, flag="silver"):
+        if flag == "silver":
+            silver_dataset = self.silver_dataset.copy()
+            silver_cohort = pd.read_csv(
+                os.path.join(self.args.save_dir, "cohort_silver.csv"),
+            )
+            silver_iids = silver_cohort["image_id"].unique()
+            silver_dataset = silver_dataset[silver_dataset.image_id.isin(silver_iids)]
+            silver_dataset["subject_id"] = silver_dataset["image_id"].map(silver_cohort.set_index("image_id")["subject_id"])
+            assert silver_dataset["subject_id"].isna().sum() == 0
+            silver_dataset["subject_id"] = silver_dataset["subject_id"].astype(int)
+
+            silver_dataset = silver_dataset[["subject_id", "study_id", "image_id", "sent_loc", "bbox", "relation", "label_name", "categoryID", "annot_id", "object_id"]]
+            self.silver_dataset = silver_dataset.reset_index(drop=True)
+
+        elif flag == "gold":
+            gold_dataset = self.gold_dataset.copy()
+            gold_cohort = pd.read_csv(
+                os.path.join(self.args.save_dir, "cohort_gold.csv"),
+            )
+            gold_iids = gold_cohort["image_id"].unique()
+            gold_dataset = gold_dataset[gold_dataset["image_id"].isin(gold_iids)]
+            assert len(gold_dataset["image_id"].unique()) == 500
+            gold_dataset["subject_id"] = gold_dataset["image_id"].map(gold_cohort.set_index("image_id")["subject_id"])
+            assert gold_dataset["subject_id"].isna().sum() == 0
+            gold_dataset["subject_id"] = gold_dataset["subject_id"].astype(int)
+
+            gold_dataset = gold_dataset[["subject_id", "study_id", "image_id", "sent_loc", "bbox", "relation", "label_name", "categoryID", "annot_id", "object_id"]]
+            self.gold_dataset = gold_dataset.reset_index(drop=True)
+
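+        # NOTE: "gold+" extends the 500 first-study gold images with silver labels for
+        # the same patients' follow-up (StudyOrder >= 2) studies; the silver labels are
+        # re-derived below so that both sources pass through identical filtering.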
+        elif flag == "gold+":
+            gold_dataset = self.gold_dataset.copy()
+            gold_cohort = pd.read_csv(
+                os.path.join(self.args.save_dir, "cohort_gold.csv"),
+            )
+            gold_iids = gold_cohort["image_id"].unique()  # 2338 images
+            gold_dataset = gold_dataset[gold_dataset["image_id"].isin(gold_iids)]  # 500 images
+            assert len(gold_dataset["image_id"].unique()) == 500
+
+            # NOTE: For studies with order >= 2, we use the silver labels;
+            # the silver pre-processing steps up to this point are therefore re-run
+            self._load_silver_attributes_relations()
+            self.preprocessLabel_category(flag="silver")
+            self.preprocessLabel_bbox(flag="silver")
+            silver_dataset = self.silver_dataset.copy()
+            silver_dataset = silver_dataset[(~silver_dataset["image_id"].isin(gold_dataset.image_id.unique())) & (silver_dataset["image_id"].isin(gold_iids))]
+
+            # The number of images of gold patients w/ 1st study (in cohort_gold.csv)
+            print("# of images for gold patients w/ 1st study", gold_dataset["image_id"].nunique())
+            # The number of images of gold patients w/ >=2nd study (in cohort_gold.csv)
+            print("# of images for gold patients w/ >=2nd study", gold_cohort[gold_cohort["StudyOrder"] >= 2]["image_id"].nunique())
+            # The number of images of gold patients w/ >=2nd study (in silver_dataset.csv ~ attribute_relations_tabular.txt)
+            print("# of images for gold patients w/ >=2nd study with silver labels", silver_dataset["image_id"].nunique())
+            print("NOTE: Some gold studies (>=2nd) are missing in the silver dataset!!!")
+            gold_dataset = pd.concat([gold_dataset, silver_dataset], axis=0)
+
+            gold_dataset["subject_id"] = gold_dataset["image_id"].map(gold_cohort.set_index("image_id")["subject_id"])
+            assert gold_dataset["subject_id"].isna().sum() == 0
+            gold_dataset["subject_id"] = gold_dataset["subject_id"].astype(int)
+
+            gold_dataset = gold_dataset[["subject_id", "study_id", "image_id", "sent_loc", "bbox", "relation", "label_name", "categoryID", "annot_id", "object_id"]]
+            self.gold_dataset = gold_dataset.reset_index(drop=True)
+            print(gold_dataset.head())
+
+        else:
+            raise ValueError("flag must be either 'silver', 'gold', or 'gold+'")
+
+    def aggregate_labels_by_report_level(self, flag="silver", agg_option="last"):
+        sort_columns = ["subject_id", "study_id", "image_id", "sent_loc", "bbox", "label_name"]
+        agg_columns = ["subject_id", "study_id", "image_id", "bbox", "label_name"]
+
+        dataset = self.get_dataset_by_flag(flag)
+
+        dataset = dataset.sort_values(by=sort_columns)
+        dataset = dataset.drop_duplicates(subset=agg_columns, keep=agg_option)
+        dataset = dataset.reset_index(drop=True)
+
+        self.keep_dataset_by_flag(dataset, flag)
+
+    def apply_attribute_p2c_ontology_to_dataset(self, flag="silver"):
+        def _apply_attribute_p2c_ontology_to_dataset(dataset, attr_c2p_map, attr_cat_v1):
+            sort_columns = ["subject_id", "study_id", "image_id", "sent_loc", "bbox", "label_name"]
+            agg_columns = ["subject_id", "study_id", "image_id", "bbox", "label_name"]
+            agg_option = "last"
+
+            dataset_parents = None
+            for child, parents in attr_c2p_map.items():
+                for parent in parents:
+                    dataset_parent = dataset[dataset.label_name == child].copy()
+                    dataset_parent["label_name"] = parent
+                    dataset_parent["categoryID"] = attr_cat_v1[parent]
+                    dataset_parent["annot_id"] = (
+                        dataset_parent["study_id"].astype(str) + "|" + dataset_parent["bbox"] + "|" + dataset_parent["relation"].astype(str) + "|" + dataset_parent["label_name"]
+                    )
+                    dataset_parents = pd.concat([dataset_parents, dataset_parent], axis=0) if dataset_parents is not None else dataset_parent
+            dataset = pd.concat([dataset, dataset_parents], axis=0)
+            dataset = dataset.sort_values(by=sort_columns)
+            dataset = dataset.drop_duplicates(subset=agg_columns, keep=agg_option)
+            return dataset
+
+        dataset = self.get_dataset_by_flag(flag)
+
+        dataset = _apply_attribute_p2c_ontology_to_dataset(dataset=dataset, attr_c2p_map=self.attr_c2p_map, attr_cat_v1=self.attr_cat_v1)
+        dataset = dataset.reset_index(drop=True)
+
+        self.keep_dataset_by_flag(dataset, flag)
+
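+    # Worked example (hypothetical labels): with attr_c2p_map = {"lung opacity": ["opacity"]},
+    # every row labeled "lung opacity" is copied, relabeled "opacity", re-keyed via annot_id,
+    # and appended; drop_duplicates(keep="last") then keeps, per (study, image, bbox, label),
+    # the row from the latest sentence after sorting by sent_loc.
+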
"sent_loc", "bbox", "label_name"] + agg_columns = ["subject_id", "study_id", "image_id", "bbox", "label_name"] + agg_option = "last" + + dataset_parents = None + for child, parents in obj_c2p_map.items(): + for parent in parents: + dataset_parent = dataset[dataset.bbox == child].copy() + dataset_parent["bbox"] = parent + dataset_parent["annot_id"] = ( + dataset_parent["study_id"].astype(str) + "|" + dataset_parent["bbox"] + "|" + dataset_parent["relation"].astype(str) + "|" + dataset_parent["label_name"] + ) + dataset_parent["object_id"] = dataset_parent["image_id"] + "_" + dataset_parent["bbox"] + dataset_parents = pd.concat([dataset_parents, dataset_parent], axis=0) if dataset_parents is not None else dataset_parent + dataset = pd.concat([dataset, dataset_parents], axis=0) + dataset = dataset.sort_values(by=sort_columns) + dataset = dataset.drop_duplicates(subset=agg_columns, keep=agg_option) + return dataset + + dataset = self.get_dataset_by_flag(flag) + + dataset = _apply_object_p2c_ontology_to_dataset(dataset=dataset, obj_c2p_map=self.obj_c2p_map) + dataset = dataset.reset_index(drop=True) + + self.keep_dataset_by_flag(dataset, flag) + + def apply_object_attribute_possible_relationship_to_dataset(self, flag="silver"): + def _apply_object_attribute_possible_relationship_to_dataset(dataset, ont): + removed_annot_ids = [] + obj_attr_combs = dataset[["bbox", "label_name"]].drop_duplicates().values + for obj_attr_comb in obj_attr_combs: + obj, attr = obj_attr_comb + if attr not in ["abnormal", "normal", "artifact"]: + possible_objs = ont["possible_attribute_of"][attr] + if obj not in possible_objs: + removed_annot_ids += dataset[(dataset["bbox"] == obj) & (dataset["label_name"] == attr)].annot_id.tolist() + + dataset = dataset[~dataset.annot_id.isin(removed_annot_ids)] + return dataset + + dataset = self.get_dataset_by_flag(flag) + + dataset = _apply_object_attribute_possible_relationship_to_dataset(dataset=dataset, ont=self.ont) + dataset = dataset.reset_index(drop=True) + + self.keep_dataset_by_flag(dataset, flag) + + def sanity_check(self, flag="silver"): + if flag == "silver": + dataset = self.silver_dataset + elif flag in ["gold", "gold+"]: + dataset = self.gold_dataset + else: + raise ValueError("flag must be either 'silver' or 'gold'") + + assert (dataset.object_id.apply(lambda x: x.split("_")[-1]) != dataset.bbox).sum() == 0 + assert (dataset.object_id.apply(lambda x: x.split("_")[0]) != dataset.image_id).sum() == 0 + assert (dataset.annot_id.apply(lambda x: x.split("|")[0]) != dataset.study_id.astype(str)).sum() == 0 + assert (dataset.annot_id.apply(lambda x: x.split("|")[1]) != dataset.bbox).sum() == 0 + assert (dataset.annot_id.apply(lambda x: x.split("|")[2]) != dataset.relation.astype(str)).sum() == 0 + assert (dataset.annot_id.apply(lambda x: x.split("|")[-1]) != dataset.label_name).sum() == 0 + + def remove_minority_label(self, flag="silver"): + REMOVED_OBJS = ["left arm", "right arm"] + REMOVED_ATTRS = [ + "artifact", + "bronchiectasis", + "pigtail catheter", + "skin fold", + "aortic graft/repair", + "diaphragmatic eventration (benign)", + "sternotomy wires", + ] + + dataset = self.get_dataset_by_flag(flag) + + dataset = dataset[(~dataset["bbox"].isin(REMOVED_OBJS)) & (~dataset["label_name"].isin(REMOVED_ATTRS))] + dataset = dataset.reset_index(drop=True) + + self.keep_dataset_by_flag(dataset, flag) + + def restore_normal_relation(self, flag="silver"): + sort_columns = ["subject_id", "study_id", "image_id", "bbox", "label_name", "relation"] + agg_columns = ["subject_id", 
"study_id", "image_id", "bbox", "label_name"] + agg_option = "last" + + # compute possible object-attribute combinations (in here, a total of 609 combinations) + objattr_combs = [(vv, 0, k, self.attr_cat_v1[k]) for k, v in self.ont["possible_attribute_of"].items() if k in self.attr_v1 for vv in v if vv in self.obj_v1] + objattr_combs = pd.DataFrame(objattr_combs, columns=["bbox", "relation", "label_name", "categoryID"]) + + dataset = self.get_dataset_by_flag(flag) + + all_data = pd.merge(dataset[["subject_id", "study_id", "image_id"]].drop_duplicates().assign(key=1), objattr_combs.assign(key=1), on="key").drop("key", axis=1) + all_data["annot_id"] = all_data["study_id"].astype(str) + "|" + all_data["bbox"] + "|" + all_data["relation"].astype(str) + "|" + all_data["label_name"] + all_data["object_id"] = all_data["image_id"] + "_" + all_data["bbox"] + + dataset = pd.concat([dataset, all_data]) + dataset = dataset.sort_values(by=sort_columns, ascending=True) + dataset = dataset.drop_duplicates(subset=agg_columns, keep=agg_option) + dataset = dataset.reset_index(drop=True) + + self.keep_dataset_by_flag(dataset, flag) + + def save_label_dataset(self, flag="silver"): + dataset = self.get_dataset_by_flag(flag) + + dataset.to_csv(os.path.join(self.args.save_dir, f"{flag}_dataset.csv"), index=False) + print(f"{flag} dataset saved, shape: {dataset.shape}") + + def split_and_save_dataset(self): + from sklearn.model_selection import train_test_split + + # 0 + # gold_dataset = self.gold_dataset + # silver_dataset = self.silver_dataset + # To avoid the problem of the order of the dataset, we load the dataset from the saved csv file. + gold_dataset = pd.read_csv(os.path.join(self.args.save_dir, "gold_dataset.csv")) + silver_dataset = pd.read_csv(os.path.join(self.args.save_dir, "silver_dataset.csv")) + + # 1: divide by relation type + silver_dataset_abn = silver_dataset[silver_dataset["relation"] == 1] + silver_dataset_nm = silver_dataset[~silver_dataset.study_id.isin(silver_dataset_abn.study_id.unique())] + + SEED = 103 + TEST_SIZE = 0.05 + + # 2: split subject_ids (train, valid) + train_subject_ids_abn = silver_dataset_abn.subject_id.unique() + train_subject_ids_abn, valid_subject_ids_abn = train_test_split(train_subject_ids_abn, test_size=TEST_SIZE, random_state=SEED) + train_subject_ids_nm = silver_dataset_nm.subject_id.unique() + train_subject_ids_nm, valid_subject_ids_nm = train_test_split(train_subject_ids_nm, test_size=TEST_SIZE, random_state=SEED) + + train_sids = ( + silver_dataset_abn[silver_dataset_abn.subject_id.isin(train_subject_ids_abn)].study_id.unique().tolist() + + silver_dataset_nm[silver_dataset_nm.subject_id.isin(train_subject_ids_nm)].study_id.unique().tolist() + ) + valid_sids = ( + silver_dataset_abn[silver_dataset_abn.subject_id.isin(valid_subject_ids_abn)].study_id.unique().tolist() + + silver_dataset_nm[silver_dataset_nm.subject_id.isin(valid_subject_ids_nm)].study_id.unique().tolist() + ) + + train_dataset = silver_dataset.loc[silver_dataset["study_id"].isin(train_sids)] + valid_dataset = silver_dataset.loc[silver_dataset["study_id"].isin(valid_sids)] + # test_dataset = gold_dataset.loc[gold_dataset["study_id"].isin(test_sids)] + test_dataset = gold_dataset.copy() + + # 3: save + train_dataset = train_dataset.reset_index(drop=True) + valid_dataset = valid_dataset.reset_index(drop=True) + test_dataset = test_dataset.reset_index(drop=True) + + train_dataset.to_csv(os.path.join(self.args.save_dir, "train_dataset.csv"), index=False) # silver + 
+
+
+def main(args):
+    # load preprocessor
+    label_preproc = LabelPreprocessor(args)
+
+    FLAGS = ["silver", "gold", "gold+"]
+    # FLAGS = ["gold+"]
+
+    for flag in FLAGS:
+        # preprocessing - labels
+        label_preproc.preprocessLabel_category(flag=flag)
+        label_preproc.preprocessLabel_bbox(flag=flag)
+        label_preproc.preprocessLabel_cohort(flag=flag)
+        label_preproc.aggregate_labels_by_report_level(flag=flag, agg_option="last")
+
+        # preprocessing - ontology
+        label_preproc.apply_attribute_p2c_ontology_to_dataset(flag=flag)
+        label_preproc.apply_object_p2c_ontology_to_dataset(flag=flag)
+        label_preproc.apply_object_attribute_possible_relationship_to_dataset(flag=flag)
+
+        label_preproc.sanity_check(flag=flag)
+
+        # preprocessing - label-space cleanup
+        label_preproc.remove_minority_label(flag=flag)
+        label_preproc.restore_normal_relation(flag=flag)
+
+        # save
+        label_preproc.save_label_dataset(flag=flag)
+
+    if "silver" in FLAGS and "gold" in FLAGS:
+        label_preproc.split_and_save_dataset()
+
+
+if __name__ == "__main__":
+    args = config()
+    main(args)
+    print("Done")