--- a
+++ b/shepherd/dataset.py
@@ -0,0 +1,204 @@
+
+
+import random
+import numpy as np
+import pandas as pd
+import pickle
+import sys
+import torch
+import time
+import re
+from torch.utils.data import Dataset
+from collections import defaultdict
+
+from project_utils import read_patients
+import project_config
+
+
+class PatientDataset(Dataset):
+
+    def __init__(self, filepath, gp_spl=None, raw_data=False, mondo_map_file=str(project_config.PROJECT_DIR / 'mondo_references.csv'), needs_disease_mapping=False, time=False): 
+        self.filepath = filepath
+        self.patients = read_patients(filepath)
+        print('Dataset filepath: ', filepath)
+        print('Number of patients: ', len(self.patients))
+        
+        # add placeholder for true genes/diseases if they don't exist
+        for patient in self.patients:
+            if 'true_genes' not in patient: patient['true_genes'] = []
+            if 'true_diseases' not in patient: patient['true_diseases'] = []
+
+        self.raw_data = raw_data
+        self.needs_disease_mapping = needs_disease_mapping
+        self.time = time
+
+        # create HPO to node_idx map
+        with open(project_config.KG_DIR / f'hpo_to_idx_dict_{project_config.CURR_KG}.pkl', 'rb') as handle:
+            self.hpo_to_idx_dict = pickle.load(handle)
+        with open(project_config.KG_DIR / f'hpo_to_name_dict_{project_config.CURR_KG}.pkl', 'rb') as handle:
+            self.hpo_to_name_dict = pickle.load(handle)
+        self.idx_to_hpo_dict = {v:self.hpo_to_name_dict[k] if k in self.hpo_to_name_dict else k for k, v in self.hpo_to_idx_dict.items()}
+
+
+        # create ensembl to node_idx map
+        # NOTE: assumes conversion from gene symbols to ensembl IDs has already occurred
+        with open(str(project_config.KG_DIR  / f'ensembl_to_idx_dict_{project_config.CURR_KG}.pkl'), 'rb') as handle:
+            self.ensembl_to_idx_dict = pickle.load(handle)
+        self.idx_to_ensembl_dict = {v:k for k, v in self.ensembl_to_idx_dict.items()}
+
+        # orphanet to mondo disease map
+        with open(str(project_config.PROJECT_DIR / 'preprocess' / 'orphanet' / 'orphanet_to_mondo_dict.pkl'), 'rb') as handle:
+            self.orpha_mondo_map = pickle.load(handle)
+
+        with open(project_config.KG_DIR / f'mondo_to_idx_dict_{project_config.CURR_KG}.pkl', 'rb') as handle:
+            self.disease_to_idx_dict = pickle.load(handle)
+        with open(project_config.KG_DIR / f'mondo_to_name_dict_{project_config.CURR_KG}.pkl', 'rb') as handle:
+            self.disease_to_name_dict = pickle.load(handle)
+        self.idx_to_disease_dict = {v:self.disease_to_name_dict[k] if k in self.disease_to_name_dict else k for k, v in self.disease_to_idx_dict.items()}
+
+
+        # degree dict from idx to degree - used for debugging
+        #NOTE: may need to subtract 1 to index into this dict
+        with open(str(project_config.KG_DIR / f'degree_dict_{project_config.CURR_KG}.pkl'), 'rb') as handle:
+            self.degree_dict = pickle.load(handle)
+
+        # get patients with similar genes
+        if all(['true_genes' in patient for patient in self.patients]): # first check to make sure all patients have true genes
+            genes_to_patients = defaultdict(list)
+            for patient in self.patients:
+                for g in patient['true_genes']:
+                    genes_to_patients[g].append(patient['id'])
+            self.patients_with_same_gene = defaultdict(list)
+            for patients in genes_to_patients.values():
+                for p in patients:
+                    self.patients_with_same_gene[p].extend([pat for pat in patients if pat != p])
+
+        # get patients with similar diseases
+        if all(['true_diseases' in patient for patient in self.patients]): # first check to make sure all patients have true diseases
+            dis_to_patients = defaultdict(list)
+            for patient in self.patients:
+                patient_diseases = patient['true_diseases']
+                for d in patient_diseases: 
+                    dis_to_patients[d].append(patient['id'])
+            self.patients_with_same_disease = defaultdict(list)
+            for patients in dis_to_patients.values():
+                for p in patients:
+                    self.patients_with_same_disease[p].extend([pat for pat in patients if pat != p])
+
+        # map from patient id to index in dataset
+        self.patient_id_to_index = {p['id']:i for i, p in enumerate(self.patients)}
+
+        print('Finished initalizing dataset')
+
+
+    def __len__(self):
+        '''
+        Returns the length of the dataset
+        '''
+        return len(self.patients)
+
+    def __getitem__(self, idx):
+        '''
+        Returns a single example from the dataset
+        '''
+        t0 = time.time()
+        patient = self.patients[idx]
+
+        additional_labels_dict = {}
+        if 'additional_labels' in patient:
+            for label, values in patient['additional_labels'].items():
+                if label == "n_hops_cand_g_p": continue
+                if values == None: values = [[-1]]
+                if type(values) != list: values = [[values]] #  wrap in list if needed
+                if type(values) == list and (len(values)==0 or type(values[0]) != list): values = [values]
+                additional_labels_dict[label] = values 
+            if 'max_percent_phen_overlap_train' not in patient['additional_labels']: additional_labels_dict['max_percent_phen_overlap_train'] = [[-1]] 
+            if 'max_phen_overlap_train' not in patient['additional_labels']: additional_labels_dict['max_phen_overlap_train'] = [[-1]]
+       
+        phenotype_node_idx = [self.hpo_to_idx_dict[p] for p in patient['positive_phenotypes'] if p in self.hpo_to_idx_dict ]
+        correct_genes_node_idx = [self.ensembl_to_idx_dict[g] for g in patient['true_genes'] if g in self.ensembl_to_idx_dict ]
+        if 'all_candidate_genes' in patient:
+            candidate_gene_node_idx = [self.ensembl_to_idx_dict[g] for g in patient['all_candidate_genes'] if g in self.ensembl_to_idx_dict ]
+        else: candidate_gene_node_idx = []
+
+        if 'true_diseases' in patient:
+            if self.needs_disease_mapping:
+                orpha_diseases = [ int(d) if len(re.match("^[0-9]*", d)[0]) > 0 else d for d in patient['true_diseases']]
+                mondo_diseases = [mondo_d for orpha_d in set(orpha_diseases).intersection(set(self.orpha_mondo_map.keys())) for mondo_d in self.orpha_mondo_map[orpha_d]]
+            else:
+                mondo_diseases = [str(d) for d in patient['true_diseases']]
+            disease_node_idx = [self.disease_to_idx_dict[d] for d in mondo_diseases if d in self.disease_to_idx_dict]
+        else: disease_node_idx = None
+        
+        if not self.raw_data:
+            phenotype_node_idx = torch.LongTensor(phenotype_node_idx)
+            correct_genes_node_idx = torch.LongTensor(correct_genes_node_idx)
+            candidate_gene_node_idx = torch.LongTensor(candidate_gene_node_idx)
+            if 'true_diseases' in patient: disease_node_idx = torch.LongTensor(disease_node_idx)
+            
+        
+        assert len(phenotype_node_idx) >= 1, f'There are no phenotypes for patient: {patient}'
+        
+        #NOTE: assumes that patient has a single causal/correct gene (the model still outputs a score for each candidate gene)
+        if not self.raw_data:
+            if len(correct_genes_node_idx) > 1:
+                #print('NOTE: The patient has multiple correct genes, but we\'re only selecting the first.')
+                correct_genes_node_idx = correct_genes_node_idx[0].unsqueeze(-1)
+
+        # get index of correct gene
+        if len(correct_genes_node_idx) == 0: # no correct genes available for the patient
+            label_idx = None
+        else:
+            if self.raw_data:
+                label_idx = [candidate_gene_node_idx.index(g) for g in correct_genes_node_idx]
+            else:
+                label_idx = (candidate_gene_node_idx == correct_genes_node_idx[0]).nonzero(as_tuple=True)[0]
+
+        if self.time:
+            t1 = time.time()
+            print(f'It takes {t1-t0:0.4f}s to get an item from the dataset')
+            
+        return (phenotype_node_idx, candidate_gene_node_idx, correct_genes_node_idx, disease_node_idx, label_idx, additional_labels_dict, patient['id'])
+
+    def node_idx_to_degree(self, idx):
+        if idx in self.degree_dict:
+            return self.degree_dict[idx]
+        else:
+            return -1
+    
+    def node_idx_to_name(self, idx):
+        if idx in self.idx_to_hpo_dict:
+            return self.idx_to_hpo_dict[idx]
+        elif idx in self.idx_to_ensembl_dict:
+            return self.idx_to_ensembl_dict[idx]
+        elif idx in self.idx_to_disease_dict:
+            return self.idx_to_disease_dict[idx]
+        elif idx == 0:
+            return 'padding'
+        else:
+            print("Exception on:", idx)
+            raise Exception
+
+    def get_similar_patients(self, patient_id, similarity_type='gene'):
+        if similarity_type == 'gene':
+            sim_pats = np.array(self.patients_with_same_gene[patient_id])
+            np.random.shuffle(sim_pats)
+            return sim_pats
+        elif similarity_type == 'disease':
+            sim_pats = np.array(self.patients_with_same_disease[patient_id])
+            np.random.shuffle(sim_pats)
+            return sim_pats
+        else:
+            raise NotImplementedError
+
+    def get_candidate_diseases(self, cand_type='all_kg_nodes'):
+        if cand_type == 'all_kg_nodes':
+            all_kg_diseases_idx = np.unique(list(self.disease_to_idx_dict.values())) # get idx of all diseases in KG
+            return torch.LongTensor(all_kg_diseases_idx)
+        elif cand_type == 'orphanet':
+            orpha_mondo_diseases = [d[0] for d in list(self.orpha_mondo_map.values())]
+            orpha_mondo_idx = np.unique([self.disease_to_idx_dict[d] for d in orpha_mondo_diseases if d in self.disease_to_idx_dict])
+            return torch.LongTensor(orpha_mondo_idx)
+        else:
+            raise NotImplementedError
+