a b/rvseg/unet-data.py
1
#!/usr/bin/env python
2
3
from __future__ import division, print_function
4
5
import os, glob
6
import argparse
7
from PIL import Image
8
import patient
9
10
11
def generate_unet_data(patient_dirs, x_dir, y_dir, file_format="{:04d}"):
    """Write every image / endocardium-mask pair of the given patients to disk.

    Pairs are numbered consecutively across all patients, in the order of
    *patient_dirs*, starting at 0. Pair number ``n`` is written as
    ``x_dir/<file_format.format(n)>.jpg`` (image) and
    ``y_dir/<file_format.format(n)>.png`` (mask).

    Args:
        patient_dirs: iterable of patient directory paths, each loaded with
            ``patient.PatientData``.
        x_dir: existing directory receiving the input images (JPEG).
        y_dir: existing directory receiving the ground-truth masks (PNG).
        file_format: ``str.format`` template that turns the running index
            into a file stem; the default gives zero-padded 4-digit names.
    """
    grayscale = 'L'  # PIL's 8-bit single-channel image mode
    # NOTE(review): mode 'L' expects uint8 arrays; if PatientData yields
    # another dtype (e.g. raw 16-bit pixel data) the bytes would be
    # misinterpreted -- confirm against patient.PatientData.

    def _pairs():
        # Yield (image, mask) pairs patient by patient; each PatientData is
        # loaded only after the previous patient's pairs have been consumed.
        for patient_dir in patient_dirs:
            data = patient.PatientData(patient_dir)
            for pair in zip(data.images, data.endocardium_masks):
                yield pair

    for index, (image, mask) in enumerate(_pairs()):
        stem = file_format.format(index)
        x_path = os.path.join(x_dir, stem + ".jpg")
        y_path = os.path.join(y_dir, stem + ".png")
        Image.fromarray(image, grayscale).save(x_path)
        Image.fromarray(mask, grayscale).save(y_path)
22
23
def main(args):
    """Split the patient directories into train/test sets and export U-Net data.

    Every ``patient*`` directory under ``args.indir`` is found and sorted by
    path. The first ``(100 - split)%`` (rounded down) go to the training set
    and the rest to the test set. Image/mask pairs are written to::

        <outdir>/train/img/0, <outdir>/train/gt/0
        <outdir>/test/img/0,  <outdir>/test/gt/0

    Args:
        args: namespace with ``indir`` (directory containing the patient
            folders), ``outdir`` (output root) and ``split`` (percentage of
            patients assigned to the test set, 0..100).

    Raises:
        ValueError: if ``args.split`` is outside 0..100.
        OSError: if any output directory already exists, because
            ``os.makedirs`` is called without tolerating existing paths.
    """
    if not 0 <= args.split <= 100:
        raise ValueError(
            "split must be a percentage between 0 and 100, got {!r}".format(args.split))

    glob_search = os.path.join(args.indir, "patient*")
    # glob returns entries in arbitrary, filesystem-dependent order; sort so
    # the "first N patients" training split is reproducible. Only keep
    # directories, since each entry is loaded as a patient folder.
    patient_dirs = sorted(d for d in glob.glob(glob_search) if os.path.isdir(d))

    print("Found patient directories:")
    for patient_dir in patient_dirs:
        print(patient_dir)

    # Exact integer arithmetic: the float form int((1 - split/100) * n)
    # truncates wrongly, e.g. split=90 with n=10 gives
    # int(0.9999999999999998) == 0 instead of 1.
    split_index = int(len(patient_dirs) * (100 - args.split) // 100)
    train_dirs = patient_dirs[:split_index]
    test_dirs = patient_dirs[split_index:]
    print("First {} patients used as training set, remaining as test.".format(split_index))

    x_train = os.path.join(args.outdir, "train/img/0")
    y_train = os.path.join(args.outdir, "train/gt/0")
    os.makedirs(x_train)
    os.makedirs(y_train)

    x_test = os.path.join(args.outdir, "test/img/0")
    y_test = os.path.join(args.outdir, "test/gt/0")
    os.makedirs(x_test)
    os.makedirs(y_test)

    generate_unet_data(train_dirs, x_train, y_train)
    generate_unet_data(test_dirs, x_test, y_test)
48
49
50
def parse_args(argv=None):
    """Build the command-line parser and parse *argv*.

    Args:
        argv: list of argument strings; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``indir``, ``outdir`` and ``split``.
    """
    parser = argparse.ArgumentParser(description="Generate data for U-Net from RV MRI dicom images.")
    # nargs='?' makes the positional optional. Without it argparse treats a
    # positional argument as required, and default='.' could never apply.
    parser.add_argument('indir', nargs='?', default='.', help="TrainingSet/ directory")
    parser.add_argument('-o', '--outdir', default='.', help="Directory to write output data")
    parser.add_argument('-s', '--split', default=20, type=int, help='Percentage of patients used for test set')
    return parser.parse_args(argv)


if __name__ == '__main__':
    main(parse_args())