--- a +++ b/experiments/train_normal_pipeline.py @@ -0,0 +1,25 @@ +import sys +from collect_generalizability_metrics import get_generalizability_gap +from training.vanilla_trainer import VanillaTrainer + +if __name__ == '__main__': + config = {"model": "DeepLab", + "device": "cuda", + "lr": 0.00001, + "batch_size": 8, + "epochs": 250} + trainer = VanillaTrainer(sys.argv[1], config) + trainer.train() + + # for i in ["TriUnet", "DeepLab", "FPN", "Unet"]: + # config["model"] = i + # trainer = VanillaTrainer(sys.argv[1], config) + # trainer.train() + # for i in range(13, 100): + # trainer = VanillaTrainer("DeepLab", i, config) + # trainer.train() + # i = 4 + # while True: + # config = {"epochs": 200, "id": i, "lr": 0.00001, "pretrain": False} + # training.vanilla_trainer.train_vanilla_predictor(config) + # i += 1