Diff of /main.py [000000] .. [d6904d]

Switch to unified view

a b/main.py
1
import argparse
2
import pathlib
3
4
import torch
5
from omegaconf import OmegaConf
6
7
from app import create_app
8
9
if __name__ == "__main__":
10
    pathlib.Path("./checkpoints").mkdir(parents=True, exist_ok=True)
11
    print("===[Start]===")
12
    parser = argparse.ArgumentParser("Covid-EMR training script", add_help=False)
13
    parser.add_argument(
14
        "--cfg", type=str, required=True, metavar="FILE", help="path to config file"
15
    )
16
    parser.add_argument(
17
        "--cuda",
18
        type=int,
19
        required=False,
20
        metavar="CUDA NUMBER",
21
        help="gpu to train",
22
    )
23
    parser.add_argument(
24
        "--db",
25
        action="store_true",
26
        help="whether to connect database",
27
    )
28
29
    parser.add_argument(
30
        "--train",
31
        action="store_true",
32
        help="whether to train model, only execute inference stage if not",
33
    )
34
35
    args = parser.parse_args()
36
    print(f"===[{args.cfg}]===")
37
    conf = OmegaConf.load(args.cfg)
38
    conf.db = args.db
39
    conf.train = args.train
40
41
    # train on cpu by default
42
    device = torch.device("cpu")
43
    if args.cuda is not None:
44
        device = torch.device(
45
            f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu"
46
        )
47
48
    create_app(conf, device)
49
    print("===[End]===")