Diff of /main.py [000000] .. [d6904d]

Switch to side-by-side view

--- a
+++ b/main.py
@@ -0,0 +1,49 @@
+import argparse
+import pathlib
+
+import torch
+from omegaconf import OmegaConf
+
+from app import create_app
+
+if __name__ == "__main__":
+    pathlib.Path("./checkpoints").mkdir(parents=True, exist_ok=True)
+    print("===[Start]===")
+    parser = argparse.ArgumentParser("Covid-EMR training script", add_help=False)
+    parser.add_argument(
+        "--cfg", type=str, required=True, metavar="FILE", help="path to config file"
+    )
+    parser.add_argument(
+        "--cuda",
+        type=int,
+        required=False,
+        metavar="CUDA NUMBER",
+        help="gpu to train",
+    )
+    parser.add_argument(
+        "--db",
+        action="store_true",
+        help="whether to connect database",
+    )
+
+    parser.add_argument(
+        "--train",
+        action="store_true",
+        help="whether to train model, only execute inference stage if not",
+    )
+
+    args = parser.parse_args()
+    print(f"===[{args.cfg}]===")
+    conf = OmegaConf.load(args.cfg)
+    conf.db = args.db
+    conf.train = args.train
+
+    # train on cpu by default
+    device = torch.device("cpu")
+    if args.cuda is not None:
+        device = torch.device(
+            f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu"
+        )
+
+    create_app(conf, device)
+    print("===[End]===")