[3f1788]: /mlp/predict.py

#!/usr/bin/env python
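"""Evaluate saved MLP checkpoints over 100 seeded train/test splits.

For each seed, the script rebuilds the TF-IDF features and test split,
loads the checkpointed skorch classifier, and pickles the per-seed
targets, predictions, and probabilities.
"""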
import datetime
import logging
import sys
sys.path.append('../')
import warnings
warnings.filterwarnings("ignore")
import pickle

import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer
from torch import optim
from skorch import NeuralNetBinaryClassifier
from skorch.toy import MLPModule
from skorch.dataset import CVSplit
from skorch.callbacks import LRScheduler, Checkpoint, EarlyStopping, ProgressBar
from utils.splits import set_group_splits
from args import args

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
sh = logging.StreamHandler()
sh.setFormatter(logging.Formatter('%(levelname)s:%(name)s: %(message)s'))
logger.addHandler(sh)

def run_100(task, task_df, args, threshold):
    preds = []
    targs = []
    probs = []
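    # Halve the learning rate whenever validation loss fails to improve
    # for one epoch.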
    reduce_lr = LRScheduler(
        policy='ReduceLROnPlateau',
        mode='min',
        factor=0.5,
        patience=1,
    )
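    # One evaluation per seed: the same pipeline is repeated over 100
    # grouped train/test splits to get a distribution of test-set results.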
    seeds = list(range(args.start_seed, args.start_seed + 100))
    for seed in tqdm(seeds, desc=f'{task} Runs'):
        # logger.info(f"Splitting with seed {seed}")
        checkpoint = Checkpoint(
            dirname=args.modeldir/f'{task}_seed_{seed}',
        )
        df = set_group_splits(task_df.copy(), group_col='hadm_id', seed=seed)
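        # Refit TF-IDF on this seed's training notes only, then densify the
        # sparse matrices since the torch MLP consumes dense float32 arrays.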
        vectorizer = TfidfVectorizer(sublinear_tf=True, ngram_range=(1, 2), binary=True, max_features=60_000)
        x_train = vectorizer.fit_transform(df.loc[df['split'] == 'train']['processed_note']).astype(np.float32)
        x_test = vectorizer.transform(df.loc[df['split'] == 'test']['processed_note']).astype(np.float32)
        x_train = np.asarray(x_train.todense())
        x_test = np.asarray(x_test.todense())
        vocab_sz = len(vectorizer.vocabulary_)
        y_test = df.loc[df['split'] == 'test'][f'{task}_label'].to_numpy()
        targs.append(y_test)
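        # Rebuild the exact architecture and skorch wrapper used at training
        # time so the checkpointed parameters can be loaded back into it.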
        clf = MLPModule(input_units=vocab_sz, output_units=1, hidden_units=args.hidden_dim,
                        num_hidden=1, dropout=args.dropout_p, squeeze_output=True)
        net = NeuralNetBinaryClassifier(
            clf,
            max_epochs=args.max_epochs,
            lr=args.lr,
            device=args.device,
            optimizer=optim.Adam,
            optimizer__weight_decay=args.wd,
            batch_size=args.batch_size,
            verbose=1,
            callbacks=[EarlyStopping, ProgressBar, checkpoint, reduce_lr],
            train_split=CVSplit(cv=0.15, stratified=True),
            iterator_train__shuffle=True,
            threshold=threshold,
        )
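        # No training happens here: drop the default validation-accuracy
        # callback, initialize the module, and restore the best parameters
        # saved by Checkpoint during training. predict_proba below returns
        # one column per class, so probs collects (n_samples, 2) arrays.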
        net.set_params(callbacks__valid_acc=None)
        net.initialize()
        net.load_params(checkpoint=checkpoint)
        pos_prob = net.predict_proba(x_test)
        probs.append(pos_prob)
        y_pred = net.predict(x_test)
        preds.append(y_pred)
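    # Persist the per-seed targets, hard predictions, and probabilities as
    # three successive pickle records in a single file; read them back in
    # the same order, e.g.
    #   with open(args.workdir/f'{task}_preds.pkl', 'rb') as f:
    #       targs, preds, probs = [pickle.load(f) for _ in range(3)]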
    with open(args.workdir/f'{task}_preds.pkl', 'wb') as f:
        pickle.dump(targs, f)
        pickle.dump(preds, f)
        pickle.dump(probs, f)

if __name__ == '__main__':
    if len(sys.argv) != 2:
        logger.error(f"Usage: {sys.argv[0]} task_name (ia|ps)")
        sys.exit(1)
    task = sys.argv[1]
    if task not in ['ia', 'ps']:
        logger.error("Task values are either ia (imminent admission) or ps (prolonged stay)")
        sys.exit(1)
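    # Select the task-specific rows, columns, and decision threshold before
    # kicking off the 100-seed evaluation.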
    args.modeldir = args.workdir/'models'
    ori_df = pd.read_csv(args.dataset_csv, usecols=args.cols, parse_dates=args.dates)
    if task == 'ia':
        task_df = ori_df.loc[ori_df['imminent_adm_label'] != -1][args.imminent_adm_cols].reset_index(drop=True)
        prefix = 'imminent_adm'
        threshold = args.ia_thresh
    elif task == 'ps':
        task_df = ori_df.loc[ori_df['chartinterval'] != 0][args.prolonged_stay_cols].reset_index(drop=True)
        prefix = 'prolonged_stay'
        threshold = args.ps_thresh

    logger.info(f"Running 100-seed test run for task {task}")
    t1 = datetime.datetime.now()
    run_100(prefix, task_df, args, threshold)
    dt = datetime.datetime.now() - t1
    logger.info(f"100-seed test run completed. Took {dt.days} days, {dt.seconds//3600} hours, and {(dt.seconds//60)%60} minutes")