Diff of /main.py [000000] .. [dcdaea]

Switch to unified view

a b/main.py
1
# -*- coding: utf-8 -*-
2
'''
3
@time: 2019/7/23 19:42
4
5
@ author: javis
6
'''
7
import torch, time, os, shutil
8
import models, utils
9
import numpy as np
10
import pandas as pd
11
from tensorboard_logger import Logger
12
from torch import nn, optim
13
from torch.utils.data import DataLoader
14
from dataset import ECGDataset
15
from config import config
16
17
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
19
torch.manual_seed(41)
20
torch.cuda.manual_seed(41)
21
22
23
# Persist the latest weights, and refresh the best-weights file when improved.
def save_ckpt(state, is_best, model_save_dir):
    """Save `state` as the current checkpoint; copy it over the best checkpoint if `is_best`."""
    latest_path = os.path.join(model_save_dir, config.current_w)
    torch.save(state, latest_path)
    if is_best:
        shutil.copyfile(latest_path, os.path.join(model_save_dir, config.best_w))
29
30
31
def train_epoch(model, optimizer, criterion, train_dataloader, show_interval=10):
    """Run one training pass over `train_dataloader`.

    Returns (mean loss, mean F1) averaged over all batches; prints a progress
    line every `show_interval` batches.
    """
    model.train()
    running_loss = 0.0
    running_f1 = 0.0
    batches = 0
    for inputs, target in train_dataloader:
        inputs, target = inputs.to(device), target.to(device)
        # standard step: clear grads, forward, backward, update
        optimizer.zero_grad()
        output = model(inputs)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        batches += 1
        running_loss += loss.item()
        batch_f1 = utils.calc_f1(target, torch.sigmoid(output))
        running_f1 += batch_f1
        if batches % show_interval == 0:
            print("%d,loss:%.3e f1:%.3f" % (batches, loss.item(), batch_f1))
    return running_loss / batches, running_f1 / batches
51
52
53
def val_epoch(model, criterion, val_dataloader, threshold=0.5):
    """Evaluate on `val_dataloader` without gradients.

    Returns (mean loss, mean F1) over all batches; `threshold` is forwarded to
    the F1 computation.
    """
    model.eval()
    total_loss, total_f1, n_batches = 0.0, 0.0, 0
    with torch.no_grad():
        for inputs, target in val_dataloader:
            inputs, target = inputs.to(device), target.to(device)
            logits = model(inputs)
            total_loss += criterion(logits, target).item()
            n_batches += 1
            probs = torch.sigmoid(logits)
            total_f1 += utils.calc_f1(target, probs, threshold)
    return total_loss / n_batches, total_f1 / n_batches
68
69
70
def train(args):
    """Full training entry point.

    Builds the model, data loaders, optimizer and weighted loss, optionally
    loads pretrained weights (args.ckpt) or resumes from a checkpoint
    directory (args.resume), then trains for config.max_epoch epochs with
    staged learning-rate decay at the epochs in config.stage_epoch.
    """
    # model
    model = getattr(models, config.model_name)()
    if args.ckpt and not args.resume:
        state = torch.load(args.ckpt, map_location='cpu')
        model.load_state_dict(state['state_dict'])
        print('train with pretrained weight val_f1', state['f1'])
    model = model.to(device)
    # data
    train_dataset = ECGDataset(data_path=config.train_data, train=True)
    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=6)
    val_dataset = ECGDataset(data_path=config.train_data, train=False)
    val_dataloader = DataLoader(val_dataset, batch_size=config.batch_size, num_workers=4)
    print("train_datasize", len(train_dataset), "val_datasize", len(val_dataset))
    # optimizer and loss (per-class weights come from the training set)
    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    w = torch.tensor(train_dataset.wc, dtype=torch.float).to(device)
    criterion = utils.WeightedMultilabel(w)
    # directory where this run's checkpoints are stored
    model_save_dir = '%s/%s_%s' % (config.ckpt, config.model_name, time.strftime("%Y%m%d%H%M"))
    if args.ex:
        model_save_dir += args.ex
    best_f1 = -1
    lr = config.lr
    start_epoch = 1
    stage = 1
    # resume training from the last checkpoint
    if args.resume:
        if os.path.exists(args.ckpt):  # here args.ckpt is the directory holding the weights
            model_save_dir = args.ckpt
            current_w = torch.load(os.path.join(args.ckpt, config.current_w))
            best_w = torch.load(os.path.join(model_save_dir, config.best_w))
            # BUGFIX: the checkpoint stores the F1 score under 'f1'; the
            # original read 'loss', corrupting best-score tracking on resume.
            best_f1 = best_w['f1']
            start_epoch = current_w['epoch'] + 1
            lr = current_w['lr']
            stage = current_w['stage']
            model.load_state_dict(current_w['state_dict'])
            # if the interruption happened exactly at a stage boundary,
            # perform the stage transition now
            if start_epoch - 1 in config.stage_epoch:
                stage += 1
                lr /= config.lr_decay
                utils.adjust_learning_rate(optimizer, lr)
                model.load_state_dict(best_w['state_dict'])
            print("=> loaded checkpoint (epoch {})".format(start_epoch - 1))
    logger = Logger(logdir=model_save_dir, flush_secs=2)
    # =========> training loop <=========
    for epoch in range(start_epoch, config.max_epoch + 1):
        since = time.time()
        train_loss, train_f1 = train_epoch(model, optimizer, criterion, train_dataloader, show_interval=100)
        val_loss, val_f1 = val_epoch(model, criterion, val_dataloader)
        print('#epoch:%02d stage:%d train_loss:%.3e train_f1:%.3f  val_loss:%0.3e val_f1:%.3f time:%s\n'
              % (epoch, stage, train_loss, train_f1, val_loss, val_f1, utils.print_time_cost(since)))
        logger.log_value('train_loss', train_loss, step=epoch)
        logger.log_value('train_f1', train_f1, step=epoch)
        logger.log_value('val_loss', val_loss, step=epoch)
        logger.log_value('val_f1', val_f1, step=epoch)
        state = {"state_dict": model.state_dict(), "epoch": epoch, "loss": val_loss, 'f1': val_f1, 'lr': lr,
                 'stage': stage}
        save_ckpt(state, best_f1 < val_f1, model_save_dir)
        best_f1 = max(best_f1, val_f1)
        if epoch in config.stage_epoch:
            # stage boundary: decay the LR and restart from the best weights so far
            stage += 1
            lr /= config.lr_decay
            best_w = os.path.join(model_save_dir, config.best_w)
            model.load_state_dict(torch.load(best_w)['state_dict'])
            # BUGFIX: format string had a stray 'f' ("%.3ef") that printed a
            # literal 'f' after the learning rate.
            print("*" * 10, "step into stage%02d lr %.3e" % (stage, lr))
            utils.adjust_learning_rate(optimizer, lr)
136
137
# Used to evaluate a loaded model on the validation split.
def val(args):
    """Load weights (optionally from args.ckpt) and report validation loss/F1
    with BCEWithLogitsLoss at each threshold in the list."""
    thresholds = [0.5]
    model = getattr(models, config.model_name)()
    if args.ckpt:
        model.load_state_dict(torch.load(args.ckpt, map_location='cpu')['state_dict'])
    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss()
    val_dataset = ECGDataset(data_path=config.train_data, train=False)
    val_dataloader = DataLoader(val_dataset, batch_size=config.batch_size, num_workers=4)
    for threshold in thresholds:
        val_loss, val_f1 = val_epoch(model, criterion, val_dataloader, threshold)
        print('threshold %.2f val_loss:%0.3e val_f1:%.3f\n' % (threshold, val_loss, val_f1))
149
150
# Used to produce the submission file.
def test(args):
    """Run inference over the test set and write a TSV submission file.

    Each line of config.test_label starts with a record id (tab-separated);
    the predicted label names (sigmoid probability > 0.5) are appended to the
    line, tab-separated, in the output file under config.sub_dir.
    """
    from dataset import transform
    from data_process import name2index
    name2idx = name2index(config.arrythmia)
    idx2name = {idx: name for name, idx in name2idx.items()}
    utils.mkdirs(config.sub_dir)
    # model
    model = getattr(models, config.model_name)()
    model.load_state_dict(torch.load(args.ckpt, map_location='cpu')['state_dict'])
    model = model.to(device)
    model.eval()
    sub_file = '%s/subA_%s.txt' % (config.sub_dir, time.strftime("%Y%m%d%H%M"))
    # BUGFIX: use context managers so both the label file and the output file
    # are closed even if inference raises (the original leaked `fout` on error
    # and never closed the label-file handle).
    with open(config.test_label, encoding='utf-8') as fin, \
            open(sub_file, 'w', encoding='utf-8') as fout, \
            torch.no_grad():
        for line in fin:
            fout.write(line.strip('\n'))
            record_id = line.split('\t')[0]  # renamed from `id` to avoid shadowing the builtin
            file_path = os.path.join(config.test_dir, record_id)
            df = pd.read_csv(file_path, sep=' ').values
            x = transform(df).unsqueeze(0).to(device)
            output = torch.sigmoid(model(x)).squeeze().cpu().numpy()
            positives = [i for i, prob in enumerate(output) if prob > 0.5]
            for i in positives:
                fout.write("\t" + idx2name[i])
            fout.write('\n')
177
178
179
180
if __name__ == '__main__':

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("command", metavar="<command>", help="train or infer")
    parser.add_argument("--ckpt", type=str, help="the path of model weight file")
    parser.add_argument("--ex", type=str, help="experience name")
    parser.add_argument("--resume", action='store_true', default=False)
    args = parser.parse_args()
    # dispatch on the requested sub-command; unknown commands are a no-op,
    # matching the original if-chain behavior
    handlers = {"train": train, "test": test, "val": val}
    handler = handlers.get(args.command)
    if handler is not None:
        handler(args)