Switch to unified view

a b/datacode/natural_image_data.py
1
import os, sys, time
2
import random
3
4
import pandas as pd
5
import PIL
6
7
import torch
8
import torchvision
9
import torchvision.transforms as torch_transforms
10
from datacode import augmentations as augs
11
12
13
class Cifar100Dataset(torch.utils.data.Dataset):
14
    def __init__(self, images_folder, csv_path, transform = None,
15
                    return_label = False):
16
        """
17
        Simple CIFAR Image data loader. image_size:32x32
18
        return_label: fine_class, coarse_label, none
19
        """
20
21
        self.images_folder = images_folder
22
        self.df = pd.read_csv(csv_path)
23
24
        self._getitem_method = self._get_image_only
25
        if return_label:
26
            self._getitem_method = self._get_image_label
27
            self.label_type = return_label
28
29
        if transform: self.transform = transform
30
        else: self.transform = torch_transforms.ToTensor()
31
32
    def _get_image_only(self, index):
33
        imgpath = os.path.join(self.images_folder, self.df["filename"][index])
34
        image = PIL.Image.open(imgpath)
35
        image = self.transform(image)
36
        return image
37
38
    def _get_image_label(self, index):
39
        imgpath = os.path.join(self.images_folder, self.df["filename"][index])
40
        image = PIL.Image.open(imgpath)
41
        image = self.transform(image)
42
        label = self.df[self.label_type][index]
43
        return image, label
44
45
    def __len__(self):
46
        return len(self.df)
47
48
    def __getitem__(self, index):
49
        return self._getitem_method(index)
50
51
    def get_info(self):
52
        return {
53
            "DataSize": self.__len__(),
54
            "Transforms": str(self.transform),
55
        }
56
57
58
59
if __name__ == "__main__":
60
61
    traindataset = Cifar100Dataset( images_folder= os.path.join(CFG.datapath,"train_images"),
62
                                    csv_path= os.path.join(CFG.datapath,"train_list.csv"),
63
                                    transform = transform_obj)
64
65
    validdataset = Cifar100Dataset( images_folder= os.path.join(CFG.datapath,"test_images"),
66
                                    csv_path= os.path.join(CFG.datapath,"test_list.csv"),
67
                                    transform = transform_obj)