# Serialized/helper/mytraining.py
import torch
import torch.nn as nn
import numpy as np
import copy
from tqdm import tqdm_notebook
import torch.utils.data as D
from apex import amp
import matplotlib.pyplot as plt

from .mymodels import mean_model

def get_model_device(model):
    """Return the torch.device on which the model's parameters live."""
    p = next(model.parameters())
    if p.is_cuda:
        device = torch.device("cuda:{}".format(p.get_device()))
    else:
        device = torch.device('cpu')
    return device

def model_train(model, optimizer, train_dataset, batch_size, num_epochs, loss_func,
                weights=None, accumulation_steps=1,
                weights_func=None, do_apex=True, validate_dataset=None,
                validate_loss=None, metric=None, param_schedualer=None,
                weights_data=False, history=None, return_model=False, model_apexed=False,
                num_workers=7, sampler=None, graph=None, k_lossf=0.01, pre_process=None,
                call_progress=None, use_batchs=True, best_average=1):

    if history is None:
        history = []
    # Buffer holding the `best_average` best snapshots (by validation loss),
    # plus one extra slot for the current epoch's candidate.
    num_average_models = min(num_epochs, best_average)
    best_models = np.empty(num_average_models + 1, dtype=object)
    best_val_loss = 1e6 * np.ones(num_average_models + 1, dtype=float)
    device = get_model_device(model)
    if do_apex and not model_apexed and (device.type == 'cuda'):
        model_apexed = True
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)
    model.zero_grad()
    tq_epoch = tqdm_notebook(range(num_epochs))
    lossf = None
    for epoch in tq_epoch:
        # Shift the snapshot buffer so slot 0 is free for this epoch.
        best_models[1:] = best_models[:num_average_models]
        best_val_loss[1:] = best_val_loss[:num_average_models]
        torch.cuda.empty_cache()
        model.do_grad()
        if param_schedualer:
            param_schedualer(epoch)
        _ = model.train()
        if sampler:
            data_loader = D.DataLoader(D.Subset(train_dataset, sampler()), num_workers=num_workers,
                                       batch_size=batch_size if use_batchs else None,
                                       shuffle=use_batchs)
        else:
            data_loader = D.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                       num_workers=num_workers)
        sum_loss = 0.
        if metric:
            metric.zero()
        tq_batch = tqdm_notebook(data_loader, leave=True)
        model.zero_grad()
        for i, batchs in enumerate(tq_batch):
            # A batch is (x, y) or (x1, ..., xn, y); everything but the target is input.
            x_batch = batchs[0].to(device) if len(batchs) == 2 else [x.to(device) for x in batchs[:-1]]
            y_batch = batchs[-1].to(device)
            if pre_process is not None:
                x_batch, y_batch = pre_process(x_batch, y_batch)
            if weights_data:
                # Per-sample weights travel as the last input tensor.
                weights = x_batch[-1]
                x_batch = x_batch[:-1]
            if weights_func:
                weights = weights_func(weights, epoch, i)
            y_preds = model(x_batch) if not isinstance(x_batch, list) else model(*x_batch)

            if weights is not None:
                loss = loss_func(y_preds, y_batch, weights=weights) / accumulation_steps
            else:
                loss = loss_func(y_preds, y_batch) / accumulation_steps

            if model_apexed:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            if (i + 1) % accumulation_steps == 0:   # wait for several backward passes
                optimizer.step()                    # then take a single optimizer step
                model.zero_grad()

            # Exponential moving average of the unscaled batch loss, for display.
            if lossf:
                lossf = (1 - k_lossf) * lossf + k_lossf * loss.detach().item() * accumulation_steps
            else:
                lossf = loss.detach().item() * accumulation_steps
            if graph is not None:
                graph(lossf)
            batch_postfix = {'loss': lossf}
            if metric:
                if isinstance(y_preds, tuple):
                    yp = tuple(y_preds[0].detach().cpu())
                else:
                    yp = y_preds.detach().cpu()
                batch_postfix.update(metric.calc(yp, y_batch.cpu().detach()))
            tq_batch.set_postfix(**batch_postfix)

            sum_loss = sum_loss + loss.detach().item() * accumulation_steps

        epoch_postfix = {'loss': sum_loss / len(data_loader)}
        if metric:
            epoch_postfix.update(metric.calc_sums())
        tq_epoch.set_postfix(**epoch_postfix)
        history.append(epoch_postfix)

        if validate_dataset:
            if validate_loss is None:
                vloss = loss_func
                val_weights = weights
            else:
                vloss = validate_loss
                val_weights = None
            res = model_evaluate(model,
                                 validate_dataset,
                                 batch_size=batch_size if use_batchs else None,
                                 loss_func=vloss,
                                 weights=val_weights,
                                 metric=metric,
                                 do_apex=False,
                                 num_workers=num_workers)

            history[-1].update(res[1])
            # Insert this epoch's snapshot and keep the buffer sorted by validation loss.
            best_val_loss[0] = res[0]
            best_models[0] = copy.deepcopy(model).to('cpu')
            best_models[0].no_grad()
            order = np.argsort(best_val_loss)
            best_models = best_models[order]
            best_val_loss = best_val_loss[order]
            tq_epoch.set_postfix(res[1])
        print(history[-1])
        if call_progress is not None:
            call_progress(history)
    if num_average_models > 1 and validate_dataset:
        # Average the parameters of the best snapshots and keep the averaged model
        # only if it beats the single best snapshot on the validation set.
        best_model = mean_model(best_models[:num_average_models])
        model = model.to('cpu')
        best_model = best_model.to(device)
        res = model_evaluate(best_model,
                             validate_dataset,
                             batch_size=batch_size if use_batchs else None,
                             loss_func=vloss,
                             weights=val_weights,
                             metric=metric,
                             do_apex=False,
                             num_workers=num_workers)
        best_model = best_model.to('cpu')
        model = model.to(device)
        if res[0] > best_val_loss[0]:
            best_model = best_models[0]
            print(best_val_loss[0])
        else:
            print(res)
    else:
        best_model = best_models[0]
        print(best_val_loss)
    return (history, best_model) if return_model else history
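
# Illustrative sketch (not from the original file): a minimal, runnable way to
# call model_train on CPU. TinyNet, mse_loss, and the random tensors are
# hypothetical; the do_grad()/no_grad() hooks are supplied here because
# model_train and model_evaluate expect them on the model.
def _demo_model_train():
    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8, 1)

        def forward(self, x):
            return self.fc(x)

        def do_grad(self):
            for p in self.parameters():
                p.requires_grad_(True)

        def no_grad(self):
            for p in self.parameters():
                p.requires_grad_(False)

    def mse_loss(y_pred, y_true, weights=None):
        err = (y_pred.squeeze(-1) - y_true) ** 2
        return (err * weights).mean() if weights is not None else err.mean()

    x, y = torch.randn(256, 8), torch.randn(256)
    train_ds = D.TensorDataset(x[:192], y[:192])
    val_ds = D.TensorDataset(x[192:], y[192:])
    model = TinyNet()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    return model_train(model, optimizer, train_ds, batch_size=32, num_epochs=3,
                       loss_func=mse_loss, validate_dataset=val_ds,
                       do_apex=False, return_model=True, num_workers=0)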

def model_run(model, dataset, do_apex=True, batch_size=32, num_workers=6):
    """Run inference over a dataset and return the concatenated predictions."""
    _ = model.eval()
    model.no_grad()
    device = get_model_device(model)
    if do_apex and (device.type == 'cuda'):
        model = amp.initialize(model, opt_level="O1", verbosity=0)
    res_list = []
    data_loader = D.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    for batchs in tqdm_notebook(data_loader):
        # The default collate function turns tuple samples into a list of tensors.
        y_preds = model(*[x.to(device) for x in batchs]) if isinstance(batchs, (list, tuple)) else model(batchs.to(device))
        res_list.append(tuple(y.cpu() for y in y_preds) if isinstance(y_preds, tuple) else y_preds.cpu())
    # Concatenate per-batch outputs, preserving the tuple structure of multi-output models.
    return tuple(torch.cat(tens) for tens in map(list, zip(*res_list))) if isinstance(res_list[0], tuple) else torch.cat(res_list)
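
# Illustrative sketch (not from the original file): batch inference with
# model_run. The dataset here is hypothetical; `model` is assumed to expose
# the custom no_grad() hook used throughout this module.
def _demo_model_run(model):
    ds = D.TensorDataset(torch.randn(64, 8))
    # Returns a single tensor, or a tuple of tensors for multi-output models.
    return model_run(model, ds, do_apex=False, num_workers=0)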

def models_run(models, dataset, do_apex=True, batch_size=32, num_workers=6):
    """Run one or more models over a dataset, reading each batch only once."""
    models_ = models if isinstance(models, list) else [models]
    for model in models_:
        _ = model.eval()
        model.no_grad()
    device = get_model_device(models_[0])

    if do_apex and (device.type == 'cuda'):
        # Rebuild the list: reassigning the loop variable would not update it.
        models_ = [amp.initialize(model, opt_level="O1", verbosity=0) for model in models_]
    res_list = [[] for _ in models_]
    data_loader = D.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    for batchs in tqdm_notebook(data_loader):
        for i, model in enumerate(models_):
            y_preds = model(*[x.to(device) for x in batchs]) if isinstance(batchs, (list, tuple)) else model(batchs.to(device))
            res_list[i].append(tuple(y.cpu().detach() for y in y_preds) if isinstance(y_preds, tuple) else y_preds.cpu().detach())
    res = []
    for i in range(len(models_)):
        res.append(tuple(torch.cat(tens) for tens in map(list, zip(*res_list[i])))
                   if isinstance(res_list[i][0], tuple) else torch.cat(res_list[i]))

    # Group results by output index across models; assumes each model returns a tuple.
    return tuple(zip(*res))
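
# Illustrative sketch (not from the original file): ensembling several trained
# models over one dataset. Note that models_run's final tuple(zip(*res))
# assumes each model returns a tuple of outputs.
def _demo_models_run(models):
    # `models` is a list of modules exposing the custom no_grad() hook.
    ds = D.TensorDataset(torch.randn(64, 8))
    return models_run(models, ds, do_apex=False, num_workers=0)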

def model_evaluate(model,
                   validate_dataset,
                   batch_size,
                   loss_func,
                   weights=None,
                   metric=None,
                   do_apex=False,
                   num_workers=6):
    """Evaluate the model on a dataset; return (mean loss, postfix dict)."""
    _ = model.eval()
    model.no_grad()
    device = get_model_device(model)
    if do_apex and (device.type == 'cuda'):
        model = amp.initialize(model, opt_level="O1", verbosity=0)
    data_loader = D.DataLoader(validate_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    sum_loss = 0.
    lossf = None
    if metric:
        metric.zero()
    tq_batch = tqdm_notebook(data_loader, leave=True)
    for i, batchs in enumerate(tq_batch):
        x_batch = batchs[0].to(device) if len(batchs) == 2 else [x.to(device) for x in batchs[:-1]]
        y_batch = batchs[-1]
        y_preds = model(x_batch) if not isinstance(x_batch, list) else model(*x_batch)
        if weights is None:
            loss = loss_func(y_preds, y_batch.to(device))
        else:
            loss = loss_func(y_preds, y_batch.to(device), weights)
        sum_loss = sum_loss + loss.detach().item()
        # Exponential moving average of the batch loss, for display.
        if lossf:
            lossf = 0.98 * lossf + 0.02 * loss.detach().item()
        else:
            lossf = loss.detach().item()
        batch_postfix = {'val_loss': lossf}
        if metric:
            if isinstance(y_preds, tuple):
                yp = tuple(y_preds[0].detach().cpu())
            else:
                yp = y_preds.detach().cpu()
            batch_postfix.update(metric.calc(yp, y_batch.cpu().detach(), prefix='val_'))

        tq_batch.set_postfix(**batch_postfix)
    epoch_postfix = {'val_loss': sum_loss / len(data_loader)}
    if metric:
        epoch_postfix.update(metric.calc_sums('val_'))

    return sum_loss / len(data_loader), epoch_postfix
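
# Illustrative sketch (not from the original file): standalone validation with
# model_evaluate. The loss here is hypothetical; the model must expose the
# custom no_grad() hook.
def _demo_model_evaluate(model, val_ds):
    def mse_loss(y_pred, y_true):
        return ((y_pred.squeeze(-1) - y_true) ** 2).mean()
    # Returns the mean loss and a postfix dict such as {'val_loss': ...}.
    return model_evaluate(model, val_ds, batch_size=32,
                          loss_func=mse_loss, num_workers=0)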

class loss_graph():
    """Live matplotlib plot of the training loss, redrawn after every batch."""
    def __init__(self, fig, ax, num_epoch=1, batch2epoch=100, limits=None):
        self.num_epoch = num_epoch
        self.batch2epoch = batch2epoch
        self.loss_arr = np.zeros(num_epoch * batch2epoch, dtype=float)
        self.arr_size = num_epoch * batch2epoch
        self.num_points = 0
        self.fig = fig
        self.ax = ax
        # Losses are clipped to these limits before plotting.
        self.limits = limits if limits is not None else (-1000, 1000)
        self.ticks = (np.arange(0, num_epoch * batch2epoch + 1, step=batch2epoch),
                      np.arange(0, num_epoch + 1, step=1))

    def __call__(self, loss):
        # Grow the buffer by one epoch's worth of batches when it fills up.
        if self.num_points == self.arr_size:
            new_arr = np.zeros(self.arr_size + self.batch2epoch, dtype=float)
            new_arr[:self.arr_size] = self.loss_arr
            self.loss_arr = new_arr
            self.arr_size = self.arr_size + self.batch2epoch
        self.loss_arr[self.num_points] = max(self.limits[0], min(self.limits[1], loss))
        self.num_points = self.num_points + 1
        _ = self.ax.clear()
        _ = self.ax.plot(self.loss_arr[0:self.num_points])
        _ = self.ax.set_xlabel('batch')
        _ = self.ax.set_ylabel('loss')
        _ = plt.xticks(self.ticks[0], self.ticks[1])
        _ = self.fig.canvas.draw()
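
# Illustrative sketch (not from the original file): wiring loss_graph into
# model_train so the smoothed loss is redrawn after every batch. Assumes an
# interactive matplotlib backend (e.g. `%matplotlib notebook` in Jupyter);
# model, optimizer, train_ds, and loss_func are hypothetical.
def _demo_loss_graph(model, optimizer, train_ds, loss_func):
    fig, ax = plt.subplots()
    graph = loss_graph(fig, ax, num_epoch=3,
                       batch2epoch=max(1, len(train_ds) // 32))
    return model_train(model, optimizer, train_ds, batch_size=32, num_epochs=3,
                       loss_func=loss_func, graph=graph,
                       do_apex=False, num_workers=0)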