Diff of /trian_resnet.py [000000] .. [5ba3a6]

Switch to side-by-side view

--- a
+++ b/trian_resnet.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import torch.optim as optim
+from tensorboardX import SummaryWriter
+
+import numpy as np
+import time
+import datetime
+import argparse
+import os
+import os.path as osp
+
+from rs_dataset import RSDataset
+from get_logger import get_logger
+from res_network import Resnet18, Resnet34, Resnet101, Densenet121, SEResNext50
+
+
+def parse_args():
+    parse = argparse.ArgumentParser()
+    parse.add_argument('--epoch', type=int, default=15)
+    parse.add_argument('--schedule_step', type=int, default=4)
+
+    parse.add_argument('--batch_size', type=int, default=8)
+    parse.add_argument('--test_batch_size', type=int, default=128)
+    parse.add_argument('--num_workers', type=int, default=32)
+
+    parse.add_argument('--eval_fre', type=int, default=1)
+    parse.add_argument('--msg_fre', type=int, default=10)
+    parse.add_argument('--save_fre', type=int, default=1)
+
+    parse.add_argument('--name', type=str, default='SEResNext50', help='log/model_out/tensorboard log')
+    parse.add_argument('--data_dir', type=str, default='/media/tiger/zzr/rsna')
+    parse.add_argument('--log_dir', type=str, default='./logs')
+    parse.add_argument('--tensorboard_dir', type=str, default='./tensorboard')
+    parse.add_argument('--model_out_dir', type=str, default='./model_out')
+    parse.add_argument('--model_out_name', type=str, default='final_model.pth')
+    parse.add_argument('--seed', type=int, default=5, help='random seed')
+    parse.add_argument('--predefinedModel', type=str, default='/media/tiger/zzr/rsna_script/model_out/191004-003700_temp/out_1.pth')
+    return parse.parse_args()
+
+
+def evalute(net, val_loader, writer, epoch, logger):
+    logger.info('------------after epo {}, eval...-----------'.format(epoch))
+    loss = 0
+    net.eval()
+    with torch.no_grad():
+        for img,lb in val_loader:
+            img, lb = img.cuda(), lb.cuda()
+            outputs = net(img)
+            loss += nn.BCELoss()(outputs, lb)
+
+    loss /= len(val_loader)
+    logger.info('loss:{:.4f}/epoch{}'.format(loss, epoch))
+    writer.add_scalar('loss', loss)
+    net.train()
+
+
+def main_worker(args, logger):
+    try:
+        writer = SummaryWriter(logdir=args.sub_tensorboard_dir)
+        train_set = RSDataset(rootpth=args.data_dir, mode='train')
+        train_loader = DataLoader(train_set,
+                                  batch_size=args.batch_size,
+                                  drop_last=True,
+                                  shuffle=True,
+                                  pin_memory=True,
+                                  num_workers=args.num_workers)
+
+        # val_set = RSDataset(rootpth=args.data_dir, mode='train')
+        # val_loader = DataLoader(val_set,
+        #                         batch_size=args.test_batch_size,
+        #                         shuffle=False,
+        #                         pin_memory=True,
+        #                         num_workers=args.num_workers)
+        net = SEResNext50()
+        net = net.train()
+        net = net.cuda()
+        # net.load_state_dict(torch.load(args.predefinedModel))
+        criterion = nn.BCELoss().cuda()
+        # criterion = nn.CrossEntropyLoss().cuda()
+        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
+        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.schedule_step, gamma=0.3)
+        loss_record = []
+
+        iter = 0
+        running_loss = []
+        st = glob_st = time.time()
+        total_iter = len(train_loader)*args.epoch
+        for epoch in range(args.epoch):
+            # 评估
+            # evalute(net, val_loader, writer, epoch, logger)
+            # if epoch != 0 and epoch % args.eval_fre == 0:
+            #     evalute(net, val_loader, writer, epoch, logger)
+
+            if epoch != 0 and epoch % args.save_fre == 0:
+                model_out_name = osp.join(args.sub_model_out_dir,'out_{}.pth'.format(epoch))
+                # 防止分布式训练保存失败
+                state_dict = net.modules.state_dict() if hasattr(net, 'module') else net.state_dict()
+                torch.save(state_dict, model_out_name)
+
+            for img, lb in train_loader:
+                iter += 1
+                img = img.cuda()
+                lb = lb.cuda()
+                optimizer.zero_grad()
+                outputs = net(img)
+                loss = criterion(outputs, lb)
+                loss.backward()
+                optimizer.step()
+
+                running_loss.append(loss.item())
+
+                if iter % args.msg_fre == 0:
+                    ed = time.time()
+                    spend = ed-st
+                    global_spend = ed-glob_st
+                    st = ed
+
+                    eta = int((total_iter-iter)*(global_spend/iter))
+                    eta = str(datetime.timedelta(seconds=eta))
+                    global_spend = str(datetime.timedelta(seconds=(int(global_spend))))
+
+                    avg_loss = np.mean(running_loss)
+                    loss_record.append(avg_loss)
+                    running_loss = []
+
+                    lr = optimizer.param_groups[0]['lr']
+
+                    msg = '. '.join([
+                        'epoch:{epoch}',
+                        'iter/total_iter:{iter}/{total_iter}',
+                        'lr:{lr:.5f}',
+                        'loss:{loss:.4f}',
+                        'spend/global_spend:{spend:.4f}/{global_spend}',
+                        'eta:{eta}'
+                    ]).format(
+                        epoch=epoch,
+                        iter=iter,
+                        total_iter=total_iter,
+                        lr=lr,
+                        loss=avg_loss,
+                        spend=spend,
+                        global_spend=global_spend,
+                        eta=eta
+                    )
+                    logger.info(msg)
+                    writer.add_scalar('loss',avg_loss,iter)
+                    writer.add_scalar('lr',lr,iter)
+
+            scheduler.step()
+        # 训练完最后评估一次
+        # evalute(net, val_loader, writer, args.epoch, logger)
+
+        out_name = osp.join(args.sub_model_out_dir,args.model_out_name)
+        torch.save(net.cpu().state_dict(),out_name)
+
+        logger.info('-----------Done!!!----------')
+
+    except:
+        logger.exception('Exception logged')
+    finally:
+        writer.close()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed(args.seed)
+
+    # 唯一标识
+    unique_name = time.strftime('%y%m%d-%H%M%S_') + args.name
+    args.unique_name = unique_name
+
+    # 每次创建作业使用不同的tensorboard目录
+    args.sub_tensorboard_dir = osp.join(args.tensorboard_dir, args.unique_name)
+    # 保存模型的目录
+    args.sub_model_out_dir = osp.join(args.model_out_dir, args.unique_name)
+
+    # 创建所有用到的目录
+    for sub_dir in [args.sub_tensorboard_dir,args.sub_model_out_dir,  args.log_dir]:
+        if not osp.exists(sub_dir):
+            os.makedirs(sub_dir)
+
+    log_file_name = osp.join(args.log_dir,args.unique_name + '.log')
+    logger = get_logger(log_file_name)
+
+    for k, v in args.__dict__.items():
+        logger.info(k)
+        logger.info(v)
+
+    main_worker(args, logger=logger)
\ No newline at end of file