b/findings_classifier/chexpert_dataset.py
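
"""Dataset pairing MIMIC-CXR images with CheXpert finding labels.

Each item yields a transformed chest X-ray tensor and a 14-dimensional
multi-label target derived from CheXbert annotations, for training a
findings classifier.
"""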

import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from skimage import io
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor

from model.lavis.data.ReportDataset import ExpandChannels
from local_config import VIS_ROOT, PATH_TO_MIMIC_CXR


class Chexpert_Dataset(Dataset):
    def __init__(self, split='train', truncate=None, loss_weighting="none", use_augs=False):
        super().__init__()

        # Load the MIMIC-CXR split file and the sectioned reports; keep only
        # studies that have a findings section.
        self.split = pd.read_csv(f'{PATH_TO_MIMIC_CXR}/mimic-cxr-jpg/2.0.0/mimic-cxr-2.0.0-split.csv')
        self.reports = pd.read_csv('mimic-cxr/reports_processed/mimic_cxr_sectioned.csv')
        self.reports = self.reports.dropna(subset=['findings'])

        self.vis_root = VIS_ROOT
        self.img_ids = {img_id: i for i, img_id in enumerate(self.reports['dicom_id'])}
        self.split_ids = set(self.split.loc[self.split['split'] == split]['dicom_id'])
        self.chexpert = pd.read_csv('data/data_files/finding_chexbert_labels.csv')
        self.chexpert_cols = ["No Finding", "Enlarged Cardiomediastinum",
                              "Cardiomegaly", "Lung Opacity",
                              "Lung Lesion", "Edema",
                              "Consolidation", "Pneumonia",
                              "Atelectasis", "Pneumothorax",
                              "Pleural Effusion", "Pleural Other",
                              "Fracture", "Support Devices"]

        # Keep only the dicom_ids belonging to the requested split; copy so the
        # column assignment below does not write into a view of self.reports.
        self.annotation = self.reports.loc[self.reports['dicom_id'].isin(self.split_ids)].copy()
        # Note files are named like "s<study_id>.txt"; Path(...).stem avoids the
        # rstrip('.txt') pitfall of stripping a *set* of characters.
        self.annotation['study_id'] = self.annotation['Note_file'].apply(lambda x: int(Path(x).stem.lstrip('s')))
        # Merge the CheXpert labels onto the annotations.
        self.annotation = pd.merge(self.annotation, self.chexpert, how='left', on='dicom_id')
        if truncate is not None:
            self.annotation = self.annotation[:truncate]

        self.vis_transforms = Compose([Resize(512), CenterCrop(488), ToTensor(), ExpandChannels()])
        if use_augs:
            aug_tfm = transforms.Compose([transforms.RandomAffine(degrees=30, shear=15),
                                          transforms.ColorJitter(brightness=0.2, contrast=0.2)])
            self.vis_transforms = transforms.Compose([self.vis_transforms, aug_tfm])
        self.loss_weighting = loss_weighting

    def get_class_weights(self) -> torch.Tensor:
        """Compute class weights based on the inverse of class frequencies.

        Returns:
            torch.Tensor: Per-class weights of shape ``(len(self.chexpert_cols),)``.
        """
        if self.loss_weighting == "none":
            return torch.ones(len(self.chexpert_cols), dtype=torch.float32)

        # Count positive labels per class over the whole split.
        label_counts = torch.zeros(len(self.chexpert_cols), dtype=torch.float32)
        for _, ann in self.annotation.iterrows():
            chexpert_labels = self._extract_chexpert_labels_from_row(ann)
            label_counts += chexpert_labels
        # Guard against division by zero for classes absent from the split.
        label_counts = label_counts.clamp(min=1.0)

        # Compute class weights.
        if self.loss_weighting == "lin":
            class_weights = len(self.annotation) / label_counts
        elif self.loss_weighting == "log":
            class_weights = torch.log(len(self.annotation) / label_counts)
        else:
            raise ValueError(f"Unknown loss_weighting: {self.loss_weighting}")

        return class_weights
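
    # Usage note: the returned weights are commonly passed to
    # torch.nn.BCEWithLogitsLoss(pos_weight=...) for multi-label training;
    # see the sketch in the __main__ block at the bottom of the file.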

    def remap_to_uint8(self, array: np.ndarray, percentiles=None) -> np.ndarray:
        """Remap values in input so the output range is :math:`[0, 255]`.

        Percentiles can be used to specify the range of values to remap.
        This is useful to discard outliers in the input data.

        :param array: Input array.
        :param percentiles: Percentiles of the input values that will be mapped to ``0`` and ``255``.
            Passing ``None`` is equivalent to using percentiles ``(0, 100)`` (but faster).
        :returns: Array with ``0`` and ``255`` as minimum and maximum values.
        """
        array = array.astype(float)
        if percentiles is not None:
            len_percentiles = len(percentiles)
            if len_percentiles != 2:
                message = (
                    'The value for percentiles should be a sequence of length 2,'
                    f' but has length {len_percentiles}'
                )
                raise ValueError(message)
            a, b = percentiles
            if a >= b:
                raise ValueError(f'Percentiles must be in ascending order, but a sequence "{percentiles}" was passed')
            if a < 0 or b > 100:
                raise ValueError(f'Percentiles must be in the range [0, 100], but a sequence "{percentiles}" was passed')
            cutoff: np.ndarray = np.percentile(array, percentiles)
            array = np.clip(array, *cutoff)
        array -= array.min()
        array /= array.max()
        array *= 255
        return array.astype(np.uint8)
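
    # Example (illustrative): remap_to_uint8(np.array([0., 2., 4.]))
    # scales the values to [0, 127, 255].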

    def load_image(self, path: Path) -> Image.Image:
        """Load an image from disk.

        The image values are remapped to :math:`[0, 255]` and cast to 8-bit unsigned integers.

        :param path: Path to image.
        :returns: Image as ``Pillow`` ``Image``.
        """
        # Read with scikit-image, then remap and convert to a grayscale Pillow
        # image for the torchvision transforms.
        if path.suffix in [".jpg", ".jpeg", ".png"]:
            image = io.imread(path)
        else:
            raise ValueError(f"Image type not supported, filename was: {path}")

        image = self.remap_to_uint8(image)
        return Image.fromarray(image).convert("L")

    def _extract_chexpert_labels_from_row(self, row: pd.Series) -> torch.Tensor:
        # Only definite positives (1) are kept; uncertain (-1) and missing
        # labels are treated as negative.
        labels = torch.zeros(len(self.chexpert_cols), dtype=torch.float32)
        for i, col in enumerate(self.chexpert_cols):
            if row[col] == 1:
                labels[i] = 1
        return labels

    def __getitem__(self, index):
        ann = self.annotation.iloc[index]
        # Load the CXR image and apply the (optionally augmented) transforms.
        image_path = os.path.join(self.vis_root, ann["Img_Folder"], ann["Img_Filename"])
        image = self.load_image(Path(image_path))
        image = self.vis_transforms(image)
        chexpert_labels = self._extract_chexpert_labels_from_row(ann)

        return {
            "image": image,
            "labels": chexpert_labels,
            "image_id": self.img_ids[ann["dicom_id"]],
            "report": ann["findings"],
            "study_id": ann["study_id"],
            "dicom_id": ann["dicom_id"],
        }

    def __len__(self):
        return len(self.annotation)


if __name__ == '__main__':
    dataset = Chexpert_Dataset()
    print(dataset[0])
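
    # Usage sketch (illustrative; assumes the CSV and image paths above
    # resolve on this machine): wire the dataset into a DataLoader and pair
    # the class weights with a multi-label BCE loss.
    from torch.utils.data import DataLoader

    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    weights = dataset.get_class_weights()  # all ones for loss_weighting="none"
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=weights)
    batch = next(iter(loader))
    logits = torch.randn_like(batch["labels"])  # stand-in for model outputs
    print(criterion(logits, batch["labels"]))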