--- a
+++ b/src/main.py
@@ -0,0 +1,157 @@
+"""Train and evaluate a BERT-based NER model on annotated admission notes.
+
+The script parses its command-line options, builds either the plain BERT
+model or the BioBERT baseline for one annotation type, runs training and
+testing, and optionally saves the trained weights.
+"""
+
+import argparse
+
+# Maps each annotation name accepted by --type to the suffix of its B-/I- tags.
+ANNOTATION_TAGS = {
+    'Medical Condition': 'MEDCOND',
+    'Symptom': 'SYMPTOM',
+    'Medication': 'MEDICATION',
+    'Vital Statistic': 'VITALSTAT',
+    'Measurement Value': 'MEASVAL',
+    'Negation Cue': 'NEGATION',
+    'Medical Procedure': 'PROCEDURE',
+}
+TYPE_REQUIREMENT = ('Type of annotation needs to be one of the following: '
+                    + ', '.join(ANNOTATION_TAGS))
+
+# Spellings accepted for the boolean options. The empty string counts as
+# False because ``bool('')`` is False as well.
+_TRUE_STRINGS = ('true', 't', 'yes', 'y', '1')
+_FALSE_STRINGS = ('false', 'f', 'no', 'n', '0', '')
+
+
+def _str2bool(value):
+    """Interpret a command-line string as a boolean.
+
+    ``type=bool`` would call ``bool()`` on the raw text, which turns every
+    non-empty string -- including 'False' -- into True, so the switches could
+    never be disabled explicitly. This converter reads the text instead.
+
+    Args:
+        value: Raw option text, e.g. 'True', 'false', 'yes' or '0'.
+
+    Returns:
+        The boolean the text stands for.
+
+    Raises:
+        argparse.ArgumentTypeError: If the text is not a known spelling.
+    """
+    text = value.strip().lower()
+    if text in _TRUE_STRINGS:
+        return True
+    if text in _FALSE_STRINGS:
+        return False
+    raise argparse.ArgumentTypeError(f'Expected a boolean value (e.g. True or False), got {value!r}.')
+
+
+parser = argparse.ArgumentParser(
+        description='This class is used to train a transformer-based model on admission notes, labelled with <3 by Patrick.')
+
+parser.add_argument('-o', '--output', type=str, default="",
+                    help='Choose where to save the model after training. Saving is optional.')
+parser.add_argument('-lr', '--learning_rate', type=float, default=1e-2,
+                    help='Choose the learning rate of the model.')
+parser.add_argument('-b', '--batch_size', type=int, default=16,
+                    help='Choose the batch size of the model.')
+parser.add_argument('-e', '--epochs', type=int, default=5,
+                    help='Choose the epochs of the model.')
+parser.add_argument('-opt', '--optimizer', type=str, default='SGD',
+                    help='Choose the optimizer to be used for the model: SGD | Adam')
+# The boolean switches accept an explicit value ("-v False") or none ("-v").
+parser.add_argument('-tr', '--transfer_learning', type=_str2bool, nargs='?', const=True, default=False,
+                    help='Choose whether the BioBERT model should be used as baseline or not.')
+parser.add_argument('-v', '--verbose', type=_str2bool, nargs='?', const=True, default=False,
+                    help='Choose whether the model should be evaluated after each epoch or only after the training.')
+parser.add_argument('-l', '--input_length', type=int, default=128,
+                    help='Choose the maximum length of the model\'s input layer.')
+parser.add_argument('-ag', '--data_augmentation', type=_str2bool, nargs='?', const=True, default=False,
+                    help='Choose whether data-augmentation should be used.')
+parser.add_argument('-t', '--type', type=str, required=True,
+                    help='Specify the type of annotation to process. ' + TYPE_REQUIREMENT)
+
+args = parser.parse_args()
+
+# Reject invalid arguments before the project and torch imports below run.
+if args.type not in ANNOTATION_TAGS:
+    raise ValueError(TYPE_REQUIREMENT)
+
+optimizer_name = args.optimizer.lower()
+if optimizer_name not in ('sgd', 'adam'):
+    raise ValueError('Optimizer needs to be one of the following: SGD, Adam')
+
+from utils.dataloader import Dataloader
+from utils.BertArchitecture import BertNER
+from utils.BertArchitecture import BioBertNER
+from utils.metric_tracking import MetricsTracking
+from utils.training import train_loop, testing
+
+import torch
+from torch.optim import SGD
+from torch.optim import Adam
+from torch.utils.data import DataLoader
+
+import numpy as np
+import pandas as pd
+
+from tqdm import tqdm
+
+#-------MAIN-------#
+
+if not args.transfer_learning:
+    print("Training base BERT model...")
+    model = BertNER(3) #O, B-, I- -> 3 entities
+    entity_tag = ANNOTATION_TAGS[args.type]
+else:
+    print("Training BERT model based on BioBERT diseases...")
+
+    if not args.type == 'Medical Condition':
+        raise ValueError('Type of annotation needs to be Medical Condition when using BioBERT as baseline.')
+
+    model = BioBertNER(3) #O, B-, I- -> 3 entities
+    entity_tag = 'DISEASE'
+
+label_to_ids = {
+    'B-' + entity_tag: 0,
+    'I-' + entity_tag: 1,
+    'O': 2
+    }
+
+# Inverse mapping of label_to_ids (id -> label).
+ids_to_label = {label_id: label for label, label_id in label_to_ids.items()}
+
+dataloader = Dataloader(label_to_ids, ids_to_label, args.transfer_learning, args.input_length, entity_tag)
+
+train, val, test = dataloader.load_dataset(augment = args.data_augmentation)
+
+if optimizer_name == 'sgd':
+    print("Using SGD optimizer...")
+    optimizer = SGD(model.parameters(), lr=args.learning_rate, momentum = 0.9)
+else:
+    print("Using Adam optimizer...")
+    optimizer = Adam(model.parameters(), lr=args.learning_rate)
+
+# Unpacked into train_loop as keyword arguments, so the keys must stay as they are.
+parameters = {
+    "model": model,
+    "train_dataset": train,
+    "eval_dataset" : val,
+    "optimizer" : optimizer,
+    "batch_size" : args.batch_size,
+    "epochs" : args.epochs,
+    "type" : entity_tag
+}
+
+train_loop(**parameters, verbose=args.verbose)
+
+testing(model, test, args.batch_size, entity_tag)
+
+#save model if wanted
+if args.output:
+    torch.save(model.state_dict(), args.output)
+    print(f"Model has successfully been saved at {args.output}!")