[92cc18]: / experiments / train_ensemble.py

Download this file

14 lines (12 with data), 401 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
from training.consistency_trainers import *
from training.vanilla_trainer import *
import sys
def _build_config():
    """Return the hard-coded hyper-parameter dict for the ensemble run.

    The dict is passed unchanged to ``EnsembleConsistencyTrainer``; how each
    key is interpreted is defined by that trainer, not here.
    """
    return dict(
        model="InductiveNet",
        device="cuda",  # NOTE(review): hard-coded CUDA device — presumably needs a GPU; confirm
        lr=1e-5,
        batch_size=4,
        epochs=250,
        use_inpainter=False,  # flag consumed by the trainer; semantics defined there
    )


def main():
    """Build the ensemble consistency trainer and run its training loop."""
    # "ensemble" is the first positional argument to the trainer —
    # presumably a run/experiment identifier; verify against the trainer.
    ensemble_trainer = EnsembleConsistencyTrainer("ensemble", _build_config())
    ensemble_trainer.train()


if __name__ == "__main__":
    main()