Diff of /scripts/submission.py [000000] .. [6673ef]

Switch to unified view

a b/scripts/submission.py
1
#!/usr/bin/env python
2
3
from __future__ import division, print_function
4
5
import os
6
import glob
7
8
import numpy as np
9
import matplotlib.pyplot as plt
10
import cv2
11
12
from rvseg import opts, patient, dataset, models
13
14
15
def load_patient_images(path, normalize=True):
16
    p = patient.PatientData(path)
17
18
    # reshape to account for channel dimension
19
    images = np.asarray(p.images, dtype='float64')[:,:,:,None]
20
21
    # maybe normalize images
22
    if normalize:
23
        dataset.normalize(images, axis=(1,2))
24
25
    return images, p.index, p.labeled, p.rotated
26
27
def get_contours(mask):
28
    mask_image = np.where(mask > 0.5, 255, 0).astype('uint8')
29
    im2, coords, hierarchy = cv2.findContours(mask_image, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
30
31
    if not coords:
32
        print("No contour detected.")
33
        coords = np.ones((1, 1, 1, 2), dtype='int')
34
    if len(coords) > 1:
35
        print("Multiple contours detected.")
36
        lengths = [len(coord) for coord in coords]
37
        coords = [coords[np.argmax(lengths)]]
38
39
    coords = np.squeeze(coords[0], axis=(1,))
40
    coords = np.append(coords, coords[:1], axis=0)
41
42
    return coords
43
44
def save_image(figname, image, mask_pred, alpha=0.3):
45
    cmap = plt.cm.gray
46
    plt.figure(figsize=(8, 3.75))
47
    plt.subplot(1, 2, 1)
48
    plt.axis("off")
49
    plt.imshow(image, cmap=cmap)
50
    plt.subplot(1, 2, 2)
51
    plt.axis("off")
52
    plt.imshow(image, cmap=cmap)
53
    plt.imshow(mask_pred, cmap=cmap, alpha=alpha)
54
    plt.savefig(figname, bbox_inches='tight')
55
    plt.close()
56
57
def main():
58
    # Sort of a hack:
59
    # args.checkpoint = turns on saving of images
60
    args = opts.parse_arguments()
61
    args.checkpoint = False     # override for now
62
63
    glob_search = os.path.join(args.datadir, "patient*")
64
    patient_dirs = sorted(glob.glob(glob_search))
65
    if len(patient_dirs) == 0:
66
        raise Exception("No patient directors found in {}".format(data_dir))
67
68
    # get image dimensions from first patient
69
    images, _, _, _ = load_patient_images(patient_dirs[0], args.normalize)
70
    _, height, width, channels = images.shape
71
    classes = 2                 # hard coded for now
72
    contour_type = {'inner': 'i', 'outer': 'o'}[args.classes]
73
74
    print("Building model...")
75
    string_to_model = {
76
        "unet": models.unet,
77
        "dilated-unet": models.dilated_unet,
78
        "dilated-densenet": models.dilated_densenet,
79
        "dilated-densenet2": models.dilated_densenet2,
80
        "dilated-densenet3": models.dilated_densenet3,
81
    }
82
    model = string_to_model[args.model]
83
84
    m = model(height=height, width=width, channels=channels, classes=classes,
85
              features=args.features, depth=args.depth, padding=args.padding,
86
              temperature=args.temperature, batchnorm=args.batchnorm,
87
              dropout=args.dropout)
88
89
    m.load_weights(args.load_weights)
90
91
    for path in patient_dirs:
92
        ret = load_patient_images(path, args.normalize)
93
        images, patient_number, frame_indices, rotated = ret
94
95
        predictions = []
96
        for image in images:
97
            mask_pred = m.predict(image[None,:,:,:]) # feed one at a time
98
            predictions.append((image[:,:,0], mask_pred[0,:,:,1]))
99
100
        for (image, mask), frame_index in zip(predictions, frame_indices):
101
            filename = "P{:02d}-{:04d}-{}contour-auto.txt".format(
102
                patient_number, frame_index, contour_type)
103
            outpath = os.path.join(args.outdir, filename)
104
            print(filename)
105
106
            contour = get_contours(mask)
107
            if rotated:
108
                height, width = image.shape
109
                x, y = contour.T
110
                x, y = height - y, x
111
                contour = np.vstack((x, y)).T
112
113
            np.savetxt(outpath, contour, fmt='%i', delimiter=' ')
114
115
            if args.checkpoint:
116
                filename = "P{:02d}-{:04d}-{}contour-auto.png".format(
117
                    patient_number, frame_index, contour_type)
118
                outpath = os.path.join(args.outdir, filename)
119
                save_image(outpath, image, np.round(mask))
120
121
if __name__ == '__main__':
122
    main()