--- a
+++ b/load_dataset/custom_datasets.py
@@ -0,0 +1,206 @@
+#custom_datasets.py
+#Copyright (c) 2020 Rachel Lea Ballantyne Draelos
+
+#MIT License
+
+#Permission is hereby granted, free of charge, to any person obtaining a copy
+#of this software and associated documentation files (the "Software"), to deal
+#in the Software without restriction, including without limitation the rights
+#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+#copies of the Software, and to permit persons to whom the Software is
+#furnished to do so, subject to the following conditions:
+
+#The above copyright notice and this permission notice shall be included in all
+#copies or substantial portions of the Software.
+
+#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+#SOFTWARE.
+
+import os
+import pickle
+import numpy as np
+import pandas as pd
+
+import torch
+from torch.utils.data import Dataset, DataLoader
+
+from . import utils
+
+#Set seeds
+np.random.seed(0)
+torch.manual_seed(0)
+torch.cuda.manual_seed(0)
+torch.cuda.manual_seed_all(0)
+
+###################################################
+# PACE Dataset for Data Stored in 2019-10-BigData #-----------------------------
+###################################################
+class CTDataset_2019_10(Dataset):    
+    def __init__(self, setname, label_type_ld,
+                 label_meanings, num_channels, pixel_bounds,
+                 data_augment, crop_type,
+                 selected_note_acc_files):
+        """CT Dataset class that works for preprocessed data in 2019-10-BigData.
+        A single example (for crop_type == 'single') is a 4D CT volume:
+            if num_channels == 3, shape [134,3,420,420]
+            if num_channels == 1, shape [402,420,420]
+        
+        Variables:
+        <setname> is either 'train' or 'valid' or 'test'
+        <label_type_ld> is 'disease_new'
+        <label_meanings>: list of strings indicating which labels should
+            be kept. Alternatively, can be the string 'all' in which case
+            all labels are kept.
+        <num_channels>: number of channels to reshape the image to.
+            == 3 if the model uses a pretrained feature extractor.
+            == 1 if the model uses only 3D convolutions.
+        <pixel_bounds>: list of ints e.g. [-1000,200]
+            Determines the lower bound, upper bound of pixel value clipping
+            and normalization.
+        <data_augment>: if True, perform data augmentation.
+        <crop_type>: is 'single' for an example consisting of one 4D numpy array
+        <selected_note_acc_files>: This should be a dictionary
+            with key equal to setname and value that is a string. If the value
+            is a path to a file, the file must be a CSV. Only note accessions
+            in this file will be used. If the value is not a valid file path,
+            all available note accs will be used, i.e. the model will be
+            trained on the whole dataset."""
+        self.setname = setname
+        self.define_subsets_list()
+        self.label_type_ld = label_type_ld
+        self.label_meanings = label_meanings
+        self.num_channels = num_channels
+        self.pixel_bounds = pixel_bounds
+        if self.setname == 'train':
+            self.data_augment = data_augment
+        else:
+            self.data_augment = False
+        print('For dataset',self.setname,'data_augment is',self.data_augment)
+        self.crop_type = crop_type
+        assert self.crop_type == 'single'
+        self.selected_note_acc_files = selected_note_acc_files
+        
+        #Define location of the CT volumes
+        self.main_clean_path = './load_dataset/fakedata'
+        self.volume_log_df = pd.read_csv('./load_dataset/fakedata/CT_Scan_Preprocessing_Log_File_FINAL_SMALL.csv',header=0,index_col=0)
+        
+        #Get the example ids
+        self.volume_accessions = self.get_volume_accessions()
+                        
+        #Get the ground truth labels
+        self.labels_df = self.get_labels_df()
+    
+    # Pytorch Required Methods #------------------------------------------------
+    def __len__(self):
+        return len(self.volume_accessions)
+        
+    def __getitem__(self, idx):
+        """Return a single sample at index <idx>. The sample is a Python
+        dictionary with keys 'data' and 'gr_truth' for the image and label,
+        respectively"""
+        return self._get_pace(self.volume_accessions[idx])
+    
+    # Volume Accession Methods #------------------------------------------------
+    def get_note_accessions(self):
+        setname_file = self.selected_note_acc_files[self.setname]
+        if os.path.isfile(setname_file):
+            print('\tObtaining note accessions from',setname_file)
+            sel_accs = pd.read_csv(setname_file,header=0)            
+            assert sorted(list(set(sel_accs['Subset_Assigned'].values.tolist())))==sorted(self.subsets_list)
+            note_accs = sel_accs.loc[:,'Accession'].values.tolist()
+            print('\tTotal theoretical note accessions in subsets:',len(note_accs))
+            return note_accs
+        else: 
+            print('\tObtaining note accessions from complete identifiers file')
+            #Read in identifiers file, which contains note_accessions
+            #Columns are MRN, Accession, Set_Assigned, Set_Should_Be, Subset_Assigned
+            all_ids = pd.read_csv('./load_dataset/fakedata/all_identifiers.csv',header=0)
+           
+            #Extract the note_accessions
+            note_accs = []
+            for subset in self.subsets_list: #e.g. ['imgvalid_a','imgvalid_b']
+                subset_note_accs = all_ids[all_ids['Subset_Assigned']==subset].loc[:,'Accession'].values.tolist()
+                note_accs += subset_note_accs
+            print('\tTotal theoretical note accessions in subsets:',len(note_accs))
+            return note_accs
+    
+    def get_volume_accessions(self):
+        note_accs = self.get_note_accessions()
+        #Translate note_accessions to volume_accessions based on what data has been
+        #preprocessed successfully. volume_log_df has note accessions as the
+        #index, and the column 'full_filename_npz' for the volume accession.
+        #The column 'status' should equal 'success' if the volume has been
+        #preprocessed correctly.
+        print('\tTotal theoretical volumes in whole dataset:',self.volume_log_df.shape[0])
+        self.volume_log_df = self.volume_log_df[self.volume_log_df['status']=='success']
+        print('\tTotal successfully preprocessed volumes in whole dataset:',self.volume_log_df.shape[0])
+        volume_accs = []
+        for note_acc in note_accs:
+            if note_acc in self.volume_log_df.index.values.tolist():
+                volume_accs.append(self.volume_log_df.at[note_acc,'full_filename_npz'])
+        print('\tFinal total successfully preprocessed volumes in requested subsets:',len(volume_accs))
+        #According to this thread: https://github.com/pytorch/pytorch/issues/13246
+        #it is better to use a numpy array than a list to reduce memory leaks.
+        return np.array(volume_accs)
+    
+    # Ground Truth Label Methods #----------------------------------------------
+    def get_labels_df(self):
+        #Get the ground truth labels based on requested label type.
+        labels_df = read_in_labels(self.label_type_ld, self.setname)
+        
+        #Now filter the ground truth labels based on the desired label meanings:
+        if self.label_meanings != 'all': #i.e. if you want to filter
+            labels_df = labels_df[self.label_meanings]
+        return labels_df
+    
+    # Fetch a CT Volume (__getitem__ implementation) #--------------------------
+    def _get_pace(self, volume_acc):
+        """<volume_acc> is for example RHAA12345_6.npz"""
+        #Load compressed npz file: [slices, square, square]
+        ctvol = np.load(os.path.join(self.main_clean_path, volume_acc))['ct']
+        
+        #Prepare the CT volume data (already torch Tensors)
+        data = utils.prepare_ctvol_2019_10_dataset(ctvol, self.pixel_bounds, self.data_augment, self.num_channels, self.crop_type)
+        
+        #Get the ground truth:
+        note_acc = self.volume_log_df[self.volume_log_df['full_filename_npz']==volume_acc].index.values.tolist()[0]
+        gr_truth = self.labels_df.loc[note_acc, :].values
+        gr_truth = torch.from_numpy(gr_truth).squeeze().type(torch.float)
+        
+        #When training on only one abnormality you must unsqueeze to prevent
+        #a dimensions error when training the model:
+        if len(self.label_meanings)==1:
+            gr_truth = gr_truth.unsqueeze(0)
+        
+        #Create the sample
+        sample = {'data': data, 'gr_truth': gr_truth, 'volume_acc': volume_acc}
+        return sample
+    
+    # Sanity Check #------------------------------------------------------------
+    def define_subsets_list(self):
+        assert self.setname in ['train','valid','test']
+        if self.setname == 'train':
+            self.subsets_list = ['imgtrain']
+        elif self.setname == 'valid':
+            self.subsets_list = ['imgvalid_a']
+        elif self.setname == 'test':
+            self.subsets_list = ['imgtest_a','imgtest_b','imgtest_c','imgtest_d']
+        print('Creating',self.setname,'dataset with subsets',self.subsets_list)
+
+#######################
+# Ground Truth Labels #---------------------------------------------------------
+#######################
+
+def read_in_labels(label_type_ld, setname):
+    """Return a pandas dataframe with the dataset labels.
+    Accession numbers are the index and labels (e.g. "pneumonia") are the columns.
+    <setname> can be 'train', 'valid', or 'test'."""
+    assert label_type_ld == 'disease_new'
+    labels_file = './load_dataset/fakedata/2019-12-18_duke_disease/img'+setname+'_BinaryLabels.csv'
+    return pd.read_csv(labels_file, header=0, index_col = 0)
+    
\ No newline at end of file