medicalbert/__main__.py

import json
import logging
import os
import random

import numpy as np
import torch

from classifiers.classifier_factory import ClassifierFactory
from cliparser import setup_parser
from config import get_configuration
from datareader.data_reader_factory import DataReaderFactory
from evaluator.evaluator_factory import EvaluatorFactory
from evaluator.standard_evaluator import StandardEvaluator
from evaluator.validation_metric_factory import ValidationMetricFactory
from tokenizers.tokenizer_factory import TokenizerFactory

def set_random_seeds(seed):
    # Seed every source of randomness in use (Python, NumPy, PyTorch) so that
    # runs are reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

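# Note (a sketch, not in the original source): if experiments run on GPU and
# full determinism is wanted, the CUDA generators can be seeded here as well,
# e.g. by adding to set_random_seeds:
#
#     if torch.cuda.is_available():
#         torch.cuda.manual_seed_all(seed)
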
def save_config(defconfig):
    # Persist the resolved configuration next to the experiment outputs so the
    # run can be reproduced later.
    experiment_dir = os.path.join(defconfig['output_dir'], defconfig['experiment_name'])
    os.makedirs(experiment_dir, exist_ok=True)

    config_path = os.path.join(experiment_dir, 'config.json')
    with open(config_path, 'w') as f:
        json.dump(defconfig, f)

if __name__ == "__main__":

    # Load config
    args = setup_parser()
    defconfig = get_configuration(args)
    print(defconfig)

    save_config(defconfig)

    # Configure logging; without this, info-level messages are discarded.
    logging.basicConfig(level=logging.INFO)
    logging.info("Number of GPUs: {}".format(torch.cuda.device_count()))

    if 'seed' in defconfig:
        set_random_seeds(defconfig['seed'])

    # Load the tokenizer to use
    tokenizerFactory = TokenizerFactory()
    tokenizer = tokenizerFactory.make_tokenizer(defconfig['tokenizer'])

    # Build a classifier object to use
    classifierFactory = ClassifierFactory(defconfig)
    classifier = classifierFactory.make_classifier(defconfig['classifier'])

    # Load the data
    dataReaderFactory = DataReaderFactory(defconfig)
    datareader = dataReaderFactory.make_datareader(defconfig['datareader'], tokenizer)

    if args.train:

        # Load from a checkpoint if we're using one (a no-op if we're not)
        if args.train_from_checkpoint:
            classifier.load_from_checkpoint(args.train_from_checkpoint)

        # Train the classifier on the data
        classifier.train(datareader)

    if args.eval:

        # Set up the correct validation metric and evaluator
        results_path = os.path.join(defconfig['output_dir'], defconfig['experiment_name'], "results")
        validator = ValidationMetricFactory().make_validator(defconfig['validation_metric'])

        evaluator = EvaluatorFactory().make_evaluator(defconfig['evaluator'], results_path, defconfig,
                                                      datareader, validator)

        checkpoints_path = os.path.join(defconfig['output_dir'], defconfig['experiment_name'], "checkpoints")

        # Loop over all the checkpoints, running an evaluation on each of them.
        # Note that os.listdir returns entries in arbitrary order.
        for checkpoint in os.listdir(checkpoints_path):

            # Load the checkpointed model
            classifier.load_from_checkpoint(os.path.join(checkpoints_path, checkpoint))

            evaluator.go(classifier, checkpoint)

        evaluator.test()
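
# Usage sketch (the flag names below are assumptions inferred from the
# attribute accesses on `args`; setup_parser in cliparser.py defines the
# actual CLI):
#
#     python -m medicalbert --train --eval
#     python -m medicalbert --train --train_from_checkpoint <checkpoint-path>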