mlp/train_100.py
#!/usr/bin/env python
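"""100-seed stability runs for the TF-IDF + MLP baseline.

For each seed, notes are split into train/test groups by hadm_id, featurized
with TF-IDF over uni- and bigrams, and used to train a single-hidden-layer
MLP via skorch, checkpointing the best model per seed.

Usage: train_100.py task_name, where task_name is ia (imminent admission)
or ps (prolonged stay).
"""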

import datetime
import logging
import sys
sys.path.append('../')

import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer

from torch import optim

from skorch import NeuralNetBinaryClassifier
from skorch.toy import MLPModule
from skorch.dataset import CVSplit
from skorch.callbacks import LRScheduler, Checkpoint, EarlyStopping, ProgressBar

from utils.splits import set_group_splits
from args import args

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
sh = logging.StreamHandler()
sh.setFormatter(logging.Formatter('%(levelname)s:%(name)s: %(message)s'))
logger.addHandler(sh)

def run_100(task, task_df, args, threshold):
  # Halve the learning rate when the monitored loss plateaus
  reduce_lr = LRScheduler(
    policy='ReduceLROnPlateau',
    mode='min',
    factor=0.5,
    patience=1,
  )
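
  # One full run per seed: re-split, re-featurize, and retrain from scratch,
  # checkpointing each seed's model separately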
  seeds = list(range(args.start_seed, args.start_seed + 100))
  for seed in tqdm(seeds, desc=f'{task} Runs'):
    logger.info(f"Splitting with seed {seed}")
    checkpoint = Checkpoint(
      dirname=args.modeldir/f'{task}_seed_{seed}',
    )
    df = set_group_splits(task_df.copy(), group_col='hadm_id', seed=seed)
    vectorizer = TfidfVectorizer(sublinear_tf=True, ngram_range=(1,2), binary=True, max_features=60_000)
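
    # Fit the vectorizer on the training split only, so no test vocabulary leaks in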
    x_train = vectorizer.fit_transform(df.loc[df['split'] == 'train', 'processed_note']).astype(np.float32)
    x_test = vectorizer.transform(df.loc[df['split'] == 'test', 'processed_note']).astype(np.float32)

    # Densify the sparse TF-IDF matrices for the torch MLP
    x_train = x_train.toarray()
    x_test = x_test.toarray()
    vocab_sz = len(vectorizer.vocabulary_)

    y_train = df.loc[df['split'] == 'train', f'{task}_label'].to_numpy()
    y_test = df.loc[df['split'] == 'test', f'{task}_label'].to_numpy()

    # Single hidden layer; input width matches the TF-IDF vocabulary
    clf = MLPModule(input_units=vocab_sz, output_units=1, hidden_units=args.hidden_dim, num_hidden=1, dropout=args.dropout_p, squeeze_output=True)
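
    # skorch holds out a stratified 15% of the training set; validation loss
    # drives early stopping and checkpointing of the best model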
    net = NeuralNetBinaryClassifier(
      clf,
      max_epochs=args.max_epochs,
      lr=args.lr,
      device=args.device,
      optimizer=optim.Adam,
      optimizer__weight_decay=args.wd,
      batch_size=args.batch_size,
      verbose=1,
      callbacks=[EarlyStopping, ProgressBar, checkpoint, reduce_lr],
      train_split=CVSplit(cv=0.15, stratified=True),
      iterator_train__shuffle=True,
      threshold=threshold,
    )
    # Remove the default valid_acc callback; EarlyStopping monitors valid_loss
    net.set_params(callbacks__valid_acc=None)
    # The binary loss expects float targets
    net.fit(x_train, y_train.astype(np.float32))

if __name__ == '__main__':
  if len(sys.argv) != 2:
    logger.error(f"Usage: {sys.argv[0]} task_name (ia|ps)")
    sys.exit(1)

  task = sys.argv[1]
  if task not in ['ia', 'ps']:
    logger.error("Task values are either ia (imminent admission) or ps (prolonged stay)")
    sys.exit(1)

  args.modeldir = args.workdir/'models'
  ori_df = pd.read_csv(args.dataset_csv, usecols=args.cols, parse_dates=args.dates)
  if task == 'ia':
    task_df = ori_df.loc[ori_df['imminent_adm_label'] != -1][args.imminent_adm_cols].reset_index(drop=True)
    prefix = 'imminent_adm'
    threshold = args.ia_thresh
  elif task == 'ps':
    task_df = ori_df.loc[ori_df['chartinterval'] != 0][args.prolonged_stay_cols].reset_index(drop=True)
    prefix = 'prolonged_stay'
    threshold = args.ps_thresh

  logger.info(f"Running 100 seed test run for task {task}")
  t1 = datetime.datetime.now()
  run_100(prefix, task_df, args, threshold)
  dt = datetime.datetime.now() - t1
  logger.info(f"100 seed test run completed. Took {dt.days} days, {dt.seconds//3600} hours, and {(dt.seconds//60)%60} minutes")