medicalbert/datareader/StandardDataReader.py
import logging
import os

import pandas as pd
import torch
from torch.utils.data import TensorDataset
from tqdm import tqdm

from datareader.abstract_data_reader import AbstractDataReader, InputExample


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id


class StandardDataReader(AbstractDataReader):
    """Reads a CSV dataset and converts each row into BERT-style features."""

    def __init__(self, config, tokenizer):
        # Expected config keys (inferred from how they are used below):
        #   'max_sequence_length' - fixed length for padded token sequences
        #   'data_dir'            - directory containing the CSV datasets
        #   'target'              - name of the label column in the CSV
        self.tokenizer = tokenizer
        self.max_sequence_length = config['max_sequence_length']
        self.config = config
        self.train = None
        self.valid = None
        self.test = None

    def build_fresh_dataset(self, dataset):
        """Reads a CSV file and returns a TensorDataset of BERT features."""
        logging.info("Building fresh dataset...")

        df = pd.read_csv(os.path.join(self.config['data_dir'], dataset))

        input_features = []

        # Normalise whitespace and case before tokenization.
        df['text'] = df['text'].str.replace(r'\t', ' ', regex=True)
        df['text'] = df['text'].str.replace(r'\n', ' ', regex=True)
        df['text'] = df['text'].str.lower()

        for _, row in tqdm(df.iterrows(), total=df.shape[0]):
            text = row['text']
            lbl = row[self.config['target']]

            # Single-text task: no second sentence, and the example's label is
            # the value from the configured target column (not the column name).
            input_example = InputExample(None, text, None, lbl)
            feature = self.convert_example_to_feature(input_example, lbl)
            input_features.append(feature)

        all_input_ids = torch.tensor([f.input_ids for f in input_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in input_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in input_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in input_features], dtype=torch.long)

        td = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        return td
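
    # Usage note (illustrative, added): the TensorDataset returned above yields
    # 4-tuples in the order (input_ids, input_mask, segment_ids, label_id), so
    # it can be wrapped directly in a torch DataLoader, e.g.
    #
    #   loader = torch.utils.data.DataLoader(td, batch_size=32, shuffle=True)
    #   for input_ids, input_mask, segment_ids, label_ids in loader:
    #       ...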

    def convert_example_to_feature(self, example, lbl):
        """Converts a single `InputExample` into an `InputFeatures` object."""

        # Tokenize the first text.
        tokens_a = self.tokenizer.tokenize(example.text_a)

        # If it's a sentence-pair task, tokenize the second.
        tokens_b = None
        if example.text_b:
            tokens_b = self.tokenizer.tokenize(example.text_b)

        if tokens_b:
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3".
            AbstractDataReader.truncate_seq_pair(tokens_a, tokens_b, self.max_sequence_length - 3)
        else:
            # Account for [CLS] and [SEP] with "- 2". Note that this keeps the
            # *tail* of the sequence, discarding tokens from the front.
            if len(tokens_a) > (self.max_sequence_length - 2):
                tokens_a = tokens_a[-(self.max_sequence_length - 2):]
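
        # Example (illustrative, added): with max_sequence_length = 8 and a
        # 10-token single-text input, only the last 6 tokens survive the
        # truncation above, leaving room for [CLS] and [SEP].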

        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            segment_ids.append(0)
        tokens.append("[SEP]")
        segment_ids.append(0)

        if tokens_b:
            for token in tokens_b:
                tokens.append(token)
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)
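
        # Worked example (illustrative, added): for tokens_a = ["chest", "pain"]
        # and no tokens_b, the assembly above yields
        #   tokens      = ["[CLS]", "chest", "pain", "[SEP]"]
        #   segment_ids = [0, 0, 0, 0]
        # and the padding below extends ids, mask and segment ids to
        # max_sequence_length.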

        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        while len(input_ids) < self.max_sequence_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)

        assert len(input_ids) == self.max_sequence_length
        assert len(input_mask) == self.max_sequence_length
        assert len(segment_ids) == self.max_sequence_length

        return InputFeatures(input_ids, input_mask, segment_ids, lbl)
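

# ---------------------------------------------------------------------------
# Illustrative smoke test (added; not part of the original module). It assumes
# a HuggingFace-style BERT tokenizer and a CSV file containing a 'text' column
# plus the configured target column; the file name, config values and batch
# size below are made up for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import BertTokenizer  # assumed tokenizer implementation

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    config = {
        "data_dir": "data",           # hypothetical data directory
        "max_sequence_length": 128,
        "target": "label",            # hypothetical label column
    }
    reader = StandardDataReader(config, tokenizer)
    dataset = reader.build_fresh_dataset("train.csv")  # hypothetical file name
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Inspect the first batch to confirm shapes match max_sequence_length.
    for input_ids, input_mask, segment_ids, label_ids in loader:
        print(input_ids.shape, label_ids.shape)
        break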