# Copyright 2022 Cristóbal Alcázar
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NIH Chest X-ray Dataset"""
15
16
17
import os
18
import datasets
19
20
from requests import get
21
from pandas import read_csv
22
23
# Module-level logger obtained through the `datasets` logging helper.
logger = datasets.logging.get_logger(__name__)
24
25
# BibTeX entry for the ChestX-ray8 paper; passed to `DatasetInfo(citation=...)`.
_CITATION = """\
@inproceedings{Wang_2017,
    doi = {10.1109/cvpr.2017.369},
    url = {https://doi.org/10.1109%2Fcvpr.2017.369},
    year = 2017,
    month = {jul},
    publisher = {{IEEE}
},
    author = {Xiaosong Wang and Yifan Peng and Le Lu and Zhiyong Lu and Mohammadhadi Bagheri and Ronald M. Summers},
    title = {{ChestX}-Ray8: Hospital-Scale Chest X-Ray Database and Benchmarks on Weakly-Supervised Classification and Localization of Common Thorax Diseases},
    booktitle = {2017 {IEEE} Conference on Computer Vision and Pattern Recognition ({CVPR})}
}
"""

# Human-readable summary; passed to `DatasetInfo(description=...)`.
_DESCRIPTION = """\
The NIH Chest X-ray dataset consists of 100,000 de-identified images of chest x-rays. The images are in PNG format.

The data is provided by the NIH Clinical Center and is available through the NIH download site: https://nihcc.app.box.com/v/ChestXray-NIHCC
"""

# Landing page of the original data release; passed to `DatasetInfo(homepage=...)`.
_HOMEPAGE = "https://nihcc.app.box.com/v/chestxray-nihcc"
48
49
50
# Base URL of the `data` folder in the Hugging Face dataset repository.
_REPO = "https://huggingface.co/datasets/alkzar90/NIH-Chest-X-ray-dataset/resolve/main/data"

# The images are shipped as zip archives numbered images_001 ... images_012.
_NUM_IMAGE_ARCHIVES = 12

_IMAGE_URLS = [
    f"{_REPO}/images/images_{index:03d}.zip"
    for index in range(1, _NUM_IMAGE_ARCHIVES + 1)
]

# Remote location of every resource the builder reads: the two split lists,
# the label and bounding-box CSV files, and the image archives.
_URLS = {
    "train_val_list": f"{_REPO}/train_val_list.txt",
    "test_list": f"{_REPO}/test_list.txt",
    "labels": f"{_REPO}/Data_Entry_2017_v2020.csv",
    "BBox": f"{_REPO}/BBox_List_2017.csv",
    "image_urls": _IMAGE_URLS,
}
78
79
80
# Finding label -> class index. Insertion order matters: `_NAMES` below is
# derived from it and is used as the `ClassLabel` name list.
_LABEL2IDX = {
    "No Finding": 0,
    "Atelectasis": 1,
    "Cardiomegaly": 2,
    "Effusion": 3,
    "Infiltration": 4,
    "Mass": 5,
    "Nodule": 6,
    "Pneumonia": 7,
    "Pneumothorax": 8,
    "Consolidation": 9,
    "Edema": 10,
    "Emphysema": 11,
    "Fibrosis": 12,
    "Pleural_Thickening": 13,
    "Hernia": 14,
}

# Label names in index order.
_NAMES = list(_LABEL2IDX)
98
99
100
class ChestXray14Config(datasets.BuilderConfig):
    """Builder configuration for the NIH ChestX-ray14 dataset.

    Args:
        name: Configuration name.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``datasets.BuilderConfig``. ``version`` defaults to ``1.0.0`` and
            ``description`` to ``"NIH ChestX-ray14"`` when not supplied.
    """

    def __init__(self, name, **kwargs):
        # Use defaults instead of hard-coded keywords: passing ``version`` or
        # ``description`` through ``kwargs`` previously raised
        # "got multiple values for keyword argument".
        kwargs.setdefault("version", datasets.Version("1.0.0"))
        kwargs.setdefault("description", "NIH ChestX-ray14")
        super().__init__(name=name, **kwargs)
110
111
112
113
class ChestXray14(datasets.GeneratorBasedBuilder):
    """Dataset builder for the NIH ChestX-ray14 images.

    Two configurations are declared:

    * ``image-classification``: an image plus a sequence of finding labels.
    * ``object-detection``: image metadata plus a list of bounding-box
      objects.
    """

    BUILDER_CONFIGS = [
        ChestXray14Config("image-classification"),
        ChestXray14Config("object-detection"),
    ]

    def _info(self):
        """Build the ``DatasetInfo`` for the selected configuration.

        Returns:
            A ``datasets.DatasetInfo`` whose features and supervised keys
            depend on ``self.config.name``.

        Raises:
            ValueError: If the configuration name is not one of the two
                declared in ``BUILDER_CONFIGS``.
        """
        config_name = self.config.name

        if config_name == "image-classification":
            features = datasets.Features(
                {
                    "image": datasets.Image(),
                    "labels": datasets.features.Sequence(
                        datasets.features.ClassLabel(
                            num_classes=len(_NAMES),
                            names=_NAMES,
                        )
                    ),
                }
            )
            keys = ("image", "labels")
        elif config_name == "object-detection":
            features = datasets.Features(
                {
                    "image_id": datasets.Value("string"),
                    "patient_id": datasets.Value("int32"),
                    "image": datasets.Image(),
                    "width": datasets.Value("int32"),
                    "height": datasets.Value("int32"),
                }
            )
            object_dict = {
                "image_id": datasets.Value("string"),
                "area": datasets.Value("int64"),
                "bbox": datasets.Sequence(datasets.Value("float32"), length=4),
            }
            # Each example carries a list of objects sharing this schema.
            features["objects"] = [object_dict]
            keys = ("image", "objects")
        else:
            # Previously this path fell through and failed later with an
            # unbound-local error on `features`.
            raise ValueError(f"Unknown configuration name: {config_name!r}")

        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            supervised_keys=keys,
            homepage=_HOMEPAGE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Download the data and assign every extracted file to a split.

        A file whose base name is listed in ``train_val_list`` goes to the
        train split; every other extracted file goes to the test split.

        Args:
            dl_manager: Download manager supplied by ``datasets``; used to
                download/extract the image archives and to walk the files.

        Returns:
            A list with the TRAIN and TEST ``datasets.SplitGenerator`` objects,
            each passing its file paths as the ``files`` generator argument.

        Raises:
            requests.HTTPError: If the train/val list cannot be fetched.
        """
        # Get the image names that belong to the train-val dataset
        logger.info("Downloading the train_val_list image names")
        # A timeout keeps a stalled connection from hanging the build forever.
        response = get(_URLS["train_val_list"], timeout=60)
        # Without this check the body of an HTTP error page would be taken as
        # the list of names and every image would silently land in the test split.
        response.raise_for_status()
        train_val_names = {line.decode("UTF8") for line in response.iter_lines()}
        # Log the size only; the full set is far too large for a log line.
        logger.info("Loaded %d train_val_list image names", len(train_val_names))

        # Download batches
        data_files = dl_manager.download_and_extract(_URLS["image_urls"])

        # Walk through every extracted archive and route each file by name.
        # NOTE(review): _URLS["test_list"] is never consulted; anything absent
        # from train_val_list is treated as test data - confirm this is intended.
        train_files = []
        test_files = []
        for batch in data_files:
            logger.info(f"Batch for data_files: {batch}")
            for img in dl_manager.iter_files(batch):
                if os.path.basename(img) in train_val_names:
                    train_files.append(img)
                else:
                    test_files.append(img)

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={"files": train_files},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={"files": test_files},
            ),
        ]

    def _generate_examples(self, files):
        """Yield ``(key, example)`` pairs for the given image paths.

        Args:
            files: Paths of the files belonging to the split being generated.

        Yields:
            For the ``image-classification`` configuration, tuples of the
            file's position in ``files`` and a dict with ``image`` (the path)
            and ``labels`` (the finding labels split on ``|``).
        """
        # NOTE(review): no branch for "object-detection" is visible here, so
        # that configuration yields no examples - confirm whether intended.
        if self.config.name == "image-classification":
            yield from self._classification_examples(files)

    def _classification_examples(self, files):
        """Yield classification examples, pairing each PNG with its labels."""
        # Read csv with image labels
        label_csv = read_csv(_URLS["labels"])

        # Index the labels once instead of scanning the whole table for every
        # image. `setdefault` keeps the first row for a given image name, as
        # the former `.values[0]` lookup did.
        findings_by_image = {}
        for image_name, findings in zip(
            label_csv["Image Index"], label_csv["Finding Labels"]
        ):
            findings_by_image.setdefault(image_name, findings)

        for i, path in enumerate(files):
            file_name = os.path.basename(path)
            # Filter before the lookup: a non-PNG file has no label row and
            # used to abort generation with an IndexError.
            if not file_name.endswith(".png"):
                continue
            findings = findings_by_image.get(file_name)
            if findings is None:
                logger.warning("No label row found for %s; skipping it", file_name)
                continue
            yield i, {
                "image": path,
                "labels": findings.split("|"),
            }
224