--- a
+++ b/src/data.py
@@ -0,0 +1,121 @@
+from torch.utils.data.dataset import Dataset
+import torch
+from src.utils import *
+# Dataset preparation for the Targeted BEHRT model: variational-autoencoder-style
+# deep unsupervised learning (masked EHR modelling) plus exposure/outcome labels.
+
+
class TBEHRT_data_formation(Dataset):
    """PyTorch Dataset for the Targeted BEHRT model.

    Wraps a pandas dataframe of EHR sequences (diagnosis codes, ages, years,
    static covariates) and yields padded, index-mapped tensors for masked EHR
    modelling (MEM) pretraining and the exposure/outcome supervised objectives.
    """

    def __init__(self, token2idx, dataframe, code='code', age='age', year='year',
                 static='static', max_len=1000, expColumn='explabel',
                 outcomeColumn='label', max_age=110, yvocab=None,
                 list2avoid=None, MEM=True):
        """
        token2idx     - dict mapping EHR tokens to integer indices
        dataframe     - pandas dataframe with the code/age/year/static columns
        code, age, year, static - names of the respective dataframe columns
        max_len       - maximum sequence length (longer records keep the tail)
        expColumn     - exposure (treatment) label column; when None, falls
                        back to `dataframe.diseaseLabel`
        outcomeColumn - outcome label column; when None, falls back to
                        `dataframe.deathLabel`
        max_age       - largest age used to build the age vocabulary
        yvocab        - year vocabulary (token2idx dict for the year sequence)
        list2avoid    - tokens/diseases excluded from the MEM masking procedure
        MEM           - masked EHR modelling flag for unsupervised learning
        """
        if list2avoid is None:
            self.acceptableVoc = token2idx
        else:
            # Masking may only draw from tokens not listed in list2avoid.
            self.acceptableVoc = {tok: idx for tok, idx in token2idx.items()
                                  if tok not in list2avoid}
            print("old Vocab size: ", len(token2idx),
                  ", and new Vocab size: ", len(self.acceptableVoc))
        self.vocab = token2idx
        self.max_len = max_len
        self.code = dataframe[code]
        self.age = dataframe[age]
        self.year = dataframe[year]
        if outcomeColumn is None:
            self.label = dataframe.deathLabel
        else:
            self.label = dataframe[outcomeColumn]
        # BUGFIX: was hard-coded to 110, silently ignoring the max_age
        # argument; the default (110) preserves the old behaviour.
        # NOTE(review): the second argument here is the *column-name string*,
        # not the year series - looks suspicious; confirm against age_vocab.
        self.age2idx, _ = age_vocab(max_age, year, symbol=None)

        if expColumn is None:
            self.treatmentLabel = dataframe.diseaseLabel
        else:
            self.treatmentLabel = dataframe[expColumn]
        self.year2idx = yvocab
        self.codeS = dataframe[static]
        self.MEM = MEM

    def __getitem__(self, index):
        """Return one padded training example.

        Returns (torch.LongTensor throughout):
            age, code2 (unmasked static+longitudinal codes),
            codeMLM (masked codes), position, segment, year, mask,
            labelMLM (MEM targets, -1 where unmasked),
            [labelOutcome], treatmentOutcome, labelcovar
        """
        age = self.age[index]
        code = self.code[index]
        year = self.year[index]

        # Keep only the most recent max_len-1 events; one slot is used by the
        # CLS token written below.
        age = age[(-self.max_len + 1):]
        code = code[(-self.max_len + 1):]
        year = year[(-self.max_len + 1):]

        treatmentOutcome = torch.LongTensor([self.treatmentLabel[index]])
        labelOutcome = self.label[index]

        # CLS is placed at the END of the sequence (not the beginning); it
        # overwrites the final event token.
        code[-1] = 'CLS'

        # Attention mask for a right-aligned (reverse-padded) sequence:
        # zeros over the left padding, ones over real tokens, plus a leading 1
        # for the prepended static-covariate position.
        # NOTE(review): final length is max_len+1, which matches code2 below
        # only when there is exactly one static covariate entry - confirm.
        mask = np.ones(self.max_len)
        mask[:-len(code)] = 0
        mask = np.append(np.array([1]), mask)

        tokensReal, code2 = code2index(code, self.vocab)

        # Left-pad the year and age sequences to max_len.
        year = seq_padding_reverse(year, self.max_len, token2idx=self.year2idx)
        age = seq_padding_reverse(age, self.max_len, token2idx=self.age2idx)

        # MEM masking, or a no-op pass-through when MEM is disabled.
        if self.MEM:
            tokens, codeMLM, labelMLM = randommaskreal(code, self.acceptableVoc)
        else:
            tokens, codeMLM, labelMLM = nonMASK(code, self.vocab)

        # Position and segment codes are derived from the padded tokens.
        tokens = seq_padding_reverse(tokens, self.max_len)
        position = position_idx(tokens)
        segment = index_seg(tokens)

        code2 = seq_padding_reverse(code2, self.max_len, symbol=self.vocab['PAD'])
        codeMLM = seq_padding_reverse(codeMLM, self.max_len, symbol=self.vocab['PAD'])
        labelMLM = seq_padding_reverse(labelMLM, self.max_len, symbol=-1)

        # Static covariates. The two extra -1 entries presumably align
        # labelcovar with positions appended downstream - TODO confirm.
        outCodeS = [int(xx) for xx in self.codeS[index]]
        fixedcovar = np.array(outCodeS)
        labelcovar = np.array(([-1] * len(outCodeS)) + [-1, -1])
        if self.MEM:
            fixedcovar, labelcovar = covarUnsupMaker(fixedcovar)

        # Prepend the static covariates to both code sequences.
        code2 = np.append(fixedcovar, code2)
        codeMLM = np.append(fixedcovar, codeMLM)

        # code2 is the unmasked (static + longitudinal) sequence, while
        # codeMLM is its masked counterpart.
        return torch.LongTensor(age), torch.LongTensor(code2), torch.LongTensor(codeMLM), \
               torch.LongTensor(position), torch.LongTensor(segment), torch.LongTensor(year), \
               torch.LongTensor(mask), torch.LongTensor(labelMLM), \
               torch.LongTensor([labelOutcome]), treatmentOutcome, torch.LongTensor(labelcovar)

    def __len__(self):
        """Number of patient records in the dataset."""
        return len(self.code)
+