# experiments/train_with_inpainters.py (revision 8eeb5a)
# Original listing: 27 lines (22 with data), 750 bytes.
from training.inductivenet_trainers import *
from training.consistency_trainers import *
from training.vanilla_trainer import *
import sys
def main(argv=None) -> None:
    """Train a model with inpainter-based augmentation.

    Usage: ``python train_with_inpainters.py <run_id> <model>``

    ``<run_id>`` is interpolated into the trainer's run name. ``<model>``
    is stored in the config under ``"model"`` and also selects the trainer:
    ``"InductiveNet"`` uses ``InductiveNetAugmentationTrainer``; any other
    value uses ``ConsistencyTrainerUsingAugmentation``.

    Args:
        argv: Full argument vector including the program name (defaults to
            ``sys.argv``). Exposed so the entry point can be driven without
            touching the real command line.

    Raises:
        SystemExit: If fewer than two positional arguments are supplied.
    """
    args = sys.argv if argv is None else argv
    if len(args) < 3:
        prog = args[0] if args else "train_with_inpainters.py"
        # Fail with a readable usage line instead of an IndexError.
        sys.exit(f"usage: {prog} <run_id> <model>")

    # `run_id` rather than `id` to avoid shadowing the builtin.
    run_id, model = args[1], args[2]

    config = {"model": model,
              "device": "cuda",
              "lr": 0.00001,
              "batch_size": 8,
              "epochs": 300,
              "use_inpainter": True}

    if model == "InductiveNet":
        # NOTE(review): this prefix is "zaugmentation" while the branch below
        # uses "augmentation" -- possibly a typo; kept verbatim because the
        # run name may already key existing logs/checkpoints. Confirm.
        trainer = InductiveNetAugmentationTrainer(
            f"inpainter_zaugmentation_{run_id}", config.copy())
    else:
        # Model-based augmentations
        trainer = ConsistencyTrainerUsingAugmentation(
            f"inpainter_augmentation_{run_id}", config.copy())
    trainer.train()


if __name__ == '__main__':
    main()