trian_resnet.py

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from tensorboardX import SummaryWriter

import numpy as np
import time
import datetime
import argparse
import os
import os.path as osp

from rs_dataset import RSDataset
from get_logger import get_logger
from res_network import Resnet18, Resnet34, Resnet101, Densenet121, SEResNext50

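# Note: tensorboardX supplies the SummaryWriter used below; recent PyTorch
# versions also ship an equivalent writer in torch.utils.tensorboard.
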
def parse_args():
    parse = argparse.ArgumentParser()
    parse.add_argument('--epoch', type=int, default=15)
    parse.add_argument('--schedule_step', type=int, default=4)

    parse.add_argument('--batch_size', type=int, default=8)
    parse.add_argument('--test_batch_size', type=int, default=128)
    parse.add_argument('--num_workers', type=int, default=32)

    parse.add_argument('--eval_fre', type=int, default=1)
    parse.add_argument('--msg_fre', type=int, default=10)
    parse.add_argument('--save_fre', type=int, default=1)

    parse.add_argument('--name', type=str, default='SEResNext50', help='log/model_out/tensorboard log')
    parse.add_argument('--data_dir', type=str, default='/media/tiger/zzr/rsna')
    parse.add_argument('--log_dir', type=str, default='./logs')
    parse.add_argument('--tensorboard_dir', type=str, default='./tensorboard')
    parse.add_argument('--model_out_dir', type=str, default='./model_out')
    parse.add_argument('--model_out_name', type=str, default='final_model.pth')
    parse.add_argument('--seed', type=int, default=5, help='random seed')
    parse.add_argument('--predefinedModel', type=str, default='/media/tiger/zzr/rsna_script/model_out/191004-003700_temp/out_1.pth')
    return parse.parse_args()

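# Validation pass: averages the BCE loss over the whole val_loader and logs it.
# Note: nn.BCELoss expects probabilities in [0, 1], so this assumes the
# network's final layer applies a sigmoid.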
def evaluate(net, val_loader, writer, epoch, logger):
    logger.info('------------after epoch {}, eval...-----------'.format(epoch))
    loss = 0
    net.eval()
    with torch.no_grad():
        for img, lb in val_loader:
            img, lb = img.cuda(), lb.cuda()
            outputs = net(img)
            loss += nn.BCELoss()(outputs, lb).item()

    loss /= len(val_loader)
    logger.info('val loss:{:.4f}/epoch {}'.format(loss, epoch))
    # tag as 'val_loss' with the epoch as global_step, so eval points neither
    # collide with the training 'loss' curve nor overwrite each other at step 0
    writer.add_scalar('val_loss', loss, epoch)
    net.train()

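# End-to-end training driver: builds data/model/optimizer, checkpoints every
# args.save_fre epochs, and logs progress every args.msg_fre iterations.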
def main_worker(args, logger):
    writer = None  # so the finally-block close() is safe if setup fails
    try:
        writer = SummaryWriter(logdir=args.sub_tensorboard_dir)
        train_set = RSDataset(rootpth=args.data_dir, mode='train')
        train_loader = DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  drop_last=True,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=args.num_workers)

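        # drop_last skips a ragged final batch; pin_memory can accelerate the
        # .cuda() host-to-device copies in the training loop below.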
        # val_set = RSDataset(rootpth=args.data_dir, mode='train')
        # val_loader = DataLoader(val_set,
        #                         batch_size=args.test_batch_size,
        #                         shuffle=False,
        #                         pin_memory=True,
        #                         num_workers=args.num_workers)
        net = SEResNext50()
        net = net.train()
        net = net.cuda()
        # net.load_state_dict(torch.load(args.predefinedModel))
        criterion = nn.BCELoss().cuda()
        # criterion = nn.CrossEntropyLoss().cuda()
        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.schedule_step, gamma=0.3)
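        # StepLR multiplies the learning rate by gamma=0.3 every
        # args.schedule_step epochs (scheduler.step() runs once per epoch).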
        loss_record = []

        iteration = 0  # renamed from `iter` to avoid shadowing the builtin
        running_loss = []
        st = glob_st = time.time()
        total_iter = len(train_loader) * args.epoch
        for epoch in range(args.epoch):
            # evaluation
            # evaluate(net, val_loader, writer, epoch, logger)
            # if epoch != 0 and epoch % args.eval_fre == 0:
            #     evaluate(net, val_loader, writer, epoch, logger)

            if epoch != 0 and epoch % args.save_fre == 0:
                model_out_name = osp.join(args.sub_model_out_dir, 'out_{}.pth'.format(epoch))
                # avoid save failures when the model is wrapped for distributed
                # training: unwrap .module (was net.modules, an AttributeError)
                state_dict = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
                torch.save(state_dict, model_out_name)

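            # standard supervised step: forward pass, BCE loss against the
            # label, backprop, SGD update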
            for img, lb in train_loader:
                iteration += 1
                img = img.cuda()
                lb = lb.cuda()
                optimizer.zero_grad()
                outputs = net(img)
                loss = criterion(outputs, lb)
                loss.backward()
                optimizer.step()

                running_loss.append(loss.item())

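                # periodic progress report: per-interval and cumulative wall
                # time, plus an ETA extrapolated from the mean iteration cost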
                if iteration % args.msg_fre == 0:
                    ed = time.time()
                    spend = ed - st
                    global_spend = ed - glob_st
                    st = ed

                    eta = int((total_iter - iteration) * (global_spend / iteration))
                    eta = str(datetime.timedelta(seconds=eta))
                    global_spend = str(datetime.timedelta(seconds=int(global_spend)))

                    avg_loss = np.mean(running_loss)
                    loss_record.append(avg_loss)
                    running_loss = []

                    lr = optimizer.param_groups[0]['lr']

                    msg = '. '.join([
                        'epoch:{epoch}',
                        'iter/total_iter:{iter}/{total_iter}',
                        'lr:{lr:.5f}',
                        'loss:{loss:.4f}',
                        'spend/global_spend:{spend:.4f}/{global_spend}',
                        'eta:{eta}'
                    ]).format(
                        epoch=epoch,
                        iter=iteration,
                        total_iter=total_iter,
                        lr=lr,
                        loss=avg_loss,
                        spend=spend,
                        global_spend=global_spend,
                        eta=eta
                    )
                    logger.info(msg)
                    writer.add_scalar('loss', avg_loss, iteration)
                    writer.add_scalar('lr', lr, iteration)

            scheduler.step()

        # run one final evaluation after training
        # evaluate(net, val_loader, writer, args.epoch, logger)

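        # final save: net.cpu() moves the parameters off the GPU, so the
        # saved state_dict can be loaded on CPU-only machines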
        out_name = osp.join(args.sub_model_out_dir, args.model_out_name)
        torch.save(net.cpu().state_dict(), out_name)

        logger.info('-----------Done!!!----------')

    except Exception:
        logger.exception('Exception logged')
    finally:
        if writer is not None:
            writer.close()

if __name__ == '__main__':
    args = parse_args()

    # seed CPU and GPU RNGs for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # unique run identifier
    unique_name = time.strftime('%y%m%d-%H%M%S_') + args.name
    args.unique_name = unique_name

    # use a separate tensorboard directory for every run
    args.sub_tensorboard_dir = osp.join(args.tensorboard_dir, args.unique_name)
    # directory for saved models
    args.sub_model_out_dir = osp.join(args.model_out_dir, args.unique_name)

    # create every directory we use
    for sub_dir in [args.sub_tensorboard_dir, args.sub_model_out_dir, args.log_dir]:
        if not osp.exists(sub_dir):
            os.makedirs(sub_dir)

    log_file_name = osp.join(args.log_dir, args.unique_name + '.log')
    logger = get_logger(log_file_name)

    # record the full configuration in the log
    for k, v in args.__dict__.items():
        logger.info(k)
        logger.info(v)

    main_worker(args, logger=logger)
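
# Example invocation (illustrative values; adjust paths to your setup):
#   python trian_resnet.py --name SEResNext50 --data_dir /path/to/rsna \
#       --batch_size 8 --epoch 15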