[d129b2]: /medicalbert/datareader/abstract_data_reader.py

import logging
import os
from pathlib import Path

import gcsfs
import pandas as pd
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

# We suppress logging below ERROR for this library, otherwise sequences longer
# than 512 tokens will spam the console.
logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)


class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs an InputExample.

        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For single
                sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second sequence.
                Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
                specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label
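
# Illustrative usage of InputExample (not part of the original file): for a
# single-sequence classification task only text_a and label are set, e.g.
#   InputExample(guid="train-0", text_a="Patient admitted with chest pain.", label="1")
# text_b is reserved for sequence-pair tasks.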


class AbstractDataReader:
    def __init__(self, config, tokenizer):
        self.tokenizer = tokenizer
        self.max_sequence_length = config['max_sequence_length']
        self.config = config
        self.train = None
        self.valid = None
        self.test = None

    @staticmethod
    def truncate_seq_pair(tokens_a, tokens_b, max_length):
        """Truncates a sequence pair in place to the maximum length."""
        # This is a simple heuristic which will always truncate the longer sequence
        # one token at a time. This makes more sense than truncating an equal percent
        # of tokens from each, since if one sequence is very short then each token
        # that's truncated likely contains more information than a longer sequence.
        while True:
            total_length = len(tokens_a) + len(tokens_b)
            if total_length <= max_length:
                break
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()
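
    # Worked example (illustrative numbers): with len(tokens_a) == 300,
    # len(tokens_b) == 250 and max_length == 510, the loop pops 40 tokens from
    # tokens_a (the longer sequence) until 260 + 250 == 510, leaving tokens_b
    # untouched.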

    def load_from_cache(self, dataset):
        path = os.path.join(self.config['output_dir'], self.config['experiment_name'])
        saved_file = os.path.join(path, Path(dataset).stem + ".pt")

        # If we're using the local filesystem.
        if saved_file[:2] != "gs":
            if os.path.isfile(saved_file):
                logging.info("Using cached dataset from local disk {} - saves time!".format(saved_file))
                return torch.load(saved_file)
            return None

        # If we're here, we're using gcsfs.
        try:
            fs = gcsfs.GCSFileSystem()
            with fs.open(saved_file, mode='rb') as f:
                return torch.load(f)
        except Exception:
            return None

    # Abstract function - how we convert examples to features should be left to the subclasses.
    def econvert_example_to_feature(self, input_example, lbl):
        pass

    def save_dataset(self, dataset, tensorDataset):
        path = os.path.join(self.config['output_dir'], self.config['experiment_name'])
        saved_file = os.path.join(path, Path(dataset).stem + ".pt")

        # If we are using local disk then make the path.
        if path[:2] != "gs":
            if not os.path.exists(path):
                os.makedirs(path)
            logging.info("saving dataset at {}".format(saved_file))
            torch.save(tensorDataset, saved_file)
        else:
            fs = gcsfs.GCSFileSystem()
            with fs.open(saved_file, 'wb') as f:
                torch.save(tensorDataset, f)
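
    # Cache-path example for load_from_cache/save_dataset above (illustrative
    # values, not from the original config): with output_dir='results',
    # experiment_name='exp1' and dataset='train.csv', the cached TensorDataset
    # lives at results/exp1/train.pt; with output_dir='gs://my-bucket/results'
    # the same file is read and written through gcsfs instead.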

    # This method is the public interface. We use this to get a dataset.
    # If a tensor dataset does not exist, we create it.
    def get_dataset(self, dataset):
        # 1. Load the cached version if we can.
        td = self.load_from_cache(dataset)

        # 2. Otherwise build a fresh copy and cache it.
        if td is None:
            td = self.build_fresh_dataset(dataset)
            self.save_dataset(dataset, td)

        return td

    def get_train(self):
        if self.train is not None:
            return self.train

        data = self.get_dataset(self.config['training_data'])

        actual_batch_size = self.config['train_batch_size'] // self.config['gradient_accumulation_steps']
        logging.info("Using gradient accumulation - physical batch size is {}".format(actual_batch_size))

        self.train = DataLoader(data, shuffle=True, batch_size=actual_batch_size)
        return self.train
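
    # For example (illustrative numbers): train_batch_size=32 with
    # gradient_accumulation_steps=4 gives a physical DataLoader batch size of 8;
    # the training loop is then expected to accumulate gradients over 4 such
    # batches to recover the effective batch size of 32.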

    def get_validation(self):
        if self.valid is not None:
            return self.valid

        data = self.get_dataset(self.config['validation_data'])
        self.valid = DataLoader(data, shuffle=False, batch_size=self.config['eval_batch_size'])
        return self.valid

    def get_test(self):
        if self.test is not None:
            return self.test

        data = self.get_dataset(self.config['test_data'])
        self.test = DataLoader(data, shuffle=False, batch_size=self.config['eval_batch_size'])
        return self.test
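

# --- Illustrative sketch only; not part of the original module. ---
# A minimal concrete reader showing how the abstract hooks are expected to fit
# together: get_dataset() calls build_fresh_dataset() on a cache miss, which in
# turn uses econvert_example_to_feature() for each row. The CSV layout ('text'
# plus an integer 'label' column), the config key 'data_dir', and the use of
# tokenizer.encode_plus are assumptions made for this example, not part of the
# original file.
class ExampleCsvDataReader(AbstractDataReader):

    def econvert_example_to_feature(self, input_example, lbl):
        # Tokenize a single example, padding/truncating to max_sequence_length.
        encoded = self.tokenizer.encode_plus(input_example.text_a,
                                             max_length=self.max_sequence_length,
                                             pad_to_max_length=True)
        return encoded['input_ids'], encoded['attention_mask'], lbl

    def build_fresh_dataset(self, dataset):
        logging.info("Building a fresh dataset from {}".format(dataset))
        df = pd.read_csv(os.path.join(self.config['data_dir'], dataset))

        input_ids, attention_masks, labels = [], [], []
        for _, row in tqdm(df.iterrows(), total=len(df)):
            example = InputExample(guid=row.name, text_a=row['text'], label=row['label'])
            ids, mask, label = self.econvert_example_to_feature(example, row['label'])
            input_ids.append(ids)
            attention_masks.append(mask)
            labels.append(label)

        return TensorDataset(torch.tensor(input_ids, dtype=torch.long),
                             torch.tensor(attention_masks, dtype=torch.long),
                             torch.tensor(labels, dtype=torch.long))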