Diff of /lr/param_search.py [000000] .. [3f1788]


--- a
+++ b/lr/param_search.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python
+
+import logging
+import datetime
+import sys
+import json
+import warnings
+
+sys.path.append('../')
+warnings.filterwarnings("ignore", category=UserWarning)
+
+import pandas as pd
+import numpy as np
+
+from scipy import stats
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.model_selection import RandomizedSearchCV
+from sklearn.linear_model import LogisticRegression
+
+from utils.splits import set_group_splits
+from args import args
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+sh = logging.StreamHandler()
+sh.setFormatter(logging.Formatter('%(levelname)s:%(name)s: %(message)s'))
+logger.addHandler(sh)
+
+if __name__ == '__main__':
+  if len(sys.argv) != 2:
+    logger.error(f"Usage: {sys.argv[0]} task_name (ia|ps)")
+    sys.exit(1)
+
+  task = sys.argv[1]
+  ori_df = pd.read_csv(args.dataset_csv, usecols=args.cols, parse_dates=args.dates)
+  if task == 'ia':
+    logger.info("Running hyperparameter search for Imminent Admission Prediction task")
+    task_df = ori_df.loc[(ori_df['imminent_adm_label'] != -1)][args.imminent_adm_cols].reset_index(drop=True)
+    label = 'imminent_adm_label'
+  elif task == 'ps':
+    logger.info("Running hyperparameter search for Prolonged Stay Prediction task")
+    task_df = ori_df[args.prolonged_stay_cols].copy()
+    label = 'prolonged_stay_label'
+  else:
+    logger.error(f"Unknown task '{task}'. Expected 'ia' or 'ps'.")
+    sys.exit(1)
+
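+  # Train/test split grouped by hadm_id (fixed seed for reproducibility).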
+  df = set_group_splits(task_df.copy(), group_col='hadm_id', seed=42)
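+  # Notes are already tokenized, so a callable analyzer (str.split) is used. Note that
+  # scikit-learn ignores ngram_range when the analyzer is a callable, so the features here are unigrams.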
+  vectorizer = TfidfVectorizer(min_df=args.min_freq, analyzer=str.split, sublinear_tf=True, ngram_range=(2,2))
+
+  x_train = vectorizer.fit_transform(df.loc[(df['split'] == 'train')]['processed_note'])
+  x_test = vectorizer.transform(df.loc[(df['split'] == 'test')]['processed_note'])
+  y_train = df.loc[(df['split'] == 'train')][label].to_numpy()
+  y_test = df.loc[(df['split'] == 'test')][label].to_numpy()
+
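+  # Fixed classifier settings; the remaining hyperparameters are explored by the random search below.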
+  clf_params = {
+    'solver': 'liblinear',
+    'multi_class': 'ovr',
+  }
+
+  clf = LogisticRegression(**clf_params)
+
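+  # Search space: C ~ Uniform(0.1, 2.1), primal vs. dual formulation, optional class
+  # balancing, and max_iter drawn uniformly from the integers 100..999.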
+  param_space = {
+    'C': stats.uniform(0.1, 2),
+    'dual': [True, False],
+    'class_weight': ['balanced', None],
+    'max_iter': stats.randint(100, 1000),
+  }
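+  # Note: the iid argument was deprecated in scikit-learn 0.22 and removed in 0.24; drop it on newer versions.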
+  random_search = RandomizedSearchCV(clf, param_space, n_iter=200, cv=10, iid=False, verbose=1, n_jobs=32)
+
+  logger.info("Starting random search...")
+  t1 = datetime.datetime.now()
+  random_search.fit(x_train, y_train)
+  dt = datetime.datetime.now() - t1
+  params_file = args.workdir/f'{task}_best_params.json'
+  logger.info(f"Random search completed. Took {dt.days}, {dt.seconds//3600} hours and, {(dt.seconds//60)%60} minutes. Writing best params to {params_file}")
+  json.dump(random_search.best_params_, params_file.open('w'))