--- a +++ b/main.py @@ -0,0 +1,49 @@ +import argparse +import pathlib + +import torch +from omegaconf import OmegaConf + +from app import create_app + +if __name__ == "__main__": + pathlib.Path("./checkpoints").mkdir(parents=True, exist_ok=True) + print("===[Start]===") + parser = argparse.ArgumentParser("Covid-EMR training script", add_help=False) + parser.add_argument( + "--cfg", type=str, required=True, metavar="FILE", help="path to config file" + ) + parser.add_argument( + "--cuda", + type=int, + required=False, + metavar="CUDA NUMBER", + help="gpu to train", + ) + parser.add_argument( + "--db", + action="store_true", + help="whether to connect database", + ) + + parser.add_argument( + "--train", + action="store_true", + help="whether to train model, only execute inference stage if not", + ) + + args = parser.parse_args() + print(f"===[{args.cfg}]===") + conf = OmegaConf.load(args.cfg) + conf.db = args.db + conf.train = args.train + + # train on cpu by default + device = torch.device("cpu") + if args.cuda is not None: + device = torch.device( + f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu" + ) + + create_app(conf, device) + print("===[End]===")