Diff of /src/main.py [000000] .. [0eda78]

Switch to unified view

a b/src/main.py
1
import argparse
2
3
parser = argparse.ArgumentParser(
4
        description='This class is used to train a transformer-based model on admission notes, labelled with <3 by Patrick.')
5
6
parser.add_argument('-o', '--output', type=str, default="",
7
                    help='Choose where to save the model after training. Saving is optional.')
8
parser.add_argument('-lr', '--learning_rate', type=float, default=1e-2,
9
                    help='Choose the learning rate of the model.')
10
parser.add_argument('-b', '--batch_size', type=int, default=16,
11
                    help='Choose the batch size of the model.')
12
parser.add_argument('-e', '--epochs', type=int, default=5,
13
                    help='Choose the epochs of the model.')
14
parser.add_argument('-opt', '--optimizer', type=str, default='SGD',
15
                    help='Choose the optimizer to be used for the model: SDG | Adam')
16
parser.add_argument('-tr', '--transfer_learning', type=bool, default=False,
17
                    help='Choose whether the BioBERT model should be used as baseline or not.')
18
parser.add_argument('-v', '--verbose', type=bool, default=False,
19
                    help='Choose whether the model should be evaluated after each epoch or only after the training.')
20
parser.add_argument('-l', '--input_length', type=int, default=128,
21
                    help='Choose the maximum length of the model\'s input layer.')
22
parser.add_argument('-ag', '--data_augmentation', type=bool, default=False,
23
                    help='Choose whether data-augmentation should be used.')
24
parser.add_argument('-t', '--type', type=str, required=True,
25
                    help='Specify the type of annotation to process. Type of annotation needs to be one of the following: Medical Condition, Symptom, Medication, Vital Statistic, Measurement Value, Negation Cue, Medical Procedure')
26
27
args = parser.parse_args()
28
29
if args.type not in ['Medical Condition', 'Symptom', 'Medication', 'Vital Statistic', 'Measurement Value', 'Negation Cue', 'Medical Procedure']:
30
    raise ValueError('Type of annotation needs to be one of the following: Medical Condition, Symptom, Medication, Vital Statistic, Measurement Value, Negation Cue, Medical Procedure')
31
32
from utils.dataloader import Dataloader
33
from utils.BertArchitecture import BertNER
34
from utils.BertArchitecture import BioBertNER
35
from utils.metric_tracking import MetricsTracking
36
from utils.training import train_loop, testing
37
38
import torch
39
from torch.optim import SGD
40
from torch.optim import Adam
41
from torch.utils.data import DataLoader
42
43
import numpy as np
44
import pandas as pd
45
46
from tqdm import tqdm
47
48
#-------MAIN-------#
49
50
if not args.transfer_learning:
51
    print("Training base BERT model...")
52
    model = BertNER(3) #O, B-, I- -> 3 entities
53
54
    if args.type == 'Medical Condition':
55
        type = 'MEDCOND'
56
    elif args.type == 'Symptom':
57
        type = 'SYMPTOM'
58
    elif args.type == 'Medication':
59
        type = 'MEDICATION'
60
    elif args.type == 'Vital Statistic':
61
        type = 'VITALSTAT'
62
    elif args.type == 'Measurement Value':
63
        type = 'MEASVAL'
64
    elif args.type == 'Negation Cue':
65
        type = 'NEGATION'
66
    elif args.type == 'Medical Procedure':
67
        type = 'PROCEDURE'
68
    else:    
69
        raise ValueError('Type of annotation needs to be one of the following: Medical Condition, Symptom, Medication, Vital Statistic, Measurement Value, Negation Cue, Medical Procedure')
70
71
else:
72
    print("Training BERT model based on BioBERT diseases...")
73
74
    if not args.type == 'Medical Condition':
75
        raise ValueError('Type of annotation needs to be Medical Condition when using BioBERT as baseline.')
76
77
    model = BioBertNER(3) #O, B-, I- -> 3 entities
78
    type = 'DISEASE'
79
80
label_to_ids = {
81
    'B-' + type: 0,
82
    'I-' + type: 1,
83
    'O': 2
84
    }
85
86
ids_to_label = {
87
    0:'B-' + type,
88
    1:'I-' + type,
89
    2:'O'
90
    }
91
92
dataloader = Dataloader(label_to_ids, ids_to_label, args.transfer_learning, args.input_length, type)
93
94
train, val, test = dataloader.load_dataset(augment = args.data_augmentation)
95
96
if args.optimizer == 'SGD':
97
    print("Using SGD optimizer...")
98
    optimizer = SGD(model.parameters(), lr=args.learning_rate, momentum = 0.9)
99
else:
100
    print("Using Adam optimizer...")
101
    optimizer = Adam(model.parameters(), lr=args.learning_rate)
102
103
parameters = {
104
    "model": model,
105
    "train_dataset": train,
106
    "eval_dataset" : val,
107
    "optimizer" : optimizer,
108
    "batch_size" : args.batch_size,
109
    "epochs" : args.epochs,
110
    "type" : type
111
}
112
113
train_loop(**parameters, verbose=args.verbose)
114
115
testing(model, test, args.batch_size, type)
116
117
#save model if wanted
118
if args.output:
119
    torch.save(model.state_dict(), args.output)
120
    print(f"Model has successfully been saved at {args.output}!")