a | b/experiments/train_ensemble.py | ||
---|---|---|---|
1 | from training.consistency_trainers import * |
||
2 | from training.vanilla_trainer import * |
||
3 | import sys |
||
4 | |||
5 | if __name__ == '__main__': |
||
6 | config = {"model": "InductiveNet", |
||
7 | "device": "cuda", |
||
8 | "lr": 0.00001, |
||
9 | "batch_size": 4, |
||
10 | "epochs": 250, |
||
11 | "use_inpainter": False} |
||
12 | trainer = EnsembleConsistencyTrainer("ensemble", config) |
||
13 | trainer.train() |