# code/bert_code/main.py
1
from preprocessor import PreProcessor
2
from MIMIC_Bert_NER import *
3
4
5
6
if __name__ == '__main__':
    # End-to-end driver: preprocess the training split, train and save the
    # BERT NER model, report validation metrics, then preprocess the test
    # split and evaluate on it.

    # Explicit local imports: this script calls os.* and torch.save directly,
    # so it should not depend on `from MIMIC_Bert_NER import *` happening to
    # re-export these modules.
    import os

    import torch

    # Training data
    save_path = "./cleaned_files_train"
    concept_path = "../dnc_code/medical_data/train_data/concept"
    text_path = "../dnc_code/medical_data/train_data/txt"

    # exist_ok=True avoids the check-then-create race of exists() + mkdir().
    os.makedirs(save_path, exist_ok=True)
    preproc = PreProcessor(concept_path, text_path, save_path)
    preproc.pre_process()

    # Train BERT
    bert = MIMICBertNER(save_path)
    bert.get_inputs()

    model = bert.train()

    # Save a trained model, configuration and tokenizer.
    # A wrapped model (e.g. one exposing `.module`) is unwrapped so that only
    # the underlying model's weights are saved.
    model_to_save = model.module if hasattr(model, 'module') else model

    # torch.save does not create missing parent directories, so make sure the
    # output directory exists first.
    # NOTE(review): MIMICBertNER may already create this directory — confirm.
    os.makedirs(bert.bert_out_address, exist_ok=True)

    # If we save using the predefined names, we can load using `from_pretrained`
    output_model_file = os.path.join(bert.bert_out_address, "pytorch_model.bin")
    output_config_file = os.path.join(bert.bert_out_address, "config.json")

    torch.save(model_to_save.state_dict(), output_model_file)
    model_to_save.config.to_json_file(output_config_file)
    bert.tokenizer.save_vocabulary(bert.bert_out_address)

    # Validation data metrics
    bert.evaluate()

    # Test data
    save_path = "./cleaned_files_test"
    concept_path = "../dnc_code/medical_data/test_data/concept"
    text_path = "../dnc_code/medical_data/test_data/txt"

    os.makedirs(save_path, exist_ok=True)
    preproc = PreProcessor(concept_path, text_path, save_path)
    preproc.pre_process()

    # Test BERT
    # Fixed: the original passed the undefined name `saved_path` here, which
    # raised NameError; the cleaned test directory is held in `save_path`.
    bert_test = MIMICBertNER(save_path)
    bert_test.get_inputs("test")
    bert_test.evaluate()