Diff of /pipeline.py [000000] .. [df6751]

Switch to unified view

a b/pipeline.py
1
import os
2
import pandas as pd
3
from tqdm import tqdm
4
import torch
5
from torchvision import transforms
6
from torch.utils.data import DataLoader, Dataset
7
from torchvision.datasets.folder import pil_loader
8
9
data_cat = ['train', 'valid'] # data categories
10
11
def get_study_level_data(study_type):
12
    """
13
    Returns a dict, with keys 'train' and 'valid' and respective values as study level dataframes, 
14
    these dataframes contain three columns 'Path', 'Count', 'Label'
15
    Args:
16
        study_type (string): one of the seven study type folder names in 'train/valid/test' dataset 
17
    """
18
    study_data = {}
19
    study_label = {'positive': 1, 'negative': 0}
20
    for phase in data_cat:
21
        BASE_DIR = 'MURA-v1.0/%s/%s/' % (phase, study_type)
22
        patients = list(os.walk(BASE_DIR))[0][1] # list of patient folder names
23
        study_data[phase] = pd.DataFrame(columns=['Path', 'Count', 'Label'])
24
        i = 0
25
        for patient in tqdm(patients): # for each patient folder
26
            for study in os.listdir(BASE_DIR + patient): # for each study in that patient folder
27
                label = study_label[study.split('_')[1]] # get label 0 or 1
28
                path = BASE_DIR + patient + '/' + study + '/' # path to this study
29
                study_data[phase].loc[i] = [path, len(os.listdir(path)), label] # add new row
30
                i+=1
31
    return study_data
32
33
class ImageDataset(Dataset):
34
    """training dataset."""
35
36
    def __init__(self, df, transform=None):
37
        """
38
        Args:
39
            df (pd.DataFrame): a pandas DataFrame with image path and labels.
40
            transform (callable, optional): Optional transform to be applied
41
                on a sample.
42
        """
43
        self.df = df
44
        self.transform = transform
45
46
    def __len__(self):
47
        return len(self.df)
48
49
    def __getitem__(self, idx):
50
        study_path = self.df.iloc[idx, 0]
51
        count = self.df.iloc[idx, 1]
52
        images = []
53
        for i in range(count):
54
            image = pil_loader(study_path + 'image%s.png' % (i+1))
55
            images.append(self.transform(image))
56
        images = torch.stack(images)
57
        label = self.df.iloc[idx, 2]
58
        sample = {'images': images, 'label': label}
59
        return sample
60
61
def get_dataloaders(data, batch_size=8, study_level=False):
62
    '''
63
    Returns dataloader pipeline with data augmentation
64
    '''
65
    data_transforms = {
66
        'train': transforms.Compose([
67
                transforms.Resize((224, 224)),
68
                transforms.RandomHorizontalFlip(),
69
                transforms.RandomRotation(10),
70
                transforms.ToTensor(),
71
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 
72
        ]),
73
        'valid': transforms.Compose([
74
            transforms.Resize((224, 224)),
75
            transforms.ToTensor(),
76
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
77
        ]),
78
    }
79
    image_datasets = {x: ImageDataset(data[x], transform=data_transforms[x]) for x in data_cat}
80
    dataloaders = {x: DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in data_cat}
81
    return dataloaders
82
83
if __name__=='main':
84
    pass