mlp/train_100.py
#!/usr/bin/env python
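"""100-seed training runs for the note-based MLP baseline.

For each seed, admissions are split into train/test groups on hadm_id,
notes are TF-IDF vectorized, and a one-hidden-layer skorch MLP is trained
and checkpointed under args.modeldir.
"""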
|
|
import datetime
import logging
import sys
sys.path.append('../')

import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer

from torch import optim

from skorch import NeuralNetBinaryClassifier
from skorch.toy import MLPModule
from skorch.dataset import CVSplit
from skorch.callbacks import LRScheduler, Checkpoint, EarlyStopping, ProgressBar

from utils.splits import set_group_splits
from args import args

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
sh = logging.StreamHandler()
sh.setFormatter(logging.Formatter('%(levelname)s:%(name)s: %(message)s'))
logger.addHandler(sh)
|
|
def run_100(task, task_df, args, threshold):
    # Halve the learning rate whenever the validation loss plateaus for more
    # than one epoch.
    reduce_lr = LRScheduler(
        policy='ReduceLROnPlateau',
        mode='min',
        factor=0.5,
        patience=1,
    )

    seeds = list(range(args.start_seed, args.start_seed + 100))
    for seed in tqdm(seeds, desc=f'{task} Runs'):
        logger.info(f"Splitting with seed {seed}")
        checkpoint = Checkpoint(
            dirname=args.modeldir/f'{task}_seed_{seed}',
        )
        df = set_group_splits(task_df.copy(), group_col='hadm_id', seed=seed)
        vectorizer = TfidfVectorizer(sublinear_tf=True, ngram_range=(1, 2), binary=True, max_features=60_000)

        x_train = vectorizer.fit_transform(df.loc[df['split'] == 'train', 'processed_note']).astype(np.float32)
        x_test = vectorizer.transform(df.loc[df['split'] == 'test', 'processed_note']).astype(np.float32)

        # skorch expects dense float32 input here; with max_features=60_000
        # this densification is memory-hungry but keeps the loop simple.
        x_train = np.asarray(x_train.todense())
        x_test = np.asarray(x_test.todense())
        vocab_sz = len(vectorizer.vocabulary_)

        y_train = df.loc[df['split'] == 'train', f'{task}_label'].to_numpy()
        y_test = df.loc[df['split'] == 'test', f'{task}_label'].to_numpy()

        clf = MLPModule(input_units=vocab_sz, output_units=1, hidden_units=args.hidden_dim, num_hidden=1, dropout=args.dropout_p, squeeze_output=True)

        net = NeuralNetBinaryClassifier(
            clf,
            max_epochs=args.max_epochs,
            lr=args.lr,
            device=args.device,
            optimizer=optim.Adam,
            optimizer__weight_decay=args.wd,
            batch_size=args.batch_size,
            verbose=1,
            callbacks=[EarlyStopping(), ProgressBar(), checkpoint, reduce_lr],
            train_split=CVSplit(cv=0.15, stratified=True),
            iterator_train__shuffle=True,
            threshold=threshold,
        )
        # Drop skorch's default valid_acc scoring callback; valid_loss drives
        # early stopping and checkpointing.
        net.set_params(callbacks__valid_acc=None)
        net.fit(x_train, y_train.astype(np.float32))
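
        # The test split is vectorized above but never scored in this script.
        # A minimal per-seed evaluation sketch (assumes scikit-learn is
        # available; not part of the original run):
        #   from sklearn.metrics import roc_auc_score
        #   probs = net.predict_proba(x_test)[:, 1]
        #   logger.info(f"Seed {seed} test AUROC: {roc_auc_score(y_test, probs):.3f}")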
|
|
if __name__ == '__main__':
    if len(sys.argv) != 2:
        logger.error(f"Usage: {sys.argv[0]} task_name (ia|ps)")
        sys.exit(1)

    task = sys.argv[1]
    if task not in ['ia', 'ps']:
        logger.error("Task values are either ia (imminent admission) or ps (prolonged stay)")
        sys.exit(1)

    args.modeldir = args.workdir/'models'
    ori_df = pd.read_csv(args.dataset_csv, usecols=args.cols, parse_dates=args.dates)
    if task == 'ia':
        task_df = ori_df.loc[ori_df['imminent_adm_label'] != -1][args.imminent_adm_cols].reset_index(drop=True)
        prefix = 'imminent_adm'
        threshold = args.ia_thresh
    elif task == 'ps':
        task_df = ori_df.loc[ori_df['chartinterval'] != 0][args.prolonged_stay_cols].reset_index(drop=True)
        prefix = 'prolonged_stay'
        threshold = args.ps_thresh

    logger.info(f"Running 100 seed test run for task {task}")
    t1 = datetime.datetime.now()
    run_100(prefix, task_df, args, threshold)
    dt = datetime.datetime.now() - t1
    logger.info(f"100 seed test run completed. Took {dt.days} days, {dt.seconds//3600} hours, and {(dt.seconds//60)%60} minutes")
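
# Example invocation (data paths and column lists come from args.py):
#   python train_100.py ia    # imminent admission
#   python train_100.py ps    # prolonged stay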