[6673ef]: / scripts / eval.py

Download this file

139 lines (120 with data), 5.2 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#!/usr/bin/env python
from __future__ import division, print_function
import numpy as np
import matplotlib.pyplot as plt
from rvseg import opts, patient, dataset, models
def save_image(figname, image, mask_true, mask_pred, alpha=0.3):
    """Save a three-panel comparison figure to ``figname``.

    Panels, left to right:
      1. the image alone,
      2. the image with ``mask_pred`` overlaid,
      3. the image with ``mask_true`` overlaid.

    Everything is drawn with the gray colormap; overlays are blended with
    transparency ``alpha``. Axes are hidden, the figure is saved with a tight
    bounding box, and then closed.
    """
    cmap = plt.cm.gray
    # Overlays to draw on top of the image, one tuple per panel.
    panel_overlays = ((), (mask_pred,), (mask_true,))

    plt.figure(figsize=(12, 3.75))
    for position, overlays in enumerate(panel_overlays, 1):
        plt.subplot(1, 3, position)
        plt.axis("off")
        plt.imshow(image, cmap=cmap)
        for overlay in overlays:
            plt.imshow(overlay, cmap=cmap, alpha=alpha)

    plt.savefig(figname, bbox_inches='tight')
    plt.close()
def sorensen_dice(y_true, y_pred):
    """Return the Sorensen-Dice coefficient of two binary masks.

    Dice = 2 * |A intersect B| / (|A| + |B|). The overlap is computed with an
    element-wise product, so both inputs are expected to hold only 0/1
    values. The callers in this file pass ``uint8`` masks.

    If both masks are empty (no foreground pixels at all), the masks agree
    perfectly and 1.0 is returned. Computing 0/0 instead would give NaN and
    poison any mean taken over the scores.
    """
    intersection = np.sum(y_true * y_pred)
    total = np.sum(y_true) + np.sum(y_pred)
    if total == 0:
        # Both masks empty: define as perfect agreement rather than NaN.
        return 1.0
    return 2*intersection / total
def jaccard(y_true, y_pred):
    """Return the Jaccard index (intersection over union) of two binary masks.

    Bitwise ``&`` / ``|`` are used, so the masks must be integer or boolean
    arrays. The callers in this file pass ``uint8`` 0/1 masks.

    If both masks are empty, the union is 0 and 1.0 is returned (perfect
    agreement). Computing 0/0 instead would give NaN.
    """
    intersection = np.sum(y_true & y_pred)
    union = np.sum(y_true | y_pred)
    if union == 0:
        # Both masks empty: define as perfect agreement rather than NaN.
        return 1.0
    return intersection / union
def compute_statistics(model, generator, steps_per_epoch, return_images=False):
    """Score ``model`` on ``steps_per_epoch`` batches drawn from ``generator``.

    Each step takes ``(images, masks_true)`` from ``generator`` and predicts
    the images one at a time. Channel 1 of each true mask is compared with
    channel 1 of the prediction, rounded to the nearest integer.
    NOTE(review): channel 1 is presumably the foreground class; confirm
    against the dataset's mask encoding.

    The mean and standard deviation of both metrics are printed.

    Returns:
        (dices, jaccards, predictions): one Dice score and one Jaccard score
        per image, plus a list of ``(image[:,:,0], mask_true[:,:,1],
        mask_pred[:,:,1])`` tuples with the prediction left unrounded. The
        tuple list is filled only when ``return_images`` is truthy; otherwise
        it is empty.
    """
    dices, jaccards, predictions = [], [], []

    for _ in range(steps_per_epoch):
        images, masks_true = next(generator)

        # A single model.predict(images) call would be the natural choice, but
        # the dilated densenet cannot handle large batch sizes, so predict
        # image by image and stack the outputs back into one batch.
        per_image = [model.predict(img[np.newaxis, :, :, :]) for img in images]
        masks_pred = np.concatenate(per_image)

        for img, truth, pred in zip(images, masks_true, masks_pred):
            truth_fg = truth[:, :, 1].astype('uint8')
            pred_fg = np.round(pred[:, :, 1]).astype('uint8')
            dices.append(sorensen_dice(truth_fg, pred_fg))
            jaccards.append(jaccard(truth_fg, pred_fg))
            if return_images:
                predictions.append((img[:, :, 0], truth[:, :, 1], pred[:, :, 1]))

    print("Dice: {:.3f} ({:.3f})".format(np.mean(dices), np.std(dices)))
    print("Jaccard: {:.3f} ({:.3f})".format(np.mean(jaccards), np.std(jaccards)))

    return dices, jaccards, predictions
def _save_prediction_images(name_template, dices, predictions):
    """Save one overlay PNG per evaluated image via ``save_image``.

    ``name_template`` is formatted with the image index and its Dice score.
    The prediction is rounded before plotting. Files are written relative to
    the current working directory.
    """
    for i, (dice, (image, mask_true, mask_pred)) in enumerate(
            zip(dices, predictions)):
        figname = name_template.format(i, dice)
        save_image(figname, image, mask_true, np.round(mask_pred))


def main():
    """Evaluate saved model weights on the training and validation splits.

    Options come from ``opts.parse_arguments()``. Two of them are repurposed
    here (a bit of a hack):

    * ``args.outfile``: if set, per-image scores are written with
      ``np.savetxt`` to ``<outfile>.train`` and ``<outfile>.val``. Each file
      has one row per image with two columns: Dice, then Jaccard.
    * ``args.checkpoint``: if truthy, an overlay PNG is saved for every
      evaluated image.
    """
    args = opts.parse_arguments()

    print("Loading dataset...")
    augmentation_args = {
        'rotation_range': args.rotation_range,
        'width_shift_range': args.width_shift_range,
        'height_shift_range': args.height_shift_range,
        'shear_range': args.shear_range,
        'zoom_range': args.zoom_range,
        'fill_mode' : args.fill_mode,
        'alpha': args.alpha,
        'sigma': args.sigma,
    }

    train_generator, train_steps_per_epoch, \
        val_generator, val_steps_per_epoch = dataset.create_generators(
            args.datadir, args.batch_size,
            validation_split=args.validation_split,
            mask=args.classes,
            shuffle_train_val=args.shuffle_train_val,
            shuffle=args.shuffle,
            seed=args.seed,
            normalize_images=args.normalize,
            augment_training=args.augment_training,
            augment_validation=args.augment_validation,
            augmentation_args=augmentation_args)

    # Get image dimensions from the first training batch.
    # NOTE(review): this consumes one batch from train_generator before the
    # evaluation below. Depending on how dataset.create_generators cycles and
    # shuffles, the training statistics may not cover every image exactly
    # once. Confirm.
    images, masks = next(train_generator)
    _, height, width, channels = images.shape
    _, _, _, classes = masks.shape

    print("Building model...")
    string_to_model = {
        "unet": models.unet,
        "dilated-unet": models.dilated_unet,
        "dilated-densenet": models.dilated_densenet,
        "dilated-densenet2": models.dilated_densenet2,
        "dilated-densenet3": models.dilated_densenet3,
    }
    model = string_to_model[args.model]
    m = model(height=height, width=width, channels=channels, classes=classes,
              features=args.features, depth=args.depth, padding=args.padding,
              temperature=args.temperature, batchnorm=args.batchnorm,
              dropout=args.dropout)
    m.load_weights(args.load_weights)

    print("Training Set:")
    train_dice, train_jaccard, train_images = compute_statistics(
        m, train_generator, train_steps_per_epoch,
        return_images=args.checkpoint)
    print()
    print("Validation Set:")
    val_dice, val_jaccard, val_images = compute_statistics(
        m, val_generator, val_steps_per_epoch,
        return_images=args.checkpoint)

    if args.outfile:
        # Shape (num_images, 2): column 0 is Dice, column 1 is Jaccard.
        train_data = np.asarray([train_dice, train_jaccard]).T
        val_data = np.asarray([val_dice, val_jaccard]).T
        np.savetxt(args.outfile + ".train", train_data)
        np.savetxt(args.outfile + ".val", val_data)

    if args.checkpoint:
        print("Saving images...")
        _save_prediction_images("train-{:03d}-{:.3f}.png",
                                train_dice, train_images)
        _save_prediction_images("val-{:03d}-{:.3f}.png",
                                val_dice, val_images)
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()