"""Train one or more segmentation models with ``VanillaTrainer``.

Usage::

    python train_normal_pipeline.py TRAINER_ARG [MODEL ...]

``TRAINER_ARG`` is forwarded unchanged as the first positional argument of
``VanillaTrainer`` (its meaning is defined by that class -- TODO confirm and
document here). With no ``MODEL`` names only ``DeepLab`` is trained; otherwise
each named model (e.g. ``TriUnet DeepLab FPN Unet``) is trained in turn with
the same hyper-parameters.
"""
import sys
from typing import Optional, Sequence

# NOTE(review): unused in this script; kept in case importing the module is
# relied upon for side effects. TODO confirm and drop.
from collect_generalizability_metrics import get_generalizability_gap  # noqa: F401
from training.vanilla_trainer import VanillaTrainer

# Hyper-parameters shared by every run; "model" is overridden per run.
BASE_CONFIG = {
    "model": "DeepLab",
    "device": "cuda",
    "lr": 0.00001,
    "batch_size": 8,
    "epochs": 250,
}

USAGE = "usage: train_normal_pipeline.py TRAINER_ARG [MODEL ...]"


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Train each requested model with ``VanillaTrainer``.

    Args:
        argv: Command-line arguments *excluding* the program name. Defaults
            to ``sys.argv[1:]``. The first element is passed to
            ``VanillaTrainer`` as-is. Any remaining elements are model names;
            if there are none, ``BASE_CONFIG["model"]`` is trained.

    Raises:
        SystemExit: If no arguments are given. The usage message is printed
            to stderr and the exit status is 1.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        # Fail with a usage message rather than a bare IndexError on argv[1].
        sys.exit(USAGE)

    trainer_arg, *models = args
    for model in models or [BASE_CONFIG["model"]]:
        # Build a fresh dict per run so a trainer that keeps a reference to
        # its config never sees a later run's "model" value.
        config = {**BASE_CONFIG, "model": model}
        trainer = VanillaTrainer(trainer_arg, config)
        trainer.train()


if __name__ == "__main__":
    main()