"""Configure a random forest and hand it to ``evaluate_model``.

Usage::

    python RandomForestClassifier.py DATASET SAVE_FILE RANDOM_SEED RARE

``DATASET`` and ``SAVE_FILE`` are passed through to ``evaluate_model``
unchanged, ``RANDOM_SEED`` is converted with ``int`` and ``RARE`` is
evaluated as a Python expression (see the review note below).
"""
import multiprocessing

# NOTE(review): the whole script body is assumed to live under this guard
# (the 'forkserver' start method re-imports the main module in workers) --
# confirm against the repository copy.
if __name__ == '__main__':
    # The start method is selected first; the remaining imports follow this
    # call, in the same order as the original script.
    multiprocessing.set_start_method('forkserver')
    import sys
    from sklearn.ensemble import RandomForestClassifier
    import pdb
    from evaluate_model import evaluate_model
    import numpy as np

    # Fail with a readable message instead of a bare IndexError when the
    # script is invoked with too few arguments. Extra arguments are ignored,
    # as before.
    if len(sys.argv) < 5:
        sys.exit(
            'usage: {} DATASET SAVE_FILE RANDOM_SEED RARE'.format(sys.argv[0])
        )

    dataset = sys.argv[1]
    save_file = sys.argv[2]
    random_seed = int(sys.argv[3])
    # NOTE(review): eval() executes arbitrary code taken from the command
    # line. If only literals such as True/False are ever passed here,
    # ast.literal_eval would be the safer choice -- confirm with the callers
    # before switching, since it rejects non-literal expressions.
    rare = eval(sys.argv[4])

    # Hyper-parameter grid handed to evaluate_model.
    hyper_params = [{
        'n_estimators': [500],
        # Four log-spaced floats: 1e-4, 1e-3, 1e-2, 1e-1. scikit-learn
        # treats a float min_samples_leaf as a fraction of the samples.
        'min_samples_leaf': np.logspace(-4, -1, 4),
        'max_features': ('sqrt', 'log2', None),
    }]

    # Create the classifier.
    clf = RandomForestClassifier(class_weight='balanced', n_jobs=1)

    # Evaluate the model.
    evaluate_model(dataset, save_file, random_seed, clf, 'RF', hyper_params,
                   False, rare=rare)