--- a +++ b/pipeline.py @@ -0,0 +1,84 @@ +import os +import pandas as pd +from tqdm import tqdm +import torch +from torchvision import transforms +from torch.utils.data import DataLoader, Dataset +from torchvision.datasets.folder import pil_loader + +data_cat = ['train', 'valid'] # data categories + +def get_study_level_data(study_type): + """ + Returns a dict, with keys 'train' and 'valid' and respective values as study level dataframes, + these dataframes contain three columns 'Path', 'Count', 'Label' + Args: + study_type (string): one of the seven study type folder names in 'train/valid/test' dataset + """ + study_data = {} + study_label = {'positive': 1, 'negative': 0} + for phase in data_cat: + BASE_DIR = 'MURA-v1.0/%s/%s/' % (phase, study_type) + patients = list(os.walk(BASE_DIR))[0][1] # list of patient folder names + study_data[phase] = pd.DataFrame(columns=['Path', 'Count', 'Label']) + i = 0 + for patient in tqdm(patients): # for each patient folder + for study in os.listdir(BASE_DIR + patient): # for each study in that patient folder + label = study_label[study.split('_')[1]] # get label 0 or 1 + path = BASE_DIR + patient + '/' + study + '/' # path to this study + study_data[phase].loc[i] = [path, len(os.listdir(path)), label] # add new row + i+=1 + return study_data + +class ImageDataset(Dataset): + """training dataset.""" + + def __init__(self, df, transform=None): + """ + Args: + df (pd.DataFrame): a pandas DataFrame with image path and labels. + transform (callable, optional): Optional transform to be applied + on a sample. + """ + self.df = df + self.transform = transform + + def __len__(self): + return len(self.df) + + def __getitem__(self, idx): + study_path = self.df.iloc[idx, 0] + count = self.df.iloc[idx, 1] + images = [] + for i in range(count): + image = pil_loader(study_path + 'image%s.png' % (i+1)) + images.append(self.transform(image)) + images = torch.stack(images) + label = self.df.iloc[idx, 2] + sample = {'images': images, 'label': label} + return sample + +def get_dataloaders(data, batch_size=8, study_level=False): + ''' + Returns dataloader pipeline with data augmentation + ''' + data_transforms = { + 'train': transforms.Compose([ + transforms.Resize((224, 224)), + transforms.RandomHorizontalFlip(), + transforms.RandomRotation(10), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]), + 'valid': transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]), + } + image_datasets = {x: ImageDataset(data[x], transform=data_transforms[x]) for x in data_cat} + dataloaders = {x: DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in data_cat} + return dataloaders + +if __name__=='main': + pass