#!/usr/bin/env python
# lr/param_search.py
2
3
import logging
4
import datetime
5
import sys
6
import json
7
import warnings
8
9
sys.path.append('../')
10
warnings.filterwarnings("ignore", category=UserWarning)
11
12
import pandas as pd
13
import numpy as np
14
15
from scipy import stats
16
from sklearn.feature_extraction.text import TfidfVectorizer
17
from sklearn.model_selection import RandomizedSearchCV
18
from sklearn.linear_model import LogisticRegression
19
20
from utils.splits import set_group_splits
21
from args import args
22
23
# Module-level logger: INFO and above are emitted through a stream handler
# using the "LEVEL:logger-name: message" layout.
sh = logging.StreamHandler()
sh.setFormatter(logging.Formatter('%(levelname)s:%(name)s: %(message)s'))

logger = logging.getLogger(__name__)
logger.addHandler(sh)
logger.setLevel(logging.INFO)
28
29
def format_elapsed(delta):
  """Render a ``datetime.timedelta`` as "D days, H hours and M minutes".

  Seconds and smaller units are truncated, not rounded.
  """
  hours, remainder = divmod(delta.seconds, 3600)
  return f"{delta.days} days, {hours} hours and {remainder // 60} minutes"


if __name__ == '__main__':
  # Randomized hyperparameter search for a TF-IDF + logistic regression model
  # on one prediction task ('ia' or 'ps'). The best parameters found are
  # written to <workdir>/<task>_best_params.json.
  if len(sys.argv) != 2:
    logger.error(f"Usage: {sys.argv[0]} task_name (ia|ps)")
    sys.exit(1)

  task = sys.argv[1]
  # Reject unknown task names before doing any work. Otherwise neither branch
  # below would bind `task_df`/`label` and the script would die with a
  # NameError only after loading the full dataset.
  if task not in ('ia', 'ps'):
    logger.error(f"Unknown task '{task}'. Usage: {sys.argv[0]} task_name (ia|ps)")
    sys.exit(1)

  ori_df = pd.read_csv(args.dataset_csv, usecols=args.cols, parse_dates=args.dates)
  if task == 'ia':
    logger.info("Running hyperparameter search for Imminent Admission Prediction task")
    label = 'imminent_adm_label'
    # Rows whose label is -1 are excluded from this task.
    task_df = ori_df.loc[(ori_df[label] != -1)][args.imminent_adm_cols].reset_index(drop=True)
  else:
    logger.info("Running hyperparameter search for Prolonged Stay Prediction task ")
    label = 'prolonged_stay_label'
    task_df = ori_df[args.prolonged_stay_cols].copy()

  # Fixed seed so the split assignment is the same on every run.
  df = set_group_splits(task_df.copy(), group_col='hadm_id', seed=42)

  # NOTE(review): scikit-learn documents that `ngram_range` only applies when
  # `analyzer` is not a callable, so with `analyzer=str.split` the features are
  # whitespace tokens and (2,2) has no effect. Left untouched to keep results
  # unchanged -- confirm whether bigrams were actually intended.
  vectorizer = TfidfVectorizer(min_df=args.min_freq, analyzer=str.split, sublinear_tf=True, ngram_range=(2,2))

  # Only the training split takes part in the search; the vocabulary is fit on
  # it alone.
  train_df = df.loc[df['split'] == 'train']
  x_train = vectorizer.fit_transform(train_df['processed_note'])
  y_train = train_df[label].to_numpy()

  clf_params = {
    'solver': 'liblinear',
    'multi_class': 'ovr',
  }
  clf = LogisticRegression(**clf_params)

  param_space = {
    # scipy's uniform(loc, scale) samples from [loc, loc + scale]: C in [0.1, 2.1].
    'C': stats.uniform(0.1, 2),
    'dual': [True, False],
    'class_weight': ['balanced', None],
    # randint(low, high) samples integers from [low, high).
    'max_iter': stats.randint(100, 1000),
  }
  # NOTE(review): `iid` was deprecated in scikit-learn 0.22 and removed in
  # 0.24; on newer releases this keyword must be dropped. Kept so behaviour on
  # the version this script was written against stays the same -- confirm the
  # pinned scikit-learn version.
  random_search = RandomizedSearchCV(clf, param_space, n_iter=200, cv=10, iid=False, verbose=1, n_jobs=32)

  logger.info("Starting random search...")
  t1 = datetime.datetime.now()
  random_search.fit(x_train, y_train)
  dt = datetime.datetime.now() - t1

  params_file = args.workdir/f'{task}_best_params.json'
  logger.info(f"Random search completed. Took {format_elapsed(dt)}. Writing best params to {params_file}")
  # Context manager guarantees the file is flushed and closed.
  with params_file.open('w') as fp:
    json.dump(random_search.best_params_, fp)