|
a |
|
b/experiments/train_normal_pipeline.py |
|
|
1 |
import sys |
|
|
2 |
from collect_generalizability_metrics import get_generalizability_gap |
|
|
3 |
from training.vanilla_trainer import VanillaTrainer |
|
|
4 |
|
|
|
5 |
def main(argv=None):
    """Train a single model using the fixed hyper-parameters defined below.

    Args:
        argv: Argument vector with the program name at index 0. Defaults to
            ``sys.argv``. ``argv[1]`` is forwarded unchanged as the first
            positional argument of ``VanillaTrainer``; any additional
            entries are ignored.

    Raises:
        SystemExit: If no argument follows the program name.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        # Exit with a readable usage line rather than letting the lookup of
        # argv[1] fail with a bare IndexError traceback.
        prog = argv[0] if argv else "train_normal_pipeline.py"
        raise SystemExit(f"usage: {prog} <trainer_arg>")

    # NOTE(review): the meaning of argv[1] is defined by VanillaTrainer,
    # which is not visible from this file - confirm against its signature.
    config = {
        "model": "DeepLab",
        "device": "cuda",
        "lr": 0.00001,
        "batch_size": 8,
        "epochs": 250,
    }
    trainer = VanillaTrainer(argv[1], config)
    trainer.train()


if __name__ == '__main__':
    main()
|
|
13 |
|
|
|
14 |
# for i in ["TriUnet", "DeepLab", "FPN", "Unet"]: |
|
|
15 |
# config["model"] = i |
|
|
16 |
# trainer = VanillaTrainer(sys.argv[1], config) |
|
|
17 |
# trainer.train() |
|
|
18 |
# for i in range(13, 100): |
|
|
19 |
# trainer = VanillaTrainer("DeepLab", i, config) |
|
|
20 |
# trainer.train() |
|
|
21 |
# i = 4 |
|
|
22 |
# while True: |
|
|
23 |
# config = {"epochs": 200, "id": i, "lr": 0.00001, "pretrain": False} |
|
|
24 |
# training.vanilla_trainer.train_vanilla_predictor(config) |
|
|
25 |
# i += 1 |