|
a |
|
b/chexbert/src/datasets/impressions_dataset.py |
|
|
1 |
import torch |
|
|
2 |
import pandas as pd |
|
|
3 |
import numpy as np |
|
|
4 |
from bert_tokenizer import load_list |
|
|
5 |
from torch.utils.data import Dataset, DataLoader |
|
|
6 |
|
|
|
7 |
class ImpressionsDataset(Dataset):
    """The dataset to contain report impressions and their labels."""

    # Label columns read from the CSV. Their order here is the order of the
    # entries in each sample's 'label' tensor.
    CONDITIONS = ('Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',
                  'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis',
                  'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture',
                  'Support Devices', 'No Finding')

    def __init__(self, csv_path, list_path):
        """ Initialize the dataset object
        @param csv_path (string): path to the csv file containing labels
        @param list_path (string): path to the list of encoded impressions
        """
        self.df = self._encode_labels(pd.read_csv(csv_path))
        # NOTE(review): row i of self.df is assumed to pair with
        # self.encoded_imp[i]; nothing here checks the two lengths match.
        self.encoded_imp = load_list(path=list_path)

    @staticmethod
    def _encode_labels(df):
        """Select the condition columns and remap their values to class indices.

        Mapping: 0 -> 2 (negative), -1 -> 3 (uncertain), blank/NaN -> 0.
        Any other value (e.g. 1) is left unchanged.

        @param df (pd.DataFrame): raw label table; must contain every column
                                  listed in CONDITIONS (KeyError otherwise)
        @return (pd.DataFrame): a new frame holding only the CONDITIONS
                                columns as int64; `df` itself is not modified
        """
        labels = df[list(ImpressionsDataset.CONDITIONS)]
        # Non-inplace reassignment: the column subset above may be flagged by
        # pandas as a copy, so in-place edits would trigger SettingWithCopyWarning.
        # Order matters: remap 0 before filling blanks with 0, otherwise blank
        # entries would end up encoded as negatives.
        labels = labels.replace(0, 2)   # negative label is 2
        labels = labels.replace(-1, 3)  # uncertain label is 3
        labels = labels.fillna(0)       # blank label is 0
        # NaNs are gone now, so cast once here instead of on every __getitem__.
        return labels.astype('int64')

    def __len__(self):
        """Compute the length of the dataset

        @return (int): size of the dataframe
        """
        return self.df.shape[0]

    def __getitem__(self, idx):
        """ Functionality to index into the dataset
        @param idx (int): Integer index into the dataset

        @return (dictionary): Has keys 'imp', 'label' and 'len'. The value of 'imp' is
                              a LongTensor of an encoded impression. The value of 'label'
                              is a LongTensor containing the labels and the value of
                              'len' is an integer representing the length of imp's value
        """
        # Samplers may hand over a 0-d tensor index; convert it to a plain int.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        label = torch.tensor(self.df.iloc[idx].to_numpy(), dtype=torch.long)
        imp = torch.tensor(self.encoded_imp[idx], dtype=torch.long)
        return {"imp": imp, "label": label, "len": imp.shape[0]}