[8eeb5a]: / experiments / train_predictors_for_model.py

Download this file

42 lines (36 with data), 1.2 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from training.inductivenet_trainers import *
from training.consistency_trainers import *
from training.vanilla_trainer import *
import sys
def build_config(model):
    """Return a fresh training-configuration dict for *model*.

    The same settings are used for every trainer launched by ``main``; each
    trainer receives its own shallow copy so that one run cannot leak config
    mutations into the next.

    Args:
        model: Model name taken from the command line; stored under "model".

    Returns:
        A new dict on every call.
    """
    return {
        "model": model,
        "device": "cuda",  # NOTE(review): presumably requires a CUDA device
        "lr": 0.00001,
        "batch_size": 8,
        "epochs": 300,
        "use_inpainter": False,
    }


def trainer_schedule(model):
    """Return the ordered (run-name prefix, trainer class) pairs for *model*.

    "InductiveNet" has its own trainer family; every other model name uses
    the generic trainers. The order matches the order in which the runs are
    executed.
    """
    if model == "InductiveNet":
        return [
            ("augmentation", InductiveNetAugmentationTrainer),
            ("consistency", InductiveNetConsistencyTrainer),
            ("vanilla", InductiveNetVanillaTrainer),
        ]
    return [
        ("consistency", ConsistencyTrainer),  # consistency training
        ("augmentation", ConsistencyTrainerUsingAugmentation),  # model-based augmentations
        ("vanilla", VanillaTrainer),  # no augmentations
    ]


def main(argv=None):
    """Train every predictor variant for one model, one run after another.

    Usage: ``train_predictors_for_model.py RUN_ID MODEL``

    Each run is named ``"<prefix>_<RUN_ID>"`` (e.g. ``consistency_7``).

    Args:
        argv: Positional arguments after the program name. Defaults to
            ``sys.argv[1:]``. Extra arguments beyond the first two are ignored.

    Raises:
        SystemExit: With a usage message when fewer than two arguments are given.
    """
    argv = sys.argv[1:] if argv is None else argv
    if len(argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} RUN_ID MODEL")
    # `run_id` rather than `id` so the builtin is not shadowed.
    run_id, model = argv[0], argv[1]
    config = build_config(model)
    for prefix, trainer_cls in trainer_schedule(model):
        # The trainer is not kept bound after train() returns, so it can be
        # released before the next trainer is constructed.
        trainer_cls(f"{prefix}_{run_id}", config.copy()).train()


if __name__ == '__main__':
    main()