Diff of /medicalbert/__main__.py [000000] .. [d129b2]

Switch to unified view

a b/medicalbert/__main__.py
1
import json, logging, os, torch
2
import random
3
from config import get_configuration
4
import numpy as np
5
from classifiers.classifier_factory import ClassifierFactory
6
from datareader.data_reader_factory import DataReaderFactory
7
from cliparser import setup_parser
8
from evaluator.evaluator_factory import EvaluatorFactory
9
from evaluator.standard_evaluator import StandardEvaluator
10
from tokenizers.tokenizer_factory import TokenizerFactory
11
12
from evaluator.validation_metric_factory import ValidationMetricFactory
13
14
15
def set_random_seeds(seed):
    """Seed every RNG the pipeline touches (stdlib, NumPy, PyTorch).

    Args:
        seed: integer seed applied identically to all three generators
            so runs are reproducible.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
19
20
def save_config(defconfig):
    """Persist the experiment configuration to disk as JSON.

    Writes ``defconfig`` to ``<output_dir>/<experiment_name>/config.json``,
    creating the experiment directory first if it does not exist.

    Args:
        defconfig: mapping with at least ``'output_dir'`` and
            ``'experiment_name'`` keys; must be JSON-serializable.
    """
    # Compute the experiment directory once instead of three times.
    experiment_dir = os.path.join(defconfig['output_dir'], defconfig['experiment_name'])

    # exist_ok=True avoids the check-then-create race the old
    # os.path.exists() + os.makedirs() pair had.
    os.makedirs(experiment_dir, exist_ok=True)

    config_path = os.path.join(experiment_dir, 'config.json')
    with open(config_path, 'w') as f:
        json.dump(defconfig, f)
28
29
if __name__ == "__main__":

    # Parse CLI arguments and resolve the experiment configuration.
    args = setup_parser()
    defconfig = get_configuration(args)
    print(defconfig)

    # Snapshot the resolved config alongside the experiment outputs.
    save_config(defconfig)

    logging.info("Number of GPUS: {}".format(torch.cuda.device_count()))

    # Reproducibility is opt-in: only seed when the config asks for it.
    if 'seed' in defconfig:
        set_random_seeds(defconfig['seed'])

    # Tokenizer, classifier, and data reader are all built from the config.
    tokenizer = TokenizerFactory().make_tokenizer(defconfig['tokenizer'])
    classifier = ClassifierFactory(defconfig).make_classifier(defconfig['classifier'])
    datareader = DataReaderFactory(defconfig).make_datareader(defconfig['datareader'], tokenizer)

    if args.train:
        # Resume from a checkpoint when one was supplied; no-op otherwise.
        if args.train_from_checkpoint:
            classifier.load_from_checkpoint(args.train_from_checkpoint)

        classifier.train(datareader)

    if args.eval:
        # Both results and checkpoints live under the experiment directory.
        experiment_dir = os.path.join(defconfig['output_dir'], defconfig['experiment_name'])
        results_path = os.path.join(experiment_dir, "results")

        validator = ValidationMetricFactory().make_validator(defconfig['validation_metric'])
        evaluator = EvaluatorFactory().make_evaluator(defconfig['evaluator'], results_path, defconfig, datareader, validator)

        # Evaluate every checkpoint the training run produced.
        checkpoints_path = os.path.join(experiment_dir, "checkpoints")
        for checkpoint_name in os.listdir(checkpoints_path):
            classifier.load_from_checkpoint(os.path.join(checkpoints_path, checkpoint_name))
            evaluator.go(classifier, checkpoint_name)

        evaluator.test()
83