--- a +++ b/main.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +''' +@time: 2019/7/23 19:42 + +@ author: javis +''' +import torch, time, os, shutil +import models, utils +import numpy as np +import pandas as pd +from tensorboard_logger import Logger +from torch import nn, optim +from torch.utils.data import DataLoader +from dataset import ECGDataset +from config import config + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +torch.manual_seed(41) +torch.cuda.manual_seed(41) + + +# 保存当前模型的权重,并且更新最佳的模型权重 +def save_ckpt(state, is_best, model_save_dir): + current_w = os.path.join(model_save_dir, config.current_w) + best_w = os.path.join(model_save_dir, config.best_w) + torch.save(state, current_w) + if is_best: shutil.copyfile(current_w, best_w) + + +def train_epoch(model, optimizer, criterion, train_dataloader, show_interval=10): + model.train() + f1_meter, loss_meter, it_count = 0, 0, 0 + for inputs, target in train_dataloader: + inputs = inputs.to(device) + target = target.to(device) + # zero the parameter gradients + optimizer.zero_grad() + # forward + output = model(inputs) + loss = criterion(output, target) + loss.backward() + optimizer.step() + loss_meter += loss.item() + it_count += 1 + f1 = utils.calc_f1(target, torch.sigmoid(output)) + f1_meter += f1 + if it_count != 0 and it_count % show_interval == 0: + print("%d,loss:%.3e f1:%.3f" % (it_count, loss.item(), f1)) + return loss_meter / it_count, f1_meter / it_count + + +def val_epoch(model, criterion, val_dataloader, threshold=0.5): + model.eval() + f1_meter, loss_meter, it_count = 0, 0, 0 + with torch.no_grad(): + for inputs, target in val_dataloader: + inputs = inputs.to(device) + target = target.to(device) + output = model(inputs) + loss = criterion(output, target) + loss_meter += loss.item() + it_count += 1 + output = torch.sigmoid(output) + f1 = utils.calc_f1(target, output, threshold) + f1_meter += f1 + return loss_meter / it_count, f1_meter / it_count + + +def train(args): + # model + model = getattr(models, config.model_name)() + if args.ckpt and not args.resume: + state = torch.load(args.ckpt, map_location='cpu') + model.load_state_dict(state['state_dict']) + print('train with pretrained weight val_f1', state['f1']) + model = model.to(device) + # data + train_dataset = ECGDataset(data_path=config.train_data, train=True) + train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=6) + val_dataset = ECGDataset(data_path=config.train_data, train=False) + val_dataloader = DataLoader(val_dataset, batch_size=config.batch_size, num_workers=4) + print("train_datasize", len(train_dataset), "val_datasize", len(val_dataset)) + # optimizer and loss + optimizer = optim.Adam(model.parameters(), lr=config.lr) + w = torch.tensor(train_dataset.wc, dtype=torch.float).to(device) + criterion = utils.WeightedMultilabel(w) + # 模型保存文件夹 + model_save_dir = '%s/%s_%s' % (config.ckpt, config.model_name, time.strftime("%Y%m%d%H%M")) + if args.ex: model_save_dir += args.ex + best_f1 = -1 + lr = config.lr + start_epoch = 1 + stage = 1 + # 从上一个断点,继续训练 + if args.resume: + if os.path.exists(args.ckpt): # 这里是存放权重的目录 + model_save_dir = args.ckpt + current_w = torch.load(os.path.join(args.ckpt, config.current_w)) + best_w = torch.load(os.path.join(model_save_dir, config.best_w)) + best_f1 = best_w['loss'] + start_epoch = current_w['epoch'] + 1 + lr = current_w['lr'] + stage = current_w['stage'] + model.load_state_dict(current_w['state_dict']) + # 如果中断点恰好为转换stage的点 + if start_epoch - 1 in config.stage_epoch: + stage += 1 + lr /= config.lr_decay + utils.adjust_learning_rate(optimizer, lr) + model.load_state_dict(best_w['state_dict']) + print("=> loaded checkpoint (epoch {})".format(start_epoch - 1)) + logger = Logger(logdir=model_save_dir, flush_secs=2) + # =========>开始训练<========= + for epoch in range(start_epoch, config.max_epoch + 1): + since = time.time() + train_loss, train_f1 = train_epoch(model, optimizer, criterion, train_dataloader, show_interval=100) + val_loss, val_f1 = val_epoch(model, criterion, val_dataloader) + print('#epoch:%02d stage:%d train_loss:%.3e train_f1:%.3f val_loss:%0.3e val_f1:%.3f time:%s\n' + % (epoch, stage, train_loss, train_f1, val_loss, val_f1, utils.print_time_cost(since))) + logger.log_value('train_loss', train_loss, step=epoch) + logger.log_value('train_f1', train_f1, step=epoch) + logger.log_value('val_loss', val_loss, step=epoch) + logger.log_value('val_f1', val_f1, step=epoch) + state = {"state_dict": model.state_dict(), "epoch": epoch, "loss": val_loss, 'f1': val_f1, 'lr': lr, + 'stage': stage} + save_ckpt(state, best_f1 < val_f1, model_save_dir) + best_f1 = max(best_f1, val_f1) + if epoch in config.stage_epoch: + stage += 1 + lr /= config.lr_decay + best_w = os.path.join(model_save_dir, config.best_w) + model.load_state_dict(torch.load(best_w)['state_dict']) + print("*" * 10, "step into stage%02d lr %.3ef" % (stage, lr)) + utils.adjust_learning_rate(optimizer, lr) + +#用于测试加载模型 +def val(args): + list_threhold = [0.5] + model = getattr(models, config.model_name)() + if args.ckpt: model.load_state_dict(torch.load(args.ckpt, map_location='cpu')['state_dict']) + model = model.to(device) + criterion = nn.BCEWithLogitsLoss() + val_dataset = ECGDataset(data_path=config.train_data, train=False) + val_dataloader = DataLoader(val_dataset, batch_size=config.batch_size, num_workers=4) + for threshold in list_threhold: + val_loss, val_f1 = val_epoch(model, criterion, val_dataloader, threshold) + print('threshold %.2f val_loss:%0.3e val_f1:%.3f\n' % (threshold, val_loss, val_f1)) + +#提交结果使用 +def test(args): + from dataset import transform + from data_process import name2index + name2idx = name2index(config.arrythmia) + idx2name = {idx: name for name, idx in name2idx.items()} + utils.mkdirs(config.sub_dir) + # model + model = getattr(models, config.model_name)() + model.load_state_dict(torch.load(args.ckpt, map_location='cpu')['state_dict']) + model = model.to(device) + model.eval() + sub_file = '%s/subA_%s.txt' % (config.sub_dir, time.strftime("%Y%m%d%H%M")) + fout = open(sub_file, 'w', encoding='utf-8') + with torch.no_grad(): + for line in open(config.test_label, encoding='utf-8'): + fout.write(line.strip('\n')) + id = line.split('\t')[0] + file_path = os.path.join(config.test_dir, id) + df = pd.read_csv(file_path, sep=' ').values + x = transform(df).unsqueeze(0).to(device) + output = torch.sigmoid(model(x)).squeeze().cpu().numpy() + ixs = [i for i, out in enumerate(output) if out > 0.5] + for i in ixs: + fout.write("\t" + idx2name[i]) + fout.write('\n') + fout.close() + + + +if __name__ == '__main__': + + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("command", metavar="<command>", help="train or infer") + parser.add_argument("--ckpt", type=str, help="the path of model weight file") + parser.add_argument("--ex", type=str, help="experience name") + parser.add_argument("--resume", action='store_true', default=False) + args = parser.parse_args() + if (args.command == "train"): + train(args) + if (args.command == "test"): + test(args) + if (args.command == "val"): + val(args) \ No newline at end of file