Diff of /trian_resnet.py [000000] .. [5ba3a6]

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from tensorboardX import SummaryWriter

import numpy as np
import time
import datetime
import argparse
import os
import os.path as osp

from rs_dataset import RSDataset
from get_logger import get_logger
from res_network import Resnet18, Resnet34, Resnet101, Densenet121, SEResNext50


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=15)
    parser.add_argument('--schedule_step', type=int, default=4)

    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--test_batch_size', type=int, default=128)
    parser.add_argument('--num_workers', type=int, default=32)

    parser.add_argument('--eval_fre', type=int, default=1)
    parser.add_argument('--msg_fre', type=int, default=10)
    parser.add_argument('--save_fre', type=int, default=1)

    parser.add_argument('--name', type=str, default='SEResNext50',
                        help='run name used for the log / model_out / tensorboard subdirectories')
    parser.add_argument('--data_dir', type=str, default='/media/tiger/zzr/rsna')
    parser.add_argument('--log_dir', type=str, default='./logs')
    parser.add_argument('--tensorboard_dir', type=str, default='./tensorboard')
    parser.add_argument('--model_out_dir', type=str, default='./model_out')
    parser.add_argument('--model_out_name', type=str, default='final_model.pth')
    parser.add_argument('--seed', type=int, default=5, help='random seed')
    parser.add_argument('--predefinedModel', type=str,
                        default='/media/tiger/zzr/rsna_script/model_out/191004-003700_temp/out_1.pth')
    return parser.parse_args()
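
# Example invocation (a sketch; the default data_dir and predefinedModel paths
# above point at the original author's machine and would need to be overridden):
#   python trian_resnet.py --epoch 15 --batch_size 8 --name SEResNext50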


def evaluate(net, val_loader, writer, epoch, logger):
    logger.info('------------after epoch {}, eval...-----------'.format(epoch))
    loss = 0
    net.eval()
    with torch.no_grad():
        for img, lb in val_loader:
            img, lb = img.cuda(), lb.cuda()
            outputs = net(img)
            # BCELoss assumes the network already applies a sigmoid to its outputs
            loss += nn.BCELoss()(outputs, lb).item()

    loss /= len(val_loader)
    logger.info('loss:{:.4f}/epoch{}'.format(loss, epoch))
    # logged under its own tag, stepped by epoch, so it does not collide with
    # the per-iteration training 'loss' scalar written in main_worker
    writer.add_scalar('val_loss', loss, epoch)
    net.train()
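
# A minimal sketch of re-enabling validation via the commented-out val_loader
# in main_worker() below (mode='val' is an assumption about RSDataset; the
# original commented code reuses mode='train'):
#   val_set = RSDataset(rootpth=args.data_dir, mode='val')
#   val_loader = DataLoader(val_set, batch_size=args.test_batch_size,
#                           shuffle=False, pin_memory=True,
#                           num_workers=args.num_workers)
#   evaluate(net, val_loader, writer, epoch, logger)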


def main_worker(args, logger):
    try:
        writer = SummaryWriter(logdir=args.sub_tensorboard_dir)
        train_set = RSDataset(rootpth=args.data_dir, mode='train')
        train_loader = DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  drop_last=True,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=args.num_workers)

        # val_set = RSDataset(rootpth=args.data_dir, mode='train')
        # val_loader = DataLoader(val_set,
        #                         batch_size=args.test_batch_size,
        #                         shuffle=False,
        #                         pin_memory=True,
        #                         num_workers=args.num_workers)
        net = SEResNext50()
        net = net.train()
        net = net.cuda()
        # net.load_state_dict(torch.load(args.predefinedModel))
        criterion = nn.BCELoss().cuda()
        # criterion = nn.CrossEntropyLoss().cuda()
        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.schedule_step, gamma=0.3)
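        # with lr=0.001, gamma=0.3 and schedule_step=4, the learning rate steps
        # through roughly 1e-3 -> 3e-4 -> 9e-5 -> 2.7e-5 at epochs 4, 8 and 12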
        loss_record = []

        it = 0
        running_loss = []
        st = glob_st = time.time()
        total_iter = len(train_loader) * args.epoch
        for epoch in range(args.epoch):
            # evaluation
            # evaluate(net, val_loader, writer, epoch, logger)
            # if epoch != 0 and epoch % args.eval_fre == 0:
            #     evaluate(net, val_loader, writer, epoch, logger)

            if epoch != 0 and epoch % args.save_fre == 0:
                model_out_name = osp.join(args.sub_model_out_dir, 'out_{}.pth'.format(epoch))
                # unwrap DataParallel to guard against save failures under distributed training
                state_dict = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
                torch.save(state_dict, model_out_name)

            for img, lb in train_loader:
                it += 1
                img = img.cuda()
                lb = lb.cuda()
                optimizer.zero_grad()
                outputs = net(img)
                loss = criterion(outputs, lb)
                loss.backward()
                optimizer.step()

                running_loss.append(loss.item())

                if it % args.msg_fre == 0:
                    ed = time.time()
                    spend = ed - st
                    global_spend = ed - glob_st
                    st = ed

                    # ETA = remaining iterations x average seconds per iteration so far
                    eta = int((total_iter - it) * (global_spend / it))
                    eta = str(datetime.timedelta(seconds=eta))
                    global_spend = str(datetime.timedelta(seconds=int(global_spend)))

                    avg_loss = np.mean(running_loss)
                    loss_record.append(avg_loss)
                    running_loss = []

                    lr = optimizer.param_groups[0]['lr']

                    msg = '. '.join([
                        'epoch:{epoch}',
                        'iter/total_iter:{iter}/{total_iter}',
                        'lr:{lr:.5f}',
                        'loss:{loss:.4f}',
                        'spend/global_spend:{spend:.4f}/{global_spend}',
                        'eta:{eta}'
                    ]).format(
                        epoch=epoch,
                        iter=it,
                        total_iter=total_iter,
                        lr=lr,
                        loss=avg_loss,
                        spend=spend,
                        global_spend=global_spend,
                        eta=eta
                    )
                    logger.info(msg)
                    writer.add_scalar('loss', avg_loss, it)
                    writer.add_scalar('lr', lr, it)

            scheduler.step()
        # run one final evaluation after training
        # evaluate(net, val_loader, writer, args.epoch, logger)

        out_name = osp.join(args.sub_model_out_dir, args.model_out_name)
        torch.save(net.cpu().state_dict(), out_name)
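        # A sketch of reloading this final checkpoint later (illustrative; it
        # mirrors the commented-out predefinedModel load above):
        #   net = SEResNext50()
        #   net.load_state_dict(torch.load(osp.join(args.sub_model_out_dir,
        #                                           args.model_out_name)))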

        logger.info('-----------Done!!!----------')

    except Exception:
        logger.exception('Exception logged')
    finally:
        writer.close()


if __name__ == '__main__':
    args = parse_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
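
    # For stricter run-to-run reproducibility one could also pin cuDNN (an
    # optional sketch, not in the original; it trades speed for determinism):
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False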

    # unique identifier for this run
    unique_name = time.strftime('%y%m%d-%H%M%S_') + args.name
    args.unique_name = unique_name

    # each run gets its own tensorboard directory
    args.sub_tensorboard_dir = osp.join(args.tensorboard_dir, args.unique_name)
    # directory for saving models
    args.sub_model_out_dir = osp.join(args.model_out_dir, args.unique_name)

    # create every directory that will be used
    for sub_dir in [args.sub_tensorboard_dir, args.sub_model_out_dir, args.log_dir]:
        if not osp.exists(sub_dir):
            os.makedirs(sub_dir)

    log_file_name = osp.join(args.log_dir, args.unique_name + '.log')
    logger = get_logger(log_file_name)

    for k, v in args.__dict__.items():
        logger.info('{}: {}'.format(k, v))

    main_worker(args, logger=logger)