--- a/medicalbert/datareader/abstract_data_reader.py
+++ b/medicalbert/datareader/abstract_data_reader.py
@@ -0,0 +1,138 @@
+# get_dataset() is the public interface. We use it to get a dataset.
+# If a cached TensorDataset does not exist, we create it.
+import logging, os, torch, gcsfs
+from pathlib import Path
+import pandas as pd
+from torch.utils.data import TensorDataset, DataLoader
+from tqdm import tqdm
+
+# We suppress logging below ERROR for this library, otherwise sequences longer than 512 tokens will spam the console.
+logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
+
+
+class InputExample(object):
+    """A single training/test example for simple sequence classification."""
+
+    def __init__(self, guid, text_a, text_b=None, label=None):
+        """Constructs an InputExample.
+
+        Args:
+            guid: Unique id for the example.
+            text_a: string. The untokenized text of the first sequence. For single
+                sequence tasks, only this sequence must be specified.
+            text_b: (Optional) string. The untokenized text of the second sequence.
+                Only must be specified for sequence pair tasks.
+            label: (Optional) string. The label of the example. This should be
+                specified for train and dev examples, but not for test examples.
+        """
+        self.guid = guid
+        self.text_a = text_a
+        self.text_b = text_b
+        self.label = label
+
+
+class AbstractDataReader:
+
+    def __init__(self, config, tokenizer):
+        self.tokenizer = tokenizer
+        self.max_sequence_length = config['max_sequence_length']
+        self.config = config
+        self.train = None
+        self.valid = None
+        self.test = None
+
+    @staticmethod
+    def truncate_seq_pair(tokens_a, tokens_b, max_length):
+        """Truncates a sequence pair in place to the maximum length."""
+
+        # This is a simple heuristic which will always truncate the longer sequence
+        # one token at a time. This makes more sense than truncating an equal percentage
+        # of tokens from each, since if one sequence is very short then each token
+        # that's truncated likely carries more information than a token from the longer sequence.
+        while True:
+            total_length = len(tokens_a) + len(tokens_b)
+            if total_length <= max_length:
+                break
+            if len(tokens_a) > len(tokens_b):
+                tokens_a.pop()
+            else:
+                tokens_b.pop()
+
+    def load_from_cache(self, dataset):
+        path = os.path.join(self.config['output_dir'], self.config['experiment_name'])
+        saved_file = os.path.join(path, Path(dataset).stem + ".pt")
+
+        # If we're using the local filesystem.
+        if saved_file[:2] != "gs":
+            if os.path.isfile(saved_file):
+                logging.info("Using cached dataset from local disk {} - saves time!".format(saved_file))
+                return torch.load(saved_file)
+            return None
+
+        # If we're here we're using gcsfs.
+        try:
+            fs = gcsfs.GCSFileSystem()
+            with fs.open(saved_file, mode='rb') as f:
+                return torch.load(f)
+        except Exception:
+            # Treat any failure to read from the bucket as a cache miss.
+            return None
+
+    # Abstract function - how we convert examples to features is left to the subclasses.
+    def convert_example_to_feature(self, input_example, lbl):
+        raise NotImplementedError
+
+    # Abstract function - subclasses build a fresh TensorDataset from the raw dataset file.
+    def build_fresh_dataset(self, dataset):
+        raise NotImplementedError
+
+    def save_dataset(self, dataset, tensor_dataset):
+        path = os.path.join(self.config['output_dir'], self.config['experiment_name'])
+        saved_file = os.path.join(path, Path(dataset).stem + ".pt")
+
+        # If we are using local disk then make sure the output directory exists.
+        if path[:2] != "gs":
+            if not os.path.exists(path):
+                os.makedirs(path)
+
+            logging.info("saving dataset at {}".format(saved_file))
+            torch.save(tensor_dataset, saved_file)
+        else:
+            fs = gcsfs.GCSFileSystem()
+            with fs.open(saved_file, 'wb') as f:
+                torch.save(tensor_dataset, f)
+
+    def get_dataset(self, dataset):
+        # 1. Load the cached version if we can.
+        td = self.load_from_cache(dataset)
+
+        # 2. Otherwise build a fresh copy and cache it.
+        if td is None:
+            td = self.build_fresh_dataset(dataset)
+            self.save_dataset(dataset, td)
+        return td
+
+    def get_train(self):
+        if self.train is not None:
+            return self.train
+
+        data = self.get_dataset(self.config['training_data'])
+        actual_batch_size = self.config['train_batch_size'] // self.config['gradient_accumulation_steps']
+
+        logging.info("Using gradient accumulation - physical batch size is {}".format(actual_batch_size))
+        self.train = DataLoader(data, shuffle=True, batch_size=actual_batch_size)
+        return self.train
+
+    def get_validation(self):
+        if self.valid is not None:
+            return self.valid
+
+        data = self.get_dataset(self.config['validation_data'])
+        self.valid = DataLoader(data, shuffle=False, batch_size=self.config['eval_batch_size'])
+        return self.valid
+
+    def get_test(self):
+        if self.test is not None:
+            return self.test
+
+        data = self.get_dataset(self.config['test_data'])
+        self.test = DataLoader(data, shuffle=False, batch_size=self.config['eval_batch_size'])
+        return self.test
\ No newline at end of file
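
For context, below is a minimal sketch of how a concrete reader might fill in the two abstract hooks and how the class would then be driven. It is illustration only and not part of this diff: the subclass name CsvClassificationReader, the CSV column names "text" and "label", and the use of a HuggingFace BertTokenizer are assumptions; the config keys shown are the ones the base class already reads.

# Illustrative sketch only (not part of the diff). Assumes a CSV dataset with
# hypothetical "text" and "label" columns and a HuggingFace BERT-style tokenizer.
import pandas as pd
import torch
from torch.utils.data import TensorDataset


class CsvClassificationReader(AbstractDataReader):  # hypothetical subclass name

    def build_fresh_dataset(self, dataset):
        # Read the raw CSV and convert each row into padded token ids + label.
        df = pd.read_csv(dataset)

        all_input_ids, all_labels = [], []
        for _, row in df.iterrows():
            example = InputExample(guid=None, text_a=row["text"], label=row["label"])
            input_ids, label_id = self.convert_example_to_feature(example, row["label"])
            all_input_ids.append(input_ids)
            all_labels.append(label_id)

        return TensorDataset(torch.tensor(all_input_ids, dtype=torch.long),
                             torch.tensor(all_labels, dtype=torch.long))

    def convert_example_to_feature(self, input_example, lbl):
        # Tokenize, truncate to the configured length, add special tokens and zero-pad.
        tokens = self.tokenizer.tokenize(input_example.text_a)[: self.max_sequence_length - 2]
        tokens = ["[CLS]"] + tokens + ["[SEP]"]
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        input_ids += [0] * (self.max_sequence_length - len(input_ids))
        return input_ids, int(lbl)


# Example driver (values are placeholders):
# config = {"output_dir": "./output", "experiment_name": "demo",
#           "training_data": "train.csv", "validation_data": "valid.csv",
#           "test_data": "test.csv", "max_sequence_length": 128,
#           "train_batch_size": 32, "gradient_accumulation_steps": 1,
#           "eval_batch_size": 32}
# tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# reader = CsvClassificationReader(config, tokenizer)
# train_loader = reader.get_train()  # builds, caches and wraps the TensorDataset

Everything else - the cache lookup, the local-vs-GCS handling and the DataLoader wiring - comes from the base class unchanged; only the two hooks are dataset specific.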