[76022b]: / semseg_train / dataset.py

Download this file

57 lines (42 with data), 1.7 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import numpy as np
import cv2
from torch.utils.data import Dataset
from pathlib import Path
# Root directory for dataset files, relative to the current working directory.
# NOTE(review): not referenced by the code visible in this module.
data_path = Path('.') / 'data'
class AngyodysplasiaDataset(Dataset):
    """Map-style PyTorch dataset of images and, in ``'train'`` mode, binary masks.

    Args:
        img_paths: Sequence of image locations (strings or ``Path`` objects).
        to_augment: Stored on the instance; not read by this class itself.
        transform: Optional callable ``transform(img, mask) -> (img, mask)``
            applied jointly to every sample. ``None`` (the default) means the
            image and mask are used as loaded.
        mode: ``'train'`` yields ``(image_tensor, mask_tensor)``; any other
            value yields ``(image_tensor, file_name_str)``.
        limit: If given, ``len()`` reports ``limit`` and every ``__getitem__``
            call ignores ``idx`` and draws a random path (with replacement)
            from ``img_paths`` via ``np.random.choice``.
    """

    def __init__(self, img_paths: list, to_augment=False, transform=None,
                 mode='train', limit=None):
        self.img_paths = img_paths
        self.to_augment = to_augment
        self.transform = transform
        self.mode = mode
        self.limit = limit

    def __len__(self):
        """Return the number of samples, e.g. ``len(dataset)``.

        This is ``limit`` when it was given, otherwise ``len(img_paths)``.
        """
        if self.limit is None:
            return len(self.img_paths)
        return self.limit

    def __getitem__(self, idx):
        """Return the sample for ``idx``, e.g. ``dataset[0]`` is the first sample.

        Returns:
            In ``'train'`` mode: ``(image_tensor, mask_tensor)`` where the image
            went through ``to_float_tensor`` and the mask gained a leading axis
            and was cast to float. In any other mode:
            ``(image_tensor, str(img_file_name))``.
        """
        if self.limit is None:
            img_file_name = self.img_paths[idx]
        else:
            # Random sampling with replacement; idx is intentionally ignored.
            img_file_name = np.random.choice(self.img_paths)

        img = load_image(img_file_name)

        if self.mode == 'train':
            mask = load_mask(img_file_name)
        else:
            # Placeholder mask matching the image's spatial size so the joint
            # (img, mask) transform can be applied the same way in every mode.
            mask = np.zeros(img.shape[:2])

        # transform defaults to None; calling it unconditionally used to fail
        # with "TypeError: 'NoneType' object is not callable".
        if self.transform is not None:
            img, mask = self.transform(img, mask)

        if self.mode == 'train':
            # Prepend an axis to the mask (presumably the channel axis for a
            # 2-D (H, W) mask -> (1, H, W)).
            mask_tensor = torch.from_numpy(np.expand_dims(mask, 0)).float()
            return to_float_tensor(img), mask_tensor
        return to_float_tensor(img), str(img_file_name)
def to_float_tensor(img):
    """Convert a channels-last numpy array into a channels-first float tensor.

    The last axis of ``img`` is moved to the front (for a 3-D array,
    ``(H, W, C)`` -> ``(C, H, W)``) and the result is cast with
    ``Tensor.float()`` (float32).
    """
    channels_first = np.moveaxis(img, source=-1, destination=0)
    as_tensor = torch.from_numpy(channels_first)
    return as_tensor.float()
def load_image(path):
    """Read an image file with OpenCV and return it in RGB channel order.

    Args:
        path: Image location (string or path-like); passed to OpenCV as
            ``str(path)``.

    Returns:
        The decoded image converted from OpenCV's BGR order to RGB.

    Raises:
        OSError: If OpenCV could not read or decode the file.
    """
    img = cv2.imread(str(path))
    # cv2.imread signals failure by returning None rather than raising;
    # without this check the error surfaces later as an opaque cv2.error
    # from cvtColor that does not mention the offending path.
    if img is None:
        raise OSError(f"Could not read image: {path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def load_mask(path):
    """Load the binary segmentation mask belonging to an image file.

    The mask location is derived from the image path by plain string
    replacement: every ``'images'`` becomes ``'masks'`` and every ``'.jpg'``
    becomes ``'_a.jpg'`` (e.g. ``data/images/x.jpg`` -> ``data/masks/x_a.jpg``).

    Args:
        path: Image location (string or path-like).

    Returns:
        ``uint8`` array with 1 where the grayscale mask is non-zero, else 0.

    Raises:
        OSError: If OpenCV could not read or decode the derived mask file.
    """
    mask_path = str(path).replace('images', 'masks').replace(r'.jpg', r'_a.jpg')
    # Flag 0 is cv2.IMREAD_GRAYSCALE: load a single-channel image.
    mask = cv2.imread(mask_path, 0)
    # cv2.imread returns None instead of raising on a missing/unreadable file;
    # without this check `mask > 0` fails with a confusing TypeError.
    if mask is None:
        raise OSError(f"Could not read mask: {mask_path}")
    # Any non-zero pixel is foreground. Thresholding at > 0 (not == 255) is
    # presumably to tolerate lossy JPEG artifacts in the masks — TODO confirm.
    return (mask > 0).astype(np.uint8)