Switch to side-by-side view

--- a
+++ b/src/predict_single_sentence.py
@@ -0,0 +1,117 @@
import argparse

# Command-line interface for running NER prediction on a single sentence.
parser = argparse.ArgumentParser(description='Enter a sentence for the model to work on.')

parser.add_argument('-l', '--length', type=int, default=128,
                    help='Choose the maximum length of the model\'s input layer.')
parser.add_argument('-m', '--model', type=str, default='../models/medcondbert.pth',
                    help='Choose the directory of the model to be used for prediction.')
# BUG FIX: the original used `type=bool`, which is broken in argparse --
# bool() of any non-empty string (including "False") is True, so
# `--transfer_learning False` silently enabled transfer learning.
# A store_true flag keeps the same default (False) and the same attribute
# name, but behaves correctly as an on/off switch.
parser.add_argument('-tr', '--transfer_learning', action='store_true',
                    help='Use a model that has been trained on BioBERT. \
                    Careful: It will not work if wrongly specified!')
parser.add_argument('sentence', type=str,
                    help='Write your input sentence, preferably an admission note!')
parser.add_argument('-t', '--type', type=str, required=True,
                    help='Specify the type of annotation to process. Type of annotation needs to be one of the following: Medical Condition, Symptom, Medication, Vital Statistic, Measurement Value, Negation Cue, Medical Procedure')

args = parser.parse_args()
max_length = args.length   # maximum tokenized input length fed to the model
model_path = args.model    # path to the saved state_dict (.pth file)
sentence = args.sentence   # raw input text to annotate
+
+from utils.dataloader import Dataloader
+from utils.BertArchitecture import BertNER, BioBertNER
+from utils.metric_tracking import MetricsTracking
+
+import torch
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+
+import numpy as np
+import pandas as pd
+
+from tqdm import tqdm
+from transformers import BertTokenizer,BertForTokenClassification
+
# Map the human-readable annotation type (CLI value) to the tag suffix used
# in the training labels.
_TYPE_TO_TAG = {
    'Medical Condition': 'MEDCOND',
    'Symptom': 'SYMPTOM',
    'Medication': 'MEDICATION',
    'Vital Statistic': 'VITALSTAT',
    'Measurement Value': 'MEASVAL',
    'Negation Cue': 'NEGATION',
    'Medical Procedure': 'PROCEDURE',
}

if not args.transfer_learning:
    # BUG FIX: the original message said "Training ..." although this script
    # only runs prediction.
    print("Predicting with base BERT model...")
    model = BertNER(3)  # O, B-, I- -> 3 entities

    # TODO(review): `type` shadows the builtin; kept for compatibility with
    # the label-map code below that reads this module-level name.
    if args.type not in _TYPE_TO_TAG:
        raise ValueError('Type of annotation needs to be one of the following: Medical Condition, Symptom, Medication, Vital Statistic, Measurement Value, Negation Cue, Medical Procedure')
    type = _TYPE_TO_TAG[args.type]

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # NOTE(review): tokens are added to the tokenizer but the model embedding
    # matrix is not resized here; the loaded state_dict is assumed to already
    # match this vocabulary -- confirm against the training script.
    tokenizer.add_tokens(['B-' + args.type, 'I-' + args.type])
else:
    print("Predicting with BERT model based on BioBERT diseases...")

    # The BioBERT-based checkpoint is only trained for the disease entity.
    if not args.type == 'Medical Condition':
        raise ValueError('Type of annotation needs to be Medical Condition when using BioBERT as baseline.')

    model = BioBertNER(3)  # O, B-, I- -> 3 entities
    tokenizer = BertTokenizer.from_pretrained('alvaroalon2/biobert_diseases_ner')
    type = 'DISEASE'
+
# Label <-> id mappings for the 3-tag BIO scheme of the selected entity type.
label_to_ids = {
    'B-' + type: 0,
    'I-' + type: 1,
    'O': 2,
}

# Derive the inverse mapping instead of writing it out by hand, so the two
# dictionaries cannot drift apart.
ids_to_label = {idx: label for label, idx in label_to_ids.items()}

# BUG FIX: map_location makes a GPU-trained checkpoint loadable on CPU-only
# machines; without it torch.load raises when CUDA is unavailable.
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()  # disable dropout etc. for deterministic inference
+
+
# Tokenize once WITHOUT special tokens so the predictions can later be
# aligned back to the visible word pieces.
t_sen = tokenizer.tokenize(sentence)

sen_code = tokenizer.encode_plus(
    sentence,
    return_tensors='pt',
    add_special_tokens=True,
    max_length=max_length,
    padding='max_length',
    return_attention_mask=True,
    truncation=True,
)

# With return_tensors='pt' the encoder already returns (1, max_length)
# tensors, so the original torch.as_tensor re-wrap and .squeeze(1) (a no-op
# on a size-128 dimension) are dropped.
input_ids = sen_code['input_ids']
attention_mask = sen_code['attention_mask']

# Inference only: no_grad avoids building the autograd graph.
with torch.no_grad():
    outputs = model(input_ids, attention_mask)

predictions = outputs.logits.argmax(dim=-1)
predictions = [ids_to_label.get(x) for x in predictions.numpy()[0]]

# Beware special tokens: drop the leading [CLS] prediction and anything past
# the visible word pieces ([SEP] and padding).
cutoff = min(len(predictions) - 1, len(t_sen))
predictions = predictions[1:cutoff + 1]
t_sen = t_sen[:cutoff]

df = pd.DataFrame({
    'Sentence': t_sen,
    'Prediction': predictions,
})
# Show the full frame regardless of terminal size.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', 1000)
pd.set_option('display.max_colwidth', None)
print(df)