a b/tools/MultiShellScripts.py
1
import argparse
import os
import socket
import subprocess
3
4
def test_df(log_dir, epoch_i=0, best_model=False, dry_run=False):
    """Evaluate a saved checkpoint by running ``tools/test_df.py`` in a subprocess.

    ``tools/test_df.py`` is resolved relative to the current working directory.

    Args:
        log_dir: Training log directory. The checkpoint path passed as
            ``--model_path1`` is ``<log_dir>/ckpt/<file>``, and
            ``<log_dir>/eval`` is passed as ``--output_dir``.
        epoch_i: Epoch number used to select ``checkpoint_epoch_<epoch_i>.pth``;
            ignored when ``best_model`` is true.
        best_model: If true, select ``model_best.pth`` instead of a
            per-epoch checkpoint.
        dry_run: If true, return the argument list without running anything.

    Returns:
        The argument list when ``dry_run`` is true, otherwise the exit status
        of the evaluation process.
    """
    output_dir = os.path.join(log_dir, "eval")
    if best_model:
        ckpt_name = "model_best.pth"
    else:
        ckpt_name = "checkpoint_epoch_{}.pth".format(epoch_i)
    model_path = os.path.join(log_dir, "ckpt", ckpt_name)

    # Build an argv list instead of a shell string so paths containing spaces
    # or shell metacharacters reach test_df.py unchanged.
    cmd = [
        "python", "tools/test_df.py",
        "--used_df", "U_NetDF",
        "--selfeat",
        "--mgpus", "6",
        "--model_path1", model_path,
        "--output_dir", output_dir,
        "--log_file", "../log_evaluation_vis.txt",
        "--vis",
    ]
    if dry_run:
        return cmd
    return subprocess.run(cmd, check=False).returncode


# The name starts with ``test_``: stop pytest from collecting this launcher
# as a test case if it is ever imported into a test module.
test_df.__test__ = False
14
15
def train(gpus="2,3", batch_size=24, output_dir="logs/acdc_logs/log_temp",
          master_port=None, dry_run=False):
    """Launch ``tools/train.py`` through ``torch.distributed.launch``.

    With no arguments this reproduces the original hard-coded run: two
    processes for GPUs "2,3", batch size 24, ``--train_with_eval``.
    ``tools/train.py`` is resolved relative to the current working directory.

    Args:
        gpus: Comma-separated GPU ids forwarded as ``--mgpus``. The number of
            ids is used as ``--nproc_per_node``.
        batch_size: Value forwarded as ``--batch_size``.
        output_dir: Value forwarded as ``--output_dir``.
        master_port: Port forwarded as ``--master_port``. ``None`` asks the OS
            for a currently unused TCP port.
        dry_run: If true, return the argument list without running anything.

    Returns:
        The argument list when ``dry_run`` is true, otherwise the exit status
        of the launcher process.
    """
    if master_port is None:
        # The previous shell string relied on ``$RANDOM``, which is bash-only.
        # os.system runs /bin/sh; under dash it expanded to nothing and
        # ``--master_port`` swallowed ``tools/train.py`` as its value.
        master_port = _find_free_port()
    nproc = len(gpus.split(","))
    cmd = [
        "python", "-m", "torch.distributed.launch",
        "--nproc_per_node", str(nproc),
        "--master_port", str(master_port),
        "tools/train.py",
        "--batch_size", str(batch_size),
        "--mgpus", gpus,
        "--output_dir", output_dir,
        "--train_with_eval",
    ]
    if dry_run:
        return cmd
    return subprocess.run(cmd, check=False).returncode


def _find_free_port():
    """Return a localhost TCP port that the OS reports as unused right now."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 lets the kernel choose an available port.
        sock.bind(("127.0.0.1", 0))
        return sock.getsockname()[1]
17
18
def main(argv=None):
    """Parse the command line and run the selected launcher.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.

    Unknown ``--scrip`` values are rejected by argparse (exit status 2). When
    no script is given, usage help is printed instead of silently doing
    nothing.
    """
    parser = argparse.ArgumentParser(description="arg parser")
    # ``--script`` is accepted as a correctly spelled alias; ``--scrip`` keeps
    # working for existing invocations.
    parser.add_argument(
        "--scrip", "--script",
        dest="scrip",
        type=str,
        default=None,
        choices=("train", "test_df"),
        help="which script to run",
    )
    args = parser.parse_args(argv)

    if args.scrip == "train":
        train()
    elif args.scrip == "test_df":
        test_df("logs/acdc_logs/logs_256_supcat_auxseg_thresh0.1/", best_model=False, epoch_i=118)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()