|
a |
|
b/pipeline.py |
|
|
1 |
import os |
|
|
2 |
import pandas as pd |
|
|
3 |
from tqdm import tqdm |
|
|
4 |
import torch |
|
|
5 |
from torchvision import transforms |
|
|
6 |
from torch.utils.data import DataLoader, Dataset |
|
|
7 |
from torchvision.datasets.folder import pil_loader |
|
|
8 |
|
|
|
9 |
# Dataset splits, in the order the functions below iterate over them.
data_cat = [
    'train',
    'valid',
]
|
|
10 |
|
|
|
11 |
def get_study_level_data(study_type, base_dir='MURA-v1.0'):
    """
    Return a dict mapping each split in ``data_cat`` ('train', 'valid') to a
    study-level DataFrame with columns 'Path', 'Count' and 'Label'.

    The directory layout scanned is
    ``<base_dir>/<phase>/<study_type>/<patient>/<study>/``; the label is read
    from the second '_'-separated field of each study folder name.

    Args:
        study_type (string): one of the seven study type folder names in the
            'train'/'valid'/'test' dataset.
        base_dir (string): dataset root directory. Defaults to 'MURA-v1.0',
            the root this function previously hard-coded.

    Returns:
        dict: phase -> DataFrame with one row per study folder:
            'Path'  -- study folder path, ending in '/'
            'Count' -- number of '.png' files in that folder
            'Label' -- 1 for 'positive', 0 for 'negative'

    Raises:
        FileNotFoundError: if ``<base_dir>/<phase>/<study_type>/`` is missing.
        KeyError / IndexError: if a study folder name's second '_'-separated
            field is not 'positive' or 'negative'.
    """
    study_label = {'positive': 1, 'negative': 0}
    study_data = {}
    for phase in data_cat:
        phase_dir = '%s/%s/%s/' % (base_dir, phase, study_type)
        # Only the immediate sub-directories are patients, so list one level
        # instead of walking the whole tree. Sorting makes the row order
        # reproducible, since os.listdir returns entries in arbitrary order.
        patients = sorted(
            name for name in os.listdir(phase_dir)
            if os.path.isdir(phase_dir + name)
        )
        rows = []
        for patient in tqdm(patients):  # for each patient folder
            patient_dir = phase_dir + patient + '/'
            for study in sorted(os.listdir(patient_dir)):  # each study of that patient
                path = patient_dir + study + '/'
                if not os.path.isdir(path):
                    continue  # ignore stray files (e.g. .DS_Store) beside study folders
                label = study_label[study.split('_')[1]]  # 0 or 1
                # Count only PNGs: ImageDataset loads image1.png..imageN.png,
                # so other files in the folder must not inflate the count.
                count = sum(1 for name in os.listdir(path) if name.endswith('.png'))
                rows.append([path, count, label])
        # Build each frame once; growing it row by row with .loc reallocates
        # on every insertion and is quadratic in the number of studies.
        study_data[phase] = pd.DataFrame(rows, columns=['Path', 'Count', 'Label'])
    return study_data
|
|
32 |
|
|
|
33 |
class ImageDataset(Dataset):
    """Study-level dataset: each item holds every image of one study, stacked.

    Usable for any split (train or valid); the split-specific behaviour comes
    from the ``transform`` passed in.
    """

    def __init__(self, df, transform=None):
        """
        Args:
            df (pd.DataFrame): one row per study, read positionally:
                column 0 is the study folder path (ending in a separator,
                since file names are appended to it directly), column 1 the
                number of images, column 2 the label.
            transform (callable, optional): applied to each loaded PIL image
                and expected to return a tensor. When None,
                ``transforms.ToTensor()`` is used so the images can still be
                stacked (they must then already share one size).
        """
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of studies (rows of ``df``)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{'images': Tensor, 'label': label}`` for study ``idx``.

        Loads ``<path>image1.png`` ... ``<path>image<count>.png`` and stacks
        the transformed images along a new leading dimension of size count.

        Raises:
            ValueError: if the study's image count is not positive
                (``torch.stack`` cannot stack an empty list).
        """
        study_path = self.df.iloc[idx, 0]
        count = int(self.df.iloc[idx, 1])
        label = self.df.iloc[idx, 2]
        if count <= 0:
            raise ValueError('study %r has no images (count=%d)' % (study_path, count))
        # Fall back to ToTensor so a None transform no longer crashes and
        # torch.stack still receives tensors.
        transform = self.transform if self.transform is not None else transforms.ToTensor()
        # Image files are numbered from 1.
        images = [
            transform(pil_loader(study_path + 'image%s.png' % (i + 1)))
            for i in range(count)
        ]
        return {'images': torch.stack(images), 'label': label}
|
|
60 |
|
|
|
61 |
def get_dataloaders(data, batch_size=8, study_level=False, num_workers=4):
    '''
    Return a dict of DataLoaders, one per split in ``data_cat``, with
    split-specific preprocessing.

    'train' images are resized to 224x224, randomly flipped horizontally and
    randomly rotated within +/-10 degrees; 'valid' images are only resized.
    Both are then converted to tensors and normalized with the standard
    ImageNet mean/std values.

    Args:
        data (dict): split name -> study-level DataFrame accepted by
            ImageDataset (e.g. the output of get_study_level_data).
        batch_size (int): studies per batch. NOTE: each sample's 'images'
            tensor has one entry per image in the study, so PyTorch's default
            collate can only batch studies with equal image counts;
            batch_size=1 always works.
        study_level (bool): currently unused; kept for backward compatibility.
        num_workers (int): loader worker processes per split (default 4, the
            value previously hard-coded).

    Returns:
        dict: split name -> torch.utils.data.DataLoader. Only the 'train'
        loader shuffles, so validation order is stable between runs.
    '''
    # Shared, stateless steps used by both pipelines.
    resize = transforms.Resize((224, 224))
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    data_transforms = {
        'train': transforms.Compose([
            resize,
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            normalize,
        ]),
        'valid': transforms.Compose([
            resize,
            transforms.ToTensor(),
            normalize,
        ]),
    }
    dataloaders = {}
    for phase in data_cat:
        dataset = ImageDataset(data[phase], transform=data_transforms[phase])
        dataloaders[phase] = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(phase == 'train'),  # augmentation/shuffling only for training
            num_workers=num_workers,
        )
    return dataloaders
|
|
82 |
|
|
|
83 |
# Script entry point. The original compared against 'main', which never
# matches; the interpreter sets __name__ to '__main__' when run directly.
if __name__ == '__main__':
    pass