|
a |
|
b/inpainting/data/data.py |
|
|
1 |
import torch |
|
|
2 |
import numpy as np |
|
|
3 |
import cv2 |
|
|
4 |
import os |
|
|
5 |
from torch.utils.data import Dataset |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
class ToTensor(object): |
|
|
9 |
def __call__(self, sample): |
|
|
10 |
entry = {} |
|
|
11 |
for k in sample: |
|
|
12 |
if k == 'rect': |
|
|
13 |
entry[k] = torch.IntTensor(sample[k]) |
|
|
14 |
else: |
|
|
15 |
entry[k] = torch.FloatTensor(sample[k]) |
|
|
16 |
return entry |
|
|
17 |
|
|
|
18 |
|
|
|
19 |
class InpaintingDataset(Dataset): |
|
|
20 |
def __init__(self, info_list, root_dir='', im_size=(256, 256), transform=None): |
|
|
21 |
self.filenames = open(info_list, 'rt').read().splitlines() |
|
|
22 |
self.root_dir = root_dir |
|
|
23 |
self.transform = transform |
|
|
24 |
self.im_size = im_size |
|
|
25 |
np.random.seed(2018) |
|
|
26 |
|
|
|
27 |
def __len__(self): |
|
|
28 |
return len(self.filenames) |
|
|
29 |
|
|
|
30 |
def read_image(self, filepath): |
|
|
31 |
image = cv2.imread(filepath) |
|
|
32 |
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) |
|
|
33 |
h, w, c = image.shape |
|
|
34 |
if h != self.im_size[0] or w != self.im_size[1]: |
|
|
35 |
ratio = max(1.0 * self.im_size[0] / h, 1.0 * self.im_size[1] / w) |
|
|
36 |
im_scaled = cv2.resize(image, None, fx=ratio, fy=ratio) |
|
|
37 |
h, w, _ = im_scaled.shape |
|
|
38 |
h_idx = (h - self.im_size[0]) // 2 |
|
|
39 |
w_idx = (w - self.im_size[1]) // 2 |
|
|
40 |
im_scaled = im_scaled[h_idx:h_idx + self.im_size[0], w_idx:w_idx + self.im_size[1], :] |
|
|
41 |
im_scaled = np.transpose(im_scaled, [2, 0, 1]) |
|
|
42 |
else: |
|
|
43 |
im_scaled = np.transpose(image, [2, 0, 1]) |
|
|
44 |
return im_scaled |
|
|
45 |
|
|
|
46 |
def __getitem__(self, idx): |
|
|
47 |
image = self.read_image(os.path.join(self.root_dir, self.filenames[idx])) |
|
|
48 |
sample = {'gt': image} |
|
|
49 |
if self.transform: |
|
|
50 |
sample = self.transform(sample) |
|
|
51 |
return sample |