Diff of /train.py [000000] .. [0d4320]


b/train.py
import argparse
import logging
import os
import pickle
from collections import Counter
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from model import VariationalGNN
from utils import train, evaluate, EHRData, collate_fn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

def main():
    parser = argparse.ArgumentParser(description='configurations')
    parser.add_argument('--result_path', type=str, default='.', help='output path for model checkpoints')
    parser.add_argument('--data_path', type=str, default='./mimic/', help='input path of the processed dataset')
    parser.add_argument('--embedding_size', type=int, default=256, help='embedding size')
    parser.add_argument('--num_of_layers', type=int, default=2, help='number of graph layers')
    parser.add_argument('--num_of_heads', type=int, default=1, help='number of attention heads')
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--dropout', type=float, default=0.4, help='dropout rate')
    parser.add_argument('--reg', type=str, default='True', help='whether to apply variational regularization')
    parser.add_argument('--lbd', type=float, default=1.0, help='weight of the KLD regularization term')

    args = parser.parse_args()
    result_path = args.result_path
    data_path = args.data_path
    in_feature = args.embedding_size
    out_feature = args.embedding_size
    n_layers = args.num_of_layers - 1
    lr = args.lr
    args.reg = (args.reg == 'True')  # argparse yields a string; convert to bool
    n_heads = args.num_of_heads
    dropout = args.dropout
    alpha = 0.1
    BATCH_SIZE = args.batch_size
    number_of_epochs = 50
    eval_freq = 1000

    # Load data
    train_x, train_y = pickle.load(open(os.path.join(data_path, 'train_csr.pkl'), 'rb'))
    val_x, val_y = pickle.load(open(os.path.join(data_path, 'validation_csr.pkl'), 'rb'))
    test_x, test_y = pickle.load(open(os.path.join(data_path, 'test_csr.pkl'), 'rb'))
    # Upsample the positive class by duplicating each positive example once
    train_upsampling = np.concatenate((np.arange(len(train_y)), np.repeat(np.where(train_y == 1)[0], 1)))
    train_x = train_x[train_upsampling]
    train_y = train_y[train_upsampling]

    # Create result root
    s = datetime.now().strftime('%Y%m%d%H%M%S')
    result_root = '%s/lr_%s-input_%s-output_%s-dropout_%s' % (result_path, lr, in_feature, out_feature, dropout)
    os.makedirs(result_root, exist_ok=True)
    # Drop any pre-existing handlers so basicConfig takes effect
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    logging.basicConfig(filename='%s/train.log' % result_root, format='%(asctime)s %(message)s', level=logging.INFO)
    logging.info('Time: %s' % s)

    # Initialize model
    num_of_nodes = train_x.shape[1] + 1
    device_ids = range(torch.cuda.device_count())
    # eICU has 1 feature on previous readmission that we didn't include in the graph
    model = VariationalGNN(in_feature, out_feature, num_of_nodes, n_heads, n_layers,
                           dropout=dropout, alpha=alpha, variational=args.reg, none_graph_features=0).to(device)
    model = nn.DataParallel(model, device_ids=device_ids)
    val_loader = DataLoader(dataset=EHRData(val_x, val_y), batch_size=BATCH_SIZE,
                            collate_fn=collate_fn, num_workers=torch.cuda.device_count(), shuffle=False)
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=1e-8)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    # Train the model
    for epoch in range(number_of_epochs):
        print("Learning rate:{}".format(optimizer.param_groups[0]['lr']))
        ratio = Counter(train_y)
        train_loader = DataLoader(dataset=EHRData(train_x, train_y), batch_size=BATCH_SIZE,
                                  collate_fn=collate_fn, num_workers=torch.cuda.device_count(), shuffle=True)
        # Weight positive examples by the negative/positive ratio to balance the loss
        pos_weight = torch.ones(1).float().to(device) * (ratio[False] / ratio[True])
        criterion = nn.BCEWithLogitsLoss(reduction="sum", pos_weight=pos_weight)
        t = tqdm(train_loader, leave=False, total=len(train_loader))
        model.train()
        total_loss = np.zeros(3)
        for idx, batch_data in enumerate(t):
            loss, kld, bce = train(batch_data, model, optimizer, criterion, args.lbd, 5)
            total_loss += np.array([loss, bce, kld])
            if idx % eval_freq == 0 and idx > 0:
                torch.save(model.state_dict(), "{}/parameter{}_{}".format(result_root, epoch, idx))
                val_auprc, _ = evaluate(model, val_loader, len(val_y))
                model.train()  # restore training mode in case evaluate() put the model in eval mode
                logging.info('epoch:%d AUPRC:%f; loss: %.4f, bce: %.4f, kld: %.4f' %
                             (epoch + 1, val_auprc, total_loss[0] / idx, total_loss[1] / idx, total_loss[2] / idx))
                print('epoch:%d AUPRC:%f; loss: %.4f, bce: %.4f, kld: %.4f' %
                      (epoch + 1, val_auprc, total_loss[0] / idx, total_loss[1] / idx, total_loss[2] / idx))
            if idx % 50 == 0 and idx > 0:
                t.set_description('[epoch:%d] loss: %.4f, bce: %.4f, kld: %.4f' %
                                  (epoch + 1, total_loss[0] / idx, total_loss[1] / idx, total_loss[2] / idx))
                t.refresh()
        scheduler.step()


if __name__ == '__main__':
    main()
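
Note: the options above map one-to-one onto the argparse flags defined in this file. A typical invocation (paths are illustrative; this assumes the preprocessed pickles live under ./mimic/) might look like:

    python train.py --data_path ./mimic/ --result_path ./results --lr 1e-4 --batch_size 32 --reg True --lbd 1.0

The class-balancing arithmetic in the training loop is easy to miss: a Counter over 0/1 labels can be indexed with False/True because False == 0 and True == 1 in Python. A minimal sketch of the pos_weight computation, using toy labels rather than the real dataset:

    import numpy as np
    from collections import Counter

    train_y = np.array([0, 0, 0, 1])   # toy labels: three negatives, one positive
    ratio = Counter(train_y)           # counts keyed by label value
    print(ratio[False] / ratio[True])  # 3.0 -> each positive is weighted 3x in BCEWithLogitsLoss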