--- a +++ b/pytorch/dataloader.py @@ -0,0 +1,140 @@ +from __future__ import absolute_import, division, print_function + +import re +from os import getcwd +from os.path import join + +import numpy as np +import pandas as pd +import torch.utils.data as data +from PIL import Image + +from torchvision.datasets.utils import check_integrity, download_url + + +class MuraDataset(data.Dataset): + """`MURA <https://stanfordmlgroup.github.io/projects/mura/>`_ Dataset : + Towards Radiologist-Level Abnormality Detection in Musculoskeletal Radiographs. + """ + url = "https://cs.stanford.edu/group/mlgroup/mura-v1.0.zip" + filename = "mura-v1.0.zip" + md5_checksum = '4c36feddb7f5698c8bf291b912c438b1' + _patient_re = re.compile(r'patient(\d+)') + _study_re = re.compile(r'study(\d+)') + _image_re = re.compile(r'image(\d+)') + _study_type_re = re.compile(r'XR_(\w+)') + + def __init__(self, csv_f, transform=None, download=False): + self.df = pd.read_csv(csv_f, names=['img', 'label'], header=None) + self.imgs = self.df.img.values.tolist() + self.labels = self.df.label.values.tolist() + # following datasets/folder.py's weird convention here... + self.samples = [tuple(x) for x in self.df.values] + # number of unique classes + self.classes = np.unique(self.labels) + self.balanced_weights = self.balance_class_weights() + + self.transform = transform + + def __len__(self): + return len(self.imgs) + + def _parse_patient(self, img_filename): + return int(self._patient_re.search(img_filename).group(1)) + + def _parse_study(self, img_filename): + return int(self._study_re.search(img_filename).group(1)) + + def _parse_image(self, img_filename): + return int(self._image_re.search(img_filename).group(1)) + + def _parse_study_type(self, img_filename): + return self._study_type_re.search(img_filename).group(1) + + def download_and_uncompress_tarball(tarball_url, dataset_dir): + """Downloads the `tarball_url` and uncompresses it locally. + Args: + tarball_url: The URL of a tarball file. + dataset_dir: The directory where the temporary files are stored. + """ + filename = tarball_url.split('/')[-1] + filepath = os.path.join(dataset_dir, filename) + + def _progress(count, block_size, total_size): + sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, + float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.flush() + + filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress) + print() + statinfo = os.stat(filepath) + print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') + if ".zip" in filename: + print("zipfile:{}".format(filepath)) + with zipfile.ZipFile(filepath, "r") as zip_ref: + zip_ref.extractall(dataset_dir) + else: + tarfile.open(filepath, 'r:gz').extractall(dataset_dir) + + def balance_class_weights(self): + count = [0] * len(self.classes) + for item in self.samples: + count[item[1]] += 1 + weight_per_class = [0.] * len(self.classes) + N = float(sum(count)) + for i in range(len(self.classes)): + weight_per_class[i] = N / float(count[i]) + weight = [0] * len(self.samples) + for idx, val in enumerate(self.samples): + weight[idx] = weight_per_class[val[1]] + return weight + + def __getitem__(self, idx): + img_filename = join(self.imgs[idx]) + patient = self._parse_patient(img_filename) + study = self._parse_study(img_filename) + image_num = self._parse_image(img_filename) + study_type = self._parse_study_type(img_filename) + + # todo(bdd) : inconsistent right now, need param for grayscale / RGB + # todo(bdd) : 'L' -> gray, 'RGB' -> Colors + image = Image.open(img_filename).convert('RGB') + label = self.labels[idx] + + if self.transform is not None: + image = self.transform(image) + + meta_data = { + 'y_true': label, + 'img_filename': img_filename, + 'patient': patient, + 'study': study, + 'study_type': study_type, + 'image_num': image_num, + 'encounter': "{}_{}_{}".format(study_type, patient, study) + } + return image, label, meta_data + + +if __name__ == '__main__': + import torchvision.transforms as transforms + import pprint + + data_dir = join(getcwd(), 'MURA-v1.0') + val_csv = join(data_dir, 'valid.csv') + val_loader = data.DataLoader( + MuraDataset(val_csv, + transforms.Compose([ + transforms.Resize(224), + transforms.CenterCrop(224), + transforms.ToTensor(), + ])), + batch_size=1, + shuffle=False, + num_workers=1, + pin_memory=False) + + for i, (image, label, meta_data) in enumerate(val_loader): + pprint.pprint(meta_data.cpu()) + if i == 40: + break