Diff of /inpainting/data/data.py [000000] .. [92cc18]

Switch to unified view

a b/inpainting/data/data.py
1
import torch
2
import numpy as np
3
import cv2
4
import os
5
from torch.utils.data import Dataset
6
7
8
class ToTensor(object):
9
    def __call__(self, sample):
10
        entry = {}
11
        for k in sample:
12
            if k == 'rect':
13
                entry[k] = torch.IntTensor(sample[k])
14
            else:
15
                entry[k] = torch.FloatTensor(sample[k])
16
        return entry
17
18
19
class InpaintingDataset(Dataset):
20
    def __init__(self, info_list, root_dir='', im_size=(256, 256), transform=None):
21
        self.filenames = open(info_list, 'rt').read().splitlines()
22
        self.root_dir = root_dir
23
        self.transform = transform
24
        self.im_size = im_size
25
        np.random.seed(2018)
26
27
    def __len__(self):
28
        return len(self.filenames)
29
30
    def read_image(self, filepath):
31
        image = cv2.imread(filepath)
32
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
33
        h, w, c = image.shape
34
        if h != self.im_size[0] or w != self.im_size[1]:
35
            ratio = max(1.0 * self.im_size[0] / h, 1.0 * self.im_size[1] / w)
36
            im_scaled = cv2.resize(image, None, fx=ratio, fy=ratio)
37
            h, w, _ = im_scaled.shape
38
            h_idx = (h - self.im_size[0]) // 2
39
            w_idx = (w - self.im_size[1]) // 2
40
            im_scaled = im_scaled[h_idx:h_idx + self.im_size[0], w_idx:w_idx + self.im_size[1], :]
41
            im_scaled = np.transpose(im_scaled, [2, 0, 1])
42
        else:
43
            im_scaled = np.transpose(image, [2, 0, 1])
44
        return im_scaled
45
46
    def __getitem__(self, idx):
47
        image = self.read_image(os.path.join(self.root_dir, self.filenames[idx]))
48
        sample = {'gt': image}
49
        if self.transform:
50
            sample = self.transform(sample)
51
        return sample