[16dd74]: / dsb2018_topcoders / albu / src / dataset / neural_dataset.py

Download this file

152 lines (128 with data), 6.5 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import random
import numpy as np
from augmentations.composition import Compose
from augmentations.transforms import ToTensor
from dataset.abstract_image_provider import AbstractImageProvider
from .image_cropper import ImageCropper, DVCropper
class Dataset:
def __init__(self, image_provider: AbstractImageProvider, image_indexes, config, stage='train', transforms=None):
self.pad = 0 if stage=='train' else config.test_pad
self.image_provider = image_provider
self.image_indexes = image_indexes if isinstance(image_indexes, list) else image_indexes.tolist()
if stage != 'train' and len(self.image_indexes) % 2: #todo bugreport it
self.image_indexes += [self.image_indexes[-1]]
self.stage = stage
self.keys = {'image', 'image_name'}
self.config = config
normalize = {'mean': [124 / 255, 117 / 255, 104 / 255],
'std': [1 / (.0167 * 255)] * 3}
self.transforms = Compose([transforms, ToTensor(config.num_classes, config.sigmoid, normalize)])
self.croppers = {}
def __getitem__(self, item):
raise NotImplementedError
def get_cropper(self, image_id, val=False):
#todo maybe cache croppers for different sizes too speedup if it's slow part?
if image_id not in self.croppers:
image = self.image_provider[image_id].image
rows, cols = image.shape[:2]
if self.config.ignore_target_size and val:
assert self.config.predict_batch_size == 1
target_rows, target_cols = rows, cols
else:
target_rows, target_cols = self.config.target_rows, self.config.target_cols
cropper = ImageCropper(rows, cols,
target_rows, target_cols,
self.pad)
self.croppers[image_id] = cropper
return self.croppers[image_id]
class TrainDataset(Dataset):
def __init__(self, image_provider, image_indexes, config, stage='train', transforms=None, partly_sequential=False):
super(TrainDataset, self).__init__(image_provider, image_indexes, config, stage, transforms=transforms)
self.keys.add('mask')
self.partly_sequential = partly_sequential
self.inner_idx = 9
self.idx = 0
masks = []
labels = []
# for im_idx in self.image_indexes:
# item = self.image_provider[im_idx]
# masks.append(item.mask)
# labels.append(item.label)
# self.dv_cropper = DVCropper(masks, labels, config.target_rows, config.target_cols)
def __getitem__(self, idx):
if self.partly_sequential:
#todo rewrite somehow better
if self.inner_idx > 8:
self.idx = idx
self.inner_idx = 0
self.inner_idx += 1
im_idx = self.image_indexes[self.idx % len(self.image_indexes)]
else:
im_idx = self.image_indexes[idx % len(self.image_indexes)]
cropper = self.get_cropper(im_idx)
item = self.image_provider[im_idx]
sx, sy = cropper.random_crop_coords()
if cropper.use_crop and self.image_provider.has_alpha:
for i in range(10):
alpha = cropper.crop_image(item.alpha, sx, sy)
if np.mean(alpha) > 5:
break
sx, sy = cropper.random_crop_coords()
else:
return self.__getitem__(random.randint(0, len(self.image_indexes)))
im = cropper.crop_image(item.image, sx, sy)
mask = cropper.crop_image(item.mask, sx, sy)
# im, mask, lbl = item.image, item.mask, item.label
# im, mask = self.dv_cropper.strange_method(idx % len(self.image_indexes), im, mask, lbl, sx, sy)
data = {'image': im, 'mask': mask, 'image_name': item.fn}
return self.transforms(**data)
def __len__(self):
return len(self.image_indexes) * max(self.config.epoch_size, 1) # epoch size is len images
class SequentialDataset(Dataset):
def __init__(self, image_provider, image_indexes, config, stage='test', transforms=None):
super(SequentialDataset, self).__init__(image_provider, image_indexes, config, stage, transforms=transforms)
self.good_tiles = []
self.init_good_tiles()
self.keys.update({'geometry'})
def init_good_tiles(self):
self.good_tiles = []
for im_idx in self.image_indexes:
cropper = self.get_cropper(im_idx, val=True)
positions = cropper.positions
if self.image_provider.has_alpha:
item = self.image_provider[im_idx]
alpha_generator = cropper.sequential_crops(item.alpha)
for idx, alpha in enumerate(alpha_generator):
if np.mean(alpha) > 5:
self.good_tiles.append((im_idx, *positions[idx]))
else:
for pos in positions:
self.good_tiles.append((im_idx, *pos))
def prepare_image(self, item, cropper, sx, sy):
im = cropper.crop_image(item.image, sx, sy)
rows, cols = item.image.shape[:2]
geometry = {'rows': rows, 'cols': cols, 'sx': sx, 'sy': sy}
data = {'image': im, 'image_name': item.fn, 'geometry': geometry}
return data
def __getitem__(self, idx):
if idx >= self.__len__():
return None
im_idx, sx, sy = self.good_tiles[idx]
cropper = self.get_cropper(im_idx)
item = self.image_provider[im_idx]
data = self.prepare_image(item, cropper, sx, sy)
return self.transforms(**data)
def __len__(self):
return len(self.good_tiles)
class ValDataset(SequentialDataset):
def __init__(self, image_provider, image_indexes, config, stage='train', transforms=None):
super(ValDataset, self).__init__(image_provider, image_indexes, config, stage, transforms=transforms)
self.keys.add('mask')
def __getitem__(self, idx):
im_idx, sx, sy = self.good_tiles[idx]
cropper = self.get_cropper(im_idx)
item = self.image_provider[im_idx]
data = self.prepare_image(item, cropper, sx, sy)
mask = cropper.crop_image(item.mask, sx, sy)
data.update({'mask': mask})
return self.transforms(**data)