Switch to unified view

a b/Classifier/Classes/Data.py
1
#############################################################################################
2
#
3
# Project:       Mascot Defense Research Project
4
# Repository:    ALL Detection System 2020
5
# Project:       AllDS2020 Classifier
6
#
7
# Author:        Allen Akhaumere
8
# Contributors:
9
# Title:         Data Class
10
# Description:   Data class for the Acute Lymphoblastic Leukemia Pytorch CNN Classifier
11
#                ALL Classifier.
12
# License:       MIT License
13
# Last Modified: 2020-07-23
14
#
15
############################################################################################
16
17
import os
18
from PIL import Image
19
from torch.utils.data import Dataset
20
from torchvision import transforms, datasets
21
22
23
24
25
class LeukemiaDataset(Dataset):
26
27
    """
28
    Acute Lymphoblastic Leukemia Dataset Reader.
29
30
     Args:
31
           df_data: Dataframe for CSV file
32
           data_dir: path to Lymphoblastic Leukemia Data
33
           transform: transforms for performing data augmentation
34
    """
35
36
    def __init__(self, df_data, data_dir='./', transform=None):
37
        super().__init__()
38
        self.df = df_data.values
39
        self.data_dir = data_dir
40
        self.transform = transform
41
42
    def __len__(self):
43
        return len(self.df)
44
45
    def __getitem__(self, index):
46
        img_name, label = self.df[index]
47
        img_path = os.path.join(self.data_dir, img_name + '.jpg')
48
        image = Image.open(img_path)
49
        if self.transform is not None:
50
            image = self.transform(image)
51
        return image, label
52
53
54
def augmentation():
55
    """Acute Lymphoblastic Leukemia data augmentation"""
56
    mean, std_dev = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
57
    training_transforms = transforms.Compose([transforms.Resize((100, 100)),
58
                                              transforms.RandomRotation(30),
59
                                              transforms.RandomResizedCrop(100),
60
                                              transforms.RandomHorizontalFlip(),
61
                                              transforms.RandomGrayscale(p=0.1),
62
                                              transforms.ToTensor(),
63
                                              transforms.Normalize(mean=mean, std=std_dev)])
64
65
    validation_transforms = transforms.Compose([
66
        transforms.Resize((100, 100)),
67
        transforms.ToTensor(),
68
        transforms.Normalize(mean=mean, std=std_dev)])
69
70
    return training_transforms, validation_transforms
71
72
73
# Load the datasets with ImageFolder
74
def load_datasets(train_dir, training_transforms, valid_dir, validation_transforms):
75
    """ """
76
    training_dataset = datasets.ImageFolder(train_dir, transform=training_transforms)
77
    validation_dataset = datasets.ImageFolder(valid_dir, transform=validation_transforms)
78
    return training_dataset, validation_dataset