# data/dataset.py
1
import numpy as np
2
import pandas as pd
3
import torch
4
from PIL import Image
5
#import torchvision.transforms as T
6
import torch.nn as nn
7
# from detectron2.data import DatasetMapper
8
9
from util import constants as C
10
from .transforms import get_transforms
11
import albumentations as A
12
from albumentations.pytorch import ToTensorV2
13
import albumentations.augmentations as AA
14
15
import pdb
16
import cv2
17
18
class SegmentationDataset(torch.utils.data.Dataset):
    """Dataset of (image, mask) tensor pairs indexed by a CSV file.

    The CSV must provide ``image_path`` and ``mask_path`` columns. Rows are
    sorted by ``batch`` and ``pair_idx`` when that is possible; otherwise the
    file order is kept. Images are read with OpenCV and masks with
    ``np.load``.
    """

    def __init__(self, dataset_path, transforms=None, split='train',
                 augmentation=None, image_size=224, pretrained=False):
        """Read the CSV index and build the transform pipeline.

        Args:
            dataset_path: Path of the CSV file listing the samples.
            transforms: Accepted for interface compatibility; not used.
            split: Forwarded to ``get_transforms``.
            augmentation: Stored on the instance and forwarded to
                ``get_transforms``.
            image_size: Forwarded to ``get_transforms``.
            pretrained: Stored on the instance as ``_pretrained``.
        """
        frame = pd.read_csv(dataset_path)
        try:
            frame = frame.sort_values(['batch', 'pair_idx']).reset_index(drop=True)
        except (KeyError, TypeError):
            # The sort columns are missing or hold unsortable values: keep
            # the rows in file order. Read errors are not swallowed.
            pass
        self._df = frame
        self._image_path = self._df['image_path']
        self._mask_path = self._df['mask_path']
        self._pretrained = pretrained
        self.augmentation = augmentation
        self._transforms = get_transforms(
            split=split,
            augmentation=augmentation,
            image_size=image_size
        )

    def get_batch_list(self, batch_size=32):
        """Return the row indices split into consecutive chunks.

        Args:
            batch_size: Maximum number of indices per chunk (default 32).

        Returns:
            A list of lists of index labels; the last chunk may be shorter.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        indices = list(self._df.index)
        return [indices[start:start + batch_size]
                for start in range(0, len(indices), batch_size)]

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self._df)

    @staticmethod
    def _rescale_to_255(image):
        """Min-max scale *image* to [0, 255] as float64; constant input maps to zeros."""
        # Work in float64 so integer inputs cannot wrap around in max - min.
        image = image.astype(np.float64)
        low = image.min()
        span = image.max() - low
        if span == 0:
            # A constant image would otherwise divide by zero and yield NaN.
            return np.zeros_like(image)
        return (image - low) / span * 255.0

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Row label used to look up the image and mask paths.

        Returns:
            ``(image, mask)`` as float32 tensors, both transposed with
            ``(2, 0, 1)``. The image is min-max normalised to [0, 1] and its
            last axis is repeated three times.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        image_file = self._image_path[index]
        image = cv2.imread(image_file, cv2.IMREAD_UNCHANGED)
        if image is None:
            # cv2.imread signals a missing or undecodable file by returning None.
            raise FileNotFoundError(f"Could not read image: {image_file}")

        image = self._rescale_to_255(image)
        # NOTE(review): the resize target is C.IMAGE_SIZE, not the image_size
        # argument given to __init__ - confirm this is intended.
        image = cv2.resize(image, (C.IMAGE_SIZE, C.IMAGE_SIZE))
        # Presumably the image is single-channel (H, W); a multi-channel file
        # would break the transpose below - TODO confirm.
        image = np.tile(image[..., None], [1, 1, 3])
        image = image.astype(np.float32) / 255.

        # The (2, 0, 1) transpose requires the stored mask to be a 3-D array.
        mask = np.load(self._mask_path[index])

        mask = torch.tensor(mask.transpose(2, 0, 1), dtype=torch.float32)
        image = torch.tensor(image.transpose(2, 0, 1), dtype=torch.float32)

        # NOTE: self._transforms is built in __init__ but is not applied here.
        return image, mask
62
63
class SegmentationDemoDataset(SegmentationDataset):
    """Segmentation dataset bound to the CSV at ``C.TEST_DATASET_PATH``."""

    def __init__(self):
        """Initialise the parent with the test CSV; all other options keep their defaults."""
        demo_csv = C.TEST_DATASET_PATH
        super().__init__(dataset_path=demo_csv)
66
67
class ImageDetectionDataset(torch.utils.data.Dataset):
    """Detection dataset that feeds per-sample dicts through a ``DatasetMapper``."""

    def __init__(self, image_path=None, annotations=None, augmentations=None):
        """Store the inputs and build the mapper.

        Args:
            image_path: Indexable collection of image file names.
            annotations: Indexable collection with one annotation entry per
                sample; its length defines the dataset length.
            augmentations: Forwarded to ``DatasetMapper``.
        """
        # Imported here because the module-level detectron2 import is
        # commented out; without this the DatasetMapper name is undefined and
        # construction fails with NameError.
        from detectron2.data import DatasetMapper

        self._image_path = image_path
        self._annotations = annotations
        self._mapper = DatasetMapper(is_train=True,
                                     image_format="RGB",
                                     augmentations=augmentations
                                     )

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self._annotations)

    def __getitem__(self, index):
        """Build the sample dict for *index* and return the mapper's output for it."""
        sample = {
            'annotations': self._annotations[index],
            'file_name': self._image_path[index],
            'image_id': index,
        }
        return self._mapper(sample)
86
87
class ImageDetectionDemoDataset(ImageDetectionDataset):
    """Two-sample detection dataset with fixed boxes over ``C.TEST_IMG_PATH``."""

    def __init__(self):
        """Initialise the parent with two hard-coded boxes per sample and no augmentations."""
        first_box = {'bbox': [438, 254, 455, 271], 'bbox_mode': 0, 'category_id': 0}
        second_box = {'bbox': [388, 259, 408, 279], 'bbox_mode': 0, 'category_id': 1}
        # List repetition: both samples refer to the same inner list object.
        per_sample = [[first_box, second_box]] * 2
        super().__init__(image_path=C.TEST_IMG_PATH,
                         annotations=per_sample,
                         augmentations=[])
93
94
class ImageClassificationDataset(torch.utils.data.Dataset):
    """Classification dataset yielding ``(image, label)`` pairs."""

    def __init__(self, image_path=None, labels=None, transforms=None):
        """Store the inputs.

        Args:
            image_path: Indexable collection of image file paths.
            labels: Indexable collection of labels; its length defines the
                dataset length.
            transforms: Optional callable applied to the RGB PIL image.
        """
        self._image_path = image_path
        self._labels = labels
        self._transforms = transforms

    def __len__(self):
        """Return the number of labels."""
        return len(self._labels)

    def __getitem__(self, index):
        """Return the image at *index* (transformed if configured) and its label.

        The label is returned as a float64 scalar tensor.
        """
        label = torch.tensor(np.float64(self._labels[index]))
        # Close the file as soon as the pixels are converted; convert()
        # returns a new image that stays usable after the file is closed.
        with Image.open(self._image_path[index]) as handle:
            image = handle.convert('RGB')
        if self._transforms is not None:
            image = self._transforms(image)

        return image, label
110
111
class ImageClassificationDemoDataset(ImageClassificationDataset):
    """Two-label classification dataset over ``C.TEST_IMG_PATH``."""

    def __init__(self):
        """Initialise the parent with labels ``[0, 1]`` and a resize-to-224 + to-tensor pipeline."""
        # Imported here because the module-level torchvision import is
        # commented out; without this the name T is undefined and
        # construction fails with NameError.
        import torchvision.transforms as T

        preprocess = T.Compose([T.Resize((224, 224)), T.ToTensor()])
        super().__init__(image_path=C.TEST_IMG_PATH, labels=[0, 1],
                         transforms=preprocess)