Diff of /pytorch/dataloader.py [000000] .. [bca7a0]

Switch to unified view

a b/pytorch/dataloader.py
1
from __future__ import absolute_import, division, print_function
2
3
import re
4
from os import getcwd
5
from os.path import join
6
7
import numpy as np
8
import pandas as pd
9
import torch.utils.data as data
10
from PIL import Image
11
12
from torchvision.datasets.utils import check_integrity, download_url
13
14
15
class MuraDataset(data.Dataset):
    """`MURA <https://stanfordmlgroup.github.io/projects/mura/>`_ Dataset :
    Towards Radiologist-Level Abnormality Detection in Musculoskeletal Radiographs.

    The dataset is driven by a header-less, two-column CSV whose rows are
    ``<image path>,<label>``. Each image path is expected to contain the
    fragments ``XR_<type>``, ``patient<N>``, ``study<N>`` and ``image<N>``,
    which are parsed into per-sample metadata by ``__getitem__``.

    Args:
        csv_f: Path (or file-like object) of the CSV, passed to
            ``pandas.read_csv``.
        transform: Optional callable applied to the loaded PIL image.
        download: Accepted for API compatibility; the constructor does not
            currently read it.

    Attributes:
        imgs: List of image paths (first CSV column).
        labels: List of labels (second CSV column).
        samples: List of ``(img, label)`` tuples.
        classes: Sorted unique label values.
        balanced_weights: One weight per sample, see
            ``balance_class_weights``.
    """
    url = "https://cs.stanford.edu/group/mlgroup/mura-v1.0.zip"
    filename = "mura-v1.0.zip"
    md5_checksum = '4c36feddb7f5698c8bf291b912c438b1'
    _patient_re = re.compile(r'patient(\d+)')
    _study_re = re.compile(r'study(\d+)')
    _image_re = re.compile(r'image(\d+)')
    _study_type_re = re.compile(r'XR_(\w+)')

    def __init__(self, csv_f, transform=None, download=False):
        self.df = pd.read_csv(csv_f, names=['img', 'label'], header=None)
        self.imgs = self.df.img.values.tolist()
        self.labels = self.df.label.values.tolist()
        # following datasets/folder.py's weird convention here...
        self.samples = [tuple(x) for x in self.df.values]
        # number of unique classes
        self.classes = np.unique(self.labels)
        self.balanced_weights = self.balance_class_weights()

        self.transform = transform

    def __len__(self):
        """Return the number of images listed in the CSV."""
        return len(self.imgs)

    @staticmethod
    def _search_group(pattern, img_filename, what):
        """Return group 1 of ``pattern`` in ``img_filename``.

        Raises:
            ValueError: if the pattern does not occur in the path, so the
                caller sees which path was malformed instead of an
                ``AttributeError`` on ``None``.
        """
        match = pattern.search(img_filename)
        if match is None:
            raise ValueError(
                "could not parse {} from image path {!r}".format(what, img_filename))
        return match.group(1)

    def _parse_patient(self, img_filename):
        """Return the integer following ``patient`` in the path."""
        return int(self._search_group(self._patient_re, img_filename, 'patient'))

    def _parse_study(self, img_filename):
        """Return the integer following ``study`` in the path."""
        return int(self._search_group(self._study_re, img_filename, 'study'))

    def _parse_image(self, img_filename):
        """Return the integer following ``image`` in the path."""
        return int(self._search_group(self._image_re, img_filename, 'image'))

    def _parse_study_type(self, img_filename):
        """Return the word following ``XR_`` in the path (e.g. ``WRIST``)."""
        return self._search_group(self._study_type_re, img_filename, 'study type')

    @staticmethod
    def download_and_uncompress_tarball(tarball_url, dataset_dir):
        """Downloads the `tarball_url` and uncompresses it locally.

        Declared as a static method because it takes no ``self``; it can be
        called on the class or on an instance.

        Args:
            tarball_url: The URL of a tarball (or zip) file.
            dataset_dir: The directory where the temporary files are stored.
        """
        # Imported locally: the module's import block does not provide these.
        import os
        import sys
        import tarfile
        import zipfile
        try:
            from urllib.request import urlretrieve
        except ImportError:  # Python 2 fallback
            from urllib import urlretrieve

        filename = tarball_url.split('/')[-1]
        filepath = os.path.join(dataset_dir, filename)

        def _progress(count, block_size, total_size):
            # urlretrieve may report a non-positive total size when the
            # server sends no length; avoid dividing by it.
            if total_size > 0:
                percent = min(
                    100.0,
                    float(count * block_size) / float(total_size) * 100.0)
                sys.stdout.write(
                    '\r>> Downloading %s %.1f%%' % (filename, percent))
            else:
                sys.stdout.write('\r>> Downloading %s' % filename)
            sys.stdout.flush()

        filepath, _ = urlretrieve(tarball_url, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
        # NOTE(review): extractall() trusts member paths inside the archive;
        # only use this with archives from a trusted source.
        if ".zip" in filename:
            print("zipfile:{}".format(filepath))
            with zipfile.ZipFile(filepath, "r") as zip_ref:
                zip_ref.extractall(dataset_dir)
        else:
            with tarfile.open(filepath, 'r:gz') as tar_ref:
                tar_ref.extractall(dataset_dir)

    def balance_class_weights(self):
        """Return one weight per sample, inversely proportional to class size.

        Each sample's weight is ``N / count(label)`` where ``N`` is the total
        number of samples. Counting is keyed on the label value itself, so
        labels need not be the contiguous integers ``0..K-1``.
        Presumably meant for a weighted sampler -- confirm against callers.

        Returns:
            list of float, aligned with ``self.samples``.
        """
        from collections import Counter

        per_class = Counter(sample[1] for sample in self.samples)
        total = float(len(self.samples))
        return [total / per_class[sample[1]] for sample in self.samples]

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            tuple ``(image, label, meta_data)`` where ``image`` is the RGB
            image after ``self.transform`` (if any), ``label`` is the CSV
            label and ``meta_data`` is a dict of fields parsed from the path.

        Raises:
            ValueError: if the image path lacks one of the expected fragments.
        """
        img_filename = self.imgs[idx]
        patient = self._parse_patient(img_filename)
        study = self._parse_study(img_filename)
        image_num = self._parse_image(img_filename)
        study_type = self._parse_study_type(img_filename)

        # todo(bdd) : inconsistent right now, need param for grayscale / RGB
        # todo(bdd) : 'L' -> gray, 'RGB' -> Colors
        image = Image.open(img_filename).convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        meta_data = {
            'y_true': label,
            'img_filename': img_filename,
            'patient': patient,
            'study': study,
            'study_type': study_type,
            'image_num': image_num,
            'encounter': "{}_{}_{}".format(study_type, patient, study)
        }
        return image, label, meta_data
117
118
119
if __name__ == '__main__':
    # Smoke test: iterate the validation split found under ./MURA-v1.0 and
    # pretty-print the metadata of the first 41 batches.
    import torchvision.transforms as transforms
    import pprint

    data_dir = join(getcwd(), 'MURA-v1.0')
    val_csv = join(data_dir, 'valid.csv')
    val_loader = data.DataLoader(
        MuraDataset(val_csv,
                    transforms.Compose([
                        transforms.Resize(224),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                    ])),
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=False)

    for i, (image, label, meta_data) in enumerate(val_loader):
        # meta_data is a plain dict after collation and has no .cpu()
        # method, so print it directly.
        pprint.pprint(meta_data)
        if i == 40:
            break