medicalbert/datareader/StandardDataReader.py

import logging
import os

import pandas as pd
import torch
from torch.utils.data import TensorDataset
from tqdm import tqdm

from datareader.abstract_data_reader import AbstractDataReader, InputExample


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id


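# Note: InputFeatures mirrors the inputs a BERT-style classifier consumes --
# token ids, attention mask and segment (token type) ids -- plus the label id.
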
class StandardDataReader(AbstractDataReader):
    """Converts a CSV dataset into BERT input features.

    `config` must provide 'data_dir', 'target' (the name of the label
    column) and 'max_sequence_length'.
    """

    def __init__(self, config, tokenizer):
        self.tokenizer = tokenizer
        self.max_sequence_length = config['max_sequence_length']
        self.config = config
        self.train = None
        self.valid = None
        self.test = None

    def build_fresh_dataset(self, dataset):
        logging.info("Building fresh dataset...")

        df = pd.read_csv(os.path.join(self.config['data_dir'], dataset))

        input_features = []
        # Normalise the raw text: replace tabs and newlines with spaces,
        # then lowercase.
        df['text'] = df['text'].str.replace(r'\t', ' ', regex=True)
        df['text'] = df['text'].str.replace(r'\n', ' ', regex=True)
        df['text'] = df['text'].str.lower()

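        # Convert each row into model-ready features; `lbl`, the value of the
        # configured target column, becomes the example's label_id.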
        for _, row in tqdm(df.iterrows(), total=df.shape[0]):
            text = row['text']
            lbl = row[self.config['target']]

            input_example = InputExample(None, text, None, self.config['target'])
            feature = self.convert_example_to_feature(input_example, lbl)
            input_features.append(feature)

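        # Stack the per-example features into long tensors and wrap them in a
        # TensorDataset so a DataLoader can batch them directly.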
        all_input_ids = torch.tensor([f.input_ids for f in input_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in input_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in input_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in input_features], dtype=torch.long)

        td = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        return td

    def convert_example_to_feature(self, example, lbl):
        """Converts a single `InputExample` into an `InputFeatures` object."""

        # Tokenize the first text.
        tokens_a = self.tokenizer.tokenize(example.text_a)

        # If it's a sentence-pair task, tokenize the second text.
        tokens_b = None
        if example.text_b:
            tokens_b = self.tokenizer.tokenize(example.text_b)

        if tokens_b:
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3".
            AbstractDataReader.truncate_seq_pair(tokens_a, tokens_b, self.max_sequence_length - 3)
        else:
            # Account for [CLS] and [SEP] with "- 2".
            if len(tokens_a) > (self.max_sequence_length - 2):
                # Keep the *last* tokens, i.e. the tail of a long document.
                tokens_a = tokens_a[-(self.max_sequence_length - 2):]

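        # Assemble the final sequence: [CLS] tokens_a [SEP], followed by
        # tokens_b [SEP] for sentence pairs. Segment ids mark the first
        # sequence with 0 and the second with 1.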
        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            segment_ids.append(0)
        tokens.append("[SEP]")
        segment_ids.append(0)

        if tokens_b:
            for token in tokens_b:
                tokens.append(token)
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        while len(input_ids) < self.max_sequence_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)

        assert len(input_ids) == self.max_sequence_length
        assert len(input_mask) == self.max_sequence_length
        assert len(segment_ids) == self.max_sequence_length

        return InputFeatures(input_ids, input_mask, segment_ids, lbl)
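
# Example usage -- a sketch, not part of the original module. It assumes a
# BERT-style tokenizer (any object with .tokenize and .convert_tokens_to_ids,
# e.g. transformers' BertTokenizer) and a CSV with a 'text' column plus an
# integer label column named by config['target']; all values below are
# illustrative:
#
#     from transformers import BertTokenizer
#     from torch.utils.data import DataLoader
#
#     tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#     config = {'data_dir': 'data', 'target': 'label',
#               'max_sequence_length': 512}
#     reader = StandardDataReader(config, tokenizer)
#     dataset = reader.build_fresh_dataset('train.csv')  # a TensorDataset
#     loader = DataLoader(dataset, batch_size=32, shuffle=True)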