# gbm/param_search.py
1
#!/usr/bin/env python
2
3
import logging
4
import datetime
5
import sys
6
import json
7
import warnings
8
9
sys.path.append('../')
10
warnings.filterwarnings("ignore")
11
12
import pandas as pd
13
14
from scipy import stats
15
from sklearn.feature_extraction.text import TfidfVectorizer
16
from sklearn.model_selection import RandomizedSearchCV
17
import lightgbm
18
19
from utils.splits import set_group_splits
20
from args import args
21
22
# Module-level logger: INFO and above, emitted to stderr as
# "LEVEL:module: message".
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
_handler = logging.StreamHandler()
_handler.setFormatter(logging.Formatter('%(levelname)s:%(name)s: %(message)s'))
logger.addHandler(_handler)
27
28
if __name__ == '__main__':
  # Usage: param_search.py <task>, where <task> is 'ia' (imminent admission)
  # or 'ps' (prolonged stay). Runs a randomized hyperparameter search for a
  # LightGBM classifier over TF-IDF bigram features of clinical notes and
  # writes the best parameter set to <workdir>/<task>_best_params.json.
  if len(sys.argv) != 2:
    logger.error(f"Usage: {sys.argv[0]} task_name (ia|ps)")
    sys.exit(1)

  task = sys.argv[1]
  if task not in ['ia', 'ps']:
    logger.error("Task values are either ia (imminent admission) or ps (prolonged stay)")
    sys.exit(1)

  ori_df = pd.read_csv(args.dataset_csv, usecols=args.cols, parse_dates=args.dates)
  if task == 'ia':
    logger.info("Running hyperparameter search for Imminent Admission Prediction task")
    # -1 marks rows without an imminent-admission label; drop them.
    task_df = ori_df.loc[(ori_df['imminent_adm_label'] != -1)][args.imminent_adm_cols].reset_index(drop=True)
    label = 'imminent_adm_label'
  else:  # task == 'ps' (guaranteed by the validation above)
    logger.info("Running hyperparameter search for Prolonged Stay Prediction task")
    # chartinterval == 0 rows carry no usable chart window for this task; drop them.
    task_df = ori_df.loc[(ori_df['chartinterval'] != 0)][args.prolonged_stay_cols].reset_index(drop=True)
    label = 'prolonged_stay_label'

  # Group-aware split on hadm_id so notes from one admission never straddle
  # the train/test boundary; fixed seed keeps the split reproducible.
  df = set_group_splits(task_df.copy(), group_col='hadm_id', seed=643)
  # Notes are pre-tokenized (space-separated), hence analyzer=str.split.
  vectorizer = TfidfVectorizer(min_df=args.min_freq, analyzer=str.split, sublinear_tf=True, ngram_range=(2,2))

  x_train = vectorizer.fit_transform(df.loc[(df['split'] == 'train')]['processed_note'])
  y_train = df.loc[(df['split'] == 'train')][label].to_numpy()

  # Fixed LightGBM settings; everything tunable lives in param_space below.
  clf_params = {
    'objective': 'binary',
    'metric': 'binary_logloss',
    'is_unbalance': True,  # reweight classes for the imbalanced labels
  }

  clf = lightgbm.LGBMClassifier(**clf_params)

  # Search distributions: scipy frozen distributions are sampled per draw;
  # stats.uniform(loc, scale) covers [loc, loc + scale].
  param_space = {
    'num_leaves': stats.randint(30, 60),
    'bagging_fraction': stats.uniform(0.2, 0.7),
    'learning_rate': stats.uniform(0.1, 0.9),
    'min_data_in_leaf': stats.randint(2, 20),
    'max_bin': stats.randint(3, 20),
    'boosting': ['gbdt', 'dart'],
    'bagging_freq': stats.randint(3, 31),
    'max_depth': stats.randint(0, 11),
    'feature_fraction': stats.uniform(0.2, 0.7),
    'lambda_l1': stats.uniform(0, 10),
    'num_iterations': stats.randint(100, 200),
  }

  # NOTE: the deprecated `iid` argument was removed in scikit-learn 0.24;
  # passing it raises TypeError there. Dropping it preserves the old
  # iid=False behavior, which has been the default since 0.22.
  random_search = RandomizedSearchCV(clf, param_space, n_iter=200, cv=5, verbose=1, n_jobs=32)

  logger.info("Starting random search...")
  t1 = datetime.datetime.now()
  random_search.fit(x_train, y_train)
  dt = datetime.datetime.now() - t1
  params_file = args.workdir/f'{task}_best_params.json'
  logger.info(f"Random search completed. Took {dt.days} days, {dt.seconds//3600} hours, and {(dt.seconds//60)%60} minutes. Writing best params to {params_file}")
  # Context manager ensures the JSON file is flushed and closed.
  with params_file.open('w') as f:
    json.dump(random_search.best_params_, f)