--- /dev/null
+++ b/src/testsuite.py
@@ -0,0 +1,93 @@
+import re
+import unittest
+
+import numpy as np
+import torch
+from transformers import BertTokenizer
+
+from utils.dataloader import Dataloader, tokenize_and_preserve_labels
+from utils.metric_tracking import MetricsTracking
+
+
+class DataloaderTest(unittest.TestCase):
+
+    def test_tokenize_sentence(self):
+        label_to_ids = {
+            'B-MEDCOND': 0,
+            'I-MEDCOND': 1,
+            'O': 2
+        }
+        ids_to_label = {
+            0: 'B-MEDCOND',
+            1: 'I-MEDCOND',
+            2: 'O'
+        }
+        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+        max_tokens = 128
+
+        sentence = "Patient presents with glaucoma, characterized by a definitive diagnosis of pigmentary glaucoma. Intraocular pressure measures at 15 mmHg, while the visual field remains normal. Visual acuity is recorded as 20/50. The patient has not undergone prior cataract surgery, but has had LASIK surgery. Additionally, comorbid ocular diseases include macular degeneration."
+
+        # One word-level label for each token produced by the regex split below.
+        labels = "O O O B-MEDCOND O O O O O O O B-MEDCOND I-MEDCOND O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-MEDCOND I-MEDCOND I-MEDCOND O B-MEDCOND I-MEDCOND O"
+
+        # Split into words, keeping possessive "'s" and punctuation as separate tokens.
+        sentence = re.findall(r"\w+|'s|['\".,!?;]", sentence.strip(), re.UNICODE)
+        labels = labels.split(" ")
+
+        t_sentence, t_labels = tokenize_and_preserve_labels(sentence, labels, tokenizer, label_to_ids, ids_to_label, max_tokens)
+
+        # Sub-tokenization must keep tokens and labels aligned and must not
+        # introduce additional B- labels.
+        self.assertEqual(len(t_sentence), len(t_labels))
+        self.assertEqual(t_labels.count("B-MEDCOND"), labels.count("B-MEDCOND"))
+
+    def test_load_dataset(self):
+        label_to_ids = {
+            'B-MEDCOND': 0,
+            'I-MEDCOND': 1,
+            'O': 2
+        }
+        ids_to_label = {
+            0: 'B-MEDCOND',
+            1: 'I-MEDCOND',
+            2: 'O'
+        }
+
+        dataloader = Dataloader(label_to_ids, ids_to_label)
+
+        dataset = dataloader.load_dataset(full=True)
+
+        self.assertEqual(len(dataset), 255)
+
+        sample = dataset[0]
+
+        self.assertEqual(len(sample), 4)  # input_ids, attention_mask, token_type_ids, entity
+
+
+class MetricsTrackingTest(unittest.TestCase):
+
+    def test_avg_metrics(self):
+        # -100 marks positions the tracker should ignore (e.g. special tokens).
+        # The values are arbitrary but chosen so that 6 of the 9 counted
+        # positions match, giving 2/3 for accuracy, precision, recall and F1.
+        predictions = np.array([-100, 0, 0, 0, 1, 1, 1, 2, 2, 2])
+        ground_truth = np.array([-100, 0, 1, 0, 1, 2, 1, 2, 0, 2])
+
+        predictions = torch.from_numpy(predictions)
+        ground_truth = torch.from_numpy(ground_truth)
+
+        tracker = MetricsTracking()
+        tracker.update(predictions, ground_truth, 0.1)
+
+        metrics = tracker.return_avg_metrics(1)  # tracker was only updated once
+
+        # assertAlmostEqual guards against floating-point rounding differences.
+        self.assertAlmostEqual(metrics['acc'], 0.667, places=3)
+        self.assertAlmostEqual(metrics['f1'], 0.667, places=3)
+        self.assertAlmostEqual(metrics['precision'], 0.667, places=3)
+        self.assertAlmostEqual(metrics['recall'], 0.667, places=3)
+        self.assertAlmostEqual(metrics['loss'], 0.1, places=3)
+
+
+if __name__ == '__main__':
+    unittest.main()
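
For reviewers: `test_tokenize_sentence` relies on `tokenize_and_preserve_labels` keeping one label aligned with every WordPiece sub-token without duplicating `B-` tags. Below is a minimal sketch of that contract, assuming the helper repeats each word's label across its sub-tokens and downgrades continuations from `B-` to `I-`; the name `tokenize_and_preserve_labels_sketch` and this exact behaviour are illustrative, not the implementation in `utils/dataloader.py`.

```python
# Illustrative sketch only -- the real tokenize_and_preserve_labels in
# utils/dataloader.py may differ in signature and behaviour.
from transformers import BertTokenizer

def tokenize_and_preserve_labels_sketch(words, labels, tokenizer, max_tokens):
    """Tokenize word by word, aligning one label with every sub-token."""
    sub_tokens, sub_labels = [], []
    for word, label in zip(words, labels):
        pieces = tokenizer.tokenize(word)  # e.g. "glaucoma" -> ["glaucoma"]
        if not pieces:
            continue
        sub_tokens.extend(pieces)
        sub_labels.append(label)  # first sub-token keeps the original label
        # Continuation sub-tokens must not repeat a B- tag.
        cont = "I-" + label[2:] if label.startswith("B-") else label
        sub_labels.extend([cont] * (len(pieces) - 1))
    return sub_tokens[:max_tokens], sub_labels[:max_tokens]

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokens, labels = tokenize_and_preserve_labels_sketch(
    ["pigmentary", "glaucoma"], ["B-MEDCOND", "I-MEDCOND"], tokenizer, 128)
assert len(tokens) == len(labels)
assert labels.count("B-MEDCOND") == 1  # the invariant the test asserts
```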
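The 0.667 expectations in `test_avg_metrics` can be reproduced by hand. Here is a quick check, assuming `MetricsTracking` masks out the `-100` ignore index and macro-averages over the three classes; sklearn is used purely for illustration, and the project's own implementation may compute the numbers differently.

```python
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

predictions  = np.array([-100, 0, 0, 0, 1, 1, 1, 2, 2, 2])
ground_truth = np.array([-100, 0, 1, 0, 1, 2, 1, 2, 0, 2])

mask = ground_truth != -100              # drop ignored positions
y_pred, y_true = predictions[mask], ground_truth[mask]

acc = accuracy_score(y_true, y_pred)     # 6 of 9 positions match -> 0.667
prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='macro')
print(round(acc, 3), round(prec, 3), round(rec, 3), round(f1, 3))  # 0.667 each
```

Each class is predicted three times and appears three times in the ground truth, with two hits per class, so per-class precision and recall are both 2/3 and the macro averages coincide with the accuracy.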