Diff of /src/testsuite.py [000000] .. [0eda78]

Switch to unified view

a b/src/testsuite.py
1
import unittest
2
import numpy as np
3
import torch
4
from utils.dataloader import Dataloader, tokenize_and_preserve_labels
5
from utils.metric_tracking import MetricsTracking 
6
from transformers import BertTokenizer
7
import re
8
9
class DataloaderTest(unittest.TestCase):
    """Unit tests for the Dataloader utilities.

    The BIO label maps are shared by both test methods, so they are
    defined once as class-level constants instead of being duplicated
    per test (the inverse map is derived so the two can never drift
    out of sync).
    """

    # BIO-style label scheme for the single MEDCOND entity type.
    LABEL_TO_IDS = {
        'B-MEDCOND': 0,
        'I-MEDCOND': 1,
        'O': 2,
    }
    # Inverse mapping, derived from LABEL_TO_IDS.
    IDS_TO_LABEL = {idx: label for label, idx in LABEL_TO_IDS.items()}

    def test_tokenize_sentence(self):
        """tokenize_and_preserve_labels must keep sub-tokens and labels aligned."""
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        max_tokens = 128

        sentence = "Patient presents with glaucoma, characterized by a definitive diagnosis of pigmentary glaucoma. Intraocular pressure measures at 15 mmHg, while the visual field remains normal. Visual acuity is recorded as 20/50. The patient has not undergone prior cataract surgery, but has had LASIK surgery. Additionally, comorbid ocular diseases include macular degeneration."

        tokens = "O O O B-MEDCOND O O O O O O O B-MEDCOND I-MEDCOND O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O O B-MEDCOND I-MEDCOND I-MEDCOND O B-MEDCOND I-MEDCOND O"

        # Split the sentence into words and punctuation marks.
        # NOTE(review): the second alternative \w+(?='s) is unreachable
        # because \w+ always matches first — harmless but dead.
        sentence = re.findall(r"\w+|\w+(?='s)|'s|['\".,!?;]", sentence.strip(), re.UNICODE)
        tokens = tokens.split(" ")

        t_sen, t_labl = tokenize_and_preserve_labels(
            sentence, tokens, tokenizer, self.LABEL_TO_IDS, self.IDS_TO_LABEL, max_tokens
        )

        # Every sub-token must receive exactly one label...
        self.assertEqual(len(t_sen), len(t_labl))
        # ...and entity starts must survive sub-tokenization unchanged.
        self.assertEqual(t_labl.count("B-MEDCOND"), tokens.count("B-MEDCOND"))

    def test_load_dataset(self):
        """Loading the full dataset yields the expected size and sample layout."""
        dataloader = Dataloader(self.LABEL_TO_IDS, self.IDS_TO_LABEL)

        dataset = dataloader.load_dataset(full = True)

        # NOTE(review): hard-coded corpus size — update if the dataset changes.
        self.assertEqual(len(dataset), 255)

        # Idiomatic indexing instead of calling __getitem__ directly.
        sample = dataset[0]

        self.assertEqual(len(sample), 4) #input_ids, attention_mask, token_type_ids, entity
58
59
60
61
class MetricsTrackingTest(unittest.TestCase):
    """Checks that MetricsTracking averages a single update correctly."""

    def test_avg_metrics(self):
        # Arbitrary prediction/ground-truth pair constructed so that
        # accuracy, F1, precision and recall all evaluate to 0.667;
        # -100 is the conventional "ignore" index.
        preds = torch.from_numpy(np.array([-100, 0, 0, 0, 1, 1, 1, 2, 2, 2]))
        truth = torch.from_numpy(np.array([-100, 0, 1, 0, 1, 2, 1, 2, 0, 2]))

        tracker = MetricsTracking()
        tracker.update(preds, truth, 0.1)

        # The tracker received exactly one update, so average over 1.
        metrics = tracker.return_avg_metrics(1)

        expected = {
            'acc': 0.667,
            'f1': 0.667,
            'precision': 0.667,
            'recall': 0.667,
            'loss': 0.1,
        }
        for metric_name, metric_value in expected.items():
            self.assertEqual(metrics[metric_name], metric_value)
80
81
if __name__ == '__main__':
    # Discover and run all test cases when this file is executed directly.
    unittest.main()