# experiments/train_normal_pipeline.py (revision 8eeb5a)
# Original listing: 26 lines (23 with data), 859 bytes.
import sys
from collect_generalizability_metrics import get_generalizability_gap
from training.vanilla_trainer import VanillaTrainer
# Hyper-parameters for a single "vanilla" training run. Kept as a template so
# every call to build_config() starts from the same values.
DEFAULT_CONFIG = {
    "model": "DeepLab",
    "device": "cuda",
    "lr": 0.00001,
    "batch_size": 8,
    "epochs": 250,
}


def build_config(model=None):
    """Return a fresh training config, optionally overriding the model name.

    A new dict is returned on every call, so mutating the result (by the
    caller or by the trainer) can never leak back into DEFAULT_CONFIG.

    Args:
        model: Architecture name stored under ``"model"``; ``None`` keeps the
            default (``"DeepLab"``). Earlier revisions of this script looped
            over "TriUnet", "DeepLab", "FPN" and "Unet" -- presumably the names
            VanillaTrainer accepts (TODO confirm).

    Returns:
        A new ``dict`` with the keys of DEFAULT_CONFIG.
    """
    config = dict(DEFAULT_CONFIG)  # values are immutable scalars; shallow copy suffices
    if model is not None:
        config["model"] = model
    return config


def main(argv=None):
    """Run one VanillaTrainer training job configured from the command line.

    Usage: ``python train_normal_pipeline.py <trainer_arg>``

    Args:
        argv: Argument list excluding the program name; defaults to
            ``sys.argv[1:]``. ``argv[0]`` is passed unchanged (as a string) as
            VanillaTrainer's first positional argument. Its meaning is defined
            by VanillaTrainer (NOTE(review): document it, e.g. run id or
            dataset). Further arguments are ignored, as before.

    Raises:
        SystemExit: with a usage message when the required argument is
            missing (previously a bare IndexError on ``sys.argv[1]``).
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        sys.exit("usage: train_normal_pipeline.py <trainer_arg>")
    trainer = VanillaTrainer(args[0], build_config())
    trainer.train()


if __name__ == '__main__':
    main()