[16dd74]: / dsb2018_topcoders / albu / src / bowl_train.py

Download this file

119 lines (100 with data), 4.5 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
import os
import cv2
from scipy.misc import imread
import numpy as np
from utils import get_csv_folds, update_config, get_folds
from config import Config
from dataset.reading_image_provider import ReadingImageProvider, CachingImageProvider, InFolderImageProvider
from dataset.raw_image import RawImageType
from pytorch_utils.concrete_eval import FullImageEvaluator
from augmentations.transforms import aug_victor
from pytorch_utils.train import train
from merge_preds import merge_files
import json
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('config_path')
parser.add_argument('--fold', type=int)
parser.add_argument('--training', action='store_true')
args = parser.parse_args()
with open(args.config_path, 'r') as f:
cfg = json.load(f)
cfg['dataset_path'] = cfg['dataset_path'] + ('' if args.training else '_test')
config = Config(**cfg)
paths = {
'masks': 'masks_all',
'images': 'images_all',
'labels': 'labels_all',
}
fn_mapping = {
'masks': lambda name: name,
'labels': lambda name: os.path.splitext(name)[0] + '.tif'
}
if args.training:
paths = {k: os.path.join(config.dataset_path, p) for k, p in paths.items()}
else:
paths = {"images": config.dataset_path}
num_workers = 0 if os.name == 'nt' else 4
class MinSizeImageType(RawImageType):
def finalyze(self, data):
rows, cols = data.shape[:2]
nrows = (256 - rows) if rows < 256 else 0
ncols = (256 - cols) if cols < 256 else 0
if nrows > 0 or ncols > 0:
return cv2.copyMakeBorder(data, 0, nrows, 0, ncols, cv2.BORDER_CONSTANT)
return data
class SigmoidBorderImageType(MinSizeImageType):
def read_mask(self):
path = os.path.join(self.paths['masks'], self.fn_mapping['masks'](self.fn))
mask = imread(path, mode='RGB')
label = self.read_label()
fin = self.finalyze(mask)
data = np.dstack((fin[...,2], fin[...,1], (label > 0).astype(np.uint8) * 255))
return data
class BorderImageType(MinSizeImageType):
def read_mask(self):
path = os.path.join(self.paths['masks'], self.fn_mapping['masks'](self.fn))
msk = imread(path, mode='RGB')
msk[..., 2] = (msk[..., 2] > 127)
msk[..., 1] = (msk[..., 1] > 127) * (msk[..., 2] == 0)
msk[..., 0] = (msk[..., 1] == 0) * (msk[..., 2] == 0)
return self.finalyze(msk.astype(np.uint8) * 255)
class PaddedImageType(BorderImageType):
def finalyze(self, data):
rows, cols = data.shape[:2]
return cv2.copyMakeBorder(data, 0, (32-rows%32), 0, (32-cols%32), cv2.BORDER_REFLECT)
class PaddedSigmoidImageType(SigmoidBorderImageType):
def finalyze(self, data):
rows, cols = data.shape[:2]
return cv2.copyMakeBorder(data, 0, (32-rows%32), 0, (32-cols%32), cv2.BORDER_REFLECT)
def train_bowl():
torch.backends.cudnn.benchmark = True
im_type = BorderImageType if not config.sigmoid else SigmoidBorderImageType
im_val_type = PaddedImageType if not config.sigmoid else PaddedSigmoidImageType
ds = CachingImageProvider(im_type, paths, fn_mapping)
val_ds = CachingImageProvider(im_val_type, paths, fn_mapping)
folds = get_csv_folds(ds, os.path.join(config.dataset_path, 'folds.csv'))
for fold, (train_idx, val_idx) in enumerate(folds):
if args.fold is not None and int(args.fold) != fold:
continue
train(ds, val_ds, fold, train_idx, val_idx, config, num_workers=num_workers, transforms=aug_victor(.97))
def eval_bowl():
global config
test = not args.training
im_val_type = PaddedImageType if not config.sigmoid else PaddedSigmoidImageType
im_prov_type = InFolderImageProvider if test else ReadingImageProvider
ds = im_prov_type(im_val_type, paths, fn_mapping)
if not test:
folds = get_csv_folds(ds, os.path.join(config.dataset_path, 'folds.csv'))
else:
folds = [([], list(range(len(ds)))) for i in range(4)]
keval = FullImageEvaluator(config, ds, test=test, flips=3, num_workers=num_workers, border=0)
for fold, (t, e) in enumerate(folds):
if args.fold is not None and int(args.fold) != fold:
continue
keval.predict(fold, e)
if test and args.fold is None:
merge_files(keval.save_dir)
if __name__ == "__main__":
train_bowl()