[d6904d]: / main.py

Download this file

50 lines (42 with data), 1.3 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import argparse
import pathlib
import torch
from omegaconf import OmegaConf
from app import create_app
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the Covid-EMR training script.

    Returns:
        A parser exposing ``--cfg`` (required path to a config file),
        ``--cuda`` (optional integer GPU index), and the boolean flags
        ``--db`` and ``--train``.
    """
    # The title is given as ``description``: passed positionally it is taken
    # as ``prog`` and replaces the program name in the usage line.
    # ``add_help`` keeps its default so ``-h/--help`` can actually display
    # the help strings declared on the options below.
    parser = argparse.ArgumentParser(description="Covid-EMR training script")
    parser.add_argument(
        "--cfg", type=str, required=True, metavar="FILE", help="path to config file"
    )
    parser.add_argument(
        "--cuda",
        type=int,
        required=False,
        metavar="CUDA NUMBER",
        help="gpu to train",
    )
    parser.add_argument(
        "--db",
        action="store_true",
        help="whether to connect database",
    )
    parser.add_argument(
        "--train",
        action="store_true",
        help="whether to train model, only execute inference stage if not",
    )
    return parser


def select_device(cuda_index=None) -> torch.device:
    """Pick the torch device for a run.

    Args:
        cuda_index: GPU index taken from ``--cuda``, or ``None`` when the
            option was not given.

    Returns:
        ``torch.device("cpu")`` when no index was requested or when
        ``torch.cuda.is_available()`` is false; otherwise the device
        ``cuda:<cuda_index>``.
    """
    # CPU is the default; CUDA availability is only probed when a GPU was
    # explicitly requested.
    if cuda_index is None or not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device(f"cuda:{cuda_index}")


def main(argv=None) -> None:
    """Parse the command line, load the config and hand over to ``create_app``.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    # Validate the command line first so that a bad invocation (or --help)
    # exits before any directory is created or banner is printed.
    args = build_parser().parse_args(argv)

    pathlib.Path("./checkpoints").mkdir(parents=True, exist_ok=True)
    print("===[Start]===")
    print(f"===[{args.cfg}]===")

    conf = OmegaConf.load(args.cfg)
    # Copy the command-line switches onto the loaded config object.
    conf.db = args.db
    conf.train = args.train

    create_app(conf, select_device(args.cuda))
    print("===[End]===")


if __name__ == "__main__":
    main()