Diff of /train.py [000000] .. [fbbdf8]

Switch to unified view

a b/train.py
1
import argparse
2
import json
3
4
from trainers import trainers
5
6
7
def parse_args():
8
    parser = argparse.ArgumentParser()
9
    parser.add_argument("--config", required=True)
10
    return parser.parse_args()
11
12
13
if __name__ == "__main__":
14
    args = parse_args()
15
    config = json.loads(open(args.config).read())
16
    trainer_type = getattr(trainers, config["type"])
17
18
    print("Trainer: ", config["type"], trainer_type)
19
    trainer = trainer_type(config)
20
    trainer.loop()