# findings_classifier/chexpert_dataset.py
1
import collections
2
import os
3
from pathlib import Path
4
5
import numpy as np
6
import pandas as pd
7
import torch
8
from PIL import Image
9
from skimage import io
10
from torch.utils.data import Dataset
11
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, transforms
12
13
from model.lavis.data.ReportDataset import ExpandChannels
14
15
from local_config import VIS_ROOT, PATH_TO_MIMIC_CXR
16
17
class Chexpert_Dataset(Dataset):
    """Image dataset pairing chest X-ray files with 14 binary finding labels.

    The constructor reads three CSV files: the official split file under
    ``PATH_TO_MIMIC_CXR``, a sectioned-reports file and a label file (the
    latter two via paths relative to the current working directory).  Reports
    without a ``findings`` entry are dropped, the remaining rows are filtered
    to the requested split and left-joined with the label table on
    ``dicom_id``.

    Each item is a dict with the keys ``image``, ``labels``, ``image_id``,
    ``report``, ``study_id`` and ``dicom_id``.
    """

    def __init__(self, split='train', truncate=None, loss_weighting="none", use_augs=False):
        """Load the annotation tables and build the image transform pipeline.

        :param split: Value of the ``split`` column (in the split CSV) to keep.
        :param truncate: If not ``None``, keep only the first ``truncate`` rows.
        :param loss_weighting: ``"none"``, ``"lin"`` or ``"log"``; consumed by
            :meth:`get_class_weights`.
        :param use_augs: If true, append a random affine transform and a colour
            jitter to the deterministic preprocessing.
        """
        super().__init__()

        # load csv files
        self.split = pd.read_csv(f'{PATH_TO_MIMIC_CXR}/mimic-cxr-jpg/2.0.0/mimic-cxr-2.0.0-split.csv')
        self.reports = pd.read_csv('mimic-cxr/reports_processed/mimic_cxr_sectioned.csv')
        self.reports = self.reports.dropna(subset=['findings'])

        self.vis_root = VIS_ROOT
        # Position of every dicom_id in the (findings-filtered) report table;
        # returned as "image_id" by __getitem__.
        self.img_ids = {img_id: i for i, img_id in enumerate(self.reports['dicom_id'])}
        self.split_ids = set(self.split.loc[self.split['split'] == split]['dicom_id'])
        self.chexpert = pd.read_csv('data/data_files/finding_chexbert_labels.csv')
        self.chexpert_cols = ["No Finding", "Enlarged Cardiomediastinum",
                              "Cardiomegaly", "Lung Opacity",
                              "Lung Lesion", "Edema",
                              "Consolidation", "Pneumonia",
                              "Atelectasis", "Pneumothorax",
                              "Pleural Effusion", "Pleural Other",
                              "Fracture", "Support Devices"]

        # Keep only the reports whose dicom_id belongs to the requested split.
        # The explicit copy avoids assigning a new column onto a slice of
        # self.reports (pandas chained-assignment warning / undefined result).
        self.annotation = self.reports.loc[self.reports['dicom_id'].isin(self.split_ids)].copy()
        self.annotation['study_id'] = self.annotation['Note_file'].apply(self._study_id_from_note_file)
        # merge chexpert labels
        self.annotation = pd.merge(self.annotation, self.chexpert, how='left', left_on=['dicom_id'], right_on=['dicom_id'])
        if truncate is not None:
            self.annotation = self.annotation[:truncate]

        self.vis_transforms = Compose([Resize(512), CenterCrop(488), ToTensor(), ExpandChannels()])
        if use_augs:
            # Augmentations run after the deterministic preprocessing.
            aug_tfm = transforms.Compose([transforms.RandomAffine(degrees=30, shear=15),
                                          transforms.ColorJitter(brightness=0.2, contrast=0.2)])
            self.vis_transforms = transforms.Compose([self.vis_transforms, aug_tfm])
        self.loss_weighting = loss_weighting

    @staticmethod
    def _study_id_from_note_file(note_file) -> int:
        """Turn a note file name such as ``"s50414267.txt"`` into ``50414267``.

        Uses the file stem instead of ``str.rstrip('.txt')``: ``rstrip`` strips
        a *set of characters*, not a suffix, and only works by accident while
        the remaining name consists of digits.
        """
        return int(Path(note_file).stem.lstrip('s'))

    def get_class_weights(self):
        """Compute per-class weights from the inverse of the positive-label frequencies.

        Returns:
            torch.Tensor: float32 tensor with one weight per entry of
            ``self.chexpert_cols``.  All ones for ``loss_weighting == "none"``,
            ``N / count`` for ``"lin"`` and ``log(N / count)`` for ``"log"``,
            where ``N`` is the number of annotation rows and ``count`` the
            number of rows whose label equals 1.

        Raises:
            ValueError: If ``self.loss_weighting`` is not one of the supported modes.
        """
        if self.loss_weighting == "none":
            return torch.ones(len(self.chexpert_cols), dtype=torch.float32)
        if self.loss_weighting not in ("lin", "log"):
            raise ValueError(
                f'Unknown loss_weighting "{self.loss_weighting}", expected "none", "lin" or "log"'
            )

        # Count positives per label in one vectorised pass; only cells equal
        # to 1 count as positive (NaN, 0 and -1 do not), matching
        # _extract_chexpert_labels_from_row.
        positives = self.annotation[self.chexpert_cols].eq(1).sum(axis=0)
        label_counts = torch.tensor(positives.to_numpy(), dtype=torch.float32)
        # A label with no positive example would otherwise yield an infinite weight.
        label_counts = label_counts.clamp(min=1.0)

        # Compute class weights
        class_weights = len(self.annotation) / label_counts
        if self.loss_weighting == "log":
            class_weights = torch.log(class_weights)
        return class_weights

    def remap_to_uint8(self, array: np.ndarray, percentiles=None) -> np.ndarray:
        """Remap values in input so the output range is :math:`[0, 255]`.

        Percentiles can be used to specify the range of values to remap.
        This is useful to discard outliers in the input data.

        :param array: Input array.
        :param percentiles: Percentiles of the input values that will be mapped to ``0`` and ``255``.
            Passing ``None`` is equivalent to using percentiles ``(0, 100)`` (but faster).
        :returns: Array with ``0`` and ``255`` as minimum and maximum values.
            A constant input (no dynamic range) is mapped to all zeros.
        :raises ValueError: If ``percentiles`` is not an ascending pair within ``[0, 100]``.
        """
        # astype() copies, so the in-place arithmetic below never touches the caller's array.
        array = array.astype(float)
        if percentiles is not None:
            len_percentiles = len(percentiles)
            if len_percentiles != 2:
                message = (
                    'The value for percentiles should be a sequence of length 2,'
                    f' but has length {len_percentiles}'
                )
                raise ValueError(message)
            a, b = percentiles
            if a >= b:
                raise ValueError(f'Percentiles must be in ascending order, but a sequence "{percentiles}" was passed')
            if a < 0 or b > 100:
                raise ValueError(f'Percentiles must be in the range [0, 100], but a sequence "{percentiles}" was passed')
            cutoff: np.ndarray = np.percentile(array, percentiles)
            array = np.clip(array, *cutoff)
        array -= array.min()
        peak = array.max()
        # Skip the rescale for constant images: dividing by zero would fill
        # the array with NaN, whose cast to uint8 is undefined.
        if peak > 0:
            array /= peak
            array *= 255
        return array.astype(np.uint8)

    def load_image(self, path) -> "Image.Image":
        """Load an image from disk.

        The image values are remapped to :math:`[0, 255]` and cast to 8-bit unsigned integers.

        :param path: Path to image (``str`` or ``pathlib.Path``).
        :returns: Image as single-channel (mode ``"L"``) ``Pillow`` ``Image``.
        :raises ValueError: If the file extension is not ``.jpg``, ``.jpeg`` or ``.png``.
        """
        path = Path(path)
        # The file is decoded with scikit-image and wrapped in a Pillow image
        # afterwards.  The extension check is case-insensitive so that e.g.
        # ".JPG" is accepted as well.
        if path.suffix.lower() not in (".jpg", ".jpeg", ".png"):
            raise ValueError(f"Image type not supported, filename was: {path}")

        image = self.remap_to_uint8(io.imread(path))
        return Image.fromarray(image).convert("L")

    def _extract_chexpert_labels_from_row(self, row: pd.Series) -> torch.Tensor:
        """Build the multi-hot label vector for one annotation row.

        Entry ``i`` is 1 if ``row[self.chexpert_cols[i]] == 1`` and 0 for any
        other value (including NaN).
        """
        flags = [1.0 if row[col] == 1 else 0.0 for col in self.chexpert_cols]
        return torch.tensor(flags, dtype=torch.float32)

    def __getitem__(self, index):
        """Return the sample at positional ``index`` as a dict.

        The image is read from ``vis_root / Img_Folder / Img_Filename`` and
        passed through ``self.vis_transforms``.
        """
        ann = self.annotation.iloc[index]
        image_path = os.path.join(self.vis_root, ann["Img_Folder"], ann["Img_Filename"])
        image = self.load_image(Path(image_path))
        image = self.vis_transforms(image)
        chexpert_labels = self._extract_chexpert_labels_from_row(ann)

        return {
            "image": image,
            "labels": chexpert_labels,
            "image_id": self.img_ids[ann["dicom_id"]],
            "report": ann["findings"],
            "study_id": ann["study_id"],
            "dicom_id": ann["dicom_id"],
        }

    def __len__(self):
        """Return the number of annotation rows (after split filtering and truncation)."""
        return len(self.annotation)
152
153
154
if __name__ == '__main__':
    # Manual smoke test: build the dataset with its default arguments and
    # print the first sample.
    dataset = Chexpert_Dataset()
    first_sample = dataset[0]
    print(first_sample)