# medicalbert/tests/tests.py
1
import unittest
2
3
import pandas as pd
4
from torch.utils.data import DataLoader
5
from transformers import BertTokenizer
6
7
from datareader.abstract_data_reader import InputExample
8
from medicalbert.datareader.chunked_data_reader import ChunkedDataReader
9
10
class TestChunkedDataReader(unittest.TestCase):
    """Unit tests for ``ChunkedDataReader`` chunking and feature conversion.

    The reader is configured with a maximum sequence length of 10. The short
    example sentence is expected to leave trailing padding, while the long one
    is expected to fill the whole sequence (``[SEP]`` at the last index).
    """

    # Vocabulary ids the assertions below rely on for 'bert-base-cased'.
    CLS_TOKEN_ID = 101
    SEP_TOKEN_ID = 102
    PAD_TOKEN_ID = 0

    SHORT_TEXT = "Hi My name is Andrew"
    LONG_TEXT = "Hi My name is Andrew Patterson and I made this and so I must test it."

    @classmethod
    def setUpClass(cls):
        """Load the tokenizer once for the whole class.

        unittest creates one TestCase instance per test method, so building the
        tokenizer in ``__init__`` would reload it for every test at collection
        time.
        """
        super().setUpClass()
        cls.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

    def setUp(self):
        """Give every test its own config and reader so no state is shared."""
        self.config = {"max_sequence_length": 10, "target": "target", "num_sections": 10}
        self.cdr = ChunkedDataReader(self.config, self.tokenizer)

    def test_chunker_gen(self):
        """``chunks`` splits ten words into groups of three plus a remainder."""
        test_input = "Hi My name is Andrew Patterson and I made this".split()

        sections = list(ChunkedDataReader.chunks(test_input, 3))

        # The msg argument is only shown on failure, so it describes the failure.
        self.assertEqual(len(sections), 4, "Wrong number of sections returned")
        self.assertEqual(sections[0], ['Hi', 'My', 'name'], "First section is wrong")
        self.assertEqual(sections[3], ['this'], "Last section is wrong")

    def assertInputFeatureIsValid(self, inputFeature, sep_index):
        """Assert that a converted feature is well formed.

        Checks that all three sequences have the configured maximum length,
        that ``[CLS]`` is first, that ``[SEP]`` sits at ``sep_index``, and that
        every id after ``[SEP]`` is the padding id.

        :param inputFeature: object exposing ``input_ids``, ``segment_ids`` and
            ``input_mask`` sequences.
        :param sep_index: index at which the ``[SEP]`` token is expected.
        """
        max_length = self.config["max_sequence_length"]

        # check the lengths
        self.assertEqual(len(inputFeature.input_ids), max_length)
        self.assertEqual(len(inputFeature.segment_ids), max_length)
        self.assertEqual(len(inputFeature.input_mask), max_length)

        # now check that the cls and sep tokens are in the correct place.
        self.assertEqual(inputFeature.input_ids[0], self.CLS_TOKEN_ID)
        self.assertEqual(inputFeature.input_ids[sep_index], self.SEP_TOKEN_ID)

        # now check that the padded space is filled with zeroes. The old check
        # used assertTrue(expected, actual), which treats `actual` as the
        # failure message and therefore could never fail.
        padding = inputFeature.input_ids[sep_index + 1:]
        self.assertTrue(
            all(token_id == self.PAD_TOKEN_ID for token_id in padding),
            "Ids after [SEP] are not all padding",
        )

    @staticmethod
    def make_test_data():
        """Return a four-row DataFrame with ``text`` and ``target`` columns.

        Rows alternate between the long and the short example sentence; every
        label is 1.
        """
        examples_text = [
            TestChunkedDataReader.LONG_TEXT,
            TestChunkedDataReader.SHORT_TEXT,
            TestChunkedDataReader.LONG_TEXT,
            TestChunkedDataReader.SHORT_TEXT,
        ]
        examples_label = [1, 1, 1, 1]

        data = {'text': examples_text, 'target': examples_label}

        return pd.DataFrame.from_dict(data)

    def test_build_fresh_dataset(self):
        """A freshly built dataset item carries four feature entries."""
        test_data = TestChunkedDataReader.make_test_data()

        tensor_dataset = self.cdr.build_fresh_dataset(test_data)

        # This checks the number of features. assertTrue(4, ...) always passed
        # because 4 is truthy, so assertEqual is used to make the check real.
        self.assertEqual(len(tensor_dataset[0]), 4, "Wrong number of features per item")

    def test_convert_section_to_feature_short(self):
        """A section shorter than the maximum length is padded after [SEP]."""
        tokens = self.tokenizer.tokenize(self.SHORT_TEXT)

        # convert to a feature
        inputFeature = self.cdr.convert_section_to_feature(tokens, "1")

        # [CLS] + 5 word pieces puts [SEP] at index 6.
        self.assertInputFeatureIsValid(inputFeature, 6)

    def test_convert_section_to_feature_long(self):
        """A section longer than the maximum length ends with [SEP] at index 9."""
        tokens = self.tokenizer.tokenize(self.LONG_TEXT)

        # convert to a feature
        inputFeature = self.cdr.convert_section_to_feature(tokens, "1")

        self.assertInputFeatureIsValid(inputFeature, 9)

    def test_convert_example_to_feature(self):
        """Smoke test: converting a long example must not raise."""
        e = InputExample(None, self.LONG_TEXT, None, 1)

        # NOTE(review): the return value is not inspected because its shape is
        # not visible from this file; add assertions once it is confirmed.
        self.cdr.convert_example_to_feature(e, 1)
98
# Allow running this test module directly (e.g. ``python tests.py``) in
# addition to discovery through a test runner.
if __name__ == '__main__':
    unittest.main()