[0eda78]: / src / main.py

Download this file

121 lines (97 with data), 4.7 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# ---- Command-line interface ----
# BUGFIX: the boolean options (-tr, -v, -ag) previously used ``type=bool``,
# which is broken with argparse: ``bool("False")`` is ``True``, so ANY
# supplied value enabled the flag. They are now proper ``store_true``
# switches — pass the flag to enable, omit it for the (unchanged) default
# of ``False``.
import argparse

parser = argparse.ArgumentParser(
    description='This class is used to train a transformer-based model on admission notes, labelled with <3 by Patrick.')
parser.add_argument('-o', '--output', type=str, default="",
                    help='Choose where to save the model after training. Saving is optional.')
parser.add_argument('-lr', '--learning_rate', type=float, default=1e-2,
                    help='Choose the learning rate of the model.')
parser.add_argument('-b', '--batch_size', type=int, default=16,
                    help='Choose the batch size of the model.')
parser.add_argument('-e', '--epochs', type=int, default=5,
                    help='Choose the epochs of the model.')
# Typo fixed in the help text: "SDG" -> "SGD".
parser.add_argument('-opt', '--optimizer', type=str, default='SGD',
                    help='Choose the optimizer to be used for the model: SGD | Adam')
parser.add_argument('-tr', '--transfer_learning', action='store_true',
                    help='Choose whether the BioBERT model should be used as baseline or not.')
parser.add_argument('-v', '--verbose', action='store_true',
                    help='Choose whether the model should be evaluated after each epoch or only after the training.')
parser.add_argument('-l', '--input_length', type=int, default=128,
                    help='Choose the maximum length of the model\'s input layer.')
parser.add_argument('-ag', '--data_augmentation', action='store_true',
                    help='Choose whether data-augmentation should be used.')
parser.add_argument('-t', '--type', type=str, required=True,
                    help='Specify the type of annotation to process. Type of annotation needs to be one of the following: Medical Condition, Symptom, Medication, Vital Statistic, Measurement Value, Negation Cue, Medical Procedure')
args = parser.parse_args()

# Validate early so a bad --type fails before any heavy imports run.
if args.type not in ['Medical Condition', 'Symptom', 'Medication', 'Vital Statistic', 'Measurement Value', 'Negation Cue', 'Medical Procedure']:
    raise ValueError('Type of annotation needs to be one of the following: Medical Condition, Symptom, Medication, Vital Statistic, Measurement Value, Negation Cue, Medical Procedure')
from utils.dataloader import Dataloader
from utils.BertArchitecture import BertNER
from utils.BertArchitecture import BioBertNER
from utils.metric_tracking import MetricsTracking
from utils.training import train_loop, testing
import torch
from torch.optim import SGD
from torch.optim import Adam
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from tqdm import tqdm
#-------MAIN-------#
# Map each human-readable annotation name to the tag suffix used in the
# BIO label scheme. A single lookup table replaces the previous
# seven-branch if/elif chain (same names, same tags).
_TYPE_TAGS = {
    'Medical Condition': 'MEDCOND',
    'Symptom': 'SYMPTOM',
    'Medication': 'MEDICATION',
    'Vital Statistic': 'VITALSTAT',
    'Measurement Value': 'MEASVAL',
    'Negation Cue': 'NEGATION',
    'Medical Procedure': 'PROCEDURE',
}

if not args.transfer_learning:
    print("Training base BERT model...")
    model = BertNER(3)  # O, B-, I- -> 3 entities
    try:
        # NOTE: keeps the original module-level name ``type`` (shadows the
        # builtin) because the rest of the script reads it.
        type = _TYPE_TAGS[args.type]
    except KeyError:
        raise ValueError('Type of annotation needs to be one of the following: Medical Condition, Symptom, Medication, Vital Statistic, Measurement Value, Negation Cue, Medical Procedure')
else:
    print("Training BERT model based on BioBERT diseases...")
    # The BioBERT baseline is only trained on disease mentions, so transfer
    # learning is restricted to the 'Medical Condition' annotation type.
    if not args.type == 'Medical Condition':
        raise ValueError('Type of annotation needs to be Medical Condition when using BioBERT as baseline.')
    model = BioBertNER(3)  # O, B-, I- -> 3 entities
    type = 'DISEASE'

# BIO tagging scheme: B- begins an entity span, I- continues it, O is
# outside any entity. Both directions of the mapping are needed (encoding
# labels for training, decoding predictions for evaluation).
label_to_ids = {
    'B-' + type: 0,
    'I-' + type: 1,
    'O': 2
}
ids_to_label = {
    0: 'B-' + type,
    1: 'I-' + type,
    2: 'O'
}
# Build the dataset splits for the selected annotation type, then the
# optimizer chosen on the command line (anything other than 'SGD' falls
# through to Adam, matching the original behavior).
dataloader = Dataloader(label_to_ids, ids_to_label, args.transfer_learning, args.input_length, type)
train, val, test = dataloader.load_dataset(augment=args.data_augmentation)

if args.optimizer != 'SGD':
    print("Using Adam optimizer...")
    optimizer = Adam(model.parameters(), lr=args.learning_rate)
else:
    print("Using SGD optimizer...")
    optimizer = SGD(model.parameters(), lr=args.learning_rate, momentum=0.9)
# Run training (optionally evaluating after every epoch when --verbose is
# set), then score the model once on the held-out test split.
train_loop(
    model=model,
    train_dataset=train,
    eval_dataset=val,
    optimizer=optimizer,
    batch_size=args.batch_size,
    epochs=args.epochs,
    type=type,
    verbose=args.verbose,
)
testing(model, test, args.batch_size, type)

# Persist the trained weights only when an output path was supplied.
if args.output:
    torch.save(model.state_dict(), args.output)
    print(f"Model has successfully been saved at {args.output}!")