|
a |
|
b/BraTs18Challege/dataprocess/data/subset.py |
|
|
1 |
#!/usr/bin/env python |
|
|
2 |
|
|
|
3 |
import os, sys, math, random |
|
|
4 |
from collections import defaultdict |
|
|
5 |
|
|
|
6 |
# Python 3 dropped ``xrange``; alias it to the lazy ``range`` so the rest of
# the script can use one name on both major versions.
if sys.version_info >= (3,):
    xrange = range
|
|
8 |
|
|
|
9 |
def exit_with_help(argv):
    """Print the command-line usage text and terminate the process.

    Args:
        argv: Argument list; ``argv[0]`` is inserted into the usage line as
            the program name.

    Raises:
        SystemExit: Always, with exit status 1.
    """
    print("""\
Usage: {0} [options] dataset subset_size [output1] [output2]

This script randomly selects a subset of the dataset.

options:
-s method : method of selection (default 0)
0 -- stratified selection (classification only)
1 -- random selection

output1 : the subset (optional)
output2 : rest of the data (optional)
If output1 is omitted, the subset will be printed on the screen.""".format(argv[0]))
    # Use sys.exit rather than the bare ``exit`` helper: the latter is only
    # injected by the ``site`` module and is missing under ``python -S`` or
    # in embedded/frozen interpreters.
    sys.exit(1)
|
|
24 |
|
|
|
25 |
def process_options(argv):
    """Parse the command line into the settings used by ``main``.

    Expected layout: ``prog [options] dataset subset_size [output1] [output2]``.
    Leading arguments that start with ``-`` are treated as options; only
    ``-s method`` is interpreted, any other dash-prefixed argument is skipped.

    Args:
        argv: Full argument list, including the program name at index 0.

    Returns:
        A tuple ``(dataset, subset_size, method, subset_file, rest_file)``:
        the dataset path, the requested subset size as an int, the selection
        method (0 or 1), a writable file object for the subset
        (``sys.stdout`` when ``output1`` is omitted) and a writable file
        object for the remaining lines (``None`` when ``output2`` is omitted).

    Raises:
        SystemExit: Via ``exit_with_help`` when required arguments are
            missing or the selection method is unknown.
        ValueError: If the method or the subset size is not an integer.
    """
    argc = len(argv)
    if argc < 3:
        exit_with_help(argv)

    # default method is stratified selection
    method = 0
    subset_file = sys.stdout
    rest_file = None

    i = 1
    while i < argc:
        # startswith() is safe for an empty-string argument, unlike argv[i][0]
        if not argv[i].startswith("-"):
            break
        if argv[i] == "-s":
            i = i + 1
            if i >= argc:
                # "-s" was the last argument: its value is missing
                exit_with_help(argv)
            method = int(argv[i])
            if method not in [0, 1]:
                print("Unknown selection method {0}".format(method))
                exit_with_help(argv)
        i = i + 1

    # Options may have consumed the arguments counted by the argc check
    # above, so confirm that both positionals are still available.
    if i + 1 >= argc:
        exit_with_help(argv)

    dataset = argv[i]
    subset_size = int(argv[i + 1])
    if i + 2 < argc:
        subset_file = open(argv[i + 2], 'w')
    if i + 3 < argc:
        rest_file = open(argv[i + 3], 'w')

    return dataset, subset_size, method, subset_file, rest_file
|
|
55 |
|
|
|
56 |
def random_selection(dataset, subset_size):
    """Choose ``subset_size`` distinct line numbers uniformly at random.

    Args:
        dataset: Path of the data file; only its number of lines is used.
        subset_size: How many line numbers to draw.

    Returns:
        A sorted list of zero-based line numbers.

    Raises:
        ValueError: From ``random.sample`` when ``subset_size`` is negative
            or larger than the number of lines in the file.
    """
    # Count the lines inside a context manager so the handle is released
    # immediately instead of whenever the garbage collector gets to it.
    with open(dataset, 'r') as data_file:
        num_lines = sum(1 for _ in data_file)
    # range() is lazy on Python 3; on Python 2 it builds a list, which
    # random.sample accepts just as well.
    return sorted(random.sample(range(num_lines), subset_size))
|
|
59 |
|
|
|
60 |
def stratified_selection(dataset, subset_size):
    """Choose line numbers so that every class label is represented.

    The first whitespace-delimited token of each line is used as its label.
    Every label is given a share of ``subset_size`` proportional to its
    frequency (rounded up, and at least one line), capped by whatever is
    left of the overall quota. Blank lines carry no label and are never
    selected.

    Args:
        dataset: Path of the data file.
        subset_size: Total number of lines to select.

    Returns:
        A sorted list of zero-based line numbers; empty when the file has no
        labelled lines.

    Raises:
        SystemExit: With status -1 when the quota is used up before every
            label has received at least one instance.
        ValueError: From ``random.sample`` when a label's share exceeds the
            number of lines carrying that label.
    """
    label_linenums = defaultdict(list)
    total = 0
    with open(dataset) as data_file:
        for linenum, line in enumerate(data_file):
            fields = line.split(None, 1)
            if not fields:
                # Blank line (e.g. a trailing newline): nothing to stratify
                # on, but keep counting so line numbers stay file-relative.
                continue
            label_linenums[fields[0]].append(linenum)
            total += 1

    if not total:
        return []

    ratio = float(subset_size) / total
    remaining = subset_size
    selected = []

    # classes with fewer data are sampled first; otherwise
    # some rare classes may not be selected
    for label in sorted(label_linenums, key=lambda x: len(label_linenums[x])):
        linenums = label_linenums[label]
        # at least one instance per class
        s = int(min(remaining, max(1, math.ceil(len(linenums) * ratio))))
        if s == 0:
            sys.stderr.write('''\
Error: failed to have at least one instance per class
1. You may have regression data.
2. Your classification data is unbalanced or too small.
Please use -s 1.
''')
            sys.exit(-1)
        remaining -= s
        selected += random.sample(linenums, s)
    return sorted(selected)
|
|
88 |
|
|
|
89 |
def main(argv=sys.argv):
    """Split a dataset into a selected subset and, optionally, the rest.

    Parses ``argv`` with ``process_options``, picks line numbers with the
    requested selection method, then copies the chosen lines to the subset
    output and every other line to the rest output when one was given.

    Args:
        argv: Full argument list, including the program name at index 0.
    """
    dataset, subset_size, method, subset_file, rest_file = process_options(argv)
    # uncomment the following line to fix the random seed
    # random.seed(0)
    selected_lines = []

    try:
        if method == 0:
            selected_lines = stratified_selection(dataset, subset_size)
        elif method == 1:
            selected_lines = random_selection(dataset, subset_size)

        # select instances based on selected_lines (sorted, so one forward
        # pass over the file is enough)
        pending = iter(selected_lines)
        wanted = next(pending, None)
        with open(dataset, 'r') as data_file:
            for linenum, line in enumerate(data_file):
                if linenum == wanted:
                    subset_file.write(line)
                    wanted = next(pending, None)
                elif rest_file:
                    rest_file.write(line)
                elif wanted is None:
                    # nothing left to select and no rest output to feed
                    break
    finally:
        # Release the outputs even when selection or copying fails. Never
        # close sys.stdout: it belongs to the interpreter, and closing it
        # breaks any caller that prints after main() returns.
        if subset_file is sys.stdout:
            subset_file.flush()
        else:
            subset_file.close()
        if rest_file:
            rest_file.close()
|
|
117 |
|
|
|
118 |
# Run the command-line interface only when executed as a script, not when
# the module is imported.
if __name__ == "__main__":
    main(argv=sys.argv)
|
|
120 |
|