opengait/modeling/modules.py

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from utils import clones, is_list_or_tuple
from torchvision.ops import RoIAlign


class HorizontalPoolingPyramid():
    """
        Horizontal Pyramid Matching for Person Re-identification
        Arxiv: https://arxiv.org/abs/1804.05275
        Github: https://github.com/SHI-Labs/Horizontal-Pyramid-Matching
    """

    def __init__(self, bin_num=None):
        if bin_num is None:
            bin_num = [16, 8, 4, 2, 1]
        self.bin_num = bin_num

    def __call__(self, x):
        """
            x  : [n, c, h, w]
            ret: [n, c, p]
        """
        n, c = x.size()[:2]
        features = []
        for b in self.bin_num:
            z = x.view(n, c, b, -1)
            z = z.mean(-1) + z.max(-1)[0]
            features.append(z)
        return torch.cat(features, -1)
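

# Usage sketch (illustrative, not part of the original file; shapes are assumed):
# each bin count slices the feature map into that many horizontal strips, pools
# every strip with mean + max, and concatenates the strip features.
#
#   hpp = HorizontalPoolingPyramid()        # 16 + 8 + 4 + 2 + 1 = 31 parts
#   feat = torch.randn(2, 256, 64, 44)      # [n, c, h, w]
#   parts = hpp(feat)                       # [2, 256, 31]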


class SetBlockWrapper(nn.Module):
    def __init__(self, forward_block):
        super(SetBlockWrapper, self).__init__()
        self.forward_block = forward_block

    def forward(self, x, *args, **kwargs):
        """
            In  x: [n, c_in, s, h_in, w_in]
            Out x: [n, c_out, s, h_out, w_out]
        """
        n, c, s, h, w = x.size()
        x = self.forward_block(x.transpose(
            1, 2).reshape(-1, c, h, w), *args, **kwargs)
        output_size = x.size()
        return x.reshape(n, s, *output_size[1:]).transpose(1, 2).contiguous()
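

# Usage sketch (illustrative, not part of the original file): SetBlockWrapper
# folds the sequence dim into the batch so any 2D block can run frame-by-frame.
#
#   frame_conv = SetBlockWrapper(BasicConv2d(1, 32, 3, 1, 1))
#   sils = torch.randn(2, 1, 30, 64, 44)    # [n, c, s, h, w]
#   outs = frame_conv(sils)                 # [2, 32, 30, 64, 44]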


class PackSequenceWrapper(nn.Module):
    def __init__(self, pooling_func):
        super(PackSequenceWrapper, self).__init__()
        self.pooling_func = pooling_func

    def forward(self, seqs, seqL, dim=2, options=None):
        """
            In  seqs: [n, c, s, ...]
            Out rets: [n, ...]
        """
        options = options if options is not None else {}
        if seqL is None:
            return self.pooling_func(seqs, **options)
        seqL = seqL[0].data.cpu().numpy().tolist()
        start = [0] + np.cumsum(seqL).tolist()[:-1]

        rets = []
        for curr_start, curr_seqL in zip(start, seqL):
            narrowed_seq = seqs.narrow(dim, curr_start, curr_seqL)
            rets.append(self.pooling_func(narrowed_seq, **options))
        if len(rets) > 0 and is_list_or_tuple(rets[0]):
            return [torch.cat([ret[j] for ret in rets])
                    for j in range(len(rets[0]))]
        return torch.cat(rets)
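

# Usage sketch (illustrative, not part of the original file; shapes are assumed):
# with torch.max as the pooling function, each variable-length subsequence is
# max-pooled over time and the results are concatenated along the batch dim.
#
#   tp = PackSequenceWrapper(torch.max)
#   seqs = torch.randn(1, 32, 25, 16, 11)           # [n, c, s, h, w], s = 10 + 15
#   seqL = torch.tensor([[10, 15]])                  # frames per subsequence
#   pooled = tp(seqs, seqL, options={"dim": 2})[0]   # [2, 32, 16, 11]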


class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding, bias=False, **kwargs)

    def forward(self, x):
        x = self.conv(x)
        return x


class SeparateFCs(nn.Module):
    def __init__(self, parts_num, in_channels, out_channels, norm=False):
        super(SeparateFCs, self).__init__()
        self.p = parts_num
        self.fc_bin = nn.Parameter(
            nn.init.xavier_uniform_(
                torch.zeros(parts_num, in_channels, out_channels)))
        self.norm = norm

    def forward(self, x):
        """
            x: [n, c_in, p]
            out: [n, c_out, p]
        """
        x = x.permute(2, 0, 1).contiguous()
        if self.norm:
            out = x.matmul(F.normalize(self.fc_bin, dim=1))
        else:
            out = x.matmul(self.fc_bin)
        return out.permute(1, 2, 0).contiguous()
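

# Usage sketch (illustrative, not part of the original file): one independent
# [c_in, c_out] weight matrix per horizontal part, applied as a batched matmul.
#
#   fc = SeparateFCs(parts_num=31, in_channels=256, out_channels=256)
#   embed = fc(parts)    # [n, 256, 31] -> [n, 256, 31]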


class SeparateBNNecks(nn.Module):
    """
        Bag of Tricks and a Strong Baseline for Deep Person Re-Identification
        CVPR Workshop:  https://openaccess.thecvf.com/content_CVPRW_2019/papers/TRMTMCT/Luo_Bag_of_Tricks_and_a_Strong_Baseline_for_Deep_Person_CVPRW_2019_paper.pdf
        Github: https://github.com/michuanhaohao/reid-strong-baseline
    """

    def __init__(self, parts_num, in_channels, class_num, norm=True, parallel_BN1d=True):
        super(SeparateBNNecks, self).__init__()
        self.p = parts_num
        self.class_num = class_num
        self.norm = norm
        self.fc_bin = nn.Parameter(
            nn.init.xavier_uniform_(
                torch.zeros(parts_num, in_channels, class_num)))
        if parallel_BN1d:
            self.bn1d = nn.BatchNorm1d(in_channels * parts_num)
        else:
            self.bn1d = clones(nn.BatchNorm1d(in_channels), parts_num)
        self.parallel_BN1d = parallel_BN1d

    def forward(self, x):
        """
            x: [n, c, p]
        """
        if self.parallel_BN1d:
            n, c, p = x.size()
            x = x.view(n, -1)  # [n, c*p]
            x = self.bn1d(x)
            x = x.view(n, c, p)
        else:
            x = torch.cat([bn(_x) for _x, bn in zip(
                x.split(1, 2), self.bn1d)], 2)  # [n, c, p]
        feature = x.permute(2, 0, 1).contiguous()
        if self.norm:
            feature = F.normalize(feature, dim=-1)  # [p, n, c]
            logits = feature.matmul(F.normalize(
                self.fc_bin, dim=1))  # [p, n, class_num]
        else:
            logits = feature.matmul(self.fc_bin)
        return feature.permute(1, 2, 0).contiguous(), logits.permute(1, 2, 0).contiguous()
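

# Usage sketch (illustrative, not part of the original file): a BNNeck per part,
# returning both the normalized feature and the per-part class logits.
#
#   bnneck = SeparateBNNecks(parts_num=31, in_channels=256, class_num=74)
#   feature, logits = bnneck(embed)   # [n, 256, 31], [n, 74, 31]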


class FocalConv2d(nn.Module):
    """
        GaitPart: Temporal Part-based Model for Gait Recognition
        CVPR2020: https://openaccess.thecvf.com/content_CVPR_2020/papers/Fan_GaitPart_Temporal_Part-Based_Model_for_Gait_Recognition_CVPR_2020_paper.pdf
        Github: https://github.com/ChaoFan96/GaitPart
    """

    def __init__(self, in_channels, out_channels, kernel_size, halving, **kwargs):
        super(FocalConv2d, self).__init__()
        self.halving = halving
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size, bias=False, **kwargs)

    def forward(self, x):
        if self.halving == 0:
            z = self.conv(x)
        else:
            h = x.size(2)
            split_size = int(h // 2**self.halving)
            z = x.split(split_size, 2)
            z = torch.cat([self.conv(_) for _ in z], 2)
        return z
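

# Usage sketch (illustrative, not part of the original file): the input is split
# into 2**halving horizontal strips and the same conv is applied to each strip,
# so information never leaks across strip borders.
#
#   fconv = FocalConv2d(32, 64, kernel_size=3, halving=3, padding=1)
#   y = fconv(torch.randn(2, 32, 64, 44))   # 8 strips of height 8 -> [2, 64, 64, 44]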


class BasicConv3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False, **kwargs):
        super(BasicConv3d, self).__init__()
        self.conv3d = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size,
                                stride=stride, padding=padding, bias=bias, **kwargs)

    def forward(self, ipts):
        '''
            ipts: [n, c, s, h, w]
            outs: [n, c, s, h, w]
        '''
        outs = self.conv3d(ipts)
        return outs


class GaitAlign(nn.Module):
    """
        GaitEdge: Beyond Plain End-to-end Gait Recognition for Better Practicality
        ECCV2022: https://arxiv.org/pdf/2203.03972v2.pdf
        Github: https://github.com/ShiqiYu/OpenGait/tree/master/configs/gaitedge
    """

    def __init__(self, H=64, W=44, eps=1, **kwargs):
        super(GaitAlign, self).__init__()
        self.H, self.W, self.eps = H, W, eps
        self.Pad = nn.ZeroPad2d((int(self.W / 2), int(self.W / 2), 0, 0))
        self.RoiPool = RoIAlign((self.H, self.W), 1, sampling_ratio=-1)

    def forward(self, feature_map, binary_mask, w_h_ratio):
        """
           In  feature_map:  [n, c, h, w]
               binary_mask:  [n, c, h, w]
               w_h_ratio:    [n, 1]
           Out aligned_sils: [n, c, H, W]
        """
        n, c, h, w = feature_map.size()
        w_h_ratio = w_h_ratio.view(-1, 1)  # [n, 1]

        h_sum = binary_mask.sum(-1)  # [n, c, h]
        _ = (h_sum >= self.eps).float().cumsum(axis=-1)  # [n, c, h]
        h_top = (_ == 0).float().sum(-1)  # [n, c]
        h_bot = (_ != torch.max(_, dim=-1, keepdim=True)
                 [0]).float().sum(-1) + 1.  # [n, c]

        w_sum = binary_mask.sum(-2)  # [n, c, w]
        w_cumsum = w_sum.cumsum(axis=-1)  # [n, c, w]
        w_h_sum = w_sum.sum(-1).unsqueeze(-1)  # [n, c, 1]
        w_center = (w_cumsum < w_h_sum / 2.).float().sum(-1)  # [n, c]

        p1 = self.W - self.H * w_h_ratio
        p1 = p1 / 2.
        p1 = torch.clamp(p1, min=0)  # [n, c]
        t_w = w_h_ratio * self.H / w
        p2 = p1 / t_w  # [n, c]

        height = h_bot - h_top  # [n, c]
        width = height * w / h  # [n, c]
        width_p = int(self.W / 2)

        feature_map = self.Pad(feature_map)
        w_center = w_center + width_p  # [n, c]

        w_left = w_center - width / 2 - p2  # [n, c]
        w_right = w_center + width / 2 + p2  # [n, c]

        w_left = torch.clamp(w_left, min=0., max=w + 2 * width_p)
        w_right = torch.clamp(w_right, min=0., max=w + 2 * width_p)

        boxes = torch.cat([w_left, h_top, w_right, h_bot], dim=-1)
        # index of each bbox within the batch
        box_index = torch.arange(n, device=feature_map.device)
        rois = torch.cat([box_index.view(-1, 1), boxes], -1)
        crops = self.RoiPool(feature_map, rois)  # [n, c, H, W]
        return crops
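

# Usage sketch (illustrative, not part of the original file; w_h_ratio is taken
# here to be the subject's width-to-height ratio, which is an assumption): the
# box around the subject is derived from the mask and RoIAlign resamples it to
# the canonical H x W.
#
#   align = GaitAlign(H=64, W=44)
#   sils = (torch.rand(4, 1, 64, 44) > 0.5).float()
#   ratio = torch.full((4, 1), 44.0 / 64.0)
#   aligned = align(sils, sils, ratio)   # [4, 1, 64, 44]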


def RmBN2dAffine(model):
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.requires_grad = False
            m.bias.requires_grad = False


'''
Modified from https://github.com/BNU-IVC/FastPoseGait/blob/main/fastposegait/modeling/components/units
'''


class Graph():
    """
    # Thanks to YAN Sijie for the released code on Github (https://github.com/yysijie/st-gcn)
    """

    def __init__(self, joint_format='coco', max_hop=2, dilation=1):
        self.joint_format = joint_format
        self.max_hop = max_hop
        self.dilation = dilation

        # get edges
        self.num_node, self.edge, self.connect_joint, self.parts = self._get_edge()

        # get adjacency matrix
        self.A = self._get_adjacency()

    def __str__(self):
        return str(self.A)

    def _get_edge(self):
        if self.joint_format == 'coco':
            # keypoints = {
            #     0: "nose",
            #     1: "left_eye",
            #     2: "right_eye",
            #     3: "left_ear",
            #     4: "right_ear",
            #     5: "left_shoulder",
            #     6: "right_shoulder",
            #     7: "left_elbow",
            #     8: "right_elbow",
            #     9: "left_wrist",
            #     10: "right_wrist",
            #     11: "left_hip",
            #     12: "right_hip",
            #     13: "left_knee",
            #     14: "right_knee",
            #     15: "left_ankle",
            #     16: "right_ankle"
            # }
            num_node = 17
            self_link = [(i, i) for i in range(num_node)]
            neighbor_link = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 6),
                             (5, 7), (7, 9), (6, 8), (8, 10), (5, 11), (6, 12), (11, 12),
                             (11, 13), (13, 15), (12, 14), (14, 16)]
            self.edge = self_link + neighbor_link
            self.center = 0
            self.flip_idx = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
            connect_joint = np.array([5, 0, 0, 1, 2, 0, 0, 5, 6, 7, 8, 5, 6, 11, 12, 13, 14])
            parts = [
                np.array([5, 7, 9]),                      # left_arm
                np.array([6, 8, 10]),                     # right_arm
                np.array([11, 13, 15]),                   # left_leg
                np.array([12, 14, 16]),                   # right_leg
                np.array([0, 1, 2, 3, 4]),                # head
            ]

        elif self.joint_format == 'coco-no-head':
            num_node = 12
            self_link = [(i, i) for i in range(num_node)]
            neighbor_link = [(0, 1),
                             (0, 2), (2, 4), (1, 3), (3, 5), (0, 6), (1, 7), (6, 7),
                             (6, 8), (8, 10), (7, 9), (9, 11)]
            self.edge = self_link + neighbor_link
            self.center = 0
            connect_joint = np.array([3, 1, 0, 2, 4, 0, 6, 8, 10, 7, 9, 11])
            parts = [
                np.array([0, 2, 4]),       # left_arm
                np.array([1, 3, 5]),       # right_arm
                np.array([6, 8, 10]),      # left_leg
                np.array([7, 9, 11])       # right_leg
            ]

        elif self.joint_format == 'alphapose' or self.joint_format == 'openpose':
            num_node = 18
            self_link = [(i, i) for i in range(num_node)]
            neighbor_link = [(0, 1), (0, 14), (0, 15), (14, 16), (15, 17),
                             (1, 2), (2, 3), (3, 4), (1, 5), (5, 6), (6, 7),
                             (1, 8), (8, 9), (9, 10), (1, 11), (11, 12), (12, 13)]
            self.edge = self_link + neighbor_link
            self.center = 1
            self.flip_idx = [0, 1, 5, 6, 7, 2, 3, 4, 11, 12, 13, 8, 9, 10, 15, 14, 17, 16]
            connect_joint = np.array([1, 1, 1, 2, 3, 1, 5, 6, 2, 8, 9, 5, 11, 12, 0, 0, 14, 15])
            parts = [
                np.array([5, 6, 7]),               # left_arm
                np.array([2, 3, 4]),               # right_arm
                np.array([11, 12, 13]),            # left_leg
                np.array([8, 9, 10]),              # right_leg
                np.array([0, 1, 14, 15, 16, 17]),  # head
            ]

        else:
            raise ValueError('Unsupported joint format: {}!'.format(self.joint_format))
        self_link = [(i, i) for i in range(num_node)]
        edge = self_link + neighbor_link
        return num_node, edge, connect_joint, parts

    def _get_hop_distance(self):
        A = np.zeros((self.num_node, self.num_node))
        for i, j in self.edge:
            A[j, i] = 1
            A[i, j] = 1
        hop_dis = np.zeros((self.num_node, self.num_node)) + np.inf
        transfer_mat = [np.linalg.matrix_power(A, d) for d in range(self.max_hop + 1)]
        arrive_mat = (np.stack(transfer_mat) > 0)
        for d in range(self.max_hop, -1, -1):
            hop_dis[arrive_mat[d]] = d
        return hop_dis

    def _get_adjacency(self):
        hop_dis = self._get_hop_distance()
        valid_hop = range(0, self.max_hop + 1, self.dilation)
        adjacency = np.zeros((self.num_node, self.num_node))
        for hop in valid_hop:
            adjacency[hop_dis == hop] = 1
        normalize_adjacency = self._normalize_digraph(adjacency)
        A = np.zeros((len(valid_hop), self.num_node, self.num_node))
        for i, hop in enumerate(valid_hop):
            A[i][hop_dis == hop] = normalize_adjacency[hop_dis == hop]
        return A

    def _normalize_digraph(self, A):
        Dl = np.sum(A, 0)
        num_node = A.shape[0]
        Dn = np.zeros((num_node, num_node))
        for i in range(num_node):
            if Dl[i] > 0:
                Dn[i, i] = Dl[i]**(-1)
        AD = np.dot(A, Dn)
        return AD
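

# Usage sketch (illustrative, not part of the original file): the adjacency
# stack has one degree-normalized slice per hop distance, up to max_hop.
#
#   graph = Graph(joint_format='coco', max_hop=2)
#   A = torch.from_numpy(graph.A).float()   # [3, 17, 17]: hop-0/1/2 slices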


class TemporalBasicBlock(nn.Module):
    """
        TemporalConv_Res_Block
        Arxiv: https://arxiv.org/abs/2010.09978
        Github: https://github.com/Thomas-yx/ResGCNv1
    """

    def __init__(self, channels, temporal_window_size, stride=1, residual=False, reduction=0, get_res=False, tcn_stride=False):
        super(TemporalBasicBlock, self).__init__()

        padding = ((temporal_window_size - 1) // 2, 0)

        if not residual:
            self.residual = lambda x: 0
        elif stride == 1:
            self.residual = lambda x: x
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(channels, channels, 1, (stride, 1)),
                nn.BatchNorm2d(channels),
            )

        self.conv = nn.Conv2d(channels, channels, (temporal_window_size, 1), (stride, 1), padding)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, res_module):
        res_block = self.residual(x)

        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x + res_block + res_module)

        return x
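

# Usage sketch (illustrative, not part of the original file): temporal conv over
# a skeleton sequence [n, c, t, v]; res_module carries an external residual and
# can be 0 when unused. TemporalBottleneckBlock below is the 1x1-down/up variant.
#
#   tcn = TemporalBasicBlock(channels=64, temporal_window_size=3, residual=True)
#   y = tcn(torch.randn(8, 64, 30, 17), 0)   # [8, 64, 30, 17]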


class TemporalBottleneckBlock(nn.Module):
    """
        TemporalConv_Res_Bottleneck
        Arxiv: https://arxiv.org/abs/2010.09978
        Github: https://github.com/Thomas-yx/ResGCNv1
    """

    def __init__(self, channels, temporal_window_size, stride=1, residual=False, reduction=4, get_res=False, tcn_stride=False):
        super(TemporalBottleneckBlock, self).__init__()
        tcn_stride = False
        padding = ((temporal_window_size - 1) // 2, 0)
        inter_channels = channels // reduction
        if get_res:
            if tcn_stride:
                stride = 2
            self.residual = nn.Sequential(
                nn.Conv2d(channels, channels, 1, (2, 1)),
                nn.BatchNorm2d(channels),
            )
            tcn_stride = True
        else:
            if not residual:
                self.residual = lambda x: 0
            elif stride == 1:
                self.residual = lambda x: x
            else:
                self.residual = nn.Sequential(
                    nn.Conv2d(channels, channels, 1, (2, 1)),
                    nn.BatchNorm2d(channels),
                )
                tcn_stride = True

        self.conv_down = nn.Conv2d(channels, inter_channels, 1)
        self.bn_down = nn.BatchNorm2d(inter_channels)
        if tcn_stride:
            stride = 2
        self.conv = nn.Conv2d(inter_channels, inter_channels, (temporal_window_size, 1), (stride, 1), padding)
        self.bn = nn.BatchNorm2d(inter_channels)
        self.conv_up = nn.Conv2d(inter_channels, channels, 1)
        self.bn_up = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, res_module):
        res_block = self.residual(x)

        x = self.conv_down(x)
        x = self.bn_down(x)
        x = self.relu(x)

        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)

        x = self.conv_up(x)
        x = self.bn_up(x)
        x = self.relu(x + res_block + res_module)
        return x


class SpatialGraphConv(nn.Module):
    """
        SpatialGraphConv_Basic_Block
        Arxiv: https://arxiv.org/abs/1801.07455
        Github: https://github.com/yysijie/st-gcn
    """

    def __init__(self, in_channels, out_channels, max_graph_distance):
        super(SpatialGraphConv, self).__init__()

        # spatial class number (distance = 0 for class 0, distance = 1 for class 1, ...)
        self.s_kernel_size = max_graph_distance + 1

        # weights of different spatial classes
        self.gcn = nn.Conv2d(in_channels, out_channels*self.s_kernel_size, 1)

    def forward(self, x, A):
        # nodes in the same class share the same weight
        x = self.gcn(x)

        # divide nodes into different classes
        n, kc, t, v = x.size()
        x = x.view(n, self.s_kernel_size, kc//self.s_kernel_size, t, v).contiguous()

        # spatial graph convolution
        x = torch.einsum('nkctv,kvw->nctw', (x, A[:self.s_kernel_size])).contiguous()

        return x
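

# Shape sketch (illustrative, not part of the original file): the 1x1 conv emits
# k = max_graph_distance + 1 channel groups, and the einsum contracts both the
# hop class k and the source node v:
# out[n, c, t, w] = sum_k sum_v x[n, k, c, t, v] * A[k, v, w].
#
#   sgc = SpatialGraphConv(in_channels=3, out_channels=64, max_graph_distance=2)
#   A = torch.from_numpy(Graph(joint_format='coco').A).float()   # [3, 17, 17]
#   y = sgc(torch.randn(8, 3, 30, 17), A)                        # [8, 64, 30, 17]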


class SpatialBasicBlock(nn.Module):
    """
        SpatialGraphConv_Res_Block
        Arxiv: https://arxiv.org/abs/2010.09978
        Github: https://github.com/Thomas-yx/ResGCNv1
    """

    def __init__(self, in_channels, out_channels, max_graph_distance, residual=False, reduction=0):
        super(SpatialBasicBlock, self).__init__()

        if not residual:
            self.residual = lambda x: 0
        elif in_channels == out_channels:
            self.residual = lambda x: x
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels),
            )

        self.conv = SpatialGraphConv(in_channels, out_channels, max_graph_distance)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, A):
        res_block = self.residual(x)

        x = self.conv(x, A)
        x = self.bn(x)
        x = self.relu(x + res_block)

        return x


class SpatialBottleneckBlock(nn.Module):
    """
        SpatialGraphConv_Res_Bottleneck
        Arxiv: https://arxiv.org/abs/2010.09978
        Github: https://github.com/Thomas-yx/ResGCNv1
    """

    def __init__(self, in_channels, out_channels, max_graph_distance, residual=False, reduction=4):
        super(SpatialBottleneckBlock, self).__init__()

        inter_channels = out_channels // reduction

        if not residual:
            self.residual = lambda x: 0
        elif in_channels == out_channels:
            self.residual = lambda x: x
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels),
            )

        self.conv_down = nn.Conv2d(in_channels, inter_channels, 1)
        self.bn_down = nn.BatchNorm2d(inter_channels)
        self.conv = SpatialGraphConv(inter_channels, inter_channels, max_graph_distance)
        self.bn = nn.BatchNorm2d(inter_channels)
        self.conv_up = nn.Conv2d(inter_channels, out_channels, 1)
        self.bn_up = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, A):
        res_block = self.residual(x)

        x = self.conv_down(x)
        x = self.bn_down(x)
        x = self.relu(x)

        x = self.conv(x, A)
        x = self.bn(x)
        x = self.relu(x)

        x = self.conv_up(x)
        x = self.bn_up(x)
        x = self.relu(x + res_block)

        return x
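

# Usage sketch (illustrative, not part of the original file): both residual GCN
# blocks consume [n, c, t, v] plus the hop-wise adjacency stack from Graph.
#
#   block = SpatialBottleneckBlock(64, 128, max_graph_distance=2, residual=True)
#   y = block(torch.randn(8, 64, 30, 17), A)   # [8, 128, 30, 17]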


class SpatialAttention(nn.Module):
    """
    This class implements the Spatial Transformer.
    Function adapted from: https://github.com/leaderj1001/Attention-Augmented-Conv2d
    """

    def __init__(self, in_channels, out_channel, A, num_point, dk_factor=0.25, kernel_size=1, Nh=8, num=4, stride=1):
        super(SpatialAttention, self).__init__()
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.dk = int(dk_factor * out_channel)
        self.dv = int(out_channel)
        self.num = num
        self.Nh = Nh
        self.num_point = num_point
        self.A = A[0] + A[1] + A[2]
        self.stride = stride
        self.padding = (self.kernel_size - 1) // 2

        assert self.Nh != 0, "integer division or modulo by zero, Nh >= 1"
        assert self.dk % self.Nh == 0, "dk should be divisible by Nh (example: out_channels: 20, dk: 40, Nh: 4)"
        assert self.dv % self.Nh == 0, "dv should be divisible by Nh (example: out_channels: 20, dv: 4, Nh: 4)"
        assert stride in [1, 2], "stride must be 1 or 2, got " + str(stride)

        self.qkv_conv = nn.Conv2d(self.in_channels, 2 * self.dk + self.dv, kernel_size=self.kernel_size,
                                  stride=stride,
                                  padding=self.padding)

        self.attn_out = nn.Conv2d(self.dv, self.dv, kernel_size=1, stride=1)

    def forward(self, x):
        # Input x
        # (batch_size, channels, 1, joints)
        B, _, T, V = x.size()

        # flat_q, flat_k, flat_v
        # (batch_size, Nh, dvh or dkh, joints)
        # dvh = dv / Nh, dkh = dk / Nh
        # q, k, v obtained by doing 2D convolution on the input (q=XWq, k=XWk, v=XWv)
        flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(x, self.dk, self.dv, self.Nh)

        # Calculate the scores, obtained by doing q*k
        # (batch_size, Nh, joints, dkh)*(batch_size, Nh, dkh, joints) = (batch_size, Nh, joints, joints)
        # The multiplication can also be divided (multi_matmul) in case of space problems
        logits = torch.matmul(flat_q.transpose(2, 3), flat_k)

        weights = F.softmax(logits, dim=-1)

        # attn_out
        # (batch, Nh, joints, dvh)
        # weights*V
        # (batch, Nh, joints, joints)*(batch, Nh, joints, dvh) = (batch, Nh, joints, dvh)
        attn_out = torch.matmul(weights, flat_v.transpose(2, 3))

        attn_out = torch.reshape(attn_out, (B, self.Nh, T, V, self.dv // self.Nh))

        attn_out = attn_out.permute(0, 1, 4, 2, 3)

        # combine_heads_2d, combine heads only after having calculated each Z separately
        # (batch, Nh*dv, 1, joints)
        attn_out = self.combine_heads_2d(attn_out)

        # Multiply by W0 (batch, out_channels, 1, joints) with out_channels = dv
        attn_out = self.attn_out(attn_out)
        return attn_out

    def compute_flat_qkv(self, x, dk, dv, Nh):
        qkv = self.qkv_conv(x)
        # T=1 in this case, because we are considering each frame separately
        N, _, T, V = qkv.size()

        q, k, v = torch.split(qkv, [dk, dk, dv], dim=1)
        q = self.split_heads_2d(q, Nh)
        k = self.split_heads_2d(k, Nh)
        v = self.split_heads_2d(v, Nh)

        dkh = dk // Nh
        q = q * (dkh ** -0.5)
        flat_q = torch.reshape(q, (N, Nh, dkh, T * V))
        flat_k = torch.reshape(k, (N, Nh, dkh, T * V))
        flat_v = torch.reshape(v, (N, Nh, dv // self.Nh, T * V))
        return flat_q, flat_k, flat_v, q, k, v

    def split_heads_2d(self, x, Nh):
        B, channels, T, V = x.size()
        ret_shape = (B, Nh, channels // Nh, T, V)
        split = torch.reshape(x, ret_shape)
        return split

    def combine_heads_2d(self, x):
        batch, Nh, dv, T, V = x.size()
        ret_shape = (batch, Nh * dv, T, V)
        return torch.reshape(x, ret_shape)
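

# Usage sketch (illustrative, not part of the original file): multi-head
# self-attention over the joint dimension, one frame at a time; A must provide
# at least three hop slices, since the constructor sums A[0] + A[1] + A[2].
#
#   A = torch.from_numpy(Graph(joint_format='coco').A).float()
#   attn = SpatialAttention(in_channels=64, out_channel=64, A=A, num_point=17)
#   y = attn(torch.randn(8, 64, 1, 17))   # [8, 64, 1, 17]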


class ParallelBN1d(nn.Module):
    def __init__(self, parts_num, in_channels, **kwargs):
        super(ParallelBN1d, self).__init__()
        self.parts_num = parts_num
        self.bn1d = nn.BatchNorm1d(in_channels * parts_num, **kwargs)

    def forward(self, x):
        '''
            x: [n, c, p]
        '''
        x = rearrange(x, 'n c p -> n (c p)')
        x = self.bn1d(x)
        x = rearrange(x, 'n (c p) -> n c p', p=self.parts_num)
        return x


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock2D(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock2D, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class BasicBlockP3D(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlockP3D, self).__init__()
        if norm_layer is None:
            norm_layer2d = nn.BatchNorm2d
            norm_layer3d = nn.BatchNorm3d
        else:
            # fall back to the given layer for both the 2D and 3D norms
            norm_layer2d = norm_layer3d = norm_layer
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.relu = nn.ReLU(inplace=True)

        self.conv1 = SetBlockWrapper(
            nn.Sequential(
                conv3x3(inplanes, planes, stride),
                norm_layer2d(planes),
                nn.ReLU(inplace=True)
            )
        )

        self.conv2 = SetBlockWrapper(
            nn.Sequential(
                conv3x3(planes, planes),
                norm_layer2d(planes),
            )
        )

        self.shortcut3d = nn.Conv3d(planes, planes, (3, 1, 1), (1, 1, 1), (1, 0, 0), bias=False)
        self.sbn = norm_layer3d(planes)

        self.downsample = downsample

    def forward(self, x):
        '''
            x: [n, c, s, h, w]
        '''
        identity = x

        out = self.conv1(x)
        out = self.relu(out + self.sbn(self.shortcut3d(out)))
        out = self.conv2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
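

# Usage sketch (illustrative, not part of the original file): a P3D-style block
# that runs the 3x3 convs per frame (via SetBlockWrapper) and adds a 3x1x1
# temporal shortcut on top of the first conv's output.
#
#   block = BasicBlockP3D(64, 64)
#   y = block(torch.randn(2, 64, 30, 16, 11))   # [2, 64, 30, 16, 11]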


class BasicBlock3D(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=[1, 1, 1], downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock3D, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        assert stride[0] in [1, 2, 3]
        if stride[0] in [1, 2]:
            tp = 1
        else:
            tp = 0
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=(3, 3, 3), stride=stride, padding=[tp, 1, 1], bias=False)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=(3, 3, 3), stride=[1, 1, 1], padding=[1, 1, 1], bias=False)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample

    def forward(self, x):
        '''
            x: [n, c, s, h, w]
        '''
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
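

# Usage sketch (illustrative, not part of the original file): the fully 3D
# counterpart of BasicBlockP3D; with the default stride all dims are preserved.
#
#   block3d = BasicBlock3D(64, 64)
#   y = block3d(torch.randn(2, 64, 30, 16, 11))   # [2, 64, 30, 16, 11]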