|
a |
|
b/src/main.py |
|
|
1 |
import argparse |
|
|
2 |
|
|
|
3 |
parser = argparse.ArgumentParser( |
|
|
4 |
description='This class is used to train a transformer-based model on admission notes, labelled with <3 by Patrick.') |
|
|
5 |
|
|
|
6 |
parser.add_argument('-o', '--output', type=str, default="", |
|
|
7 |
help='Choose where to save the model after training. Saving is optional.') |
|
|
8 |
parser.add_argument('-lr', '--learning_rate', type=float, default=1e-2, |
|
|
9 |
help='Choose the learning rate of the model.') |
|
|
10 |
parser.add_argument('-b', '--batch_size', type=int, default=16, |
|
|
11 |
help='Choose the batch size of the model.') |
|
|
12 |
parser.add_argument('-e', '--epochs', type=int, default=5, |
|
|
13 |
help='Choose the epochs of the model.') |
|
|
14 |
parser.add_argument('-opt', '--optimizer', type=str, default='SGD', |
|
|
15 |
help='Choose the optimizer to be used for the model: SDG | Adam') |
|
|
16 |
parser.add_argument('-tr', '--transfer_learning', type=bool, default=False, |
|
|
17 |
help='Choose whether the BioBERT model should be used as baseline or not.') |
|
|
18 |
parser.add_argument('-v', '--verbose', type=bool, default=False, |
|
|
19 |
help='Choose whether the model should be evaluated after each epoch or only after the training.') |
|
|
20 |
parser.add_argument('-l', '--input_length', type=int, default=128, |
|
|
21 |
help='Choose the maximum length of the model\'s input layer.') |
|
|
22 |
parser.add_argument('-ag', '--data_augmentation', type=bool, default=False, |
|
|
23 |
help='Choose whether data-augmentation should be used.') |
|
|
24 |
parser.add_argument('-t', '--type', type=str, required=True, |
|
|
25 |
help='Specify the type of annotation to process. Type of annotation needs to be one of the following: Medical Condition, Symptom, Medication, Vital Statistic, Measurement Value, Negation Cue, Medical Procedure') |
|
|
26 |
|
|
|
27 |
args = parser.parse_args() |
|
|
28 |
|
|
|
29 |
if args.type not in ['Medical Condition', 'Symptom', 'Medication', 'Vital Statistic', 'Measurement Value', 'Negation Cue', 'Medical Procedure']: |
|
|
30 |
raise ValueError('Type of annotation needs to be one of the following: Medical Condition, Symptom, Medication, Vital Statistic, Measurement Value, Negation Cue, Medical Procedure') |
|
|
31 |
|
|
|
32 |
from utils.dataloader import Dataloader |
|
|
33 |
from utils.BertArchitecture import BertNER |
|
|
34 |
from utils.BertArchitecture import BioBertNER |
|
|
35 |
from utils.metric_tracking import MetricsTracking |
|
|
36 |
from utils.training import train_loop, testing |
|
|
37 |
|
|
|
38 |
import torch |
|
|
39 |
from torch.optim import SGD |
|
|
40 |
from torch.optim import Adam |
|
|
41 |
from torch.utils.data import DataLoader |
|
|
42 |
|
|
|
43 |
import numpy as np |
|
|
44 |
import pandas as pd |
|
|
45 |
|
|
|
46 |
from tqdm import tqdm |
|
|
47 |
|
|
|
48 |
#-------MAIN-------# |
|
|
49 |
|
|
|
50 |
if not args.transfer_learning: |
|
|
51 |
print("Training base BERT model...") |
|
|
52 |
model = BertNER(3) #O, B-, I- -> 3 entities |
|
|
53 |
|
|
|
54 |
if args.type == 'Medical Condition': |
|
|
55 |
type = 'MEDCOND' |
|
|
56 |
elif args.type == 'Symptom': |
|
|
57 |
type = 'SYMPTOM' |
|
|
58 |
elif args.type == 'Medication': |
|
|
59 |
type = 'MEDICATION' |
|
|
60 |
elif args.type == 'Vital Statistic': |
|
|
61 |
type = 'VITALSTAT' |
|
|
62 |
elif args.type == 'Measurement Value': |
|
|
63 |
type = 'MEASVAL' |
|
|
64 |
elif args.type == 'Negation Cue': |
|
|
65 |
type = 'NEGATION' |
|
|
66 |
elif args.type == 'Medical Procedure': |
|
|
67 |
type = 'PROCEDURE' |
|
|
68 |
else: |
|
|
69 |
raise ValueError('Type of annotation needs to be one of the following: Medical Condition, Symptom, Medication, Vital Statistic, Measurement Value, Negation Cue, Medical Procedure') |
|
|
70 |
|
|
|
71 |
else: |
|
|
72 |
print("Training BERT model based on BioBERT diseases...") |
|
|
73 |
|
|
|
74 |
if not args.type == 'Medical Condition': |
|
|
75 |
raise ValueError('Type of annotation needs to be Medical Condition when using BioBERT as baseline.') |
|
|
76 |
|
|
|
77 |
model = BioBertNER(3) #O, B-, I- -> 3 entities |
|
|
78 |
type = 'DISEASE' |
|
|
79 |
|
|
|
80 |
label_to_ids = { |
|
|
81 |
'B-' + type: 0, |
|
|
82 |
'I-' + type: 1, |
|
|
83 |
'O': 2 |
|
|
84 |
} |
|
|
85 |
|
|
|
86 |
ids_to_label = { |
|
|
87 |
0:'B-' + type, |
|
|
88 |
1:'I-' + type, |
|
|
89 |
2:'O' |
|
|
90 |
} |
|
|
91 |
|
|
|
92 |
dataloader = Dataloader(label_to_ids, ids_to_label, args.transfer_learning, args.input_length, type) |
|
|
93 |
|
|
|
94 |
train, val, test = dataloader.load_dataset(augment = args.data_augmentation) |
|
|
95 |
|
|
|
96 |
if args.optimizer == 'SGD': |
|
|
97 |
print("Using SGD optimizer...") |
|
|
98 |
optimizer = SGD(model.parameters(), lr=args.learning_rate, momentum = 0.9) |
|
|
99 |
else: |
|
|
100 |
print("Using Adam optimizer...") |
|
|
101 |
optimizer = Adam(model.parameters(), lr=args.learning_rate) |
|
|
102 |
|
|
|
103 |
parameters = { |
|
|
104 |
"model": model, |
|
|
105 |
"train_dataset": train, |
|
|
106 |
"eval_dataset" : val, |
|
|
107 |
"optimizer" : optimizer, |
|
|
108 |
"batch_size" : args.batch_size, |
|
|
109 |
"epochs" : args.epochs, |
|
|
110 |
"type" : type |
|
|
111 |
} |
|
|
112 |
|
|
|
113 |
train_loop(**parameters, verbose=args.verbose) |
|
|
114 |
|
|
|
115 |
testing(model, test, args.batch_size, type) |
|
|
116 |
|
|
|
117 |
#save model if wanted |
|
|
118 |
if args.output: |
|
|
119 |
torch.save(model.state_dict(), args.output) |
|
|
120 |
print(f"Model has successfully been saved at {args.output}!") |