|
a |
|
b/tools/MultiShellScripts.py |
|
|
1 |
import os |
|
|
2 |
import argparse |
|
|
3 |
|
|
|
4 |
def test_df(log_dir, epoch_i=0, best_model=False):
    """Run ``tools/test_df.py`` on a checkpoint located under ``log_dir``.

    Args:
        log_dir: Run directory. ``<log_dir>/ckpt/<checkpoint>`` is passed as
            ``--model_path1`` and ``<log_dir>/eval`` as ``--output_dir``.
        epoch_i: Epoch number used to build the checkpoint file name
            ``checkpoint_epoch_<epoch_i>.pth``. Ignored when ``best_model``
            is true.
        best_model: When true, use ``model_best.pth`` instead of a per-epoch
            checkpoint.
    """
    import shlex  # local import: only needed to quote the two paths below

    output_dir = os.path.join(log_dir, "eval")
    if best_model:
        ckpt_name = "model_best.pth"
    else:
        ckpt_name = "checkpoint_epoch_{}.pth".format(epoch_i)
    model_path = os.path.join(log_dir, "ckpt", ckpt_name)

    # os.system hands the whole string to a shell, so the paths are quoted:
    # a log_dir containing spaces or shell metacharacters would otherwise be
    # split into several arguments or interpreted by the shell. For ordinary
    # paths shlex.quote returns the text unchanged.
    # Adjacent literals are used instead of a backslash continuation inside
    # the string so that source indentation does not leak into the command.
    command = (
        "python tools/test_df.py --used_df U_NetDF --selfeat --mgpus 6 "
        "--model_path1 {} --output_dir {} "
        "--log_file ../log_evaluation_vis.txt --vis"
    ).format(shlex.quote(model_path), shlex.quote(output_dir))
    os.system(command)
|
|
14 |
|
|
|
15 |
def train():
    """Launch ``tools/train.py`` through ``torch.distributed.launch``.

    The launcher is started with ``--nproc_per_node 2`` and the training
    script receives ``--batch_size 24 --mgpus 2,3``, writing to
    ``logs/acdc_logs/log_temp`` with ``--train_with_eval`` enabled.
    """
    import socket  # local import: only used to pick the rendezvous port

    # The port used to come from the shell variable $RANDOM. os.system runs
    # the command through the system shell, and $RANDOM is a bash extension:
    # under a plain POSIX sh (e.g. dash) it expands to nothing, which makes
    # --master_port swallow "tools/train.py" as its value. Asking the OS for
    # a free port works under any shell and avoids privileged or busy ports.
    # NOTE: the port is released before the launcher binds it, so a small
    # race with other processes remains (as it did with $RANDOM).
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(("", 0))
        master_port = probe.getsockname()[1]

    os.system(
        "python -m torch.distributed.launch --nproc_per_node 2 "
        "--master_port {} tools/train.py --batch_size 24 --mgpus 2,3 "
        "--output_dir logs/acdc_logs/log_temp --train_with_eval".format(master_port)
    )
|
|
17 |
|
|
|
18 |
if __name__ == "__main__":
    # Command-line entry point: --scrip selects which launcher to run.
    cli = argparse.ArgumentParser(description="arg parser")
    cli.add_argument("--scrip", type=str, default=None, help="which scrips to running")
    selected = cli.parse_args().scrip

    # Any other value (including the default None) runs nothing.
    if selected == "test_df":
        test_df(
            "logs/acdc_logs/logs_256_supcat_auxseg_thresh0.1/",
            best_model=False,
            epoch_i=118,
        )
    elif selected == "train":
        train()