Switch to side-by-side view

--- a
+++ b/medicalbert/datareader/abstract_data_reader.py
@@ -0,0 +1,138 @@
+# This method is the public interface. We use this to get a dataset.
+# If a tensor dataset does not exist, we create it.
+import logging, os, torch, gcsfs
+from pathlib import Path
+import pandas as pd
+from torch.utils.data import TensorDataset, DataLoader
+from tqdm import tqdm
+
+#We suppress logging below error for this library, otherwise seq. longer than 512 will spam the console.
+logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
+
+class InputExample(object):
+    """A single training/test example for simple sequence classification."""
+
+    def __init__(self, guid, text_a, text_b=None, label=None):
+        """Constructs a InputExample.
+
+        Args:
+            guid: Unique id for the example.
+            text_a: string. The untokenized text of the first sequence. For single
+            sequence tasks, only this sequence must be specified.
+            text_b: (Optional) string. The untokenized text of the second sequence.
+            Only must be specified for sequence pair tasks.
+            label: (Optional) string. The label of the example. This should be
+            specified for train and dev examples, but not for test examples.
+        """
+        self.guid = guid
+        self.text_a = text_a
+        self.text_b = text_b
+        self.label = label
+
+class AbstractDataReader:
+
+    def __init__(self, config, tokenizer):
+        self.tokenizer = tokenizer
+        self.max_sequence_length = config['max_sequence_length']
+        self.config = config
+        self.train = None
+        self.valid = None
+        self.test = None
+
+    @staticmethod
+    def truncate_seq_pair(tokens_a, tokens_b, max_length):
+        """Truncates a sequence pair in place to the maximum length."""
+
+        # This is a simple heuristic which will always truncate the longer sequence
+        # one token at a time. This makes more sense than truncating an equal percent
+        # of tokens from each, since if one sequence is very short then each token
+        # that's truncated likely contains more information than a longer sequence.
+        while True:
+            total_length = len(tokens_a) + len(tokens_b)
+            if total_length <= max_length:
+                break
+            if len(tokens_a) > len(tokens_b):
+                tokens_a.pop()
+            else:
+                tokens_b.pop()
+
+    def load_from_cache(self, dataset):
+        path = os.path.join(self.config['output_dir'], self.config['experiment_name'])
+        saved_file = os.path.join(path, Path(dataset).stem + ".pt")
+
+        # If we're using localfilesystem.
+        if saved_file[:2] != "gs":
+            if os.path.isfile(saved_file):
+                logging.info("Using Cached dataset from local disk {} - saves time!".format(saved_file))
+                return torch.load(saved_file)
+
+        #If we're here were using gcsfs
+        try:
+            fs = gcsfs.GCSFileSystem()
+            with fs.open(saved_file, mode='rb') as f:
+                return torch.load(f)
+        except:
+            return None
+
+    # Abstract function - how we convert examples to features should be left to the subclasses
+    def econvert_example_to_feature(self, input_example, lbl):
+        pass
+
+
+    def save_dataset(self, dataset, tensorDataset):
+        path = os.path.join(self.config['output_dir'], self.config['experiment_name'])
+
+        saved_file = os.path.join(path, Path(dataset).stem + ".pt")
+
+        # If we are using local disk then make the path.
+        if path[:2] != "gs":
+            if not os.path.exists(path):
+                os.makedirs(path)
+
+            logging.info("saving dataset at {}".format(saved_file))
+            torch.save(tensorDataset, saved_file)
+        else:
+            fs = gcsfs.GCSFileSystem()
+            with fs.open(saved_file, 'wb') as f:
+                torch.save(tensorDataset, f)
+
+    def get_dataset(self, dataset):
+
+        # 1. load cached version if we can
+        td = self.load_from_cache(dataset)
+
+        # build a fresh copy
+        if td is None:
+            td = self.build_fresh_dataset(dataset)
+
+            self.save_dataset(dataset, td)
+        return td
+
+    def get_train(self):
+        if self.train:
+            return self.train
+
+        data = self.get_dataset(self.config['training_data'])
+        actual_batch_size = self.config['train_batch_size'] // self.config['gradient_accumulation_steps']
+
+        logging.info("Using gradient accumulation - physical batch size is {}".format(actual_batch_size))
+        self.train = DataLoader(data, shuffle=True, batch_size=actual_batch_size)
+        return self.train
+
+    def get_validation(self):
+        if self.valid:
+            return self.valid
+
+        data = self.get_dataset(self.config['validation_data'])
+
+        self.valid = DataLoader(data, shuffle=False, batch_size=self.config['eval_batch_size'])
+        return self.valid
+
+    def get_test(self):
+        if self.test:
+            return self.test
+
+        data = self.get_dataset(self.config['test_data'])
+
+        self.test = DataLoader(data, shuffle=False, batch_size=self.config['eval_batch_size'])
+        return self.test
\ No newline at end of file