|
a |
|
b/src/rnn/rnn_dataset.py |
|
|
1 |
import torch |
|
|
2 |
from torch.utils.data import Dataset |
|
|
3 |
import numpy as np |
|
|
4 |
|
|
|
5 |
from src.utils import labeltarget |
|
|
6 |
from src.rnn.rnn_utils import encode_sentence |
|
|
7 |
|
|
|
8 |
# Label vocabularies used to build the per-sample targets in this module.
# Each list holds ten entries; the position of a code in its list is passed
# through to ``labeltarget`` together with the sample's codes.
# NOTE(review): presumably these are the ten most frequent codes/categories in
# the training corpus -- confirm against the data-preparation step.
frequent_icd9category = ['401', '427', '276', '414', '272', '250', '428', '518', '285', '584']
frequent_icd9code = ['4019', '4280', '42731', '41401', '5849', '25000', '2724', '51881', '5990', '53081']
frequent_icd10category = ['I10', 'I25', 'E78', 'I50', 'I48', 'N17', 'E87', 'E11', 'J96', 'N39']
frequent_icd10code = ['I10', 'I2510', 'I509', 'I4891', 'N179', 'E119', 'E784', 'E785', 'J9690', 'J9600']
|
|
12 |
|
|
|
13 |
|
|
|
14 |
class rnndataset(Dataset):
    """Map-style dataset yielding encoded diagnosis text and ICD label targets.

    Each item is a pair ``(x, y)``:

    * ``x`` -- tensor built from ``encode_sentence`` applied to the row's
      ``discharge_diagnosis`` value.
    * ``y`` -- dict with the keys ``'icd9code'``, ``'icd9cat'``,
      ``'icd10code'`` and ``'icd10cat'``; each value is a tensor built from
      ``labeltarget`` applied to the matching column and label vocabulary.
    """

    def __init__(self, df, vocab2index, max_len = 50):
        """Store the data frame and the encoding parameters.

        Args:
            df: table indexed by column name and positionally via ``.iloc``.
                ``__getitem__`` reads the columns ``discharge_diagnosis``,
                ``ICD9_CODE``, ``ICD9_CATEGORY``, ``ICD10`` and
                ``ICD10_CATEGORY``.
            vocab2index: vocabulary mapping handed to ``encode_sentence``.
            max_len: length argument handed to ``encode_sentence``.
        """
        self.df = df
        # Cached once; later changes to ``df`` are not reflected in __len__.
        self.nsamples = len(df)
        self.vocab2index = vocab2index
        self.max_len = max_len

    def __getitem__(self, index):
        """Return ``(x, y)`` for the row at positional ``index``."""
        text = self.df['discharge_diagnosis'].iloc[index]
        encoded = encode_sentence(text, self.vocab2index, self.max_len)
        x = torch.from_numpy(np.array(encoded))

        # (target key, source column, label vocabulary), in output order.
        label_sources = (
            ('icd9code', "ICD9_CODE", frequent_icd9code),
            ('icd9cat', "ICD9_CATEGORY", frequent_icd9category),
            ('icd10code', "ICD10", frequent_icd10code),
            ('icd10cat', "ICD10_CATEGORY", frequent_icd10category),
        )
        # labeltarget's result goes straight to torch.from_numpy, so it must
        # be a NumPy array.
        y = {
            key: torch.from_numpy(labeltarget(self.df[column].iloc[index], labels))
            for key, column, labels in label_sources
        }
        return x, y

    def __len__(self):
        """Return the number of rows recorded at construction time."""
        return self.nsamples