train.py

import argparse
import logging
import os
import pickle
from collections import Counter
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from model import VariationalGNN
from utils import train, evaluate, EHRData, collate_fn

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)


def main():
    parser = argparse.ArgumentParser(description='configurations')
    parser.add_argument('--result_path', type=str, default='.', help='output path of model checkpoints')
    parser.add_argument('--data_path', type=str, default='./mimc', help='input path of processed dataset')
    parser.add_argument('--embedding_size', type=int, default=256, help='embedding size')
    parser.add_argument('--num_of_layers', type=int, default=2, help='number of graph layers')
    parser.add_argument('--num_of_heads', type=int, default=1, help='number of attention heads')
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--dropout', type=float, default=0.4, help='dropout')
    parser.add_argument('--reg', type=str, default="True", help='whether to use variational regularization ("True"/"False")')
    parser.add_argument('--lbd', type=float, default=1.0, help='weight of the regularization term')

    args = parser.parse_args()
    result_path = args.result_path
    data_path = args.data_path
    in_feature = args.embedding_size
    out_feature = args.embedding_size
    n_layers = args.num_of_layers - 1
    lr = args.lr
    # '--reg' arrives as a string; convert it to a bool
    args.reg = (args.reg == "True")
    n_heads = args.num_of_heads
    dropout = args.dropout
    alpha = 0.1
    BATCH_SIZE = args.batch_size
    # Training schedule fixed in code rather than exposed as flags
    number_of_epochs = 50
    eval_freq = 1000

    # Load data
    train_x, train_y = pickle.load(open(os.path.join(data_path, 'train_csr.pkl'), 'rb'))
    val_x, val_y = pickle.load(open(os.path.join(data_path, 'validation_csr.pkl'), 'rb'))
    test_x, test_y = pickle.load(open(os.path.join(data_path, 'test_csr.pkl'), 'rb'))
    # Upsample the positive class by appending each positive index one extra time
    train_upsampling = np.concatenate((np.arange(len(train_y)), np.repeat(np.where(train_y == 1)[0], 1)))
    train_x = train_x[train_upsampling]
    train_y = train_y[train_upsampling]
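
    # Each *_csr.pkl above is assumed to hold a (features, labels) pair from preprocessing:
    # a sparse CSR feature matrix and a binary outcome array (an assumption based on how
    # EHRData and the index-based upsampling consume them).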
|
|
    # Create result root
    s = datetime.now().strftime('%Y%m%d%H%M%S')
    result_root = '%s/lr_%s-input_%s-output_%s-dropout_%s' % (result_path, lr, in_feature, out_feature, dropout)
    if not os.path.exists(result_root):
        os.mkdir(result_root)
    # Remove any pre-existing handlers so basicConfig can attach the file handler
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    logging.basicConfig(filename='%s/train.log' % result_root, format='%(asctime)s %(message)s', level=logging.INFO)
    logging.info("Time:%s" % (s))

    # Initialize models
    num_of_nodes = train_x.shape[1] + 1
    device_ids = range(torch.cuda.device_count())
    # eICU has 1 feature on previous readmission that we didn't include in the graph
    model = VariationalGNN(in_feature, out_feature, num_of_nodes, n_heads, n_layers,
                           dropout=dropout, alpha=alpha, variational=args.reg, none_graph_features=0).to(device)
    model = nn.DataParallel(model, device_ids=device_ids)
    val_loader = DataLoader(dataset=EHRData(val_x, val_y), batch_size=BATCH_SIZE,
                            collate_fn=collate_fn, num_workers=torch.cuda.device_count(), shuffle=False)
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=1e-8)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
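
    # Note: nn.DataParallel splits each batch across all visible GPUs (on CPU or with a
    # single GPU it simply falls through to the wrapped module); num_workers for the
    # data loaders is likewise tied to the GPU count.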
|
|
    # Train models: rebuild the train DataLoader and the class-weighted BCE criterion each
    # epoch; pos_weight is the negative-to-positive count ratio, offsetting class imbalance.
    for epoch in range(number_of_epochs):
        print("Learning rate:{}".format(optimizer.param_groups[0]['lr']))
        ratio = Counter(train_y)
        train_loader = DataLoader(dataset=EHRData(train_x, train_y), batch_size=BATCH_SIZE,
                                  collate_fn=collate_fn, num_workers=torch.cuda.device_count(), shuffle=True)
        pos_weight = torch.ones(1).float().to(device) * (ratio[False] / ratio[True])
        criterion = nn.BCEWithLogitsLoss(reduction="sum", pos_weight=pos_weight)
        t = tqdm(iter(train_loader), leave=False, total=len(train_loader))
        model.train()
        total_loss = np.zeros(3)
        for idx, batch_data in enumerate(t):
            loss, kld, bce = train(batch_data, model, optimizer, criterion, args.lbd, 5)
            total_loss += np.array([loss, bce, kld])
            # Periodically checkpoint the model and log validation AUPRC
            if idx % eval_freq == 0 and idx > 0:
                torch.save(model.state_dict(), "{}/parameter{}_{}".format(result_root, epoch, idx))
                val_auprc, _ = evaluate(model, val_loader, len(val_y))
                logging.info('epoch:%d AUPRC:%f; loss: %.4f, bce: %.4f, kld: %.4f' %
                             (epoch + 1, val_auprc, total_loss[0] / idx, total_loss[1] / idx, total_loss[2] / idx))
                print('epoch:%d AUPRC:%f; loss: %.4f, bce: %.4f, kld: %.4f' %
                      (epoch + 1, val_auprc, total_loss[0] / idx, total_loss[1] / idx, total_loss[2] / idx))
            if idx % 50 == 0 and idx > 0:
                t.set_description('[epoch:%d] loss: %.4f, bce: %.4f, kld: %.4f' %
                                  (epoch + 1, total_loss[0] / idx, total_loss[1] / idx, total_loss[2] / idx))
                t.refresh()
        scheduler.step()


if __name__ == '__main__':
    main()
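
# Example invocation (illustrative; the processed-data directory layout is an assumption
# about the preprocessing output, not something this script creates):
#   python train.py --data_path ./mimic/ --result_path ./results \
#       --embedding_size 256 --num_of_layers 2 --num_of_heads 1 \
#       --batch_size 32 --lr 1e-4 --dropout 0.4 --reg True --lbd 1.0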