Serialized/helper/mymodels.py

import torch
import torch.nn as nn
import numpy as np
import torchvision
import torch.nn.functional as F
import math
import copy
import collections
from pytorchcv.model_provider import get_model as ptcv_get_model
from pytorchcv.models.common import conv3x3_block
import pretrainedmodels

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

def l2_norm(input,axis=1):
    norm = torch.norm(input,2,axis,True)
    output = torch.div(input, norm)
    return output

class Window(nn.Module):
    def forward(self, x):
        return torch.clamp(x,0,1)

class ArcMarginProduct(nn.Module):
    r"""Large margin arc distance (ArcFace) weight layer.

        Args:
            in_features: size of each input sample
            out_features: size of each output sample
            weights: optional tensor used to initialize the class weights

        Note: this layer only returns ``cos(theta)``, the cosine similarity
        between the L2-normalized features and the L2-normalized class
        weights; the scale ``s`` and the additive margin ``m`` of the ArcFace
        logit ``cos(theta + m)`` are not applied here.
        """
    def __init__(self, in_features, out_features,weights=None):
        super(ArcMarginProduct, self).__init__()
        if weights is None:
            self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
            self.reset_parameters()
        else:
            self.weight = nn.Parameter(weights)

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
#        self.k.data=torch.ones(1,dtype=torch.float)

    def forward(self, features):
        cosine = F.linear(l2_norm(features), l2_norm(self.weight))
        return cosine

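# Hedged sketch (not part of the original training code): the scale `s` and
# margin `m` mentioned in the ArcMarginProduct docstring are typically folded
# into the logits before cross-entropy, roughly as below. The name and default
# values here are illustrative assumptions, not this repository's settings.
def arcface_logits(cosine, labels, s=30.0, m=0.5):
    # clamp for numerical safety before recovering the angle
    theta = torch.acos(cosine.clamp(-1 + 1e-7, 1 - 1e-7))
    # add the angular margin only to the target-class angle
    target = F.one_hot(labels, num_classes=cosine.size(1)).to(cosine.dtype)
    logits = torch.cos(theta + m * target)
    return s * logits
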
class ArcClassifier(nn.Module):
    def __init__(self,in_features, out_features,weights=None):
        super(ArcClassifier, self).__init__()
        self.classifier = ArcMarginProduct(in_features, out_features,weights=weights)
        self.dropout1=nn.Dropout(p=0.5, inplace=True)

    def forward(self, x,eq):
        out = self.dropout1(x-eq)
        out = self.classifier(out)
        return out

    def no_grad(self):
        for param in self.parameters():
            param.requires_grad=False

    def do_grad(self):
        for param in self.parameters():
            param.requires_grad=True


class MyDenseNet(nn.Module):
    def __init__(self,model,
                 num_classes,
                 num_channels=1,
                 strategy='copy',
                 add_noise=0.,
                 dropout=0.5,
                 arcface=False,
                 return_features=False,
                 norm=False,
                 intermediate=0,
                 extra_pool=1,
                 pool_type='max',
                 wso=None,
                 dont_do_grad=['wso'],
                 do_bn=False):
        super(MyDenseNet, self).__init__()
        self.features= torch.nn.Sequential()
        self.num_channels=num_channels
        self.dont_do_grad=dont_do_grad
        self.pool_type=pool_type
        self.norm=norm
        self.return_features=return_features
        self.num_classes=num_classes
        self.extra_pool=extra_pool
        # wso: window (center, width) pairs turned into a learnable 1x1 conv
        # followed by a sigmoid, so the CT windowing is applied inside the net
        if wso is not None:
            conv_ = nn.Conv2d(1,self.num_channels, kernel_size=(1, 1))
            if hasattr(wso, '__iter__'):
                conv_.weight.data.copy_(torch.tensor([[[[1./wso[0][1]]]],[[[1./wso[1][1]]]],[[[1./wso[2][1]]]]]))
                conv_.bias.data.copy_(torch.tensor([0.5 - wso[0][0]/wso[0][1],
                                                    0.5 - wso[1][0]/wso[1][1],
                                                    0.5 - wso[2][0]/wso[2][1]]))

            self.features.add_module('wso_conv',conv_)
            self.features.add_module('wso_window',nn.Sigmoid())
            if do_bn:
                self.features.add_module('wso_norm',nn.BatchNorm2d(self.num_channels))
            else:
                self.features.add_module('wso_norm',nn.InstanceNorm2d(self.num_channels))
        # rebuild the first conv so it accepts num_channels inputs, copying the
        # pretrained RGB filters in a shuffled channel order (optionally with noise)
        if (strategy == 'copy') or (num_channels!=3):
            base = list(list(model.children())[0].named_children())[1:]
            conv0 = model.state_dict()['features.conv0.weight']
            new_conv=nn.Conv2d(self.num_channels, conv0.shape[0], kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            a=(np.arange(3*(self.num_channels//3+1),dtype=int)%3)
            np.random.shuffle(a)
            for i in range(self.num_channels):
                new_conv.state_dict()['weight'][:,i,:,:]=conv0.clone()[:,a[i],:,:]*(1.0+torch.randn_like(conv0[:,a[i],:,:])*add_noise)
            self.features.add_module('conv0',new_conv)
        else:
            base = list(list(model.children())[0].named_children())
        for (n,l) in base:
            self.features.add_module(n,l)
        if intermediate==0:
            self.num_features=list(model.children())[-1].in_features
            self.intermediate=None
        else:
            self.num_features=intermediate
            self.intermediate=nn.Linear(list(model.children())[-1].in_features, self.num_features)
        self.dropout1=nn.Dropout(p=dropout, inplace=True)
        if arcface:
            self.classifier=ArcMarginProduct(self.num_features, num_classes)
        else:
            self.classifier = nn.Linear(self.num_features//self.extra_pool, self.num_classes)

    def forward(self, x):
        x = self.features(x)
        x = F.relu(x, inplace=True)
        # global pooling over the spatial dims; the 3D pool additionally collapses
        # groups of extra_pool consecutive channels into a single feature each
        if self.pool_type=='avg':
            x = F.avg_pool3d(x.unsqueeze(1), kernel_size=(self.extra_pool,)+x.size()[2:]).view(x.size(0), -1)
        else:
            x = F.max_pool3d(x.unsqueeze(1), kernel_size=(self.extra_pool,)+x.size()[2:]).view(x.size(0), -1)
#        x = F.max_pool1d(x.view(x.unsqueeze(1),self.extra_pool).squeeze()
        x = self.dropout1(x)
        if self.intermediate is not None:
            x = self.intermediate(x)
            x = F.relu(x)
        features = x
        if self.norm:
            features = l2_norm(features,axis=1)
        out = self.classifier(features)
        return out if not self.return_features else (out,features)

    def parameter_scheduler(self,epoch):
        do_first=['classifier','wso']
        if epoch>0:
            for n,p in self.named_parameters():
                p.requires_grad=True
        else:
            for n,p in self.named_parameters():
                p.requires_grad= any(nd in n for nd in do_first)

    def no_grad(self):
        for param in self.parameters():
            param.requires_grad=False

    def do_grad(self):
        for n,p in self.named_parameters():
            p.requires_grad= not any(nd in n for nd in self.dont_do_grad)

    def get_optimizer_parameters(self,klr):
        # Layer-wise learning rates: the stem gets a small LR, later denseblocks get
        # progressively larger ones, all scaled by klr; bias parameters are excluded
        # from weight decay. (Note: 'ws_norm' does not match the 'wso_norm' module
        # name; the wso parameters are grouped separately below.)
        zero_layer=['conv0','norm0','ws_norm']
        param_optimizer = list(self.named_parameters())
        num_blocks=4
        no_decay=['bias']
        optimizer_grouped_parameters=[
            {'params': [p for n, p in param_optimizer
                        if (not any(nd in n for nd in no_decay)) and any(nd in n for nd in zero_layer)],
             'lr':klr*2e-5, 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer
                        if any(nd in n for nd in no_decay) and any(nd in n for nd in zero_layer)],
             'lr':klr*2e-5, 'weight_decay': 0.0}]
        optimizer_grouped_parameters.extend([
            {'params': [p for n, p in param_optimizer
                        if (not any(nd in n for nd in no_decay)) and ('wso' in n)],
             'lr':klr*1e-5, 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer
                        if any(nd in n for nd in no_decay) and ('wso' in n)],
             'lr':klr*1e-5, 'weight_decay': 0.0}])
        optimizer_grouped_parameters.extend([
            {'params': [p for n, p in param_optimizer
                        if (not any(nd in n for nd in no_decay)) and ('classifier' in n)],
             'lr':klr*1e-3, 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer
                        if any(nd in n for nd in no_decay) and ('classifier' in n)],
             'lr':klr*1e-3, 'weight_decay': 0.0}])
        for i in range(num_blocks):
            optimizer_grouped_parameters.extend([
                {'params': [p for n, p in param_optimizer
                            if (not any(nd in n for nd in no_decay)) and ('denseblock{}'.format(i+1) in n)],
                 'lr':klr*(2.0**i)*2e-5, 'weight_decay': 0.01},
                {'params': [p for n, p in param_optimizer
                            if any(nd in n for nd in no_decay) and ('denseblock{}'.format(i+1) in n)],
                 'lr':klr*(2.0**i)*2e-5, 'weight_decay': 0.0}])
        optimizer_grouped_parameters.extend([
            {'params': [p for n, p in param_optimizer
                        if (not any(nd in n for nd in no_decay)) and ('norm5' in n)],
             'lr':klr*1e-4, 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer
                        if any(nd in n for nd in no_decay) and ('norm5' in n)],
             'lr':klr*1e-4, 'weight_decay': 0.0}])
        return optimizer_grouped_parameters

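# Hedged usage sketch, not part of the original pipeline: MyDenseNet expects a
# torchvision-style DenseNet (first child = `features`, last child = the final
# Linear). The (center, width) windows below are placeholders, not the
# project's actual settings.
def _example_build_mydensenet():
    base = torchvision.models.densenet121()   # pretrained weights would normally be loaded here
    net = MyDenseNet(base, num_classes=6, num_channels=3, strategy='copy',
                     wso=[(40, 80), (80, 200), (600, 2800)], pool_type='avg')
    # the wso 1x1 conv + sigmoid turns the single-channel CT input into 3 windowed channels
    return net(torch.randn(2, 1, 256, 256))
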
class MySENet(nn.Module):
    def __init__(self,model,
                 num_classes,
                 num_channels=3,
                 dropout=0.2,
                 return_features=False,
                 wso=None,
                 full_copy=False,
                 dont_do_grad=['wso'],
                 extra_pool=1,
                 do_bn=False):
        super(MySENet, self).__init__()
        self.num_classes=num_classes
        self.return_features=return_features
        self.num_channels = num_channels
        self.features= torch.nn.Sequential()
        self.extra_pool=extra_pool
        self.dont_do_grad=dont_do_grad
        if full_copy:
            # re-wrap an already wrapped model: reuse its `features` block
            # (including any wso modules) and its dont_do_grad list
            for (n,l) in list(list(model.children())[0].named_children()):
                self.features.add_module(n,l)
            if wso is not None:
                self.dont_do_grad=model.dont_do_grad
        else:
            if wso is not None:
                conv_ = nn.Conv2d(1,self.num_channels, kernel_size=(1, 1))
                if hasattr(wso, '__iter__'):
                    conv_.weight.data.copy_(torch.tensor([[[[1./wso[0][1]]]],[[[1./wso[1][1]]]],[[[1./wso[2][1]]]]]))
                    conv_.bias.data.copy_(torch.tensor([0.5 - wso[0][0]/wso[0][1],
                                                        0.5 - wso[1][0]/wso[1][1],
                                                        0.5 - wso[2][0]/wso[2][1]]))

                self.features.add_module('wso_conv',conv_)
                # note: named 'wso_relu' but it is actually a sigmoid window
                self.features.add_module('wso_relu',nn.Sigmoid())
                if do_bn:
                    self.features.add_module('wso_norm',nn.BatchNorm2d(self.num_channels))
                else:
                    self.features.add_module('wso_norm',nn.InstanceNorm2d(self.num_channels))

#            layer0= torch.nn.Sequential()
#            layer0.add_module('conv1',model.conv1)
#            layer0.add_module('bn1',model.bn1)
            se_layers={'layer0':model.layer0,
                       'layer1':model.layer1,
                       'layer2':model.layer2,
                       'layer3':model.layer3,
                       'layer4':model.layer4}
            for key in se_layers:
                self.features.add_module(key,se_layers[key])
        # keep None if dropout is None, otherwise wrap it in a Dropout module
        self.dropout = dropout if dropout is None else nn.Dropout(p=dropout, inplace=True)
        self.classifier=nn.Linear(model.last_linear.in_features//self.extra_pool, self.num_classes)

    def forward(self, x):
        x = self.features(x)
        x = F.max_pool3d(x.unsqueeze(1), kernel_size=(self.extra_pool,)+x.size()[2:]).view(x.size(0), -1)
        if self.dropout is not None:
            x = self.dropout(x)
        features = x
        out = self.classifier(features)
        return out if not self.return_features else (out,features)

    def parameter_scheduler(self,epoch):
        do_first=['classifier']
        if epoch>0:
            for n,p in self.named_parameters():
                p.requires_grad=True
        else:
            for n,p in self.named_parameters():
                p.requires_grad= any(nd in n for nd in do_first)

    def no_grad(self):
        for param in self.parameters():
            param.requires_grad=False

    def do_grad(self):
        for n,p in self.named_parameters():
            p.requires_grad= not any(nd in n for nd in self.dont_do_grad)

    def get_optimizer_parameters(self,klr):
        # per-layer learning rates for the SE trunk (layer0..layer4), scaled by klr;
        # bias parameters are excluded from weight decay
        param_optimizer = list(self.named_parameters())
        num_blocks=5
        no_decay=['bias']
        optimizer_grouped_parameters=[
            {'params': [p for n, p in param_optimizer
                        if (not any(nd in n for nd in no_decay)) and ('classifier' in n)],
             'lr':klr*2e-4, 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer
                        if any(nd in n for nd in no_decay) and ('classifier' in n)],
             'lr':klr*2e-4, 'weight_decay': 0.0}]
        optimizer_grouped_parameters.extend([
            {'params': [p for n, p in param_optimizer
                        if (not any(nd in n for nd in no_decay)) and ('wso' in n)],
             'lr':klr*5e-6, 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer
                        if any(nd in n for nd in no_decay) and ('wso' in n)],
             'lr':klr*5e-6, 'weight_decay': 0.0}])
        for i in range(num_blocks):
            optimizer_grouped_parameters.extend([
                {'params': [p for n, p in param_optimizer
                            if (not any(nd in n for nd in no_decay)) and ('layer{}'.format(i) in n)],
                 'lr':klr*(2.0**i)*1e-5, 'weight_decay': 0.01},
                {'params': [p for n, p in param_optimizer
                            if any(nd in n for nd in no_decay) and ('layer{}'.format(i) in n)],
                 'lr':klr*(2.0**i)*1e-5, 'weight_decay': 0.0}])
        return optimizer_grouped_parameters

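# Hedged usage sketch, not part of the original pipeline: MySENet expects a
# pretrainedmodels-style SENet with layer0..layer4 and a `last_linear` head
# (e.g. se_resnext50_32x4d). The window values are placeholders.
def _example_build_mysenet():
    base = pretrainedmodels.se_resnext50_32x4d(num_classes=1000, pretrained=None)
    net = MySENet(base, num_classes=6, num_channels=3,
                  wso=[(40, 80), (80, 200), (600, 2800)])
    return net(torch.randn(2, 1, 224, 224))
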
class MyEfficientNet(nn.Module):
    def __init__(self,model,num_classes,num_channels=3,dropout=0.5,return_features=False,wso=True,
                 full_copy=False,
                 dont_do_grad=['wso'],
                 extra_pool=1,
                 num_features=None):
        super(MyEfficientNet, self).__init__()
        self.num_classes=num_classes
        self.return_features=return_features
        self.num_channels = num_channels
        self.features= torch.nn.Sequential()
        self.extra_pool=extra_pool
        self.dont_do_grad=dont_do_grad
        if full_copy:
            # re-wrap an already wrapped model: reuse its `features` block and its dont_do_grad list
            for (n,l) in list(list(model.children())[0].named_children()):
                self.features.add_module(n,l)
            if wso is not None:
                self.dont_do_grad=model.dont_do_grad

        else:
            if wso is not None:
                # with wso=True the 1x1 windowing conv keeps its random initialization;
                # an iterable of (center, width) pairs initializes it to fixed windows
                conv_ = nn.Conv2d(1,self.num_channels, kernel_size=(1, 1))
                if hasattr(wso, '__iter__'):
                    conv_.weight.data.copy_(torch.tensor([[[[1./wso[0][1]]]],[[[1./wso[1][1]]]],[[[1./wso[2][1]]]]]))
                    conv_.bias.data.copy_(torch.tensor([0.5 - wso[0][0]/wso[0][1],
                                                        0.5 - wso[1][0]/wso[1][1],
                                                        0.5 - wso[2][0]/wso[2][1]]))

                self.features.add_module('wso_conv',conv_)
                self.features.add_module('wso_relu',nn.Sigmoid())
                self.features.add_module('wso_norm',nn.InstanceNorm2d(self.num_channels))
            for (n,l) in list(list(model.children())[0].named_children()):
                self.features.add_module(n,l)
        self.dropout = dropout if dropout is None else nn.Dropout(p=dropout, inplace=True)
        if num_features is None:
            self.classifier=nn.Linear(model.output.fc.in_features//self.extra_pool, self.num_classes)
        else:
            self.classifier=nn.Linear(num_features, self.num_classes)

    def forward(self, x):
        x = self.features(x)
        x = F.avg_pool2d(x, kernel_size=x.size(-1)).view(x.size(0), -1)
        if self.extra_pool>1:
            x = x.view(x.shape[0],x.shape[1]//self.extra_pool,self.extra_pool).mean(-1)
        if self.dropout is not None:
            x = self.dropout(x)
        features = x
        out = self.classifier(features)
        return out if not self.return_features else (out,features)

    def parameter_scheduler(self,epoch):
        do_first=['classifier']
        if epoch>0:
            for n,p in self.named_parameters():
                p.requires_grad=True
        else:
            for n,p in self.named_parameters():
                p.requires_grad= any(nd in n for nd in do_first)

    def no_grad(self):
        for param in self.parameters():
            param.requires_grad=False

    def do_grad(self):
        for n,p in self.named_parameters():
            p.requires_grad= not any(nd in n for nd in self.dont_do_grad)

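# Hedged usage sketch, not part of the original pipeline: MyEfficientNet expects
# a pytorchcv-style EfficientNet (first child = `features`, head at `output.fc`).
# With the default wso=True the 1x1 windowing conv is randomly initialized;
# passing a list of (center, width) pairs initializes it to fixed CT windows.
def _example_build_myefficientnet():
    base = ptcv_get_model("efficientnet_b0", pretrained=False)
    net = MyEfficientNet(base, num_classes=6, num_channels=3)
    return net(torch.randn(2, 1, 224, 224))
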
class NeighborsNet(nn.Module):
    """Classify a slice from its own features concatenated with the features of
    num_neighbors slices on each side; the input has shape
    (batch, 2*num_neighbors+1, num_features)."""
    def __init__(self,num_classes,num_features=1024,num_neighbors=1,classifier_layer=None,intermidiate=None,dropout=0.2):
        super(NeighborsNet, self).__init__()
        self.num_classes=num_classes
        if classifier_layer is not None:
            self.num_features = classifier_layer.in_features
        else:
            self.num_features=num_features
        self.num_neighbors=num_neighbors
        layers=collections.OrderedDict()
        if dropout>0:
            layers['dropout']=nn.Dropout(p=dropout)

        if intermidiate is not None:
            layers['intermidiate']=nn.Linear(self.num_features*(2*self.num_neighbors+1), intermidiate)
            layers['relu']=nn.ReLU()
            layers['classifier']=nn.Linear(intermidiate, self.num_classes)
        else:
            layers['classifier']=nn.Linear(self.num_features*(2*self.num_neighbors+1), self.num_classes)
        if (classifier_layer is not None) and (intermidiate is None):
            # warm-start from an existing per-slice classifier: full weight on the
            # center slice, 0.1 weight on each neighbor, and a rescaled bias
            _=layers['classifier'].bias.data.copy_((1.0+0.2*self.num_neighbors)*classifier_layer.bias.data)
            d = torch.cat([0.1*classifier_layer.weight.data for i in range(self.num_neighbors)]+\
                             [classifier_layer.weight.data]+\
                             [0.1*classifier_layer.weight.data for i in range(self.num_neighbors)],dim=1)
            _=layers['classifier'].weight.data.copy_(d)
        self.network= torch.nn.Sequential(layers)

    def forward(self, x):
        x = x.view((x.shape[0],-1))
        return self.network(x)

    def parameter_scheduler(self,epoch):
        do_first=['classifier']
        if epoch>0:
            for n,p in self.named_parameters():
                p.requires_grad=True
        else:
            for n,p in self.named_parameters():
                p.requires_grad= any(nd in n for nd in do_first)

    def no_grad(self):
        for param in self.parameters():
            param.requires_grad=False

    def do_grad(self):
        for param in self.parameters():
            param.requires_grad=True

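# Hedged usage sketch, not part of the original pipeline: per-slice feature
# vectors are stacked with their neighbors along dim 1 before being passed in.
def _example_build_neighborsnet():
    net = NeighborsNet(num_classes=6, num_features=1024, num_neighbors=1)
    # (batch, 2*num_neighbors+1 slices, features) -> flattened inside forward()
    return net(torch.randn(4, 3, 1024))
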
class ResModelPool(nn.Module):
    """Sequence head over per-slice feature vectors: a 2D conv spanning 9
    neighboring slices and the full feature dimension collapses the features,
    then densely connected 1D convs along the slice axis emit 6 logits per slice."""
    def __init__(self,in_size):
        super(ResModelPool, self).__init__()
        self.dont_do_grad=[]
        self.conv2d1=torch.nn.Conv2d(1, 64, (9,in_size),stride=(1,in_size), padding=(4,0))
        self.bn0=torch.nn.BatchNorm1d(64)
#        self.relu0=torch.nn.ReLU()
        self.conv1d1=torch.nn.Conv1d(64, 64, 7, padding=3)
        self.bn1=torch.nn.BatchNorm1d(64)
        self.relu1=torch.nn.ReLU()
        self.conv1d2=torch.nn.Conv1d(128, 64, 5, padding=2)
        self.bn2=torch.nn.BatchNorm1d(64)
        self.relu2=torch.nn.ReLU()
        self.conv1d3=torch.nn.Conv1d(192, 6, 3, padding=1)

    def forward(self, x):
        # x: (batch, num_slices, in_size)
        x=x.unsqueeze(1)
        x = self.conv2d1(x)
        # (batch, 64, num_slices, 1) -> (batch, 64, num_slices)
        x=F.max_pool2d(x,kernel_size=(1,x.shape[-1])).squeeze(-1)
        x0 = self.bn0(x)
#        x0 = self.relu0(x)
        x = self.conv1d1(x0)
        x = self.bn1(x)
        x1 = self.relu1(x)
        x = torch.cat([x0,x1],1)
        x = self.conv1d2(x)
        x = self.bn2(x)
        x2 = self.relu2(x)
        x = torch.cat([x0,x1,x2],1)
        # output: (batch, num_slices, 6)
        out = self.conv1d3(x).transpose(-1,-2)
        return out

    def no_grad(self):
        for param in self.parameters():
            param.requires_grad=False

    def do_grad(self):
        for n,p in self.named_parameters():
            p.requires_grad= not any(nd in n for nd in self.dont_do_grad)

def mean_model(models):
    """Return a copy of models[0] whose parameters are the element-wise mean of
    the parameters of all models in the list; buffers such as BatchNorm running
    statistics are taken from models[0] and are not averaged."""
    model = copy.deepcopy(models[0])
    params=[]
    for model_ in models:
        params.append(dict(model_.named_parameters()))

    param_dict=dict(model.named_parameters())

    for name in param_dict.keys():
        _=param_dict[name].data.copy_(torch.cat([param[name].data[...,None] for param in params],-1).mean(-1))
    return model
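
# Hedged smoke test, not part of the original code: quick shape checks for the
# sequence head and the weight-averaging helper. The feature size 1024 and the
# 60-slice series length are illustrative values only.
if __name__ == "__main__":
    head = ResModelPool(in_size=1024)
    dummy = torch.randn(2, 60, 1024)            # (batch, slices per series, features)
    print(head(dummy).shape)                    # expected: torch.Size([2, 60, 6])
    averaged = mean_model([ResModelPool(1024), ResModelPool(1024)])
    print(type(averaged).__name__)              # ResModelPool with averaged weights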