# src/rnn/rnn_dataset.py
1
import torch
2
from torch.utils.data import Dataset
3
import numpy as np 
4
5
from src.utils import labeltarget
6
from src.rnn.rnn_utils import encode_sentence
7
8
# Label vocabularies handed to ``labeltarget`` by ``rnndataset.__getitem__``.
# Presumably the most frequent ICD codes/categories in the corpus (per the
# names) -- the selection criterion is not visible in this file.
frequent_icd9category = [
    "401", "427", "276", "414", "272",
    "250", "428", "518", "285", "584",
]
frequent_icd9code = [
    "4019", "4280", "42731", "41401", "5849",
    "25000", "2724", "51881", "5990", "53081",
]
frequent_icd10category = [
    "I10", "I25", "E78", "I50", "I48",
    "N17", "E87", "E11", "J96", "N39",
]
frequent_icd10code = [
    "I10", "I2510", "I509", "I4891", "N179",
    "E119", "E784", "E785", "J9690", "J9600",
]
12
13
14
class rnndataset(Dataset):
    """Map-style dataset pairing encoded diagnosis text with ICD label targets.

    Items are built on demand from one row of ``df``: the
    ``discharge_diagnosis`` value is encoded with ``encode_sentence`` and the
    four ICD columns are converted to targets with ``labeltarget``.

    Args:
        df: Table accessed as ``df[column].iloc[index]``; it must provide the
            columns ``discharge_diagnosis``, ``ICD9_CODE``, ``ICD9_CATEGORY``,
            ``ICD10`` and ``ICD10_CATEGORY`` (presumably a pandas DataFrame --
            confirm against the caller).
        vocab2index: Vocabulary lookup forwarded unchanged to
            ``encode_sentence``.
        max_len: Length argument forwarded unchanged to ``encode_sentence``.
    """

    def __init__(self, df, vocab2index, max_len: int = 50):
        super().__init__()
        self.df = df
        # Size is captured once here; it is not refreshed if ``df`` is later
        # resized through this attribute.
        self.nsamples = len(df)
        self.vocab2index = vocab2index
        self.max_len = max_len

    def __getitem__(self, index):
        """Return ``(x, y)`` for the row at position ``index``.

        Returns:
            A tuple ``(x, y)`` where ``x`` is the tensor built from the encoded
            ``discharge_diagnosis`` entry and ``y`` is a dict with the keys
            ``icd9code``, ``icd9cat``, ``icd10code`` and ``icd10cat``, each
            holding the tensor built from ``labeltarget``'s result for that
            column.
        """
        # (output key, source column, label vocabulary). The vocabularies are
        # module-level lists, looked up at call time.
        target_specs = (
            ("icd9code", "ICD9_CODE", frequent_icd9code),
            ("icd9cat", "ICD9_CATEGORY", frequent_icd9category),
            ("icd10code", "ICD10", frequent_icd10code),
            ("icd10cat", "ICD10_CATEGORY", frequent_icd10category),
        )

        encoded = encode_sentence(
            self.df["discharge_diagnosis"].iloc[index],
            self.vocab2index,
            self.max_len,
        )
        x = torch.from_numpy(np.array(encoded))

        # ``torch.from_numpy`` only accepts ndarrays, so ``labeltarget`` is
        # relied upon to return one.
        y = {
            key: torch.from_numpy(labeltarget(self.df[column].iloc[index], labels))
            for key, column, labels in target_specs
        }
        return x, y

    def __len__(self) -> int:
        """Return the number of rows recorded at construction time."""
        return self.nsamples