Diff of /rsf/rsf_baseline.py [000000] .. [785f18]

Switch to unified view

a b/rsf/rsf_baseline.py
1
"""
2
Baseline random survival forest script.
3
Leon Zheng.
4
5
TODO: choose the right features and fine-tune the model.
6
"""
7
8
9
import preprocessing
10
from sksurv.ensemble import RandomSurvivalForest
11
from sklearn.model_selection import cross_validate
12
13
"""
14
Reading data
15
"""
16
17
def load_data(features=None):
    """Load the Owkin data and build the inputs for the survival forest.

    Args:
        features: Optional column selection applied to the feature table
            as ``X_df[features]``. ``None`` (the default) keeps every
            column.

    Returns:
        A tuple ``(X_df, y_df, X, y)``:
        ``X_df`` is the feature table (restricted to ``features`` when
        given), ``y_df`` is the target table as returned by the loader,
        ``X`` is ``X_df.to_numpy()``, and ``y`` is the result of
        ``preprocessing.y_dataframe_to_rsf_input(y_df)``.
    """
    # The loader returns three values; the third one is not used here.
    X_df, y_df, _ = preprocessing.load_owkin_data()
    # Identity check on purpose: `features != None` compares element-wise
    # for array-like selections (NumPy array, pandas Index), and using
    # that result as an `if` condition raises ValueError.
    if features is not None:
        X_df = X_df[features]
    X = X_df.to_numpy()
    # NOTE(review): presumably converts the target table to the structured
    # (event, time) array expected by scikit-survival -- confirm in
    # preprocessing.py.
    y = preprocessing.y_dataframe_to_rsf_input(y_df)
    return X_df, y_df, X, y
24
25
# Load the data set without a feature selection (all columns are kept).
X_df, y_df, X, y = load_data()
# Column labels of the feature table, in column order.
feature_name = [*X_df.columns.values]
27
28
"""
29
Train model
30
"""
31
# Hyper-parameters of the baseline forest; the keys match the keyword
# arguments of RandomSurvivalForest so the dict can be unpacked directly.
params = {
    'min_samples_leaf': 1,
    'min_samples_split': 10,
    'n_estimators': 10,
}
rsf = RandomSurvivalForest(max_features="sqrt", n_jobs=-1, **params)
# Evaluate with 5-fold cross-validation and print the raw result dict.
print(cross_validate(rsf, X, y, cv=5))