Switch to unified view

a b/datasets/dataset_classifier.py
1
import os,sys
2
import numpy as np
3
from PIL import Image as PILImage
4
import torch
5
import torch.nn.functional as F
6
from torch.utils import data as data
7
from torchvision import transforms as transforms
8
9
# dataset for sign detection and char detection
10
class COVID_CT_DATA(data.Dataset):
11
12
     def __init__(self, **kwargs):           
13
         super(COVID_CT_DATA).__init__()
14
         self.stage = kwargs['stage']
15
         # this returns the path to data dir
16
         self.data = kwargs['data']         
17
         self.fs = sorted(os.listdir(self.data))
18
         self.size = kwargs['img_size']
19
         # this returns the path to 
20
         self.img_fname = None
21
22
     def transform_img(self, img):
23
         # Faster R-CNN does the normalization
24
         t_ = transforms.Compose([
25
                             #transforms.ToPILImage(),
26
                             transforms.Resize(self.size),
27
                             transforms.ToTensor(),
28
                             ])
29
         img = t_(img)
30
         return img
31
32
     def load_img_label(self, idx):
33
         lab=torch.zeros(3, dtype=torch.float)
34
         lab[int(self.fs[idx].split('_')[0])] = 1
35
         im = PILImage.open(os.path.join(self.data, self.fs[idx]))
36
         if im.mode !='RGB':
37
            im = im.convert(mode='RGB')
38
         im = self.transform_img(im)
39
         return im, lab
40
41
     #'magic' method: size of the dataset
42
     def __len__(self):
43
         return len(os.listdir(self.data))
44
45
     # return one datapoint
46
     def __getitem__(self, idx):
47
         X,y = self.load_img_label(idx)
48
         return X,y
49
50