Switch to unified view

a b/BraTs18Challege/dataprocess/data/subset.py
1
#!/usr/bin/env python
2
3
import os, sys, math, random
4
from collections import defaultdict
5
6
if sys.version_info[0] >= 3:
7
    xrange = range
8
9
def exit_with_help(argv):
10
    print("""\
11
Usage: {0} [options] dataset subset_size [output1] [output2]
12
13
This script randomly selects a subset of the dataset.
14
15
options:
16
-s method : method of selection (default 0)
17
     0 -- stratified selection (classification only)
18
     1 -- random selection
19
20
output1 : the subset (optional)
21
output2 : rest of the data (optional)
22
If output1 is omitted, the subset will be printed on the screen.""".format(argv[0]))
23
    exit(1)
24
25
def process_options(argv):
26
    argc = len(argv)
27
    if argc < 3:
28
        exit_with_help(argv)
29
30
    # default method is stratified selection
31
    method = 0
32
    subset_file = sys.stdout
33
    rest_file = None
34
35
    i = 1
36
    while i < argc:
37
        if argv[i][0] != "-":
38
            break
39
        if argv[i] == "-s":
40
            i = i + 1
41
            method = int(argv[i])
42
            if method not in [0,1]:
43
                print("Unknown selection method {0}".format(method))
44
                exit_with_help(argv)
45
        i = i + 1
46
47
    dataset = argv[i]
48
    subset_size = int(argv[i+1])
49
    if i+2 < argc:
50
        subset_file = open(argv[i+2],'w')
51
    if i+3 < argc:
52
        rest_file = open(argv[i+3],'w')
53
54
    return dataset, subset_size, method, subset_file, rest_file
55
56
def random_selection(dataset, subset_size):
57
    l = sum(1 for line in open(dataset,'r'))
58
    return sorted(random.sample(xrange(l), subset_size))
59
60
def stratified_selection(dataset, subset_size):
61
    labels = [line.split(None,1)[0] for line in open(dataset)]
62
    label_linenums = defaultdict(list)
63
    for i, label in enumerate(labels):
64
        label_linenums[label] += [i]
65
66
    l = len(labels)
67
    remaining = subset_size
68
    ret = []
69
70
    # classes with fewer data are sampled first; otherwise
71
    # some rare classes may not be selected
72
    for label in sorted(label_linenums, key=lambda x: len(label_linenums[x])):
73
        linenums = label_linenums[label]
74
        label_size = len(linenums)
75
        # at least one instance per class
76
        s = int(min(remaining, max(1, math.ceil(label_size*(float(subset_size)/l)))))
77
        if s == 0:
78
            sys.stderr.write('''\
79
Error: failed to have at least one instance per class
80
    1. You may have regression data.
81
    2. Your classification data is unbalanced or too small.
82
Please use -s 1.
83
''')
84
            sys.exit(-1)
85
        remaining -= s
86
        ret += [linenums[i] for i in random.sample(xrange(label_size), s)]
87
    return sorted(ret)
88
89
def main(argv=sys.argv):
90
    dataset, subset_size, method, subset_file, rest_file = process_options(argv)
91
    #uncomment the following line to fix the random seed
92
    #random.seed(0)
93
    selected_lines = []
94
95
    if method == 0:
96
        selected_lines = stratified_selection(dataset, subset_size)
97
    elif method == 1:
98
        selected_lines = random_selection(dataset, subset_size)
99
100
    #select instances based on selected_lines
101
    dataset = open(dataset,'r')
102
    prev_selected_linenum = -1
103
    for i in xrange(len(selected_lines)):
104
        for cnt in xrange(selected_lines[i]-prev_selected_linenum-1):
105
            line = dataset.readline()
106
            if rest_file:
107
                rest_file.write(line)
108
        subset_file.write(dataset.readline())
109
        prev_selected_linenum = selected_lines[i]
110
    subset_file.close()
111
112
    if rest_file:
113
        for line in dataset:
114
            rest_file.write(line)
115
        rest_file.close()
116
    dataset.close()
117
118
if __name__ == '__main__':
119
    main(sys.argv)
120