medicalbert/datareader/chunked_data_reader.py
import logging
import os

import pandas as pd
import torch
from torch.utils.data import TensorDataset
from tqdm import tqdm

from datareader.FeatureSetBuilder import FeatureSetBuilder
from datareader.abstract_data_reader import AbstractDataReader, InputExample


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id


class ChunkedDataReader(AbstractDataReader):

    def __init__(self, config, tokenizer):
        self.tokenizer = tokenizer
        self.max_sequence_length = config['max_sequence_length']
        self.config = config
        self.train = None
        self.valid = None
        self.test = None
        self.num_sections = config['num_sections']

    @staticmethod
    def chunks(lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    def _convert_rows_to_list_of_feature(self, df):
        input_features = []
        for _, row in tqdm(df.iterrows(), total=df.shape[0]):
            text = row['text']
            lbl = row[self.config['target']]

            input_example = InputExample(None, text, None, lbl)
            feature = self.convert_example_to_feature(input_example, lbl)

            input_features.append(feature)

        return input_features

    def build_fresh_dataset(self, dataset):
        df = pd.read_csv(os.path.join(self.config['data_dir'], dataset))

        return self.build_fresh_dataset_from_df(df)

    def build_fresh_dataset_from_df(self, df):
        logging.info("Building fresh dataset...")

        features = self._convert_rows_to_list_of_feature(df)

        # Gather each feature field into a tensor for the TensorDataset.
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)

        return TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)

    def convert_example_to_feature(self, example, label):

        # create a new feature set builder for this example
        input_feature_builder = FeatureSetBuilder(label)

        # tokenize the text into a list of tokens
        tokens_a = self.tokenizer.tokenize(example.text_a)

        # chunk the token list into sections that fit within max_sequence_length
        generator = self.chunks(tokens_a, self.max_sequence_length - 2)

        for section in generator:
            # convert the section to a feature
            section_feature = self.convert_section_to_feature(section, label)

            input_feature_builder.add(section_feature)

        # ensure every example has exactly num_sections sections, padding with a dummy section
        input_feature_builder.resize(self.num_sections, self.convert_section_to_feature([0], label))
        assert len(input_feature_builder.get()) == self.num_sections

        # flatten the per-section features into parallel lists
        input_ids = [feature.input_ids for feature in input_feature_builder.features]
        input_masks = [feature.input_mask for feature in input_feature_builder.features]
        segment_ids = [feature.segment_ids for feature in input_feature_builder.features]

        # wrap the per-section lists in a single InputFeatures for the whole example
        return InputFeatures(input_ids, input_masks, segment_ids, label)

    def convert_section_to_feature(self, tokens_a, label):

        # Truncate the section if needed
        if len(tokens_a) > (self.max_sequence_length - 2):
            tokens_a = tokens_a[-(self.max_sequence_length - 2):]

        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            segment_ids.append(0)
        tokens.append("[SEP]")
        segment_ids.append(0)

        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        while len(input_ids) < self.max_sequence_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)

        assert len(input_ids) == self.max_sequence_length
        assert len(input_mask) == self.max_sequence_length
        assert len(segment_ids) == self.max_sequence_length

        return InputFeatures(input_ids, input_mask, segment_ids, label)
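
A minimal usage sketch, not part of the diff above: it assumes a BERT-style tokenizer such as transformers.BertTokenizer (the reader only relies on tokenize() and convert_tokens_to_ids()) and a config dict carrying the keys the class reads; the "data" directory, the file name "train.csv", and the "label" target column are illustrative placeholders.

from transformers import BertTokenizer

from datareader.chunked_data_reader import ChunkedDataReader

# Placeholder config; the keys mirror the ones ChunkedDataReader accesses.
config = {
    "data_dir": "data",            # directory containing the CSV files
    "max_sequence_length": 128,    # tokens per section, including [CLS] and [SEP]
    "num_sections": 4,             # every example is resized to exactly this many sections
    "target": "label",             # name of the label column in the CSV
}

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
reader = ChunkedDataReader(config, tokenizer)

# Each row of data/train.csv (a 'text' column plus the target column) becomes
# num_sections chunks of max_sequence_length token ids in the returned TensorDataset.
train_dataset = reader.build_fresh_dataset("train.csv")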