landmark_extraction/models/yolo.py

import argparse
import logging
import math
import sys
from copy import deepcopy
from pathlib import Path  # explicit imports for names used below (also provided via the wildcard imports)

sys.path.append('./')  # to run '$ python *.py' files in subdirectories
logger = logging.getLogger(__name__)
import torch
from models.common import *
from models.experimental import *
from utils.autoanchor import check_anchor_order
from utils.general import make_divisible, check_file, set_logging
from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
    select_device, copy_attr
from utils.loss import SigmoidBin

try:
    import thop  # for FLOPS computation
except ImportError:
    thop = None

class Detect(nn.Module):
    stride = None  # strides computed during build
    export = False  # onnx export
    end2end = False
    include_nms = False

    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
        super(Detect, self).__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv

    def forward(self, x):
        # x = x.copy()  # for profiling
        z = []  # inference output
        self.training |= self.export
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
                y = x[i].sigmoid()
                if not torch.onnx.is_in_onnx_export():
                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:
                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                z.append(y.view(bs, -1, self.no))

        if self.training:
            out = x
        elif self.end2end:
            out = torch.cat(z, 1)
        elif self.include_nms:
            z = self.convert(z)
            out = (z, )
        else:
            out = (torch.cat(z, 1), x)

        return out

    @staticmethod
    def _make_grid(nx=20, ny=20):
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()

    def convert(self, z):
        z = torch.cat(z, 1)
        box = z[:, :, :4]
        conf = z[:, :, 4:5]
        score = z[:, :, 5:]
        score *= conf
        convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
                                      dtype=torch.float32,
                                      device=z.device)
        box @= convert_matrix
        return (box, score)
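
# Decode reference (Detect): with sigmoid outputs y, boxes are recovered as
#   xy = (2 * y[..., 0:2] - 0.5 + grid) * stride      # center, in input pixels
#   wh = (2 * y[..., 2:4]) ** 2 * anchor_grid         # 0..4x the anchor size
# and `convert` maps (cx, cy, w, h) to corner format with a single matmul:
#   [cx, cy, w, h] @ convert_matrix == [cx - w/2, cy - h/2, cx + w/2, cy + h/2]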


class IDetect(nn.Module):
    stride = None  # strides computed during build
    export = False  # onnx export
    end2end = False
    include_nms = False

    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
        super(IDetect, self).__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv

        self.ia = nn.ModuleList(ImplicitA(x) for x in ch)
        self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch)

    def forward(self, x):
        # x = x.copy()  # for profiling
        z = []  # inference output
        self.training |= self.export
        for i in range(self.nl):
            x[i] = self.m[i](self.ia[i](x[i]))  # conv
            x[i] = self.im[i](x[i])
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = x[i].sigmoid()
                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x)

    def fuseforward(self, x):
        # x = x.copy()  # for profiling
        z = []  # inference output
        self.training |= self.export
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = x[i].sigmoid()
                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                z.append(y.view(bs, -1, self.no))

        if self.training:
            out = x
        elif self.end2end:
            out = torch.cat(z, 1)
        elif self.include_nms:
            z = self.convert(z)
            out = (z, )
        else:
            out = (torch.cat(z, 1), x)

        return out

    def fuse(self):
        print("IDetect.fuse")
        # fuse ImplicitA and Convolution
        for i in range(len(self.m)):
            c1, c2, _, _ = self.m[i].weight.shape
            c1_, c2_, _, _ = self.ia[i].implicit.shape
            self.m[i].bias += torch.matmul(self.m[i].weight.reshape(c1, c2), self.ia[i].implicit.reshape(c2_, c1_)).squeeze(1)

        # fuse ImplicitM and Convolution
        for i in range(len(self.m)):
            c1, c2, _, _ = self.im[i].implicit.shape
            self.m[i].bias *= self.im[i].implicit.reshape(c2)
            self.m[i].weight *= self.im[i].implicit.transpose(0, 1)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()

    def convert(self, z):
        z = torch.cat(z, 1)
        box = z[:, :, :4]
        conf = z[:, :, 4:5]
        score = z[:, :, 5:]
        score *= conf
        convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
                                      dtype=torch.float32,
                                      device=z.device)
        box @= convert_matrix
        return (box, score)
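
# Fuse reference (IDetect.fuse): ImplicitA adds a learned per-channel offset a
# before the 1x1 conv and ImplicitM scales its output by m, so both fold exactly
# into the conv parameters:
#   conv(x + a) = W x + (W a + b)         -> bias += W @ a
#   m * (W x + b) = (m * W) x + (m * b)   -> weight *= m, bias *= m
# After fusing, `forward` is swapped for `fuseforward`, which skips ia/im.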


class IKeypoint(nn.Module):
    stride = None  # strides computed during build
    export = False  # onnx export

    def __init__(self, nc=80, anchors=(), nkpt=17, ch=(), inplace=True, dw_conv_kpt=False):  # detection layer
        super(IKeypoint, self).__init__()
        self.nc = nc  # number of classes
        self.nkpt = nkpt
        self.dw_conv_kpt = dw_conv_kpt
        self.no_det = nc + 5  # number of outputs per anchor for box and class
        self.no_kpt = 3 * self.nkpt  # number of outputs per anchor for keypoints
        self.no = self.no_det + self.no_kpt
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        self.flip_test = False
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no_det * self.na, 1) for x in ch)  # output conv

        self.ia = nn.ModuleList(ImplicitA(x) for x in ch)
        self.im = nn.ModuleList(ImplicitM(self.no_det * self.na) for _ in ch)

        if self.nkpt is not None:
            if self.dw_conv_kpt:  # keypoint head is slightly more complex
                self.m_kpt = nn.ModuleList(
                            nn.Sequential(DWConv(x, x, k=3), Conv(x, x),
                                          DWConv(x, x, k=3), Conv(x, x),
                                          DWConv(x, x, k=3), Conv(x, x),
                                          DWConv(x, x, k=3), Conv(x, x),
                                          DWConv(x, x, k=3), Conv(x, x),
                                          DWConv(x, x, k=3), nn.Conv2d(x, self.no_kpt * self.na, 1)) for x in ch)
            else:  # keypoint head is a single convolution
                self.m_kpt = nn.ModuleList(nn.Conv2d(x, self.no_kpt * self.na, 1) for x in ch)

        self.inplace = inplace  # use in-place ops (e.g. slice assignment)

    def forward(self, x):
        # x = x.copy()  # for profiling
        z = []  # inference output
        self.training |= self.export
        for i in range(self.nl):
            if self.nkpt is None or self.nkpt == 0:
                x[i] = self.im[i](self.m[i](self.ia[i](x[i])))  # conv
            else:
                x[i] = torch.cat((self.im[i](self.m[i](self.ia[i](x[i]))), self.m_kpt[i](x[i])), dim=1)

            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            x_det = x[i][..., :6]  # box (4) + obj + cls; assumes nc == 1
            x_kpt = x[i][..., 6:]

            if not self.training:  # inference
                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
                kpt_grid_x = self.grid[i][..., 0:1]
                kpt_grid_y = self.grid[i][..., 1:2]

                if self.nkpt == 0:
                    y = x[i].sigmoid()
                else:
                    y = x_det.sigmoid()

                if self.inplace:
                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2)  # wh
                    if self.nkpt != 0:
                        x_kpt[..., 0::3] = (x_kpt[..., 0::3] * 2. - 0.5 + kpt_grid_x.repeat(1, 1, 1, 1, self.nkpt)) * self.stride[i]  # kpt x
                        x_kpt[..., 1::3] = (x_kpt[..., 1::3] * 2. - 0.5 + kpt_grid_y.repeat(1, 1, 1, 1, self.nkpt)) * self.stride[i]  # kpt y
                        # alternative keypoint decodings (experimental):
                        #x_kpt[..., 0::3] = (x_kpt[..., ::3] + kpt_grid_x.repeat(1,1,1,1,17)) * self.stride[i]  # xy
                        #x_kpt[..., 1::3] = (x_kpt[..., 1::3] + kpt_grid_y.repeat(1,1,1,1,17)) * self.stride[i]  # xy
                        #x_kpt[..., 0::3] = ((x_kpt[..., 0::3].tanh() * 2.) ** 3 * self.anchor_grid[i][...,0].unsqueeze(4).repeat(1,1,1,1,self.nkpt)) + kpt_grid_x.repeat(1,1,1,1,17) * self.stride[i]  # xy
                        #x_kpt[..., 1::3] = ((x_kpt[..., 1::3].tanh() * 2.) ** 3 * self.anchor_grid[i][...,1].unsqueeze(4).repeat(1,1,1,1,self.nkpt)) + kpt_grid_y.repeat(1,1,1,1,17) * self.stride[i]  # xy
                        #x_kpt[..., 0::3] = (((x_kpt[..., 0::3].sigmoid() * 4.) ** 2 - 8.) * self.anchor_grid[i][...,0].unsqueeze(4).repeat(1,1,1,1,self.nkpt)) + kpt_grid_x.repeat(1,1,1,1,17) * self.stride[i]  # xy
                        #x_kpt[..., 1::3] = (((x_kpt[..., 1::3].sigmoid() * 4.) ** 2 - 8.) * self.anchor_grid[i][...,1].unsqueeze(4).repeat(1,1,1,1,self.nkpt)) + kpt_grid_y.repeat(1,1,1,1,17) * self.stride[i]  # xy
                        x_kpt[..., 2::3] = x_kpt[..., 2::3].sigmoid()  # kpt confidence

                    y = torch.cat((xy, wh, y[..., 4:], x_kpt), dim=-1)

                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                    if self.nkpt != 0:
                        y[..., 6:] = (y[..., 6:] * 2. - 0.5 + self.grid[i].repeat((1, 1, 1, 1, self.nkpt))) * self.stride[i]  # xy
                    y = torch.cat((xy, wh, y[..., 4:]), -1)

                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
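
# Per-anchor channel layout (IKeypoint): [x, y, w, h, obj, cls, (kx, ky, kconf) * nkpt];
# x_det takes the first 6 channels (so nc == 1 is assumed here), and each keypoint's
# (kx, ky) is decoded against the grid like the box center, with a sigmoid confidence.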


class IAuxDetect(nn.Module):
    stride = None  # strides computed during build
    export = False  # onnx export

    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
        super(IAuxDetect, self).__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch[:self.nl])  # output conv
        self.m2 = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch[self.nl:])  # output conv

        self.ia = nn.ModuleList(ImplicitA(x) for x in ch[:self.nl])
        self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch[:self.nl])

    def forward(self, x):
        # x = x.copy()  # for profiling
        z = []  # inference output
        self.training |= self.export
        for i in range(self.nl):
            x[i] = self.m[i](self.ia[i](x[i]))  # conv
            x[i] = self.im[i](x[i])
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            x[i + self.nl] = self.m2[i](x[i + self.nl])
            x[i + self.nl] = x[i + self.nl].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = x[i].sigmoid()
                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x[:self.nl])

    @staticmethod
    def _make_grid(nx=20, ny=20):
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
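
# IAuxDetect pairs each lead head (self.m, with implicit knowledge) with an auxiliary
# head (self.m2) on a second set of feature maps: training returns all 2 * nl outputs
# so the loss can supervise both, while inference keeps only the lead heads (x[:self.nl]).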


class IBin(nn.Module):
    stride = None  # strides computed during build
    export = False  # onnx export

    def __init__(self, nc=80, anchors=(), ch=(), bin_count=21):  # detection layer
        super(IBin, self).__init__()
        self.nc = nc  # number of classes
        self.bin_count = bin_count

        self.w_bin_sigmoid = SigmoidBin(bin_count=self.bin_count, min=0.0, max=4.0)
        self.h_bin_sigmoid = SigmoidBin(bin_count=self.bin_count, min=0.0, max=4.0)
        # classes, x, y, obj
        self.no = nc + 3 + \
            self.w_bin_sigmoid.get_length() + self.h_bin_sigmoid.get_length()  # w-bce, h-bce
            # + self.x_bin_sigmoid.get_length() + self.y_bin_sigmoid.get_length()

        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        a = torch.tensor(anchors).float().view(self.nl, -1, 2)
        self.register_buffer('anchors', a)  # shape(nl,na,2)
        self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv

        self.ia = nn.ModuleList(ImplicitA(x) for x in ch)
        self.im = nn.ModuleList(ImplicitM(self.no * self.na) for _ in ch)

    def forward(self, x):
        #self.x_bin_sigmoid.use_fw_regression = True
        #self.y_bin_sigmoid.use_fw_regression = True
        self.w_bin_sigmoid.use_fw_regression = True
        self.h_bin_sigmoid.use_fw_regression = True

        # x = x.copy()  # for profiling
        z = []  # inference output
        self.training |= self.export
        for i in range(self.nl):
            x[i] = self.m[i](self.ia[i](x[i]))  # conv
            x[i] = self.im[i](x[i])
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i] = self._make_grid(nx, ny).to(x[i].device)

                y = x[i].sigmoid()
                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
                #y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh

                #px = (self.x_bin_sigmoid.forward(y[..., 0:12]) + self.grid[i][..., 0]) * self.stride[i]
                #py = (self.y_bin_sigmoid.forward(y[..., 12:24]) + self.grid[i][..., 1]) * self.stride[i]

                pw = self.w_bin_sigmoid.forward(y[..., 2:24]) * self.anchor_grid[i][..., 0]
                ph = self.h_bin_sigmoid.forward(y[..., 24:46]) * self.anchor_grid[i][..., 1]

                #y[..., 0] = px
                #y[..., 1] = py
                y[..., 2] = pw
                y[..., 3] = ph

                y = torch.cat((y[..., 0:4], y[..., 46:]), dim=-1)

                z.append(y.view(bs, -1, y.shape[-1]))

        return x if self.training else (torch.cat(z, 1), x)

    @staticmethod
    def _make_grid(nx=20, ny=20):
        yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
        return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
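
# Per-anchor channel layout (IBin, bin_count=21): [x, y, w_bins(22), h_bins(22), obj, cls...],
# since each SigmoidBin spans bin_count + 1 channels here (see the 2:24 and 24:46 slices).
# The binned w/h are regressed back to scalars, written into channels 2 and 3, and the
# raw bin channels are then dropped via torch.cat((y[..., 0:4], y[..., 46:]), dim=-1).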


class Model(nn.Module):
    def __init__(self, cfg='yolor-csp-c.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
        super(Model, self).__init__()
        self.traced = False
        if isinstance(cfg, dict):
            self.yaml = cfg  # model dict
        else:  # is *.yaml
            import yaml  # for torch hub
            self.yaml_file = Path(cfg).name
            with open(cfg) as f:
                self.yaml = yaml.load(f, Loader=yaml.SafeLoader)  # model dict

        # Define model
        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
        if nc and nc != self.yaml['nc']:
            logger.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml['nc'] = nc  # override yaml value
        if anchors:
            logger.info(f'Overriding model.yaml anchors with anchors={anchors}')
            self.yaml['anchors'] = round(anchors)  # override yaml value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
        self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
        # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])

        # Build strides, anchors
        m = self.model[-1]  # Detect()
        if isinstance(m, Detect):
            s = 256  # 2x min stride
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
            m.anchors /= m.stride.view(-1, 1, 1)
            check_anchor_order(m)
            self.stride = m.stride
            self._initialize_biases()  # only run once
            # print('Strides: %s' % m.stride.tolist())
        if isinstance(m, IDetect):
            s = 256  # 2x min stride
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
            m.anchors /= m.stride.view(-1, 1, 1)
            check_anchor_order(m)
            self.stride = m.stride
            self._initialize_biases()  # only run once
            # print('Strides: %s' % m.stride.tolist())
        if isinstance(m, IAuxDetect):
            s = 256  # 2x min stride
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))[:4]])  # forward
            #print(m.stride)
            m.anchors /= m.stride.view(-1, 1, 1)
            check_anchor_order(m)
            self.stride = m.stride
            self._initialize_aux_biases()  # only run once
            # print('Strides: %s' % m.stride.tolist())
        if isinstance(m, IBin):
            s = 256  # 2x min stride
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
            m.anchors /= m.stride.view(-1, 1, 1)
            check_anchor_order(m)
            self.stride = m.stride
            self._initialize_biases_bin()  # only run once
            # print('Strides: %s' % m.stride.tolist())
        if isinstance(m, IKeypoint):
            s = 256  # 2x min stride
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
            m.anchors /= m.stride.view(-1, 1, 1)
            check_anchor_order(m)
            self.stride = m.stride
            self._initialize_biases_kpt()  # only run once
            # print('Strides: %s' % m.stride.tolist())

        # Init weights, biases
        initialize_weights(self)
        self.info()
        logger.info('')
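
    # Stride bootstrap: the dummy 256x256 forward above measures each detection
    # layer's output grid, so stride = 256 / grid_size (e.g. [8, 16, 32]); anchors
    # are then divided by the stride so the buffer stores them in grid units.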

    def forward(self, x, augment=False, profile=False):
        if augment:
            img_size = x.shape[-2:]  # height, width
            s = [1, 0.83, 0.67]  # scales
            f = [None, 3, None]  # flips (2-ud, 3-lr)
            y = []  # outputs
            for si, fi in zip(s, f):
                xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
                yi = self.forward_once(xi)[0]  # forward
                # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
                yi[..., :4] /= si  # de-scale
                if fi == 2:
                    yi[..., 1] = img_size[0] - yi[..., 1]  # de-flip ud
                elif fi == 3:
                    yi[..., 0] = img_size[1] - yi[..., 0]  # de-flip lr
                y.append(yi)
            return torch.cat(y, 1), None  # augmented inference, train
        else:
            return self.forward_once(x, profile)  # single-scale inference, train

    def forward_once(self, x, profile=False):
        y, dt = [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers

            if not hasattr(self, 'traced'):
                self.traced = False

            if self.traced:
                if isinstance(m, (Detect, IDetect, IAuxDetect, IKeypoint)):
                    break

            if profile:
                c = isinstance(m, (Detect, IDetect, IAuxDetect, IBin))
                o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPS
                for _ in range(10):
                    m(x.copy() if c else x)
                t = time_synchronized()
                for _ in range(10):
                    m(x.copy() if c else x)
                dt.append((time_synchronized() - t) * 100)
                print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))

            x = m(x)  # run

            y.append(x if m.i in self.save else None)  # save output

        if profile:
            print('%.1fms total' % sum(dt))
        return x
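
    # Routing notes: m.f == -1 means "take the previous layer's output"; any other
    # index is looked up in y, which caches only the outputs listed in self.save
    # (everything else is stored as None to save memory). On a traced model,
    # execution stops before the detection head so it can run outside the trace.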

    def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        # https://arxiv.org/abs/1708.02002 section 3.3
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
        m = self.model[-1]  # Detect() module
        for mi, s in zip(m.m, m.stride):  # from
            b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
            b.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
            b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

    def _initialize_aux_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        # https://arxiv.org/abs/1708.02002 section 3.3
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
        m = self.model[-1]  # Detect() module
        for mi, mi2, s in zip(m.m, m.m2, m.stride):  # from
            b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
            b.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
            b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
            b2 = mi2.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
            b2.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
            b2.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
            mi2.bias = torch.nn.Parameter(b2.view(-1), requires_grad=True)

    def _initialize_biases_bin(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        # https://arxiv.org/abs/1708.02002 section 3.3
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
        m = self.model[-1]  # Bin() module
        bc = m.bin_count
        for mi, s in zip(m.m, m.stride):  # from
            b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
            old = b[:, (0, 1, 2, bc + 3)].data
            obj_idx = 2 * bc + 4
            b[:, :obj_idx].data += math.log(0.6 / (bc + 1 - 0.99))
            b[:, obj_idx].data += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
            b[:, (obj_idx + 1):].data += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
            b[:, (0, 1, 2, bc + 3)].data = old
            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

    def _initialize_biases_kpt(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        # https://arxiv.org/abs/1708.02002 section 3.3
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
        m = self.model[-1]  # Detect() module
        for mi, s in zip(m.m, m.stride):  # from
            b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
            b.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
            b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())  # cls
            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
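
    # Bias priors (RetinaNet, https://arxiv.org/abs/1708.02002 sec. 3.3): the obj
    # bias log(8 / (640 / s) ** 2) assumes roughly 8 objects per 640x640 image at
    # stride s, and the cls bias log(0.6 / (nc - 0.99)) targets a small per-class
    # prior so background logits don't dominate early training.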

    def _print_biases(self):
        m = self.model[-1]  # Detect() module
        for mi in m.m:  # from
            b = mi.bias.detach().view(m.na, -1).T  # conv.bias(255) to (3,85)
            print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))

    # def _print_weights(self):
    #     for m in self.model.modules():
    #         if type(m) is Bottleneck:
    #             print('%10.3g' % (m.w.detach().sigmoid() * 2))  # shortcut weights

    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
        print('Fusing layers... ')
        for m in self.model.modules():
            if isinstance(m, RepConv):
                #print(f" fuse_repvgg_block")
                m.fuse_repvgg_block()
            elif isinstance(m, RepConv_OREPA):
                #print(f" switch_to_deploy")
                m.switch_to_deploy()
            elif type(m) is Conv and hasattr(m, 'bn'):
                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                delattr(m, 'bn')  # remove batchnorm
                m.forward = m.fuseforward  # update forward
            elif isinstance(m, IDetect):
                m.fuse()
                m.forward = m.fuseforward
        self.info()
        return self

    def nms(self, mode=True):  # add or remove NMS module
        present = type(self.model[-1]) is NMS  # last layer is NMS
        if mode and not present:
            print('Adding NMS... ')
            m = NMS()  # module
            m.f = -1  # from
            m.i = self.model[-1].i + 1  # index
            self.model.add_module(name='%s' % m.i, module=m)  # add
            self.eval()
        elif not mode and present:
            print('Removing NMS... ')
            self.model = self.model[:-1]  # remove
        return self

    def autoshape(self):  # add autoShape module
        print('Adding autoShape... ')
        m = autoShape(self)  # wrap model
        copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=())  # copy attributes
        return m

    def info(self, verbose=False, img_size=640):  # print model information
        model_info(self, verbose, img_size)


def parse_model(d, ch):  # model_dict, input_channels(3)
    logger.info('\n%3s%18s%3s%10s  %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
    anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            try:
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings
            except Exception:
                pass  # keep non-evaluable strings (e.g. 'nearest') as-is

        n = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in [nn.Conv2d, Conv, RobustConv, RobustConv2, DWConv, GhostConv, RepConv, RepConv_OREPA, DownC,
                 SPP, SPPF, SPPCSPC, GhostSPPCSPC, MixConv2d, Focus, Stem, GhostStem, CrossConv,
                 Bottleneck, BottleneckCSPA, BottleneckCSPB, BottleneckCSPC,
                 RepBottleneck, RepBottleneckCSPA, RepBottleneckCSPB, RepBottleneckCSPC,
                 Res, ResCSPA, ResCSPB, ResCSPC,
                 RepRes, RepResCSPA, RepResCSPB, RepResCSPC,
                 ResX, ResXCSPA, ResXCSPB, ResXCSPC,
                 RepResX, RepResXCSPA, RepResXCSPB, RepResXCSPC,
                 Ghost, GhostCSPA, GhostCSPB, GhostCSPC,
                 SwinTransformerBlock, STCSPA, STCSPB, STCSPC,
                 SwinTransformer2Block, ST2CSPA, ST2CSPB, ST2CSPC]:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            if m in [DownC, SPPCSPC, GhostSPPCSPC,
                     BottleneckCSPA, BottleneckCSPB, BottleneckCSPC,
                     RepBottleneckCSPA, RepBottleneckCSPB, RepBottleneckCSPC,
                     ResCSPA, ResCSPB, ResCSPC,
                     RepResCSPA, RepResCSPB, RepResCSPC,
                     ResXCSPA, ResXCSPB, ResXCSPC,
                     RepResXCSPA, RepResXCSPB, RepResXCSPC,
                     GhostCSPA, GhostCSPB, GhostCSPC,
                     STCSPA, STCSPB, STCSPC,
                     ST2CSPA, ST2CSPB, ST2CSPC]:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum([ch[x] for x in f])
        elif m is Chuncat:
            c2 = sum([ch[x] for x in f])
        elif m is Shortcut:
            c2 = ch[f[0]]
        elif m is Foldcut:
            c2 = ch[f] // 2
        elif m in [Detect, IDetect, IAuxDetect, IBin, IKeypoint]:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
        elif m is ReOrg:
            c2 = ch[f] * 4
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum([x.numel() for x in m_.parameters()])  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        logger.info('%3s%18s%3s%10.0f  %-40s%-30s' % (i, f, n, np, t, args))  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
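
# Scaling example: with depth_multiple gd = 0.33 and width_multiple gw = 0.50, a layer
# spec with n=9 and args=[512, ...] becomes n = max(round(9 * 0.33), 1) = 3 repeats and
# c2 = make_divisible(512 * 0.50, 8) = 256 output channels (c2 == no, the detection
# output width, is deliberately left unscaled).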


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='yolor-csp-c.yaml', help='model.yaml')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--profile', action='store_true', help='profile model speed')
    opt = parser.parse_args()
    opt.cfg = check_file(opt.cfg)  # check file
    set_logging()
    device = select_device(opt.device)

    # Create model
    model = Model(opt.cfg).to(device)
    model.train()

    if opt.profile:
        img = torch.rand(1, 3, 640, 640).to(device)
        y = model(img, profile=True)

    # Profile
    # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
    # y = model(img, profile=True)

    # Tensorboard
    # from torch.utils.tensorboard import SummaryWriter
    # tb_writer = SummaryWriter()
    # print("Run 'tensorboard --logdir=models/runs' to view tensorboard at http://localhost:6006/")
    # tb_writer.add_graph(model.model, img)  # add model to tensorboard
    # tb_writer.add_image('test', img[0], dataformats='CWH')  # add model to tensorboard
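
    # Inference sketch (illustrative):
    # model = Model(opt.cfg).to(device).fuse().eval()
    # img = torch.rand(1, 3, 640, 640).to(device)
    # pred = model(img)[0]  # (1, num_anchors_total, no), ready for NMS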