import os
import pandas as pd
from tqdm import tqdm
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets.folder import pil_loader
# Dataset phases. Each name is used both as a folder name under the dataset
# root and as a key of the dicts built by get_study_level_data and
# get_dataloaders.
data_cat = ['train', 'valid']
def get_study_level_data(study_type):
    """
    Build one study-level DataFrame per phase for a single study type.

    For every phase in ``data_cat`` the folder
    'MURA-v1.0/<phase>/<study_type>/<patient>/<study>/' is scanned and one
    row is recorded per study folder.

    Args:
        study_type (string): one of the seven study type folder names in
            'train/valid/test' dataset

    Returns:
        dict: maps each phase to a DataFrame with three columns:
            'Path'  - path of the study folder, ending with '/',
            'Count' - number of directory entries inside the study folder,
            'Label' - 1 for a '..._positive' study folder, 0 for '..._negative'.

    Raises:
        FileNotFoundError: if the phase/study-type directory cannot be listed.
        IndexError, KeyError: if a study folder name does not have the form
            '<prefix>_positive' or '<prefix>_negative'.
    """
    study_label = {'positive': 1, 'negative': 0}
    columns = ['Path', 'Count', 'Label']
    study_data = {}
    for phase in data_cat:
        base_dir = 'MURA-v1.0/%s/%s/' % (phase, study_type)
        # Only the first item produced by os.walk (the top level) is needed;
        # taking it lazily avoids traversing every patient and study folder
        # before the real scan below even starts.
        top_level = next(os.walk(base_dir), None)
        if top_level is None:
            raise FileNotFoundError('cannot list study directory: %s' % base_dir)
        patients = top_level[1]  # patient folder names

        # Collect plain rows and build the frame once; growing a DataFrame
        # row by row re-allocates it on every insertion.
        rows = []
        for patient in tqdm(patients):
            patient_dir = base_dir + patient
            for study in os.listdir(patient_dir):
                label = study_label[study.split('_')[1]]
                # The trailing '/' is kept on purpose: ImageDataset appends
                # the image file name directly to this string.
                path = patient_dir + '/' + study + '/'
                rows.append([path, len(os.listdir(path)), label])
        study_data[phase] = pd.DataFrame(rows, columns=columns)
    return study_data
class ImageDataset(Dataset):
    """Study-level dataset: one item is every image of a study plus its label."""

    def __init__(self, df, transform=None):
        """
        Args:
            df (pd.DataFrame): one row per study. Columns are read by
                position: 0 = study folder path (the image file name is
                appended directly, so it must end with a separator),
                1 = number of images in the study, 2 = label.
            transform (callable, optional): applied to every loaded image.
                Its output must be a tensor so that the images of a study
                can be stacked. When omitted, each image is converted with
                torchvision's ``to_tensor``.
        """
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of studies (rows of ``df``)."""
        return len(self.df)

    def __getitem__(self, idx):
        """
        Load all images of the study at row ``idx``.

        Images are expected to be named 'image1.png' ... 'image<count>.png'
        inside the study folder.

        Returns:
            dict: {'images': the transformed images stacked along a new
            first dimension, 'label': the value of column 2}.
        """
        study_path = self.df.iloc[idx, 0]
        count = self.df.iloc[idx, 1]
        label = self.df.iloc[idx, 2]

        transform = self.transform
        if transform is None:
            # The constructor advertises ``transform`` as optional, but
            # calling None would raise TypeError. Fall back to a plain
            # PIL -> tensor conversion so the result can still be stacked.
            from torchvision.transforms.functional import to_tensor as transform

        images = torch.stack([
            transform(pil_loader(study_path + 'image%s.png' % number))
            for number in range(1, count + 1)
        ])
        return {'images': images, 'label': label}
def get_dataloaders(data, batch_size=8, study_level=False):
    """
    Build one shuffled DataLoader per phase in ``data_cat``.

    Args:
        data (dict): maps each phase name to the DataFrame handed to
            ImageDataset for that phase.
        batch_size (int): batch size passed to every DataLoader.
        study_level (bool): accepted but not used by this function.

    Returns:
        dict: maps each phase name to its DataLoader. The 'train' pipeline
        adds a random horizontal flip and a random rotation of up to 10
        degrees; the 'valid' pipeline only resizes, converts and normalizes.
    """
    target_size = (224, 224)
    # Per-channel normalization constants (presumably the usual ImageNet
    # statistics - confirm against the model these loaders feed).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    train_pipeline = transforms.Compose([
        transforms.Resize(target_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    valid_pipeline = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    pipelines = {'train': train_pipeline, 'valid': valid_pipeline}

    dataloaders = {}
    for phase in data_cat:
        dataset = ImageDataset(data[phase], transform=pipelines[phase])
        dataloaders[phase] = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
        )
    return dataloaders
# Script entry point. The module name of a directly executed file is
# '__main__' (with the double underscores); there is nothing to run yet.
if __name__ == '__main__':
    pass