# chexbert/src/datasets/impressions_dataset.py
1
import torch
2
import pandas as pd
3
import numpy as np
4
from bert_tokenizer import load_list
5
from torch.utils.data import Dataset, DataLoader
6
7
class ImpressionsDataset(Dataset):
    """The dataset to contain report impressions and their labels.

    Labels are read from a CSV file and re-encoded (see ``_encode_labels``);
    the encoded impressions are whatever ``load_list`` returns for
    ``list_path`` and are looked up by position in ``__getitem__``.
    """

    # Label columns selected from the CSV. Their order here is the order of
    # the entries in every 'label' tensor returned by __getitem__.
    LABEL_COLUMNS = (
        'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',
        'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis',
        'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture',
        'Support Devices', 'No Finding',
    )

    def __init__(self, csv_path, list_path):
        """ Initialize the dataset object
        @param csv_path (string): path to the csv file containing labels
        @param list_path (string): path to the list of encoded impressions
        """
        super().__init__()
        self.df = self._encode_labels(pd.read_csv(csv_path))
        # NOTE(review): nothing here checks that the loaded list has one entry
        # per CSV row; __getitem__ indexes both with the same idx.
        self.encoded_imp = load_list(path=list_path)

    @staticmethod
    def _encode_labels(frame):
        """ Select the label columns and re-encode their values
        @param frame (DataFrame): raw frame holding (at least) LABEL_COLUMNS

        @return (DataFrame): new frame restricted to LABEL_COLUMNS in which
                             0 became 2 (negative), -1 became 3 (uncertain)
                             and missing values became 0 (blank). All other
                             values are left unchanged. The input frame is
                             not modified.
        """
        labels = frame[list(ImpressionsDataset.LABEL_COLUMNS)]
        # Non-inplace chain: same three steps, in the same order, as the
        # original in-place calls, but without mutating a frame that was
        # derived by column selection. The order matters: blanks must be
        # filled with 0 only after the original 0s were moved to 2.
        return labels.replace(0, 2).replace(-1, 3).fillna(0)

    def __len__(self):
        """Compute the length of the dataset

        @return (int): size of the dataframe
        """
        return self.df.shape[0]

    def __getitem__(self, idx):
        """ Functionality to index into the dataset
        @param idx (int): Integer index into the dataset

        @return (dictionary): Has keys 'imp', 'label' and 'len'. The value of 'imp' is
                              a LongTensor of an encoded impression. The value of 'label'
                              is a LongTensor containing the labels and the value of
                              'len' is an integer representing the length of imp's value
        """
        if torch.is_tensor(idx):
            # Convert tensor indices to plain Python values before indexing.
            idx = idx.tolist()
        label = torch.LongTensor(self.df.iloc[idx].to_numpy())
        imp = torch.LongTensor(self.encoded_imp[idx])
        return {"imp": imp, "label": label, "len": imp.shape[0]}