|
a |
|
b/Serialized/helper/mytraining.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
import numpy as np |
|
|
4 |
import torchvision |
|
|
5 |
import torch.nn.functional as F |
|
|
6 |
from torch.utils.data import Dataset, DataLoader |
|
|
7 |
import math |
|
|
8 |
from torch.optim import Optimizer |
|
|
9 |
from torch.optim.optimizer import required |
|
|
10 |
from torch.nn.utils import clip_grad_norm_ |
|
|
11 |
import logging |
|
|
12 |
import abc |
|
|
13 |
import sys |
|
|
14 |
from tqdm import tqdm_notebook |
|
|
15 |
import torch.utils.data as D |
|
|
16 |
import torch.nn.functional as F |
|
|
17 |
from apex import amp |
|
|
18 |
#from .mymodels import out_to_predict,out_to_predict_in,out_to_predict_test,out_to_predict_test_simple |
|
|
19 |
import copy |
|
|
20 |
from scipy.spatial import distance_matrix |
|
|
21 |
import matplotlib.pyplot as plt |
|
|
22 |
from .mymodels import mean_model |
|
|
23 |
def get_model_device(model):
    """Return the torch.device that `model`'s parameters currently live on."""
    first_param = next(model.parameters())
    if first_param.is_cuda:
        return torch.device("cuda:{}".format(first_param.get_device()))
    return torch.device('cpu')
|
|
30 |
|
|
|
31 |
def model_train(model,optimizer,train_dataset,batch_size,num_epochs,loss_func,
                weights=None,accumulation_steps=1,
                weights_func=None,do_apex=True,validate_dataset=None,
                validate_loss=None,metric=None,param_schedualer=None,
                weights_data=False,history=None,return_model=False,model_apexed=False,
                num_workers=7,sampler=None,graph=None,k_lossf=0.01,pre_process=None,
                call_progress=None,use_batchs=True,best_average=1):
    """Train `model` on `train_dataset`, optionally validating each epoch and
    keeping a pool of the `best_average` checkpoints with the lowest
    validation loss.

    Main parameters:
        model, optimizer: network and optimizer. `model` is expected to expose
            project-specific `do_grad()`/`no_grad()` methods (not torch builtins).
        loss_func: callable(y_preds, y_batch[, weights=...]) -> scalar loss.
        accumulation_steps: gradient accumulation factor; optimizer.step() runs
            every `accumulation_steps` batches.
        weights / weights_func / weights_data: static sample weights, a
            per-(epoch, batch) weight schedule, or weights carried as the last
            input tensor of each batch, respectively.
        validate_dataset / validate_loss: enable per-epoch validation through
            model_evaluate(); when `validate_loss` is None the training
            `loss_func` (and `weights`) are reused.
        do_apex / model_apexed: wrap model+optimizer with NVIDIA apex AMP ("O1")
            on CUDA unless already wrapped.
        k_lossf: EMA coefficient for the smoothed loss shown in the progress bar.
        graph / call_progress: optional callbacks for live loss plotting and
            progress reporting.
        best_average: size of the best-checkpoint pool; when >1 the pool is
            averaged with mean_model() at the end and the average is kept only
            if it beats the single best checkpoint.

    Returns:
        `history` (list of per-epoch postfix dicts), or `(history, best_model)`
        when `return_model` is True.
    """
    if history is None:
        history = []
    num_average_models=min(num_epochs,best_average)
    # Slot 0 receives the current epoch's checkpoint; slots 1.. hold the best so far.
    best_models=np.empty(num_average_models+1,dtype=object)
    # BUGFIX: np.float was removed in NumPy 1.20+; builtin float is the same dtype (float64).
    best_val_loss=1e6*np.ones(num_average_models+1,dtype=float)
    device = get_model_device(model)
    if do_apex and not model_apexed and (device.type=='cuda'):
        model_apexed=True
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1",verbosity=0)
    model.zero_grad()
    tq_epoch=tqdm_notebook(range(num_epochs))
    lossf=None  # EMA-smoothed loss for the progress bar
    for epoch in tq_epoch:
        # Shift the best-model pool down one slot to make room for this epoch.
        best_models[1:]=best_models[:num_average_models]
        best_val_loss[1:]=best_val_loss[:num_average_models]
        torch.cuda.empty_cache()
        model.do_grad()  # NOTE(review): project-specific method, not a torch builtin
        if param_schedualer:
            param_schedualer(epoch)
        _=model.train()
        if sampler:
            # sampler() returns the subset of dataset indices to train on this epoch.
            data_loader=D.DataLoader(D.Subset(train_dataset,sampler()),num_workers=num_workers,
                                     batch_size=batch_size if use_batchs else None,
                                     shuffle=use_batchs)
        else:
            data_loader=D.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=num_workers)
        sum_loss = 0.
        if metric:
            metric.zero()
        tq_batch = tqdm_notebook(data_loader,leave=True)
        model.zero_grad()
        for i,(batchs) in enumerate(tq_batch):
            # Last element of the batch is the target; everything before it is input.
            x_batch=batchs[0].to(device) if len(batchs)==2 else [x.to(device) for x in batchs[:-1]]
            y_batch=batchs[-1].to(device)
            if pre_process is not None:
                x_batch,y_batch = pre_process(x_batch,y_batch)
            if weights_data:
                # Sample weights travel as the last input tensor of the batch.
                weights=x_batch[-1]
                x_batch=x_batch[:-1]
            if weights_func:
                weights=weights_func(weights,epoch,i)
            y_preds = model(x_batch) if not isinstance(x_batch,list) else model(*x_batch)

            # Scale the loss so gradients accumulate to the full-batch average.
            if weights is not None:
                loss = loss_func(y_preds,y_batch,weights=weights)/accumulation_steps
            else:
                loss = loss_func(y_preds,y_batch)/accumulation_steps

            if model_apexed:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            if (i+1) % accumulation_steps == 0:  # Wait for several backward steps
                optimizer.step()                 # Now we can do an optimizer step
                model.zero_grad()

            if lossf:
                lossf = (1-k_lossf)*lossf+k_lossf*loss.detach().item()*accumulation_steps
            else:
                lossf = loss.detach().item()*accumulation_steps
            if graph is not None:
                graph(lossf)
            batch_postfix={'loss':lossf}
            if metric:
                if isinstance(y_preds,tuple):
                    yp = tuple(y_preds[0].detach().cpu())
                else:
                    yp =y_preds.detach().cpu()
                batch_postfix.update(metric.calc(yp,y_batch.cpu().detach()))
            tq_batch.set_postfix(**batch_postfix)

            sum_loss=sum_loss+loss.detach().item()*accumulation_steps

        epoch_postfix={'loss':sum_loss/len(data_loader)}
        if metric:
            epoch_postfix.update(metric.calc_sums())
        tq_epoch.set_postfix(**batch_postfix)
        history.append(batch_postfix)

        if validate_dataset:
            if validate_loss is None:
                vloss = loss_func
                val_weights = weights
            else:
                vloss = validate_loss
                val_weights =None
            res=model_evaluate(model,
                               validate_dataset,
                               batch_size = batch_size if use_batchs else None ,
                               loss_func=vloss,
                               weights=val_weights,
                               metric=metric,
                               do_apex=False,
                               num_workers=num_workers)

            history[-1].update(res[1])
            # Insert this epoch's checkpoint into slot 0, then re-sort the pool so
            # the lowest-validation-loss checkpoints come first.
            best_val_loss[0] = res[0]
            best_models[0] = copy.deepcopy(model).to('cpu')
            best_models[0].no_grad()
            order=np.argsort(best_val_loss)
            best_models=best_models[order]
            best_val_loss=best_val_loss[order]
            tq_epoch.set_postfix(res[1])
        print(history[-1])
        if call_progress is not None:
            call_progress(history)
    if num_average_models>1:
        # Average the best checkpoints and keep the average only if it beats the
        # single best checkpoint on the validation set.
        # NOTE(review): this branch reads vloss/val_weights, which are only bound
        # when validate_dataset is truthy — confirm callers always pass one here.
        best_model=mean_model(best_models[:num_average_models])
        model=model.to('cpu')
        best_model=best_model.to(device)
        res=model_evaluate(best_model,
                           validate_dataset,
                           batch_size = batch_size if use_batchs else None ,
                           loss_func=vloss,
                           weights=val_weights,
                           metric=metric,
                           do_apex=False,
                           num_workers=num_workers)
        best_model=best_model.to('cpu')
        model=model.to(device)
        if res[0]>best_val_loss[0]:
            best_model=best_models[0]
            print (best_val_loss[0])
        else:
            print (res)
    else:
        best_model=best_models[0]
        print (best_val_loss)
    return (history,best_model) if return_model else history
|
|
173 |
|
|
|
174 |
|
|
|
175 |
|
|
|
176 |
|
|
|
177 |
def model_run(model,dataset,do_apex=True,batch_size=32,num_workers=6):
    """Run `model` over `dataset` in order (no shuffling) and return the
    concatenated outputs — a single tensor, or a tuple of tensors when the
    model returns tuples."""
    _ = model.eval()
    model.no_grad()  # NOTE(review): project-specific method, not a torch builtin
    device = get_model_device(model)
    if do_apex and (device.type == 'cuda'):
        model = amp.initialize(model, opt_level="O1", verbosity=0)
    collected = []
    loader = D.DataLoader(dataset, batch_size=batch_size, shuffle=False,
                          num_workers=num_workers)
    for batch in tqdm_notebook(loader):
        if isinstance(batch, tuple):
            preds = model(*[item.to(device) for item in batch])
        else:
            preds = model(batch.to(device))
        if isinstance(preds, tuple):
            collected.append(tuple([p.cpu() for p in preds]))
        else:
            collected.append(preds.cpu())
    if isinstance(collected[0], tuple):
        return tuple([torch.cat(chunks) for chunks in map(list, zip(*collected))])
    return torch.cat(collected)
|
|
189 |
|
|
|
190 |
def models_run(models,dataset,do_apex=True,batch_size=32,num_workers=6):
    """Run one or several models over `dataset` (no shuffling) and collect
    their outputs.

    Parameters:
        models: a single model or a list of models; all are assumed to live on
            the same device as the first one.
        dataset: dataset yielding either single tensors or (input..., ) tuples.
        do_apex: wrap each model with apex AMP ("O1") when on CUDA.

    Returns:
        tuple(zip(*res)): per-sample-group tuples pairing up each model's
        concatenated outputs.
    """
    islist = isinstance(models,list)
    # Work on a copy so the AMP re-wrapping below never mutates the caller's list.
    models_ = list(models) if islist else [models]
    for model in models_:
        _=model.eval()
        model.no_grad()  # NOTE(review): project-specific method, not a torch builtin
    device = get_model_device(models_[0])

    if do_apex and (device.type=='cuda'):
        # BUGFIX: the original assigned amp.initialize(...) to the loop variable,
        # which silently discarded the wrapped model; store it back into the list.
        for idx, m in enumerate(models_):
            models_[idx] = amp.initialize(m, opt_level="O1", verbosity=0)
    # One output accumulator per model.
    res_list = [[] for _ in models_]
    data_loader=D.DataLoader(dataset,batch_size=batch_size,shuffle=False,num_workers=num_workers)
    for batchs in tqdm_notebook(data_loader):
        for i,model in enumerate(models_):
            y_preds=model(*[x.to(device) for x in batchs]) if isinstance(batchs,tuple) else model(batchs.to(device))
            res_list[i].append(tuple([y.cpu().detach() for y in y_preds]) if isinstance(y_preds,tuple) else y_preds.cpu().detach())
    res=[]
    for i in range(len(models_)):
        # Concatenate each model's per-batch outputs (tuple-aware).
        res.append(tuple([torch.cat(tens) for tens in map(list,
                   zip(*res_list[i]))]) if isinstance(res_list[i][0],tuple) else torch.cat(res_list[i]))

    return tuple(zip(*res))
|
|
218 |
|
|
|
219 |
|
|
|
220 |
|
|
|
221 |
def model_evaluate(model,
                   validate_dataset,
                   batch_size,
                   loss_func,
                   weights=None,
                   metric=None,
                   do_apex=False,
                   num_workers=6):
    """Evaluate `model` on `validate_dataset`.

    Returns a pair: (mean loss over all batches, postfix dict containing
    'val_loss' and any `metric` summary values prefixed with 'val_').
    """
    _ = model.eval()
    model.no_grad()  # NOTE(review): project-specific method, not a torch builtin
    device = get_model_device(model)
    if do_apex and (device.type == 'cuda'):
        model = amp.initialize(model, opt_level="O1", verbosity=0)
    loader = D.DataLoader(validate_dataset, batch_size=batch_size,
                          shuffle=False, num_workers=num_workers)
    running_total = 0.
    smoothed = None  # EMA-smoothed loss for the progress bar
    if metric:
        metric.zero()
    progress = tqdm_notebook(loader, leave=True)
    for batch in progress:
        # Last element of the batch is the target; everything before it is input.
        if len(batch) == 2:
            inputs = batch[0].to(device)
        else:
            inputs = [t.to(device) for t in batch[:-1]]
        targets = batch[-1]
        outputs = model(*inputs) if isinstance(inputs, list) else model(inputs)
        if weights is None:
            loss = loss_func(outputs, targets.to(device))
        else:
            loss = loss_func(outputs, targets.to(device), weights)
        current = loss.detach().item()
        running_total += current
        smoothed = (0.98 * smoothed + 0.02 * current) if smoothed else current
        postfix = {'val_loss': smoothed}
        if metric:
            if isinstance(outputs, tuple):
                preds_cpu = tuple(outputs[0].detach().cpu())
            else:
                preds_cpu = outputs.detach().cpu()
            postfix.update(metric.calc(preds_cpu, targets.cpu().detach(), prefix='val_'))

        progress.set_postfix(**postfix)
    summary = {'val_loss': running_total / len(loader)}
    if metric:
        summary.update(metric.calc_sums('val_'))

    return running_total / len(loader), summary
|
|
267 |
|
|
|
268 |
class loss_graph():
    """Live matplotlib plot of the (clipped) training loss, one point per batch.

    Intended to be passed as the `graph` callback of model_train(); each call
    appends one loss value and redraws the axes.
    """
    def __init__(self,fig,ax,num_epoch=1,batch2epoch=100,limits=None):
        # fig, ax: matplotlib figure and axes to draw into.
        # num_epoch: expected number of epochs (initial buffer sizing).
        # batch2epoch: batches per epoch; also the x-tick spacing.
        # limits: optional (lo, hi) clip range for plotted losses; defaults to (-1000, 1000).
        self.num_epoch=num_epoch
        self.batch2epoch=batch2epoch
        # BUGFIX: np.float was removed in NumPy 1.20+; builtin float is the same dtype (float64).
        self.loss_arr=np.zeros(num_epoch*batch2epoch,dtype=float)
        self.arr_size=num_epoch*batch2epoch
        self.num_points=0
        self.fig=fig
        self.ax = ax
        self.limits=limits if limits is not None else (-1000,1000)
        # Tick positions in batches, labelled in epochs.
        self.ticks = (np.arange(0, num_epoch*batch2epoch+1, step=batch2epoch),np.arange(0, num_epoch+1, step=1))
    def __call__(self,loss):
        """Record one loss value (clipped to `limits`) and redraw the plot."""
        if self.num_points==self.arr_size:
            # Buffer full: grow by one epoch's worth of slots.
            new_arr=np.zeros(self.arr_size+self.batch2epoch,dtype=float)
            new_arr[:self.arr_size]=self.loss_arr
            self.loss_arr=new_arr
            # BUGFIX: the original never updated arr_size after growing, so a
            # second growth never triggered and indexing eventually overran.
            self.arr_size=self.arr_size+self.batch2epoch
        self.loss_arr[self.num_points]=max(self.limits[0],min(self.limits[1],loss))
        self.num_points=self.num_points+1
        _=self.ax.clear()
        _=self.ax.plot(self.loss_arr[0:self.num_points])
        _=self.ax.set_xlabel('batch')
        _=self.ax.set_ylabel('loss')
        _=plt.xticks(self.ticks[0],self.ticks[1])
        _=self.fig.canvas.draw()
|
|
292 |
|
|
|
293 |
|