Diff of /AICare-baselines/train.py [000000] .. [0f1df3]

--- a
+++ b/AICare-baselines/train.py
@@ -0,0 +1,77 @@
+import lightning as L
+from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
+from lightning.pytorch.loggers import CSVLogger
+
+# from configs.experiments_mimic import hparams
+from configs.exp import hparams
+from datasets.loader.datamodule import EhrDataModule
+from datasets.loader.load_los_info import get_los_info
+from pipelines import DlPipeline, MlPipeline
+
+def run_ml_experiment(config):
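+    # Fit a machine-learning baseline (wrapped in MlPipeline) on one fold/seed and return its best validation performance.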
+    los_config = get_los_info(f'datasets/{config["dataset"]}/processed/fold_{config["fold"]}')
+    config.update({"los_info": los_config})
+
+    # data
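+    # EhrDataModule is assumed to load the preprocessed EHR features/labels for this fold from the processed/ directory.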
+    dm = EhrDataModule(f'datasets/{config["dataset"]}/processed/fold_{config["fold"]}', batch_size=config["batch_size"])
+    # logger
+    checkpoint_filename = f'{config["model"]}-fold{config["fold"]}-seed{config["seed"]}'
+    logger = CSVLogger(save_dir="logs", name=f'train/{config["dataset"]}/{config["task"]}', version=checkpoint_filename)
+    L.seed_everything(config["seed"]) # seed for reproducibility
+
+    # train/val/test
+    pipeline = MlPipeline(config)
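+    # One CPU epoch is enough here: the ML model is presumably fitted in a single pass inside MlPipeline.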
+    trainer = L.Trainer(accelerator="cpu", max_epochs=1, logger=logger, num_sanity_val_steps=0)
+    trainer.fit(pipeline, dm)
+    perf = pipeline.cur_best_performance
+    return perf
+
+def run_dl_experiment(config):
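+    # Train a deep-learning model (DlPipeline) with early stopping and checkpointing; return its best validation performance.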
+    los_config = get_los_info(f'datasets/{config["dataset"]}/processed/fold_{config["fold"]}')
+    config.update({"los_info": los_config})
+
+    # data
+    dm = EhrDataModule(f'datasets/{config["dataset"]}/processed/fold_{config["fold"]}', batch_size=config["batch_size"])
+    # logger
+    checkpoint_filename = f'{config["model"]}-fold{config["fold"]}-seed{config["seed"]}'
+    if config.get("time_aware", False):
+        checkpoint_filename += "-ta"  # time-aware loss applied
+    logger = CSVLogger(save_dir="logs", name=f'train/{config["dataset"]}/{config["task"]}', version=checkpoint_filename)
+
+    # Early-stopping and checkpoint callbacks, keyed on the task's primary metric
+    if config["task"] in ["outcome", "multitask"]:
+        early_stopping_callback = EarlyStopping(monitor="auprc", patience=config["patience"], mode="max")
+        checkpoint_callback = ModelCheckpoint(filename="best", monitor="auprc", mode="max")
+    elif config["task"] == "los":
+        early_stopping_callback = EarlyStopping(monitor="mae", patience=config["patience"], mode="min")
+        checkpoint_callback = ModelCheckpoint(filename="best", monitor="mae", mode="min")
+    else:
+        raise ValueError(f'Unsupported task: {config["task"]}')
+
+    L.seed_everything(config["seed"]) # seed for reproducibility
+
+    # train/val/test
+    pipeline = DlPipeline(config)
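+    # NOTE: devices=[1] pins training to GPU index 1; adjust to match the available hardware.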
+    trainer = L.Trainer(accelerator="gpu", devices=[1], max_epochs=config["epochs"], logger=logger, callbacks=[early_stopping_callback, checkpoint_callback], num_sanity_val_steps=0)
+    trainer.fit(pipeline, dm)
+    perf = pipeline.cur_best_performance
+    return perf
+
+if __name__ == "__main__":
+    best_hparams = hparams # [TO-SPECIFY]
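+    # Each entry in hparams is expected to be a config dict (model, dataset, task, batch_size, ...); sweep each one over the chosen folds and seeds.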
+    for config in best_hparams:
+        run_func = run_ml_experiment if config["model"] in ["RF", "DT", "GBDT", "XGBoost", "CatBoost", "LR", "LightGBM"] else run_dl_experiment
+        seeds = [0] # [0,1,2,3,4]
+        folds = ['nshot']
+        for fold in folds:
+            config["fold"] = fold
+            for seed in seeds:
+                config["seed"] = seed
+                perf = run_func(config)
+                print(f"{config}, Val Performance: {perf}")