Diff of /rf/param_search.py [000000] .. [3f1788]

Switch to unified view

--- a/rf/param_search.py
+++ b/rf/param_search.py
1
#!/usr/bin/env python
2
3
import logging
4
import datetime
5
import sys
6
import json
7
import warnings
8
9
sys.path.append('../')
10
warnings.filterwarnings("ignore", category=UserWarning)
11
12
import pandas as pd
13
import numpy as np
14
15
from scipy import stats
16
from sklearn.feature_extraction.text import TfidfVectorizer
17
from sklearn.model_selection import RandomizedSearchCV
18
from sklearn.ensemble import RandomForestClassifier
19
20
from utils.splits import set_group_splits
21
from args import args
22
23
# Module-level logger: emits INFO-and-above records to stderr in the
# compact "LEVEL:name: message" form used by this package's scripts.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(levelname)s:%(name)s: %(message)s')
sh = logging.StreamHandler()
sh.setFormatter(formatter)
logger.addHandler(sh)
28
29
if __name__ == '__main__':
  # CLI: one positional argument selecting the prediction task
  # ('ia' = imminent admission, 'ps' = prolonged stay).
  if len(sys.argv) != 2:
    logger.error(f"Usage: {sys.argv[0]} task_name (ia|ps)")
    sys.exit(1)

  task = sys.argv[1]
  # Validate the task name *before* the expensive CSV load. Previously an
  # unknown task fell through both branches and crashed later with a
  # NameError on the undefined `task_df`/`label`.
  if task not in ('ia', 'ps'):
    logger.error(f"Unknown task '{task}'. Usage: {sys.argv[0]} task_name (ia|ps)")
    sys.exit(1)

  ori_df = pd.read_csv(args.dataset_csv, usecols=args.cols, parse_dates=args.dates)
  if task == 'ia':
    logger.info(f"Running hyperparameter search for Imminent Admission Prediction task")
    # Rows labelled -1 carry no imminent-admission label for this task; drop them.
    task_df = ori_df.loc[(ori_df['imminent_adm_label'] != -1)][args.imminent_adm_cols].reset_index(drop=True)
    label = 'imminent_adm_label'
  else:  # task == 'ps' — guaranteed by the validation above
    logger.info(f"Running hyperparameter search for Prolonged Stay Prediction task ")
    task_df = ori_df[args.prolonged_stay_cols].copy()
    label = 'prolonged_stay_label'

  # Group-aware train/test split keyed on admission id so notes from one
  # admission never straddle the split boundary.
  df = set_group_splits(task_df.copy(), group_col='hadm_id', seed=42)
  # Bigram TF-IDF over pre-tokenized notes (tokens already whitespace-separated,
  # hence analyzer=str.split rather than the default word analyzer).
  vectorizer = TfidfVectorizer(min_df=args.min_freq, analyzer=str.split, sublinear_tf=True, ngram_range=(2,2))

  x_train = vectorizer.fit_transform(df.loc[(df['split'] == 'train')]['processed_note'])
  x_test = vectorizer.transform(df.loc[(df['split'] == 'test')]['processed_note'])
  y_train = df.loc[(df['split'] == 'train')][label].to_numpy()
  y_test = df.loc[(df['split'] == 'test')][label].to_numpy()

  clf = RandomForestClassifier()

  # Search space: continuous/int dimensions use scipy distributions so
  # RandomizedSearchCV can sample them; discrete choices are plain lists.
  param_space = {
    'n_estimators': stats.randint(100, 200),
    'class_weight': ['balanced', 'balanced_subsample', None],
    'criterion': ['gini', 'entropy'],
    'max_depth': [2, 3, 4, None],
    'min_samples_leaf': stats.randint(2, 8),
    'min_samples_split': stats.randint(2, 8),
    'max_features': stats.uniform(0.1, 0.5),
    'oob_score': [True, False],
  }
  # NOTE: the deprecated `iid` parameter was dropped here — it was removed in
  # scikit-learn 0.24 and passing it raises TypeError on current versions;
  # since 0.22 the default scoring already behaves as iid=False did.
  random_search = RandomizedSearchCV(clf, param_space, n_iter=200, cv=5, verbose=1, n_jobs=32)

  logger.info("Starting random search...")
  t1 = datetime.datetime.now()
  random_search.fit(x_train, y_train)
  dt = datetime.datetime.now() - t1
  params_file = args.workdir/f'{task}_best_params.json'
  logger.info(f"Random search completed. Took {dt.days} days, {dt.seconds//3600} hours, and {(dt.seconds//60)%60} minutes. Writing best params to {params_file}")
  # Context manager guarantees the params file is flushed and closed
  # (the previous bare `.open('w')` leaked the handle).
  with params_file.open('w') as f:
    json.dump(random_search.best_params_, f)